13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
31 template<
typename FitnessFunction = GiniGain,
32 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
33 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
34 typename ElemType = double,
35 bool NoRecursion =
false>
37 public NumericSplitType<FitnessFunction>::template
38 AuxiliarySplitInfo<ElemType>,
39 public CategoricalSplitType<FitnessFunction>::template
40 AuxiliarySplitInfo<ElemType>
60 template<
typename MatType>
63 const arma::Row<size_t>& labels,
64 const size_t numClasses,
65 const size_t minimumLeafSize = 10);
78 template<
typename MatType>
80 const arma::Row<size_t>& labels,
81 const size_t numClasses,
82 const size_t minimumLeafSize = 10);
139 template<
typename MatType>
140 void Train(
const MatType& data,
142 const arma::Row<size_t>& labels,
143 const size_t numClasses,
144 const size_t minimumLeafSize = 10);
157 template<
typename MatType>
158 void Train(
const MatType& data,
159 const arma::Row<size_t>& labels,
160 const size_t numClasses,
161 const size_t minimumLeafSize = 10);
169 template<
typename VecType>
170 size_t Classify(
const VecType& point)
const;
181 template<
typename VecType>
184 arma::vec& probabilities)
const;
193 template<
typename MatType>
195 arma::Row<size_t>& predictions)
const;
207 template<
typename MatType>
209 arma::Row<size_t>& predictions,
210 arma::mat& probabilities)
const;
215 template<
typename Archive>
216 void Serialize(Archive& ar,
const unsigned int );
233 template<
typename VecType>
238 std::vector<DecisionTree*> children;
240 size_t splitDimension;
243 size_t dimensionTypeOrMajorityClass;
251 arma::vec classProbabilities;
256 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
257 NumericAuxiliarySplitInfo;
258 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
259 CategoricalAuxiliarySplitInfo;
264 template<
typename RowType>
265 void CalculateClassProbabilities(
const RowType& labels,
266 const size_t numClasses);
272 template<
typename FitnessFunction = GiniGain,
273 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
274 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
275 typename ElemType =
double>
278 CategoricalSplitType,
286 #include "decision_tree_impl.hpp"
DecisionTree(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
Auxiliary information for a dataset, including mappings to/from strings and the datatype of each dimension.
This class implements a generic decision tree learner.
The core includes that mlpack expects: standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, false > DecisionStump
Convenience typedef for decision stumps (single-level decision trees).
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
size_t NumChildren() const
Get the number of children.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
void Train(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
Train the decision tree on the given data.
size_t CalculateDirection(const VecType &point) const
Given a point, and assuming this node is not a leaf, calculate the index of the child node this point would go to.
void Serialize(Archive &ar, const unsigned int)
Serialize the tree.
~DecisionTree()
Clean up memory.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.