mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
hoeffding_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_TREE_HPP
14 #define MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
18 #include "gini_impurity.hpp"
21 
22 namespace mlpack {
23 namespace tree {
24 
55 template<typename FitnessFunction = GiniImpurity,
56  template<typename> class NumericSplitType =
58  template<typename> class CategoricalSplitType =
59  HoeffdingCategoricalSplit
60 >
62 {
63  public:
65  typedef NumericSplitType<FitnessFunction> NumericSplit;
67  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
68 
91  template<typename MatType>
92  HoeffdingTree(const MatType& data,
93  const data::DatasetInfo& datasetInfo,
94  const arma::Row<size_t>& labels,
95  const size_t numClasses,
96  const bool batchTraining = true,
97  const double successProbability = 0.95,
98  const size_t maxSamples = 0,
99  const size_t checkInterval = 100,
100  const size_t minSamples = 100,
101  const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
102  = CategoricalSplitType<FitnessFunction>(0, 0),
103  const NumericSplitType<FitnessFunction>& numericSplitIn =
104  NumericSplitType<FitnessFunction>(0));
105 
125  HoeffdingTree(const data::DatasetInfo& datasetInfo,
126  const size_t numClasses,
127  const double successProbability = 0.95,
128  const size_t maxSamples = 0,
129  const size_t checkInterval = 100,
130  const size_t minSamples = 100,
131  const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
132  = CategoricalSplitType<FitnessFunction>(0, 0),
133  const NumericSplitType<FitnessFunction>& numericSplitIn =
134  NumericSplitType<FitnessFunction>(0),
135  std::unordered_map<size_t, std::pair<size_t, size_t>>*
136  dimensionMappings = NULL);
137 
144  HoeffdingTree(const HoeffdingTree& other);
145 
149  ~HoeffdingTree();
150 
159  template<typename MatType>
160  void Train(const MatType& data,
161  const arma::Row<size_t>& labels,
162  const bool batchTraining = true);
163 
170  template<typename VecType>
171  void Train(const VecType& point, const size_t label);
172 
178  size_t SplitCheck();
179 
181  size_t SplitDimension() const { return splitDimension; }
182 
184  size_t MajorityClass() const { return majorityClass; }
186  size_t& MajorityClass() { return majorityClass; }
187 
189  double MajorityProbability() const { return majorityProbability; }
191  double& MajorityProbability() { return majorityProbability; }
192 
194  size_t NumChildren() const { return children.size(); }
195 
197  const HoeffdingTree& Child(const size_t i) const { return *children[i]; }
199  HoeffdingTree& Child(const size_t i) { return *children[i]; }
200 
202  double SuccessProbability() const { return successProbability; }
204  void SuccessProbability(const double successProbability);
205 
207  size_t MinSamples() const { return minSamples; }
209  void MinSamples(const size_t minSamples);
210 
212  size_t MaxSamples() const { return maxSamples; }
214  void MaxSamples(const size_t maxSamples);
215 
217  size_t CheckInterval() const { return checkInterval; }
219  void CheckInterval(const size_t checkInterval);
220 
228  template<typename VecType>
229  size_t CalculateDirection(const VecType& point) const;
230 
238  template<typename VecType>
239  size_t Classify(const VecType& point) const;
240 
252  template<typename VecType>
253  void Classify(const VecType& point, size_t& prediction, double& probability)
254  const;
255 
263  template<typename MatType>
264  void Classify(const MatType& data, arma::Row<size_t>& predictions) const;
265 
277  template<typename MatType>
278  void Classify(const MatType& data,
279  arma::Row<size_t>& predictions,
280  arma::rowvec& probabilities) const;
281 
285  void CreateChildren();
286 
288  template<typename Archive>
289  void Serialize(Archive& ar, const unsigned int /* version */);
290 
291  private:
292  // We need to keep some information for before we have split.
293 
295  std::vector<NumericSplitType<FitnessFunction>> numericSplits;
297  std::vector<CategoricalSplitType<FitnessFunction>> categoricalSplits;
298 
300  std::unordered_map<size_t, std::pair<size_t, size_t>>* dimensionMappings;
302  bool ownsMappings;
303 
305  size_t numSamples;
307  size_t numClasses;
309  size_t maxSamples;
311  size_t checkInterval;
313  size_t minSamples;
315  const data::DatasetInfo* datasetInfo;
317  bool ownsInfo;
319  double successProbability;
320 
321  // And we need to keep some information for after we have split.
322 
324  size_t splitDimension;
326  size_t majorityClass;
329  double majorityProbability;
331  typename CategoricalSplitType<FitnessFunction>::SplitInfo categoricalSplit;
333  typename NumericSplitType<FitnessFunction>::SplitInfo numericSplit;
335  std::vector<HoeffdingTree*> children;
336 };
337 
338 } // namespace tree
339 } // namespace mlpack
340 
341 #include "hoeffding_tree_impl.hpp"
342 
343 #endif
Auxiliary information for a dataset, including mappings to/from strings and the datatype of each dimension.
double SuccessProbability() const
Get the confidence required for a split.
size_t SplitDimension() const
Get the splitting dimension (size_t(-1) if no split).
size_t CheckInterval() const
Get the number of samples before a split check is performed.
~HoeffdingTree()
Clean up memory.
The HoeffdingTree object represents all of the necessary information for a Hoeffding-bound-based decision tree.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Train(const MatType &data, const arma::Row< size_t > &labels, const bool batchTraining=true)
Train on a set of points, either in streaming mode or in batch mode, with the given labels.
const HoeffdingTree & Child(const size_t i) const
Get a child.
void CreateChildren()
Given that this node should split, create the children.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
size_t MajorityClass() const
Get the majority class.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
void Serialize(Archive &ar, const unsigned int)
Serialize the split.
size_t NumChildren() const
Get the number of children.
HoeffdingNumericSplit< FitnessFunction, double > HoeffdingDoubleNumericSplit
Convenience typedef.
size_t & MajorityClass()
Modify the majority class.
HoeffdingTree(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const bool batchTraining=true, const double successProbability=0.95, const size_t maxSamples=0, const size_t checkInterval=100, const size_t minSamples=100, const CategoricalSplitType< FitnessFunction > &categoricalSplitIn=CategoricalSplitType< FitnessFunction >(0, 0), const NumericSplitType< FitnessFunction > &numericSplitIn=NumericSplitType< FitnessFunction >(0))
Construct the Hoeffding tree with the given parameters and given training data.
size_t SplitCheck()
Check if a split would satisfy the conditions of the Hoeffding bound with the node's specified success probability.
HoeffdingTree & Child(const size_t i)
Modify a child.
size_t MaxSamples() const
Get the maximum number of samples before a split is forced.
size_t MinSamples() const
Get the minimum number of samples for a split.
double MajorityProbability() const
Get the probability of the majority class (based on training samples).
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
double & MajorityProbability()
Modify the probability of the majority class.
size_t Classify(const VecType &point) const
Classify the given point, using this node and the entire (sub)tree beneath it.