mlpack  3.4.2
hoeffding_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_TREE_HPP
14 #define MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
18 #include "gini_impurity.hpp"
21 
22 namespace mlpack {
23 namespace tree {
24 
/**
 * The HoeffdingTree object represents all of the necessary information for a
 * Hoeffding-bound-based decision tree.
 *
 * Template parameters:
 *  - FitnessFunction: passed as the template argument to both split types
 *    (defaults to GiniImpurity).
 *  - NumericSplitType: instantiated as NumericSplitType<FitnessFunction> for
 *    numeric dimensions.
 *  - CategoricalSplitType: instantiated as CategoricalSplitType<FitnessFunction>
 *    for categorical dimensions (defaults to HoeffdingCategoricalSplit).
 *
 * NOTE(review): this is a Doxygen listing extraction. Each line is prefixed
 * with its original header line number. Comment lines and several
 * hyperlinked code lines were dropped; the gaps are flagged below.
 */
55 template<typename FitnessFunction = GiniImpurity,
56  template<typename> class NumericSplitType =
// NOTE(review): original line 57 (the default for NumericSplitType) is missing
// from this extraction. The cross-reference to HoeffdingDoubleNumericSplit
// (hoeffding_numeric_split.hpp:148) suggests it reads
// `HoeffdingDoubleNumericSplit,` -- confirm against the real header.
58  template<typename> class CategoricalSplitType =
59  HoeffdingCategoricalSplit
60 >
// NOTE(review): the `class HoeffdingTree` class-head line does not appear in
// this extraction (the cross-reference places the definition at
// hoeffding_tree.hpp:62).
62 {
63  public:
//! Allow access to the numeric split type.
65  typedef NumericSplitType<FitnessFunction> NumericSplit;
//! Allow access to the categorical split type.
67  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
68 
/**
 * Construct the Hoeffding tree with the given parameters and given training
 * data.
 *
 * @param data Training points.
 * @param datasetInfo Auxiliary information about the dataset (see
 *     data::DatasetMapper).
 * @param labels Labels of the training points.
 * @param numClasses Number of classes in the data.
 * @param batchTraining Train in batch mode if true, otherwise in streaming
 *     mode.
 * @param successProbability Confidence required for a split.
 * @param maxSamples Maximum number of samples before a split is forced.
 *     NOTE(review): the meaning of the default 0 is not visible here
 *     (presumably "no limit") -- confirm in the implementation.
 * @param checkInterval Number of samples before a split check is performed.
 * @param minSamples Minimum number of samples for a split.
 * @param categoricalSplitIn Prototype categorical split object (presumably
 *     copied for each categorical dimension -- confirm).
 * @param numericSplitIn Prototype numeric split object (presumably copied for
 *     each numeric dimension -- confirm).
 */
93  template<typename MatType>
94  HoeffdingTree(const MatType& data,
95  const data::DatasetInfo& datasetInfo,
96  const arma::Row<size_t>& labels,
97  const size_t numClasses,
98  const bool batchTraining = true,
99  const double successProbability = 0.95,
100  const size_t maxSamples = 0,
101  const size_t checkInterval = 100,
102  const size_t minSamples = 100,
103  const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
104  = CategoricalSplitType<FitnessFunction>(0, 0),
105  const NumericSplitType<FitnessFunction>& numericSplitIn =
106  NumericSplitType<FitnessFunction>(0));
107 
/**
 * Construct the Hoeffding tree with the given parameters, but training on no
 * data.
 *
 * @param datasetInfo Auxiliary information about the dataset.
 * @param numClasses Number of classes in the data.
 * @param successProbability Confidence required for a split.
 * @param maxSamples Maximum number of samples before a split is forced.
 * @param checkInterval Number of samples before a split check is performed.
 * @param minSamples Minimum number of samples for a split.
 * @param categoricalSplitIn Prototype categorical split object.
 * @param numericSplitIn Prototype numeric split object.
 * @param dimensionMappings Optional dimension mapping (NULL by default).
 *     NOTE(review): presumably, when NULL, the tree builds its own mapping
 *     and sets ownsMappings -- confirm in the implementation.
 * @param copyDatasetInfo NOTE(review): presumably selects whether the tree
 *     keeps its own copy of datasetInfo (see ownsInfo) rather than pointing
 *     at the caller's object -- confirm.
 */
130  HoeffdingTree(const data::DatasetInfo& datasetInfo,
131  const size_t numClasses,
132  const double successProbability = 0.95,
133  const size_t maxSamples = 0,
134  const size_t checkInterval = 100,
135  const size_t minSamples = 100,
136  const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
137  = CategoricalSplitType<FitnessFunction>(0, 0),
138  const NumericSplitType<FitnessFunction>& numericSplitIn =
139  NumericSplitType<FitnessFunction>(0),
140  std::unordered_map<size_t, std::pair<size_t, size_t>>*
141  dimensionMappings = NULL,
142  const bool copyDatasetInfo = true);
143 
// NOTE(review): listing lines 149, 157 and 162 render empty here. The
// cross-references list three more special members of this class whose
// declarations were presumably dropped at these positions (mapping of
// line -> member not verified):
//   HoeffdingTree()                           -- "Construct a Hoeffding tree
//                                                 with no data and no
//                                                 information."
//   HoeffdingTree(const HoeffdingTree& other) -- copies the whole tree
//                                                 (may use a lot of memory).
//   ~HoeffdingTree()                          -- "Clean up memory."
149 
157 
162 
//! Train on a set of points, either in streaming mode or in batch mode, with
//! the given labels.
171  template<typename MatType>
172  void Train(const MatType& data,
173  const arma::Row<size_t>& labels,
174  const bool batchTraining = true);
175 
//! Train on a set of points, either in streaming mode or in batch mode, with
//! the given labels and the given dataset information (`info`).
180  template<typename MatType>
181  void Train(const MatType& data,
182  const data::DatasetInfo& info,
183  const arma::Row<size_t>& labels,
184  const bool batchTraining = true);
185 
//! Train on a single point in streaming mode, with the given label.
192  template<typename VecType>
193  void Train(const VecType& point, const size_t label);
194 
//! Check if a split would satisfy the conditions of the Hoeffding bound with
//! the node's specified success probability.
//! NOTE(review): the meaning of the returned size_t is not visible in this
//! header -- confirm in hoeffding_tree_impl.hpp.
200  size_t SplitCheck();
201 
//! Get the splitting dimension (size_t(-1) if no split).
203  size_t SplitDimension() const { return splitDimension; }
204 
//! Get the majority class.
206  size_t MajorityClass() const { return majorityClass; }
//! Modify the majority class.
208  size_t& MajorityClass() { return majorityClass; }
209 
//! Get the probability of the majority class (based on training samples).
211  double MajorityProbability() const { return majorityProbability; }
//! Modify the probability of the majority class.
213  double& MajorityProbability() { return majorityProbability; }
214 
//! Get the number of children (the size of the children vector).
216  size_t NumChildren() const { return children.size(); }
217 
//! Get child i. Uses std::vector::operator[], so i is not bounds-checked.
219  const HoeffdingTree& Child(const size_t i) const { return *children[i]; }
//! Modify child i. Uses std::vector::operator[], so i is not bounds-checked.
221  HoeffdingTree& Child(const size_t i) { return *children[i]; }
222 
//! Get the confidence required for a split.
224  double SuccessProbability() const { return successProbability; }
//! Modify the confidence required for a split.
226  void SuccessProbability(const double successProbability);
227 
//! Get the minimum number of samples for a split.
229  size_t MinSamples() const { return minSamples; }
//! Modify the minimum number of samples for a split.
231  void MinSamples(const size_t minSamples);
232 
//! Get the maximum number of samples before a split is forced.
234  size_t MaxSamples() const { return maxSamples; }
//! Modify the maximum number of samples before a split is forced.
236  void MaxSamples(const size_t maxSamples);
237 
//! Get the number of samples before a split check is performed.
239  size_t CheckInterval() const { return checkInterval; }
//! Modify the number of samples before a split check is performed.
241  void CheckInterval(const size_t checkInterval);
242 
//! Given a point and that this node is not a leaf, calculate the index of the
//! child node this point would go to.
250  template<typename VecType>
251  size_t CalculateDirection(const VecType& point) const;
252 
//! Classify the given point, using this node and the entire (sub)tree beneath
//! it. Returns the predicted class.
260  template<typename VecType>
261  size_t Classify(const VecType& point) const;
262 
//! Get the size of the Hoeffding Tree.
264  size_t NumDescendants() const;
265 
//! Classify the given point and also return an estimate of the probability
//! that the prediction is correct (via the prediction/probability outputs).
277  template<typename VecType>
278  void Classify(const VecType& point, size_t& prediction, double& probability)
279  const;
280 
//! Classify the given points, using this node and the entire (sub)tree
//! beneath it; results are written to predictions.
288  template<typename MatType>
289  void Classify(const MatType& data, arma::Row<size_t>& predictions) const;
290 
//! Classify the given points, using this node and the entire (sub)tree
//! beneath it; also writes a probability estimate per point to probabilities.
302  template<typename MatType>
303  void Classify(const MatType& data,
304  arma::Row<size_t>& predictions,
305  arma::rowvec& probabilities) const;
306 
// NOTE(review): listing line 311 renders empty. The cross-references list
// `void CreateChildren()` ("Given that this node should split, create the
// children."), which was presumably declared here -- confirm.
311 
//! Serialize the tree. The version argument is unused (its name is commented
//! out).
313  template<typename Archive>
314  void serialize(Archive& ar, const unsigned int /* version */);
315 
316  private:
317  // We need to keep some information for before we have split.
318 
//! Numeric split objects kept before this node splits
//! (NOTE(review): presumably one per numeric dimension -- confirm).
320  std::vector<NumericSplitType<FitnessFunction>> numericSplits;
//! Categorical split objects kept before this node splits
//! (NOTE(review): presumably one per categorical dimension -- confirm).
322  std::vector<CategoricalSplitType<FitnessFunction>> categoricalSplits;
323 
//! Raw pointer to a map from a size_t key to a (size_t, size_t) pair.
//! NOTE(review): presumably maps each dimension to its split kind and its
//! index into numericSplits/categoricalSplits -- confirm in the implementation.
325  std::unordered_map<size_t, std::pair<size_t, size_t>>* dimensionMappings;
//! Whether this node owns dimensionMappings (presumably: is responsible for
//! freeing it -- confirm in the destructor).
327  bool ownsMappings;
328 
//! Number of samples seen (NOTE(review): counting window not visible here).
330  size_t numSamples;
//! Number of classes in the data.
332  size_t numClasses;
//! Maximum number of samples before a split is forced.
334  size_t maxSamples;
//! Number of samples before a split check is performed.
336  size_t checkInterval;
//! Minimum number of samples for a split.
338  size_t minSamples;
//! Pointer to the dataset information used by this node.
340  const data::DatasetInfo* datasetInfo;
//! Whether this node owns datasetInfo (presumably: is responsible for freeing
//! it -- confirm in the destructor).
342  bool ownsInfo;
//! Confidence required for a split.
344  double successProbability;
345 
346  // And we need to keep some information for after we have split.
347 
//! Splitting dimension (size_t(-1) if no split).
349  size_t splitDimension;
//! Majority class at this node.
351  size_t majorityClass;
//! Probability of the majority class (based on training samples).
354  double majorityProbability;
//! Split information for a categorical split (presumably used when
//! splitDimension is categorical -- confirm).
356  typename CategoricalSplitType<FitnessFunction>::SplitInfo categoricalSplit;
//! Split information for a numeric split (presumably used when splitDimension
//! is numeric -- confirm).
358  typename NumericSplitType<FitnessFunction>::SplitInfo numericSplit;
//! Child nodes, held by raw pointer. NOTE(review): ownership is not expressed
//! in the type; the destructor ("Clean up memory") presumably deletes them.
360  std::vector<HoeffdingTree*> children;
361 };
362 
363 } // namespace tree
364 } // namespace mlpack
365 
366 #include "hoeffding_tree_impl.hpp"
367 
368 #endif
mlpack::tree::HoeffdingTree::MaxSamples
size_t MaxSamples() const
Get the maximum number of samples before a split is forced.
Definition: hoeffding_tree.hpp:234
mlpack::tree::HoeffdingTree::NumDescendants
size_t NumDescendants() const
Get the size of the Hoeffding Tree.
mlpack::tree::HoeffdingTree::Train
void Train(const MatType &data, const arma::Row< size_t > &labels, const bool batchTraining=true)
Train on a set of points, either in streaming mode or in batch mode, with the given labels.
prereqs.hpp
The core includes that mlpack expects; standard C++ includes and Armadillo.
mlpack::tree::HoeffdingTree::CalculateDirection
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go to.
mlpack::tree::HoeffdingTree::HoeffdingTree
HoeffdingTree()
Construct a Hoeffding tree with no data and no information.
mlpack::tree::HoeffdingTree::MajorityClass
size_t MajorityClass() const
Get the majority class.
Definition: hoeffding_tree.hpp:206
mlpack::tree::HoeffdingTree::MinSamples
void MinSamples(const size_t minSamples)
Modify the minimum number of samples for a split.
mlpack::tree::HoeffdingDoubleNumericSplit
HoeffdingNumericSplit< FitnessFunction, double > HoeffdingDoubleNumericSplit
Convenience typedef.
Definition: hoeffding_numeric_split.hpp:148
mlpack::tree::HoeffdingTree::MaxSamples
void MaxSamples(const size_t maxSamples)
Modify the maximum number of samples before a split is forced.
mlpack::tree::HoeffdingTree::~HoeffdingTree
~HoeffdingTree()
Clean up memory.
mlpack::tree::HoeffdingTree::Train
void Train(const MatType &data, const data::DatasetInfo &info, const arma::Row< size_t > &labels, const bool batchTraining=true)
Train on a set of points, either in streaming mode or in batch mode, with the given labels and the given DatasetInfo.
mlpack::tree::HoeffdingTree::Classify
size_t Classify(const VecType &point) const
Classify the given point, using this node and the entire (sub)tree beneath it.
hoeffding_categorical_split.hpp
mlpack::tree::HoeffdingTree::SplitDimension
size_t SplitDimension() const
Get the splitting dimension (size_t(-1) if no split).
Definition: hoeffding_tree.hpp:203
mlpack::tree::HoeffdingTree::NumChildren
size_t NumChildren() const
Get the number of children.
Definition: hoeffding_tree.hpp:216
mlpack::tree::HoeffdingTree::SplitCheck
size_t SplitCheck()
Check if a split would satisfy the conditions of the Hoeffding bound with the node's specified success probability.
mlpack::tree::HoeffdingTree::Classify
void Classify(const VecType &point, size_t &prediction, double &probability) const
Classify the given point and also return an estimate of the probability that the prediction is correct.
mlpack
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_cli11.hpp:21
mlpack::tree::HoeffdingTree::CheckInterval
size_t CheckInterval() const
Get the number of samples before a split check is performed.
Definition: hoeffding_tree.hpp:239
hoeffding_numeric_split.hpp
mlpack::tree::HoeffdingTree::Child
HoeffdingTree & Child(const size_t i)
Modify a child.
Definition: hoeffding_tree.hpp:221
mlpack::tree::HoeffdingTree::MinSamples
size_t MinSamples() const
Get the minimum number of samples for a split.
Definition: hoeffding_tree.hpp:229
gini_impurity.hpp
mlpack::tree::HoeffdingTree::Classify
void Classify(const MatType &data, arma::Row< size_t > &predictions, arma::rowvec &probabilities) const
Classify the given points, using this node and the entire (sub)tree beneath it.
mlpack::tree::HoeffdingTree::serialize
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
mlpack::tree::HoeffdingTree::SuccessProbability
double SuccessProbability() const
Get the confidence required for a split.
Definition: hoeffding_tree.hpp:224
mlpack::tree::HoeffdingTree::CategoricalSplit
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
Definition: hoeffding_tree.hpp:67
mlpack::tree::HoeffdingTree::CheckInterval
void CheckInterval(const size_t checkInterval)
Modify the number of samples before a split check is performed.
mlpack::tree::HoeffdingTree::SuccessProbability
void SuccessProbability(const double successProbability)
Modify the confidence required for a split.
mlpack::tree::HoeffdingTree::MajorityClass
size_t & MajorityClass()
Modify the majority class.
Definition: hoeffding_tree.hpp:208
mlpack::tree::HoeffdingTree::Classify
void Classify(const MatType &data, arma::Row< size_t > &predictions) const
Classify the given points, using this node and the entire (sub)tree beneath it.
mlpack::tree::HoeffdingTree::HoeffdingTree
HoeffdingTree(const HoeffdingTree &other)
Copy another tree (warning: this will duplicate the tree entirely, and may use a lot of memory).
mlpack::tree::HoeffdingTree::Train
void Train(const VecType &point, const size_t label)
Train on a single point in streaming mode, with the given label.
mlpack::tree::HoeffdingTree::HoeffdingTree
HoeffdingTree(const data::DatasetInfo &datasetInfo, const size_t numClasses, const double successProbability=0.95, const size_t maxSamples=0, const size_t checkInterval=100, const size_t minSamples=100, const CategoricalSplitType< FitnessFunction > &categoricalSplitIn=CategoricalSplitType< FitnessFunction >(0, 0), const NumericSplitType< FitnessFunction > &numericSplitIn=NumericSplitType< FitnessFunction >(0), std::unordered_map< size_t, std::pair< size_t, size_t >> *dimensionMappings=NULL, const bool copyDatasetInfo=true)
Construct the Hoeffding tree with the given parameters, but training on no data.
mlpack::tree::HoeffdingTree::MajorityProbability
double MajorityProbability() const
Get the probability of the majority class (based on training samples).
Definition: hoeffding_tree.hpp:211
dataset_mapper.hpp
mlpack::tree::HoeffdingTree::Child
const HoeffdingTree & Child(const size_t i) const
Get a child.
Definition: hoeffding_tree.hpp:219
mlpack::tree::HoeffdingTree
The HoeffdingTree object represents all of the necessary information for a Hoeffding-bound-based decision tree.
Definition: hoeffding_tree.hpp:62
mlpack::tree::HoeffdingTree::CreateChildren
void CreateChildren()
Given that this node should split, create the children.
mlpack::tree::HoeffdingTree::HoeffdingTree
HoeffdingTree(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const bool batchTraining=true, const double successProbability=0.95, const size_t maxSamples=0, const size_t checkInterval=100, const size_t minSamples=100, const CategoricalSplitType< FitnessFunction > &categoricalSplitIn=CategoricalSplitType< FitnessFunction >(0, 0), const NumericSplitType< FitnessFunction > &numericSplitIn=NumericSplitType< FitnessFunction >(0))
Construct the Hoeffding tree with the given parameters and given training data.
mlpack::tree::HoeffdingTree::MajorityProbability
double & MajorityProbability()
Modify the probability of the majority class.
Definition: hoeffding_tree.hpp:213
mlpack::tree::HoeffdingTree::NumericSplit
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
Definition: hoeffding_tree.hpp:65
mlpack::data::DatasetMapper
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
Definition: dataset_mapper.hpp:42