mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
cosine_tree.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
13 #define MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include <boost/heap/priority_queue.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
21 // Predeclare classes for CosineNodeQueue typedef.
22 class CompareCosineNode;
23 class CosineTree;
24 
25 // CosineNodeQueue typedef.
26 typedef boost::heap::priority_queue<CosineTree*,
27  boost::heap::compare<CompareCosineNode> > CosineNodeQueue;
28 
29 class CosineTree
30 {
31  public:
40  CosineTree(const arma::mat& dataset);
41 
51  CosineTree(CosineTree& parentNode, const std::vector<size_t>& subIndices);
52 
67  CosineTree(const arma::mat& dataset,
68  const double epsilon,
69  const double delta);
70 
74  ~CosineTree();
75 
85  void ModifiedGramSchmidt(CosineNodeQueue& treeQueue,
86  arma::vec& centroid,
87  arma::vec& newBasisVector,
88  arma::vec* addBasisVector = NULL);
89 
102  double MonteCarloError(CosineTree* node,
103  CosineNodeQueue& treeQueue,
104  arma::vec* addBasisVector1 = NULL,
105  arma::vec* addBasisVector2 = NULL);
106 
112  void ConstructBasis(CosineNodeQueue& treeQueue);
113 
119  void CosineNodeSplit();
120 
127  void ColumnSamplesLS(std::vector<size_t>& sampledIndices,
128  arma::vec& probabilities, size_t numSamples);
129 
136  size_t ColumnSampleLS();
137 
150  size_t BinarySearch(arma::vec& cDistribution, double value, size_t start,
151  size_t end);
152 
160  void CalculateCosines(arma::vec& cosines);
161 
166  void CalculateCentroid();
167 
169  void GetFinalBasis(arma::mat& finalBasis) { finalBasis = basis; }
170 
172  const arma::mat& GetDataset() const { return dataset; }
173 
175  std::vector<size_t>& VectorIndices() { return indices; }
176 
178  void L2Error(const double error) { this->l2Error = error; }
180  double L2Error() const { return l2Error; }
181 
183  arma::vec& Centroid() { return centroid; }
184 
186  void BasisVector(arma::vec& bVector) { this->basisVector = bVector; }
187 
189  arma::vec& BasisVector() { return basisVector; }
190 
192  CosineTree* Parent() const { return parent; }
194  CosineTree*& Parent() { return parent; }
195 
197  CosineTree* Left() const { return left; }
199  CosineTree*& Left() { return left; }
200 
202  CosineTree* Right() const { return right; }
204  CosineTree*& Right() { return right; }
205 
207  size_t NumColumns() const { return numColumns; }
208 
210  double FrobNormSquared() const { return frobNormSquared; }
211 
213  size_t SplitPointIndex() const { return indices[splitPointIndex]; }
214 
215  private:
217  const arma::mat& dataset;
219  double delta;
221  arma::mat basis;
223  CosineTree* parent;
225  CosineTree* left;
227  CosineTree* right;
229  std::vector<size_t> indices;
231  arma::vec l2NormsSquared;
233  arma::vec centroid;
235  arma::vec basisVector;
237  size_t splitPointIndex;
239  size_t numColumns;
241  double l2Error;
243  double frobNormSquared;
244 };
245 
246 class CompareCosineNode
247 {
248  public:
249 
250  // Comparison function for construction of priority queue.
251  bool operator() (const CosineTree* a, const CosineTree* b) const
252  {
253  return a->L2Error() < b->L2Error();
254  }
255 };
256 
257 } // namespace tree
258 } // namespace mlpack
259 
260 #endif
bool operator()(const CosineTree *a, const CosineTree *b) const
double FrobNormSquared() const
Get the Frobenius norm squared of columns in the node.
void ModifiedGramSchmidt(CosineNodeQueue &treeQueue, arma::vec &centroid, arma::vec &newBasisVector, arma::vec *addBasisVector=NULL)
Calculates the orthonormalization of the passed centroid, with respect to the current vector subspace...
arma::vec & Centroid()
Get a modifiable reference to the centroid vector.
void GetFinalBasis(arma::mat &finalBasis)
Returns the basis of the constructed subspace.
CosineTree *& Left()
Modify the pointer to the left child of the node.
double MonteCarloError(CosineTree *node, CosineNodeQueue &treeQueue, arma::vec *addBasisVector1=NULL, arma::vec *addBasisVector2=NULL)
Estimates the squared error of the projection of the input node's matrix onto the current vector subs...
void ConstructBasis(CosineNodeQueue &treeQueue)
Constructs the final basis matrix, after the cosine tree construction.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void L2Error(const double error)
Set the Monte Carlo error.
const arma::mat & GetDataset() const
Get a const reference to the dataset matrix.
CosineTree * Left() const
Get pointer to the left child of the node.
void ColumnSamplesLS(std::vector< size_t > &sampledIndices, arma::vec &probabilities, size_t numSamples)
Sample 'numSamples' points from the Length-Squared distribution of the cosine node.
size_t SplitPointIndex() const
Get the column index of split point of the node.
size_t ColumnSampleLS()
Sample a point from the Length-Squared distribution of the cosine node.
CosineTree(const arma::mat &dataset)
CosineTree constructor for the root node of the tree.
void CosineNodeSplit()
This function splits the cosine node into two children based on the cosines of the columns contained ...
double L2Error() const
Get the Monte Carlo error.
CosineTree *& Right()
Modify the pointer to the right child of the node.
CosineTree *& Parent()
Modify the pointer to the parent node.
std::vector< size_t > & VectorIndices()
Get the indices of columns in the node.
CosineTree * Parent() const
Get pointer to the parent node.
~CosineTree()
Clean up the CosineTree: release allocated memory (including children).
CosineTree * Right() const
Get pointer to the right child of the node.
size_t BinarySearch(arma::vec &cDistribution, double value, size_t start, size_t end)
Sample a column based on the cumulative Length-Squared distribution of the cosine node...
void BasisVector(arma::vec &bVector)
Set the basis vector of the node.
void CalculateCentroid()
Calculate centroid of the columns present in the node.
size_t NumColumns() const
Get number of columns of input matrix in the node.
arma::vec & BasisVector()
Get the basis vector of the node.
void CalculateCosines(arma::vec &cosines)
Calculate cosines of the columns present in the node, with respect to the sampled splitting point...
boost::heap::priority_queue< CosineTree *, boost::heap::compare< CompareCosineNode > > CosineNodeQueue
Definition: cosine_tree.hpp:23