Go to the documentation of this file.
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
22 #include <type_traits>
34 template<
typename FitnessFunction = GiniGain,
35 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
36 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
37 typename DimensionSelectionType = AllDimensionSelect,
38 typename ElemType = double,
39 bool NoRecursion =
false>
41 public NumericSplitType<FitnessFunction>::template
42 AuxiliarySplitInfo<ElemType>,
43 public CategoricalSplitType<FitnessFunction>::template
44 AuxiliarySplitInfo<ElemType>
71 template<
typename MatType,
typename LabelsType>
75 const size_t numClasses,
76 const size_t minimumLeafSize = 10,
77 const double minimumGainSplit = 1e-7,
78 const size_t maximumDepth = 0,
79 DimensionSelectionType dimensionSelector =
80 DimensionSelectionType());
98 template<
typename MatType,
typename LabelsType>
101 const size_t numClasses,
102 const size_t minimumLeafSize = 10,
103 const double minimumGainSplit = 1e-7,
104 const size_t maximumDepth = 0,
105 DimensionSelectionType dimensionSelector =
106 DimensionSelectionType());
127 template<
typename MatType,
typename LabelsType,
typename WeightsType>
132 const size_t numClasses,
134 const size_t minimumLeafSize = 10,
135 const double minimumGainSplit = 1e-7,
136 const size_t maximumDepth = 0,
137 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
139 typename std::remove_reference<WeightsType>::type>::value>* = 0);
159 template<
typename MatType,
typename LabelsType,
typename WeightsType>
165 const size_t numClasses,
167 const size_t minimumLeafSize = 10,
168 const double minimumGainSplit = 1e-7,
170 typename std::remove_reference<WeightsType>::type>::value>* = 0);
189 template<
typename MatType,
typename LabelsType,
typename WeightsType>
193 const size_t numClasses,
195 const size_t minimumLeafSize = 10,
196 const double minimumGainSplit = 1e-7,
197 const size_t maximumDepth = 0,
198 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
200 typename std::remove_reference<WeightsType>::type>::value>* = 0);
220 template<
typename MatType,
typename LabelsType,
typename WeightsType>
225 const size_t numClasses,
227 const size_t minimumLeafSize = 10,
228 const double minimumGainSplit = 1e-7,
229 const size_t maximumDepth = 0,
230 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
232 typename std::remove_reference<WeightsType>::type>::value>* = 0);
296 template<
typename MatType,
typename LabelsType>
300 const size_t numClasses,
301 const size_t minimumLeafSize = 10,
302 const double minimumGainSplit = 1e-7,
303 const size_t maximumDepth = 0,
304 DimensionSelectionType dimensionSelector =
305 DimensionSelectionType());
324 template<
typename MatType,
typename LabelsType>
327 const size_t numClasses,
328 const size_t minimumLeafSize = 10,
329 const double minimumGainSplit = 1e-7,
330 const size_t maximumDepth = 0,
331 DimensionSelectionType dimensionSelector =
332 DimensionSelectionType());
355 template<
typename MatType,
typename LabelsType,
typename WeightsType>
359 const size_t numClasses,
361 const size_t minimumLeafSize = 10,
362 const double minimumGainSplit = 1e-7,
363 const size_t maximumDepth = 0,
364 DimensionSelectionType dimensionSelector =
365 DimensionSelectionType(),
367 std::remove_reference<WeightsType>::type>::value>* = 0);
388 template<
typename MatType,
typename LabelsType,
typename WeightsType>
391 const size_t numClasses,
393 const size_t minimumLeafSize = 10,
394 const double minimumGainSplit = 1e-7,
395 const size_t maximumDepth = 0,
396 DimensionSelectionType dimensionSelector =
397 DimensionSelectionType(),
399 std::remove_reference<WeightsType>::type>::value>* = 0);
407 template<
typename VecType>
419 template<
typename VecType>
422 arma::vec& probabilities)
const;
431 template<
typename MatType>
433 arma::Row<size_t>& predictions)
const;
445 template<
typename MatType>
447 arma::Row<size_t>& predictions,
448 arma::mat& probabilities)
const;
453 template<
typename Archive>
475 template<
typename VecType>
485 std::vector<DecisionTree*> children;
487 size_t splitDimension;
490 size_t dimensionTypeOrMajorityClass;
498 arma::vec classProbabilities;
503 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
504 NumericAuxiliarySplitInfo;
505 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
506 CategoricalAuxiliarySplitInfo;
511 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
512 void CalculateClassProbabilities(
const RowType& labels,
513 const size_t numClasses,
514 const WeightsRowType& weights);
533 template<
bool UseWeights,
typename MatType>
534 double Train(MatType& data,
538 arma::Row<size_t>& labels,
539 const size_t numClasses,
540 arma::rowvec& weights,
541 const size_t minimumLeafSize,
542 const double minimumGainSplit,
543 const size_t maximumDepth,
544 DimensionSelectionType& dimensionSelector);
562 template<
bool UseWeights,
typename MatType>
563 double Train(MatType& data,
566 arma::Row<size_t>& labels,
567 const size_t numClasses,
568 arma::rowvec& weights,
569 const size_t minimumLeafSize,
570 const double minimumGainSplit,
571 const size_t maximumDepth,
572 DimensionSelectionType& dimensionSelector);
578 template<
typename FitnessFunction =
GiniGain,
582 typename ElemType =
double>
585 CategoricalSplitType,
604 #include "decision_tree_impl.hpp"
double Train(MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Train the decision tree on the given weighted data, assuming that all dimensions are numeric.
DecisionTree(const DecisionTree &other, MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Take ownership of another decision tree and train on the given data and labels with weights, assuming that the data is all of the numeric type.
DecisionTree(MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Construct the decision tree on the given data and labels with weights, assuming that the data is all of the numeric type.
The core includes that mlpack expects; standard C++ includes and Armadillo.
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
DecisionTree(const DecisionTree &other)
Copy another tree.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Construct the decision tree on the given data and labels with weights, where the data can be both numeric and categorical.
size_t NumClasses() const
Get the number of classes in the tree.
This dimension selection policy allows any dimension to be selected for splitting.
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, double, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm).
Linear algebra utility functions, generally performed on matrices or vectors.
typename enable_if< B, T >::type enable_if_t
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
DecisionTree(const size_t numClasses=1)
Construct a decision tree without training it.
void Classify(const VecType &point, size_t &prediction, arma::vec &probabilities) const
Classify the given point and also return estimates of the probability for each class in the given vector.
double Train(MatType data, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data, assuming that all dimensions are numeric.
size_t NumChildren() const
Get the number of children.
DecisionTree(DecisionTree &&other)
Take ownership of another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
DecisionTree & operator=(DecisionTree &&other)
Take ownership of another tree.
DecisionTree(MatType data, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, assuming that the data is all of the numeric type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
void Classify(const MatType &data, arma::Row< size_t > &predictions, arma::mat &probabilities) const
Classify the given points and also return estimates of the probabilities for each class in the given matrix.
~DecisionTree()
Clean up memory.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
void Classify(const MatType &data, arma::Row< size_t > &predictions) const
Classify the given points, using the entire tree.
This class implements a generic decision tree learner.
DecisionTree(const DecisionTree &other, MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Take ownership of another decision tree and train on the given data and labels with weights, where the data can be both numeric and categorical.
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Train the decision tree on the given weighted data.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...