mlpack  3.4.2
reparametrization.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "layer_types.hpp"
19 #include "../activation_functions/softplus_function.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
52 template <
53  typename InputDataType = arma::mat,
54  typename OutputDataType = arma::mat
55 >
57 {
58  public:
61 
70  Reparametrization(const size_t latentSize,
71  const bool stochastic = true,
72  const bool includeKl = true,
73  const double beta = 1);
74 
82  template<typename eT>
83  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
84 
94  template<typename eT>
95  void Backward(const arma::Mat<eT>& input,
96  const arma::Mat<eT>& gy,
97  arma::Mat<eT>& g);
98 
100  OutputDataType const& OutputParameter() const { return outputParameter; }
102  OutputDataType& OutputParameter() { return outputParameter; }
103 
105  OutputDataType const& Delta() const { return delta; }
107  OutputDataType& Delta() { return delta; }
108 
110  size_t const& OutputSize() const { return latentSize; }
112  size_t& OutputSize() { return latentSize; }
113 
115  double Loss()
116  {
117  if (!includeKl)
118  return 0;
119 
120  return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
121  - arma::pow(mean, 2) + 1) / mean.n_cols;
122  }
123 
125  bool Stochastic() const { return stochastic; }
126 
128  bool IncludeKL() const { return includeKl; }
129 
131  double Beta() const { return beta; }
132 
/**
 * Serialize the layer.
 *
 * @param ar Archive to read from or write to.
 */
template <typename Archive>
void serialize(Archive& ar, const unsigned int /* version */);
138 
139  private:
//! Size of the sampled latent vector; returned by OutputSize().
141  size_t latentSize;
142 
//! Stored stochastic flag; returned by Stochastic().
144  bool stochastic;
145 
//! When false, Loss() returns 0 instead of the KL term.
147  bool includeKl;
148 
//! Weight applied to the KL term in Loss(); returned by Beta().
150  double beta;
151 
//! Locally-stored delta object; returned by Delta().
153  OutputDataType delta;
154 
//! NOTE(review): presumably the Gaussian noise sample drawn in Forward();
//! confirm against reparametrization_impl.hpp.
156  OutputDataType gaussianSample;
157 
//! Mean of the latent Gaussian as used by Loss(); its column count is the
//! divisor in Loss() (presumably one column per sample; TODO confirm).
159  OutputDataType mean;
160 
//! NOTE(review): presumably the raw value that the softplus transform
//! (softplus_function.hpp is included) maps to stdDev; confirm.
163  OutputDataType preStdDev;
164 
//! Standard deviation of the latent Gaussian as used by Loss().
166  OutputDataType stdDev;
167 
//! Locally-stored output parameter; returned by OutputParameter().
169  OutputDataType outputParameter;
170 }; // class Reparametrization
171 
172 } // namespace ann
173 } // namespace mlpack
174 
175 // Include implementation.
176 #include "reparametrization_impl.hpp"
177 
178 #endif
mlpack::ann::Reparametrization::serialize
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
mlpack::ann::Reparametrization::OutputSize
size_t & OutputSize()
Modify the output size.
Definition: reparametrization.hpp:112
prereqs.hpp
The core includes that mlpack expects; standard C++ includes and Armadillo.
mlpack::ann::Reparametrization::OutputParameter
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: reparametrization.hpp:102
mlpack::ann::Reparametrization::Reparametrization
Reparametrization(const size_t latentSize, const bool stochastic=true, const bool includeKl=true, const double beta=1)
Create the Reparametrization layer object using the specified sample vector size.
mlpack::ann::Reparametrization::OutputParameter
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: reparametrization.hpp:100
mlpack::ann::Reparametrization::OutputSize
size_t const & OutputSize() const
Get the output size.
Definition: reparametrization.hpp:110
mlpack
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_cli11.hpp:21
mlpack::ann::Reparametrization::Delta
OutputDataType const & Delta() const
Get the delta.
Definition: reparametrization.hpp:105
mlpack::ann::Reparametrization::Forward
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
mlpack::ann::Reparametrization::Stochastic
bool Stochastic() const
Get the value of the stochastic parameter.
Definition: reparametrization.hpp:125
mlpack::ann::Reparametrization::Reparametrization
Reparametrization()
Create the Reparametrization object.
mlpack::ann::Reparametrization
Implementation of the Reparametrization layer class.
Definition: reparametrization.hpp:57
mlpack::ann::Reparametrization::Loss
double Loss()
Get the KL divergence from the standard normal distribution.
Definition: reparametrization.hpp:115
mlpack::ann::Reparametrization::IncludeKL
bool IncludeKL() const
Get the value of the includeKl parameter.
Definition: reparametrization.hpp:128
mlpack::ann::Reparametrization::Delta
OutputDataType & Delta()
Modify the delta.
Definition: reparametrization.hpp:107
mlpack::ann::Reparametrization::Beta
double Beta() const
Get the value of the beta hyperparameter.
Definition: reparametrization.hpp:131
mlpack::ann::Reparametrization::Backward
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
layer_types.hpp