Stan  1.0
probability, sampling & optimization
nuts_diag.hpp
Go to the documentation of this file.
1 #ifndef __STAN__MCMC__NUTS_DIAG_H__
2 #define __STAN__MCMC__NUTS_DIAG_H__
3 
4 #include <ctime>
5 #include <cstddef>
6 #include <iostream>
7 #include <vector>
8 
9 #include <boost/random/normal_distribution.hpp>
10 #include <boost/random/mersenne_twister.hpp>
11 #include <boost/random/variate_generator.hpp>
12 #include <boost/random/uniform_01.hpp>
13 
14 #include <stan/math/util.hpp>
17 #include <stan/mcmc/hmc_base.hpp>
18 #include <stan/mcmc/util.hpp>
19 
20 namespace stan {
21 
22  namespace mcmc {
23 
32  template <class BaseRNG = boost::mt19937>
33  class nuts_diag : public hmc_base<BaseRNG> {
  private:

    // Slice/trajectory divergence cutoff (negative, fixed at -1000 in the
    // ctor): a leapfrog point terminates tree growth as soon as
    // newH - u <= _maxchange, i.e. the log-joint H has fallen more than
    // |_maxchange| below the slice level u (see build_tree's base case).
    const double _maxchange;

    // Limit on the number of trajectory doublings per iteration;
    // a negative value means no limit (see the loop in next_impl()).
    const int _maxdepth;

    // Depth of last sample taken (-1 before any samples)
    int _lastdepth;

    // Vector of per-parameter step sizes.
    std::vector<double> _step_sizes;
    // Running statistics to estimate per-coordinate std. deviations:
    // per-coordinate sums of in-slice points visited during tree building.
    std::vector<double> _x_sum;
    // Per-coordinate sums of squares for the same points.
    std::vector<double> _xsq_sum;
    // Count of points accumulated into _x_sum/_xsq_sum.
    int _x_sum_n;
    // Next time we should adapt the per-parameter step sizes
    // (compared against _n_adapt_steps; doubles after each adaptation).
    int _next_diag_adapt;
53 
61  inline static bool compute_criterion(std::vector<double>& xplus,
62  std::vector<double>& xminus,
63  std::vector<double>& mplus,
64  std::vector<double>& mminus) {
65  std::vector<double> total_direction;
66  stan::math::sub(xplus, xminus, total_direction);
67  return stan::math::dot(total_direction, mminus) > 0
68  && stan::math::dot(total_direction, mplus) > 0;
69  }
70 
71  public:
72 
100  int maxdepth = 10,
101  double epsilon = -1,
102  double epsilon_pm = 0.0,
103  bool epsilon_adapt = true,
104  double delta = 0.6,
105  double gamma = 0.05,
106  BaseRNG base_rng = BaseRNG(std::time(0)),
107  const std::vector<double>* params_r = 0,
108  const std::vector<int>* params_i = 0)
109  : hmc_base<BaseRNG>(model,
110  epsilon,
111  epsilon_pm,
112  epsilon_adapt,
113  delta,
114  gamma,
115  base_rng,
116  params_r,
117  params_i),
118 
119  _maxchange(-1000),
120  _maxdepth(maxdepth),
121  _lastdepth(-1),
122 
123  _step_sizes(model.num_params_r(), 1.0),
124  _x_sum(model.num_params_r(), 0),
125  _xsq_sum(model.num_params_r(), 0),
126  _x_sum_n(0),
127  _next_diag_adapt(10)
128  {
129  // start at 10 * epsilon because NUTS cheaper for larger epsilon
130  this->adaptation_init(10.0);
131  }
132 
139 
    /**
     * Produce the next sample via one NUTS iteration: resample the
     * momentum, draw a slice variable, repeatedly double the trajectory
     * until the no-U-turn criterion fails (or _maxdepth is hit), and
     * perform step-size adaptation if enabled.
     *
     * @return The new sample (position, discrete params, log prob).
     */
    virtual sample next_impl() {
      // Initialize the algorithm: fresh standard-normal momentum, with
      // identical copies at both trajectory ends.
      std::vector<double> mminus(this->_model.num_params_r());
      for (size_t i = 0; i < mminus.size(); ++i)
        mminus[i] = this->_rand_unit_norm();
      std::vector<double> mplus(mminus);
      // The log-joint probability of the momentum and position terms, i.e.
      // -(kinetic energy + potential energy)
      double H0 = -0.5 * stan::math::dot_self(mminus) + this->_logp;

      // Both ends of the trajectory start at the current state.
      std::vector<double> gradminus(this->_g);
      std::vector<double> gradplus(this->_g);
      std::vector<double> xminus(this->_x);
      std::vector<double> xplus(this->_x);

      // Sample the slice variable (kept on the log scale: u = log(slice)).
      double u = log(this->_rand_uniform_01()) + H0;
      // nvalid counts trajectory points inside the slice; the start counts.
      int nvalid = 1;
      int direction = 2 * (this->_rand_uniform_01() > 0.5) - 1;
      bool criterion = true;

      // Repeatedly double the set of points we've visited.
      // dummy1..3 absorb the unused end-point outputs of build_tree.
      std::vector<double> newx, newgrad, dummy1, dummy2, dummy3;
      double newlogp = -1;
      double prob_sum = -1;
      int newnvalid = -1;
      int n_considered = 0;
      // for-loop with depth outside to set lastdepth
      int depth = 0;

      double epsilon = this->_epsilon;
      // only vary (jitter) epsilon after done adapting
      if (!this->adapting() && this->varying_epsilon()) {
        // Draw uniformly from epsilon * [1 - _epsilon_pm, 1 + _epsilon_pm].
        double low = epsilon * (1.0 - this->_epsilon_pm);
        double high = epsilon * (1.0 + this->_epsilon_pm);
        double range = high - low;
        epsilon = low + (range * this->_rand_uniform_01());
      }
      this->_epsilon_last = epsilon; // use epsilon_last in tree build

      // Doubling loop; negative _maxdepth means unlimited depth.
      while (criterion && (_maxdepth < 0 || depth <= _maxdepth)) {
        // Pick a random direction and extend that end by 2^depth points.
        direction = 2 * (this->_rand_uniform_01() > 0.5) - 1;
        if (direction == -1)
          build_tree(xminus, mminus, gradminus, u, direction, depth,
                     H0, xminus, mminus, gradminus, dummy1, dummy2, dummy3,
                     newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
                     n_considered);
        else
          build_tree(xplus, mplus, gradplus, u, direction, depth,
                     H0, dummy1, dummy2, dummy3, xplus, mplus, gradplus,
                     newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
                     n_considered);
        // We can't look at the results of this last doubling if criterion==false
        if (!criterion)
          break;
        criterion = compute_criterion(xplus, xminus, mplus, mminus);
        // Metropolis-Hastings to determine if we can jump to a point in
        // the new half-tree; 1e-100 guards against 0/0 when nvalid == 0.
        if (this->_rand_uniform_01() < float(newnvalid) / (1e-100+float(nvalid))) {
          this->_x = newx;
          this->_g = newgrad;
          this->_logp = newlogp;
        }
        nvalid += newnvalid;
        ++depth;
      }
      _lastdepth = depth;

      // Now we just have to update global (epsilon) and local
      // (step_sizes) step sizes, if adaptation is on.
      // adapt_stat is the mean acceptance statistic over visited points.
      double adapt_stat = prob_sum / float(n_considered);
      if (this->adapting()) {
        // epsilon: dual-averaging step toward the target accept rate _delta.
        double adapt_g = adapt_stat - this->_delta;
        std::vector<double> gvec(1, -adapt_g);
        std::vector<double> result;
        this->_da.update(gvec, result);
        this->_epsilon = exp(result[0]);
        // step_sizes. Doesn't happen every step: triggered on a
        // doubling schedule (10, 20, 40, ... adaptation steps).
        if (this->_n_adapt_steps == _next_diag_adapt) {
          _next_diag_adapt *= 2;
          double step_size_sq_sum = 0;
          for (size_t i = 0; i < _step_sizes.size(); i++) {
            // Per-coordinate sample std. dev. from the running sums,
            // which are reset for the next adaptation window.
            double Ex = _x_sum[i] / _x_sum_n;
            double Exsq = _xsq_sum[i] / _x_sum_n;
            _x_sum[i] = 0;
            _xsq_sum[i] = 0;
            _step_sizes[i] = sqrt(Exsq - Ex*Ex);
            step_size_sq_sum += _step_sizes[i] * _step_sizes[i];
          }
          if (step_size_sq_sum > 0.0) {
            _x_sum_n = 0;
            // Normalize so the step-size vector has RMS value 1.
            double normalizer = sqrt((double)_step_sizes.size())
              / sqrt(step_size_sq_sum);
            for (size_t i = 0; i < _step_sizes.size(); i++)
              _step_sizes[i] *= normalizer;
          } else {
            // Degenerate statistics (e.g. no variation seen): fall back
            // to unit step sizes. NOTE(review): _x_sum_n is NOT reset on
            // this branch — looks intentional to keep accumulating, but
            // worth confirming.
            for (size_t i = 0; i < _step_sizes.size(); i++)
              _step_sizes[i] = 1.0;
          }
        }
      }
      // NOTE(review): `result` is filled by xbar() but never read here —
      // apparently dead code retained from an earlier revision.
      std::vector<double> result;
      this->_da.xbar(result);
      double avg_eta = 1.0 / this->n_steps();
      this->update_mean_stat(avg_eta,adapt_stat);

      // _z holds the (fixed) discrete parameters — presumably unchanged
      // by NUTS, which only moves the continuous parameters.
      return mcmc::sample(this->_x, this->_z, this->_logp);
    }
254 
255  virtual void write_sampler_param_names(std::ostream& o) {
256  o << "treedepth__,";
257  if (this->_epsilon_adapt || this->varying_epsilon())
258  o << "stepsize__,";
259  }
260 
261  virtual void write_sampler_params(std::ostream& o) {
262  o << _lastdepth << ',';
263  if (this->_epsilon_adapt || this->varying_epsilon())
264  o << this->_epsilon_last << ',';
265  }
266 
267  virtual void write_adaptation_params(std::ostream& o) {
268  o << "# (mcmc::nuts_diag) adaptation finished" << '\n';
269  o << "# step size=" << this->_epsilon << '\n';
270  o << "# parameter step size multipliers:\n"; // FIXME: names/delineation requires access to model
271  o << "# ";
272  for (size_t k = 0; k < _step_sizes.size(); ++k) {
273  if (k > 0) o << ',';
274  o << _step_sizes[k];
275  }
276  o << '\n';
277  }
278 
279  virtual void get_sampler_param_names(std::vector<std::string>& names) {
280  names.clear();
281  names.push_back("treedepth__");
282  if (this->_epsilon_adapt || this->varying_epsilon())
283  names.push_back("stepsize__");
284  }
285 
286  virtual void get_sampler_params(std::vector<double>& values) {
287  values.clear();
288  values.push_back(_lastdepth);
289  if (this->_epsilon_adapt || this->varying_epsilon())
290  values.push_back(this->_epsilon_last);
291  }
292 
327  void build_tree(const std::vector<double>& x,
328  const std::vector<double>& m,
329  const std::vector<double>& grad,
330  double u,
331  int direction,
332  int depth,
333  double H0,
334  std::vector<double>& xminus,
335  std::vector<double>& mminus,
336  std::vector<double>& gradminus,
337  std::vector<double>& xplus,
338  std::vector<double>& mplus,
339  std::vector<double>& gradplus,
340  std::vector<double>& newx,
341  std::vector<double>& newgrad,
342  double& newlogp,
343  int& nvalid,
344  bool& criterion,
345  double& prob_sum,
346  int& n_considered) {
347  if (depth == 0) { // base case
348  xminus = x;
349  gradminus = grad;
350  mminus = m;
351  newlogp = rescaled_leapfrog(this->_model, this->_z, _step_sizes,
352  xminus, mminus, gradminus,
353  direction * this->_epsilon_last,
354  this->_error_msgs, this->_output_msgs);
355  newx = xminus;
356  newgrad = gradminus;
357  xplus = xminus;
358  mplus = mminus;
359  gradplus = gradminus;
360  double newH = newlogp - 0.5 * stan::math::dot_self(mminus);
361  if (newH != newH) // treat nan as -inf
362  newH = -std::numeric_limits<double>::infinity();
363  nvalid = newH > u;
364  criterion = newH - u > _maxchange;
365  prob_sum = stan::math::min(1, exp(newH - H0));
366  n_considered = 1;
367  this->nfevals_plus_eq(1);
368  // Update running statistics if point is in slice
369  if (nvalid) {
370  _x_sum_n++;
371  for (size_t i = 0; i < newx.size(); i++) {
372  _x_sum[i] += newx[i];
373  _xsq_sum[i] += newx[i] * newx[i];
374  }
375  }
376  } else { // depth >= 1
377  build_tree(x, m, grad, u, direction, depth-1, H0, xminus, mminus,
378  gradminus, xplus, mplus, gradplus, newx, newgrad, newlogp,
379  nvalid, criterion, prob_sum, n_considered);
380  if (criterion) {
381  std::vector<double> dummy1, dummy2, dummy3;
382  std::vector<double> newx2;
383  std::vector<double> newgrad2;
384  double newlogp2;
385  int nvalid2;
386  bool criterion2;
387  double prob_sum2;
388  int n_considered2;
389  if (direction == -1)
390  build_tree(xminus, mminus, gradminus, u, direction, depth-1, H0,
391  xminus, mminus, gradminus, dummy1, dummy2, dummy3,
392  newx2, newgrad2, newlogp2, nvalid2, criterion2,
393  prob_sum2, n_considered2);
394  else
395  build_tree(xplus, mplus, gradplus, u, direction, depth-1, H0,
396  dummy1, dummy2, dummy3, xplus, mplus, gradplus,
397  newx2, newgrad2, newlogp2, nvalid2, criterion2,
398  prob_sum2, n_considered2);
399  if (criterion &&
400  (this->_rand_uniform_01()
401  < float(nvalid2) / float(nvalid+nvalid2))){
402  newx = newx2;
403  newgrad = newgrad2;
404  newlogp = newlogp2;
405  }
406  n_considered += n_considered2;
407  prob_sum += prob_sum2;
408  criterion &= criterion2;
409  nvalid += nvalid2;
410  }
411  criterion &= compute_criterion(xplus, xminus, mplus, mminus);
412  }
413  }
414 
415 
416  };
417 
418 
419  }
420 
421 }
422 
423 #endif
void update(const std::vector< double > &g, std::vector< double > &xk)
Produces the next iterate xk given the current gradient g.
Definition: dualaverage.hpp:53
void xbar(std::vector< double > &xbar)
Get the exponentially weighted moving average of all previous iterates.
Definition: dualaverage.hpp:86
bool adapting()
Return whether or not parameter adaptation is on.
int n_steps()
Return the number of iterations for this sampler.
void update_mean_stat(double avg_eta, double adapt_stat)
Updates the mean statistic given the specified adaptation statistic and weighting.
void nfevals_plus_eq(int n)
Add the specified number of evaluations to the number of function evaluations.
boost::uniform_01< boost::mt19937 & > _rand_uniform_01
Definition: hmc_base.hpp:44
stan::model::prob_grad & _model
Definition: hmc_base.hpp:27
void adaptation_init(double epsilon_scale)
Definition: hmc_base.hpp:90
boost::variate_generator< boost::mt19937 &, boost::normal_distribution<> > _rand_unit_norm
Definition: hmc_base.hpp:41
No-U-Turn Sampler (NUTS) with varying step sizes.
Definition: nuts_diag.hpp:33
virtual void write_sampler_param_names(std::ostream &o)
Write out any sampler-specific parameter names for output.
Definition: nuts_diag.hpp:255
void build_tree(const std::vector< double > &x, const std::vector< double > &m, const std::vector< double > &grad, double u, int direction, int depth, double H0, std::vector< double > &xminus, std::vector< double > &mminus, std::vector< double > &gradminus, std::vector< double > &xplus, std::vector< double > &mplus, std::vector< double > &gradplus, std::vector< double > &newx, std::vector< double > &newgrad, double &newlogp, int &nvalid, bool &criterion, double &prob_sum, int &n_considered)
The core recursion in NUTS.
Definition: nuts_diag.hpp:327
virtual void get_sampler_param_names(std::vector< std::string > &names)
Get any sampler-specific parameter names.
Definition: nuts_diag.hpp:279
virtual void write_sampler_params(std::ostream &o)
Write out any sampler-specific parameters for output.
Definition: nuts_diag.hpp:261
~nuts_diag()
Destroy this sampler.
Definition: nuts_diag.hpp:138
virtual void get_sampler_params(std::vector< double > &values)
Get any sampler-specific parameters.
Definition: nuts_diag.hpp:286
virtual void write_adaptation_params(std::ostream &o)
Use this method to write the adaptation parameters into the output.
Definition: nuts_diag.hpp:267
virtual sample next_impl()
Return the next sample.
Definition: nuts_diag.hpp:145
nuts_diag(stan::model::prob_grad &model, int maxdepth=10, double epsilon=-1, double epsilon_pm=0.0, bool epsilon_adapt=true, double delta=0.6, double gamma=0.05, BaseRNG base_rng=BaseRNG(std::time(0)), const std::vector< double > *params_r=0, const std::vector< int > *params_i=0)
Construct a No-U-Turn Sampler (NUTS) for the specified model, using the specified step size and numbe...
Definition: nuts_diag.hpp:99
Representation of a MCMC sample.
Definition: sampler.hpp:16
The prob_grad class represents densities with fixed numbers of discrete and scalar parameters and the...
Definition: prob_grad.hpp:24
virtual size_t num_params_r()
Definition: prob_grad.hpp:52
var sqrt(const var &a)
Return the square root of the specified variable (cmath).
Definition: agrad.hpp:1761
static void grad(chainable *vi)
Compute the gradient for all variables starting from the specified root variable implementation.
Definition: agrad.hpp:2187
var log(const var &a)
Return the natural log of the specified variable (cmath).
Definition: agrad.hpp:1730
var exp(const var &a)
Return the exponentiation of the specified variable (cmath).
Definition: agrad.hpp:1716
double epsilon()
Return minimum positive number representable.
double e()
Return the base of the natural logarithm.
int min(const std::vector< int > &x)
Returns the minimum coefficient in the specified column vector.
Definition: matrix.hpp:917
double dot(std::vector< double > &x, std::vector< double > &y)
Definition: util.hpp:30
void sub(std::vector< double > &x, std::vector< double > &y, std::vector< double > &result)
Definition: util.hpp:54
double dot_self(const Eigen::Matrix< double, R, C > &v)
Returns the dot product of the specified vector with itself.
Definition: matrix.hpp:846
double rescaled_leapfrog(stan::model::prob_grad &model, std::vector< int > z, const std::vector< double > &step_sizes, std::vector< double > &x, std::vector< double > &m, std::vector< double > &g, double epsilon, std::ostream *error_msgs=0, std::ostream *output_msgs=0)
Definition: util.hpp:79
Probability, optimization and sampling library.
Definition: agrad.cpp:6

     [ Stan Home Page ] © 2011–2012, Stan Development Team.