mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dtree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DET_DTREE_HPP
14 #define MLPACK_METHODS_DET_DTREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace det {
20 
44 class DTree
45 {
46  public:
50  DTree();
51 
60  DTree(const arma::vec& maxVals,
61  const arma::vec& minVals,
62  const size_t totalPoints);
63 
72  DTree(arma::mat& data);
73 
86  DTree(const arma::vec& maxVals,
87  const arma::vec& minVals,
88  const size_t start,
89  const size_t end,
90  const double logNegError);
91 
103  DTree(const arma::vec& maxVals,
104  const arma::vec& minVals,
105  const size_t totalPoints,
106  const size_t start,
107  const size_t end);
108 
110  ~DTree();
111 
122  double Grow(arma::mat& data,
123  arma::Col<size_t>& oldFromNew,
124  const bool useVolReg = false,
125  const size_t maxLeafSize = 10,
126  const size_t minLeafSize = 5);
127 
136  double PruneAndUpdate(const double oldAlpha,
137  const size_t points,
138  const bool useVolReg = false);
139 
145  double ComputeValue(const arma::vec& query) const;
146 
154  void WriteTree(FILE *fp, const size_t level = 0) const;
155 
163  int TagTree(const int tag = 0);
164 
171  int FindBucket(const arma::vec& query) const;
172 
178  void ComputeVariableImportance(arma::vec& importances) const;
179 
186  double LogNegativeError(const size_t totalPoints) const;
187 
191  bool WithinRange(const arma::vec& query) const;
192 
193  private:
194  // The indices in the complete set of points
195  // (after all forms of swapping in the original data
196  // matrix to align all the points in a node
197  // consecutively in the matrix. The 'old_from_new' array
198  // maps the points back to their original indices.
199 
202  size_t start;
205  size_t end;
206 
208  arma::vec maxVals;
210  arma::vec minVals;
211 
213  size_t splitDim;
214 
216  double splitValue;
217 
219  double logNegError;
220 
222  double subtreeLeavesLogNegError;
223 
225  size_t subtreeLeaves;
226 
228  bool root;
229 
231  double ratio;
232 
234  double logVolume;
235 
237  int bucketTag;
238 
240  double alphaUpper;
241 
243  DTree* left;
245  DTree* right;
246 
247  public:
249  size_t Start() const { return start; }
251  size_t End() const { return end; }
253  size_t SplitDim() const { return splitDim; }
255  double SplitValue() const { return splitValue; }
257  double LogNegError() const { return logNegError; }
259  double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
261  size_t SubtreeLeaves() const { return subtreeLeaves; }
264  double Ratio() const { return ratio; }
266  double LogVolume() const { return logVolume; }
268  DTree* Left() const { return left; }
270  DTree* Right() const { return right; }
272  bool Root() const { return root; }
274  double AlphaUpper() const { return alphaUpper; }
275 
277  const arma::vec& MaxVals() const { return maxVals; }
279  arma::vec& MaxVals() { return maxVals; }
280 
282  const arma::vec& MinVals() const { return minVals; }
284  arma::vec& MinVals() { return minVals; }
285 
289  template<typename Archive>
290  void Serialize(Archive& ar, const unsigned int /* version */)
291  {
292  using data::CreateNVP;
293 
294  ar & CreateNVP(start, "start");
295  ar & CreateNVP(end, "end");
296  ar & CreateNVP(maxVals, "maxVals");
297  ar & CreateNVP(minVals, "minVals");
298  ar & CreateNVP(splitDim, "splitDim");
299  ar & CreateNVP(splitValue, "splitValue");
300  ar & CreateNVP(logNegError, "logNegError");
301  ar & CreateNVP(subtreeLeavesLogNegError, "subtreeLeavesLogNegError");
302  ar & CreateNVP(subtreeLeaves, "subtreeLeaves");
303  ar & CreateNVP(root, "root");
304  ar & CreateNVP(ratio, "ratio");
305  ar & CreateNVP(logVolume, "logVolume");
306  ar & CreateNVP(bucketTag, "bucketTag");
307  ar & CreateNVP(alphaUpper, "alphaUpper");
308 
309  if (Archive::is_loading::value)
310  {
311  if (left)
312  delete left;
313  if (right)
314  delete right;
315  }
316 
317  ar & CreateNVP(left, "left");
318  ar & CreateNVP(right, "right");
319  }
320 
321  private:
322 
323  // Utility methods.
324 
328  bool FindSplit(const arma::mat& data,
329  size_t& splitDim,
330  double& splitValue,
331  double& leftError,
332  double& rightError,
333  const size_t minLeafSize = 5) const;
334 
338  size_t SplitData(arma::mat& data,
339  const size_t splitDim,
340  const double splitValue,
341  arma::Col<size_t>& oldFromNew) const;
342 
343 };
344 
345 } // namespace det
346 } // namespace mlpack
347 
348 #endif // MLPACK_METHODS_DET_DTREE_HPP
double SplitValue() const
Return the split value of this node.
Definition: dtree.hpp:255
size_t Start() const
Return the starting index of points contained in this node.
Definition: dtree.hpp:249
double AlphaUpper() const
Return the upper part of the alpha sum.
Definition: dtree.hpp:274
~DTree()
Clean up memory allocated by the tree.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename boost::enable_if< HasSerialize< T >>::type *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
double Ratio() const
Return the ratio of points in this node to the points in the whole dataset.
Definition: dtree.hpp:264
The core includes that mlpack expects; standard C++ includes and Armadillo.
arma::vec & MaxVals()
Modify the maximum values.
Definition: dtree.hpp:279
const arma::vec & MinVals() const
Return the minimum values.
Definition: dtree.hpp:282
size_t End() const
Return the first index of a point not contained in this node.
Definition: dtree.hpp:251
void Serialize(Archive &ar, const unsigned int)
Serialize the density estimation tree.
Definition: dtree.hpp:290
void WriteTree(FILE *fp, const size_t level=0) const
Print the tree in a depth-first manner (this function is called recursively).
bool WithinRange(const arma::vec &query) const
Return whether a query point is within the range of this node.
double Grow(arma::mat &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
bool Root() const
Return whether or not this is the root of the tree.
Definition: dtree.hpp:272
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:259
size_t SplitDim() const
Return the split dimension of this node.
Definition: dtree.hpp:253
DTree * Right() const
Return the right child.
Definition: dtree.hpp:270
int TagTree(const int tag=0)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
double ComputeValue(const arma::vec &query) const
Compute the logarithm of the density estimate of a given query point.
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:261
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd...
Definition: dtree.hpp:44
double LogNegError() const
Return the log negative error of this node.
Definition: dtree.hpp:257
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
double LogVolume() const
Return the logarithm of the volume of this node.
Definition: dtree.hpp:266
DTree * Left() const
Return the left child.
Definition: dtree.hpp:268
const arma::vec & MaxVals() const
Return the maximum values.
Definition: dtree.hpp:277
double LogNegativeError(const size_t totalPoints) const
Compute the log-negative-error for this node, given the total number of points in the dataset.
int FindBucket(const arma::vec &query) const
Return the tag of the leaf containing the query.
arma::vec & MinVals()
Modify the minimum values.
Definition: dtree.hpp:284
DTree()
Create an empty density estimation tree.