Go to the documentation of this file.
13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
19 #include "../activation_functions/softplus_function.hpp"
53 typename InputDataType = arma::mat,
54 typename OutputDataType = arma::mat
71 const bool stochastic =
true,
72 const bool includeKl =
true,
73 const double beta = 1);
83 void Forward(
const arma::Mat<eT>& input, arma::Mat<eT>& output);
96 const arma::Mat<eT>& gy,
105 const OutputDataType& Delta() const { return this->delta; } //!< Get the delta.
107 OutputDataType& Delta() { return this->delta; } //!< Modify the delta.
120 return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
121 - arma::pow(mean, 2) + 1) / mean.n_cols;
131 double Beta() const { return this->beta; } //!< Get the value of the beta hyperparameter.
136 template<
typename Archive>
153 OutputDataType delta;
156 OutputDataType gaussianSample;
163 OutputDataType preStdDev;
166 OutputDataType stdDev;
169 OutputDataType outputParameter;
176 #include "reparametrization_impl.hpp"
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
size_t & OutputSize()
Modify the output size.
The core includes that mlpack expects: standard C++ includes and Armadillo.
OutputDataType & OutputParameter()
Modify the output parameter.
Reparametrization(const size_t latentSize, const bool stochastic=true, const bool includeKl=true, const double beta=1)
Create the Reparametrization layer object using the specified sample vector size.
OutputDataType const & OutputParameter() const
Get the output parameter.
size_t const & OutputSize() const
Get the output size.
Linear algebra utility functions, generally performed on matrices or vectors.
OutputDataType const & Delta() const
Get the delta.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
bool Stochastic() const
Get the value of the stochastic parameter.
Reparametrization()
Create the Reparametrization object.
Implementation of the Reparametrization layer class.
double Loss()
Get the KL divergence with respect to the standard normal distribution.
bool IncludeKL() const
Get the value of the includeKl parameter.
OutputDataType & Delta()
Modify the delta.
double Beta() const
Get the value of the beta hyperparameter.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.