mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "gini_gain.hpp"
20 
21 namespace mlpack {
22 namespace tree {
23 
31 template<typename FitnessFunction = GiniGain,
32  template<typename> class NumericSplitType = BestBinaryNumericSplit,
33  template<typename> class CategoricalSplitType = AllCategoricalSplit,
34  typename ElemType = double,
35  bool NoRecursion = false>
36 class DecisionTree :
37  public NumericSplitType<FitnessFunction>::template
38  AuxiliarySplitInfo<ElemType>,
39  public CategoricalSplitType<FitnessFunction>::template
40  AuxiliarySplitInfo<ElemType>
41 {
42  public:
44  typedef NumericSplitType<FitnessFunction> NumericSplit;
46  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
47 
60  template<typename MatType>
61  DecisionTree(const MatType& data,
62  const data::DatasetInfo& datasetInfo,
63  const arma::Row<size_t>& labels,
64  const size_t numClasses,
65  const size_t minimumLeafSize = 10);
66 
78  template<typename MatType>
79  DecisionTree(const MatType& data,
80  const arma::Row<size_t>& labels,
81  const size_t numClasses,
82  const size_t minimumLeafSize = 10);
83 
90  DecisionTree(const size_t numClasses = 1);
91 
98  DecisionTree(const DecisionTree& other);
99 
105  DecisionTree(DecisionTree&& other);
106 
113  DecisionTree& operator=(const DecisionTree& other);
114 
121 
125  ~DecisionTree();
126 
139  template<typename MatType>
140  void Train(const MatType& data,
141  const data::DatasetInfo& datasetInfo,
142  const arma::Row<size_t>& labels,
143  const size_t numClasses,
144  const size_t minimumLeafSize = 10);
145 
157  template<typename MatType>
158  void Train(const MatType& data,
159  const arma::Row<size_t>& labels,
160  const size_t numClasses,
161  const size_t minimumLeafSize = 10);
162 
169  template<typename VecType>
170  size_t Classify(const VecType& point) const;
171 
181  template<typename VecType>
182  void Classify(const VecType& point,
183  size_t& prediction,
184  arma::vec& probabilities) const;
185 
193  template<typename MatType>
194  void Classify(const MatType& data,
195  arma::Row<size_t>& predictions) const;
196 
207  template<typename MatType>
208  void Classify(const MatType& data,
209  arma::Row<size_t>& predictions,
210  arma::mat& probabilities) const;
211 
215  template<typename Archive>
216  void Serialize(Archive& ar, const unsigned int /* version */);
217 
219  size_t NumChildren() const { return children.size(); }
220 
222  const DecisionTree& Child(const size_t i) const { return *children[i]; }
224  DecisionTree& Child(const size_t i) { return *children[i]; }
225 
233  template<typename VecType>
234  size_t CalculateDirection(const VecType& point) const;
235 
236  private:
238  std::vector<DecisionTree*> children;
240  size_t splitDimension;
243  size_t dimensionTypeOrMajorityClass;
251  arma::vec classProbabilities;
252 
256  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
257  NumericAuxiliarySplitInfo;
258  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
259  CategoricalAuxiliarySplitInfo;
260 
264  template<typename RowType>
265  void CalculateClassProbabilities(const RowType& labels,
266  const size_t numClasses);
267 };
268 
272 template<typename FitnessFunction = GiniGain,
273  template<typename> class NumericSplitType = BestBinaryNumericSplit,
274  template<typename> class CategoricalSplitType = AllCategoricalSplit,
275  typename ElemType = double>
276 using DecisionStump = DecisionTree<FitnessFunction,
277  NumericSplitType,
278  CategoricalSplitType,
279  ElemType,
280  false>;
281 
282 } // namespace tree
283 } // namespace mlpack
284 
285 // Include implementation.
286 #include "decision_tree_impl.hpp"
287 
288 #endif
DecisionTree(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
Auxiliary information for a dataset, including mappings to/from strings and the datatype of each dimension.
This class implements a generic decision tree learner.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, false > DecisionStump
Convenience typedef for decision stumps (single level decision trees).
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
size_t NumChildren() const
Get the number of children.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
void Train(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
Train the decision tree on the given data.
size_t CalculateDirection(const VecType &point) const
Given a point, and assuming this node is not a leaf, calculate the index of the child node this point would go to.
void Serialize(Archive &ar, const unsigned int)
Serialize the tree.
~DecisionTree()
Clean up memory.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.