mlpack  3.4.2
gan.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP
12 #define MLPACK_METHODS_ANN_GAN_GAN_HPP
13 
14 #include <mlpack/core.hpp>
15 
23 
24 
25 namespace mlpack {
26 namespace ann {
27 
// The implementation of the standard GAN module.
//
// Template parameters:
//   Model                  - type used for both the generator and the
//                            discriminator members.
//   InitializationRuleType - type of the stored initialization rule.
//   Noise                  - type of the stored noise function.
//   PolicyType             - selects which Evaluate / EvaluateWithGradient /
//                            Gradient overload is enabled (StandardGAN, DCGAN,
//                            WGAN or WGANGP); defaults to StandardGAN.
57 template<
58  typename Model,
59  typename InitializationRuleType,
60  typename Noise,
61  typename PolicyType = StandardGAN
62 >
63 class GAN
64 {
65  public:
    // Constructor for GAN class.
    // The generator and discriminator are taken by value; the initialization
    // rule and the noise function are taken by reference.
    // clippingParameter defaults to 0.01 and lambda defaults to 10.0.
    // NOTE(review): the exact meaning of noiseDim, batchSize,
    // generatorUpdateStep, preTrainSize and multiplier is not visible in this
    // declaration -- confirm against the constructor definition.
83  GAN(Model generator,
84  Model discriminator,
85  InitializationRuleType& initializeRule,
86  Noise& noiseFunction,
87  const size_t noiseDim,
88  const size_t batchSize,
89  const size_t generatorUpdateStep,
90  const size_t preTrainSize,
91  const double multiplier,
92  const double clippingParameter = 0.01,
93  const double lambda = 10.0);
94 
    // Copy constructor.
96  GAN(const GAN&);
97 
    // Move constructor.
99  GAN(GAN&&);
100 
    // Initialize the generator, discriminator and weights of the model for
    // training. The training data is taken by value.
107  void ResetData(arma::mat trainData);
108 
109  // Reset function.
110  void Reset();
111 
    // Train function.
    // Takes the training data by value, the optimizer by reference and any
    // number of callbacks as forwarding references; returns a double.
    // NOTE(review): presumably the returned value is the final objective
    // reported by the optimizer -- confirm against the definition.
123  template<typename OptimizerType, typename... CallbackTypes>
124  double Train(arma::mat trainData,
125  OptimizerType& Optimizer,
126  CallbackTypes&&... callbacks);
127 
    // Evaluate function for the Standard GAN and DCGAN.
    // Only enabled (via std::enable_if on the return type) when Policy is
    // StandardGAN or DCGAN; returns a double.
137  template<typename Policy = PolicyType>
138  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
139  std::is_same<Policy, DCGAN>::value, double>::type
140  Evaluate(const arma::mat& parameters,
141  const size_t i,
142  const size_t batchSize);
143 
    // Evaluate function for the WGAN.
    // Only enabled when Policy is WGAN; returns a double.
152  template<typename Policy = PolicyType>
153  typename std::enable_if<std::is_same<Policy, WGAN>::value,
154  double>::type
155  Evaluate(const arma::mat& parameters,
156  const size_t i,
157  const size_t batchSize);
158 
    // Evaluate function for the WGAN-GP.
    // Only enabled when Policy is WGANGP; returns a double.
167  template<typename Policy = PolicyType>
168  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
169  double>::type
170  Evaluate(const arma::mat& parameters,
171  const size_t i,
172  const size_t batchSize);
173 
    // EvaluateWithGradient function for the Standard GAN and DCGAN.
    // Only enabled when Policy is StandardGAN or DCGAN; returns a double.
    // GradType is the type of the gradient argument, passed by non-const
    // reference (presumably filled in by the call -- confirm).
184  template<typename GradType, typename Policy = PolicyType>
185  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
186  std::is_same<Policy, DCGAN>::value, double>::type
187  EvaluateWithGradient(const arma::mat& parameters,
188  const size_t i,
189  GradType& gradient,
190  const size_t batchSize);
191 
    // EvaluateWithGradient function for the WGAN.
    // Only enabled when Policy is WGAN; returns a double.
202  template<typename GradType, typename Policy = PolicyType>
203  typename std::enable_if<std::is_same<Policy, WGAN>::value,
204  double>::type
205  EvaluateWithGradient(const arma::mat& parameters,
206  const size_t i,
207  GradType& gradient,
208  const size_t batchSize);
209 
    // EvaluateWithGradient function for the WGAN-GP.
    // Only enabled when Policy is WGANGP; returns a double.
220  template<typename GradType, typename Policy = PolicyType>
221  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
222  double>::type
223  EvaluateWithGradient(const arma::mat& parameters,
224  const size_t i,
225  GradType& gradient,
226  const size_t batchSize);
227 
    // Gradient function for Standard GAN and DCGAN.
    // Only enabled when Policy is StandardGAN or DCGAN; returns void and
    // takes the gradient matrix by non-const reference.
238  template<typename Policy = PolicyType>
239  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
240  std::is_same<Policy, DCGAN>::value, void>::type
241  Gradient(const arma::mat& parameters,
242  const size_t i,
243  arma::mat& gradient,
244  const size_t batchSize);
245 
    // Gradient function for WGAN.
    // Only enabled when Policy is WGAN; returns void.
256  template<typename Policy = PolicyType>
257  typename std::enable_if<std::is_same<Policy, WGAN>::value, void>::type
258  Gradient(const arma::mat& parameters,
259  const size_t i,
260  arma::mat& gradient,
261  const size_t batchSize);
262 
    // Gradient function for WGAN-GP.
    // Only enabled when Policy is WGANGP; returns void.
273  template<typename Policy = PolicyType>
274  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
275  void>::type
276  Gradient(const arma::mat& parameters,
277  const size_t i,
278  arma::mat& gradient,
279  const size_t batchSize);
280 
    // Shuffle the order of function visitation.
285  void Shuffle();
286 
    // This function does a forward pass through the GAN network.
292  void Forward(const arma::mat& input);
293 
    // This function predicts the output of the network on the given input.
    // The input is taken by value; the result is written through `output`.
300  void Predict(arma::mat input, arma::mat& output);
301 
    // Return the parameters of the network.
303  const arma::mat& Parameters() const { return parameter; }
    // Modify the parameters of the network.
305  arma::mat& Parameters() { return parameter; }
306 
    // Return the generator of the GAN.
308  const Model& Generator() const { return generator; }
    // Modify the generator of the GAN.
310  Model& Generator() { return generator; }
    // Return the discriminator of the GAN.
312  const Model& Discriminator() const { return discriminator; }
    // Modify the discriminator of the GAN.
314  Model& Discriminator() { return discriminator; }
315 
    // Return the number of separable functions (the number of predictor
    // points).
317  size_t NumFunctions() const { return numFunctions; }
318 
    // Get the matrix of responses to the input data points.
320  const arma::mat& Responses() const { return responses; }
    // Modify the matrix of responses to the input data points.
322  arma::mat& Responses() { return responses; }
323 
    // Get the matrix of data points (predictors).
325  const arma::mat& Predictors() const { return predictors; }
    // Modify the matrix of data points (predictors).
327  arma::mat& Predictors() { return predictors; }
328 
    // Serialize the model. The version argument is unnamed (unused here).
330  template<typename Archive>
331  void serialize(Archive& ar, const unsigned int /* version */);
332 
333  private:
    // NOTE(review): no description is visible for this helper; by its name it
    // presumably re-applies the `deterministic` setting -- confirm against the
    // definition.
338  void ResetDeterministic();
339 
    // Matrix of data points (predictors); exposed through Predictors().
341  arma::mat predictors;
    // Parameters of the network; exposed through Parameters().
343  arma::mat parameter;
    // Generator of the GAN; exposed through Generator().
345  Model generator;
    // Discriminator of the GAN; exposed through Discriminator().
347  Model discriminator;
    // Stored initialization rule and noise function (held by value).
349  InitializationRuleType initializeRule;
351  Noise noiseFunction;
353  size_t noiseDim;
    // Number of separable functions (the number of predictor points);
    // returned by NumFunctions().
355  size_t numFunctions;
357  size_t batchSize;
359  size_t currentBatch;
361  size_t generatorUpdateStep;
363  size_t preTrainSize;
365  double multiplier;
367  double clippingParameter;
369  double lambda;
371  bool reset;
    // DeltaVisitor exposes the delta parameter of the given module.
373  DeltaVisitor deltaVisitor;
    // Matrix of responses to the input data points; exposed through
    // Responses().
375  arma::mat responses;
377  arma::mat currentInput;
379  arma::mat currentTarget;
    // OutputParameterVisitor exposes the output parameter of the given module.
381  OutputParameterVisitor outputParameterVisitor;
    // WeightSizeVisitor returns the number of weights of the given module.
383  WeightSizeVisitor weightSizeVisitor;
    // ResetVisitor executes the Reset() function.
385  ResetVisitor resetVisitor;
387  arma::mat gradient;
389  arma::mat gradientDiscriminator;
391  arma::mat noiseGradientDiscriminator;
393  arma::mat normGradientDiscriminator;
395  arma::mat noise;
397  arma::mat gradientGenerator;
399  bool deterministic;
    // NOTE(review): presumably the weight counts of the generator and the
    // discriminator respectively -- confirm where they are assigned.
401  size_t genWeights;
403  size_t discWeights;
404 };
405 
406 } // namespace ann
407 } // namespace mlpack
408 
409 // Include implementation.
410 #include "gan_impl.hpp"
411 #include "wgan_impl.hpp"
412 #include "wgangp_impl.hpp"
413 
414 
415 #endif
reset_visitor.hpp
mlpack::ann::GAN::GAN
GAN(const GAN &)
Copy constructor.
weight_set_visitor.hpp
ffn.hpp
mlpack::ann::GAN::Gradient
std::enable_if< std::is_same< Policy, WGANGP >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for WGAN-GP.
mlpack::ann::GAN::Responses
const arma::mat & Responses() const
Get the matrix of responses to the input data points.
Definition: gan.hpp:320
mlpack::ann::GAN::serialize
void serialize(Archive &ar, const unsigned int)
Serialize the model.
mlpack::ann::GAN::Gradient
std::enable_if< std::is_same< Policy, WGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for WGAN.
mlpack::ann::GAN::Generator
const Model & Generator() const
Return the generator of the GAN.
Definition: gan.hpp:308
inception_score.hpp
mlpack::ann::GAN::Predictors
arma::mat & Predictors()
Modify the matrix of data points (predictors).
Definition: gan.hpp:327
mlpack::ann::GAN::Generator
Model & Generator()
Modify the generator of the GAN.
Definition: gan.hpp:310
mlpack::ann::GAN::Reset
void Reset()
mlpack::ann::GAN::Predict
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
mlpack::ann::GAN::EvaluateWithGradient
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN.
mlpack::ann::GAN::Discriminator
const Model & Discriminator() const
Return the discriminator of the GAN.
Definition: gan.hpp:312
mlpack::ann::GAN::Forward
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network.
mlpack
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_cli11.hpp:21
mlpack::ann::GAN::Evaluate
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN.
mlpack::ann::GAN::Shuffle
void Shuffle()
Shuffle the order of function visitation.
mlpack::ann::GAN::Predictors
const arma::mat & Predictors() const
Get the matrix of data points (predictors).
Definition: gan.hpp:325
mlpack::ann::GAN::GAN
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
mlpack::ann::GAN::NumFunctions
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: gan.hpp:317
mlpack::ann::GAN::Gradient
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
mlpack::ann::ResetVisitor
ResetVisitor executes the Reset() function.
Definition: reset_visitor.hpp:27
mlpack::ann::GAN::GAN
GAN(GAN &&)
Move constructor.
mlpack::ann::GAN::EvaluateWithGradient
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
mlpack::ann::GAN
The implementation of the standard GAN module.
Definition: gan.hpp:64
mlpack::ann::GAN::Parameters
const arma::mat & Parameters() const
Return the parameters of the network.
Definition: gan.hpp:303
mlpack::ann::OutputParameterVisitor
OutputParameterVisitor exposes the output parameter of the given module.
Definition: output_parameter_visitor.hpp:28
mlpack::ann::DeltaVisitor
DeltaVisitor exposes the delta parameter of the given module.
Definition: delta_visitor.hpp:28
mlpack::ann::GAN::Train
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train function.
mlpack::ann::GAN::Evaluate
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN-GP.
mlpack::ann::GAN::ResetData
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
output_parameter_visitor.hpp
mlpack::ann::GAN::Responses
arma::mat & Responses()
Modify the matrix of responses to the input data points.
Definition: gan.hpp:322
mlpack::ann::GAN::Evaluate
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
mlpack::ann::GAN::Discriminator
Model & Discriminator()
Modify the discriminator of the GAN.
Definition: gan.hpp:314
core.hpp
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
mlpack::ann::GAN::EvaluateWithGradient
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN-GP.
mlpack::ann::GAN::Parameters
arma::mat & Parameters()
Modify the parameters of the network.
Definition: gan.hpp:305
gan_policies.hpp
weight_size_visitor.hpp
mlpack::ann::WeightSizeVisitor
WeightSizeVisitor returns the number of weights of the given module.
Definition: weight_size_visitor.hpp:28