mlpack  3.4.2
normal_distribution.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_DISTRIBUTIONS_NORMAL_DISTRIBUTION_HPP
14 #define MLPACK_METHODS_ANN_DISTRIBUTIONS_NORMAL_DISTRIBUTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "../activation_functions/logistic_function.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
31 template <typename DataType = arma::mat>
33 {
34  public:
40 
47  NormalDistribution(const DataType& mean, const DataType& sigma);
48 
54  DataType Probability(const DataType& observation) const
55  {
56  return arma::exp(LogProbability(observation));
57  }
58 
64  DataType LogProbability(const DataType& observation) const;
65 
74  void ProbBackward(const DataType& observation,
75  DataType& dmu,
76  DataType& dsigma) const;
77 
85  void Probability(const DataType& x, DataType& probabilities) const
86  {
87  probabilities = Probability(x);
88  }
89 
97  void LogProbability(const DataType& x, DataType& probabilities) const
98  {
99  probabilities = LogProbability(x);
100  }
101 
108  DataType Sample() const;
109 
111  const DataType& Mean() const { return mean; }
112 
114  DataType& Mean() { return mean; }
115 
117  const DataType& StandardDeviation() const { return sigma; }
118 
120  DataType& StandardDeviation() { return sigma; }
121 
123  size_t Dimensionality() const { return mean.n_elem; }
124 
128  template<typename Archive>
129  void serialize(Archive& ar, const unsigned int /* version */);
130 
131  private:
133  DataType mean;
134 
136  DataType sigma;
137 }; // class NormalDistribution
138 
139 } // namespace ann
140 } // namespace mlpack
141 
142 // Include implementation.
143 #include "normal_distribution_impl.hpp"
144 
145 #endif
mlpack::ann::NormalDistribution::NormalDistribution
NormalDistribution(const DataType &mean, const DataType &sigma)
Create a Normal distribution with the given mean and sigma.
mlpack::ann::NormalDistribution::LogProbability
void LogProbability(const DataType &x, DataType &probabilities) const
Calculates the log of normal probability density function for each data point (column) in the given matrix.
Definition: normal_distribution.hpp:97
mlpack::ann::NormalDistribution::Dimensionality
size_t Dimensionality() const
Return the dimensionality of this distribution.
Definition: normal_distribution.hpp:123
prereqs.hpp
The core includes that mlpack expects; standard C++ includes and Armadillo.
mlpack::ann::NormalDistribution::LogProbability
DataType LogProbability(const DataType &observation) const
Return the log probabilities of the given matrix of observations.
mlpack::ann::NormalDistribution::Probability
void Probability(const DataType &x, DataType &probabilities) const
Calculates the normal probability density function for each data point (column) in the given matrix.
Definition: normal_distribution.hpp:85
mlpack::ann::NormalDistribution
Implementation of the Normal Distribution function.
Definition: normal_distribution.hpp:33
mlpack::ann::NormalDistribution::Sample
DataType Sample() const
Return a randomly generated observation according to the probability distribution defined by this object.
mlpack::ann::NormalDistribution::StandardDeviation
const DataType & StandardDeviation() const
Get the standard deviation.
Definition: normal_distribution.hpp:117
mlpack::ann::NormalDistribution::ProbBackward
void ProbBackward(const DataType &observation, DataType &dmu, DataType &dsigma) const
Stores the gradient of the probabilities of the observations with respect to mean and standard deviation.
mlpack::ann::NormalDistribution::Mean
DataType & Mean()
Modify the mean.
Definition: normal_distribution.hpp:114
mlpack
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_cli11.hpp:21
mlpack::ann::NormalDistribution::Mean
const DataType & Mean() const
Get the mean.
Definition: normal_distribution.hpp:111
mlpack::ann::NormalDistribution::NormalDistribution
NormalDistribution()
Default constructor, which creates a Normal distribution with zero dimension.
mlpack::ann::NormalDistribution::Probability
DataType Probability(const DataType &observation) const
Return the probabilities of the given matrix of observations.
Definition: normal_distribution.hpp:54
mlpack::ann::NormalDistribution::StandardDeviation
DataType & StandardDeviation()
Modify the standard deviation.
Definition: normal_distribution.hpp:120
mlpack::ann::NormalDistribution::serialize
void serialize(Archive &ar, const unsigned int)
Serialize the distribution.