Stan 1.0: probability, sampling & optimization

nuts.hpp
#ifndef __STAN__MCMC__NUTS_H__
#define __STAN__MCMC__NUTS_H__

#include <cmath>
#include <cstddef>
#include <ctime>
#include <iostream>
#include <limits>
#include <vector>

#include <boost/random/normal_distribution.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/variate_generator.hpp>
#include <boost/random/uniform_01.hpp>

#include <stan/math/util.hpp>
#include <stan/mcmc/hmc_base.hpp>
#include <stan/mcmc/util.hpp>

namespace stan {

  namespace mcmc {

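    /**
     * No-U-Turn Sampler (NUTS).
     *
     * Samples by simulating Hamiltonian dynamics and doubling the
     * trajectory until it starts to double back on itself (the
     * "no-U-turn" criterion of Hoffman and Gelman), so the number of
     * leapfrog steps never has to be tuned by hand.
     */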
    template <class BaseRNG = boost::mt19937>
    class nuts : public hmc_base<BaseRNG> {
    private:

      // Stop immediately if H < u - _maxchange
      const double _maxchange;

      // Limit tree depth
      const int _maxdepth;

      // Depth of last sample taken (-1 before any samples)
      int _lastdepth;

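      /**
       * The no-U-turn stopping criterion: keep extending the
       * trajectory only while its total displacement, xplus - xminus,
       * has a positive dot product with the momenta at both endpoints.
       */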
      inline static bool compute_criterion(std::vector<double>& xplus,
                                           std::vector<double>& xminus,
                                           std::vector<double>& mplus,
                                           std::vector<double>& mminus) {
        std::vector<double> total_direction;
        stan::math::sub(xplus, xminus, total_direction);
        return stan::math::dot(total_direction, mminus) > 0
          && stan::math::dot(total_direction, mplus) > 0;
      }

    public:

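      /**
       * Construct a No-U-Turn Sampler (NUTS) for the specified model,
       * with the given maximum tree depth, step size and step-size
       * jitter, adaptation target delta, and random-number generator.
       */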
      nuts(stan::model::prob_grad& model,
           int maxdepth = 10,
           double epsilon = -1,
           double epsilon_pm = 0.0,
           bool epsilon_adapt = true,
           double delta = 0.6,
           double gamma = 0.05,
           BaseRNG base_rng = BaseRNG(std::time(0)),
           const std::vector<double>* params_r = 0,
           const std::vector<int>* params_i = 0)
        : hmc_base<BaseRNG>(model,
                            epsilon,
                            epsilon_pm,
                            epsilon_adapt,
                            delta,
                            gamma,
                            base_rng,
                            params_r,
                            params_i),
          _maxchange(-1000),
          _maxdepth(maxdepth),
          _lastdepth(-1)
      {
        // Start adaptation at 10 * epsilon because NUTS is cheaper for larger epsilon.
        this->adaptation_init(10.0);
      }

      ~nuts() { }

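      /**
       * Return the next sample.
       *
       * One NUTS iteration: draw a fresh momentum, sample the slice
       * variable, then repeatedly double the trajectory until the
       * no-U-turn criterion fails or the maximum tree depth is hit.
       */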
      virtual sample next_impl() {
        // Initialize the algorithm
        std::vector<double> mminus(this->_model.num_params_r());
        for (size_t i = 0; i < mminus.size(); ++i)
          mminus[i] = this->_rand_unit_norm();
        std::vector<double> mplus(mminus);
        // The log-joint probability of the momentum and position terms, i.e.
        // -(kinetic energy + potential energy)
        double H0 = -0.5 * stan::math::dot_self(mminus) + this->_logp;

        std::vector<double> gradminus(this->_g);
        std::vector<double> gradplus(this->_g);
        std::vector<double> xminus(this->_x);
        std::vector<double> xplus(this->_x);

        // Sample the slice variable
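        // (in log space: u = log(u') with u' ~ Uniform(0, exp(H0)),
        // so exp(u) is a uniform draw under the unnormalized joint density)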
        double u = log(this->_rand_uniform_01()) + H0;
        int nvalid = 1;
        int direction = 2 * (this->_rand_uniform_01() > 0.5) - 1;
        bool criterion = true;
148 
149  // Repeatedly double the set of points we've visited
150  std::vector<double> newx, newgrad, dummy1, dummy2, dummy3;
151  double newlogp = -1;
152  double prob_sum = -1;
153  int newnvalid = -1;
154  int n_considered = 0;
155  // for-loop with depth outside to set lastdepth
156  int depth = 0;
157 
158  double epsilon = this->_epsilon;
159  // only vary epsilon after done adapting
160  if (!this->adapting() && this->varying_epsilon()) {
161  double low = epsilon * (1.0 - this->_epsilon_pm);
162  double high = epsilon * (1.0 + this->_epsilon_pm);
163  double range = high - low;
164  epsilon = low + (range * this->_rand_uniform_01());
165  }
166  this->_epsilon_last = epsilon; // use epsilon_last in tree build
167 
        while (criterion && (_maxdepth < 0 || depth <= _maxdepth)) {
          direction = 2 * (this->_rand_uniform_01() > 0.5) - 1;
          if (direction == -1)
            build_tree(xminus, mminus, gradminus, u, direction, depth,
                       H0, xminus, mminus, gradminus, dummy1, dummy2, dummy3,
                       newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
                       n_considered);
          else
            build_tree(xplus, mplus, gradplus, u, direction, depth,
                       H0, dummy1, dummy2, dummy3, xplus, mplus, gradplus,
                       newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
                       n_considered);
          // We can't look at the results of this last doubling if criterion == false
          if (!criterion)
            break;
          criterion = compute_criterion(xplus, xminus, mplus, mminus);
          // Metropolis-Hastings to determine if we can jump to a point in
          // the new half-tree
          if (this->_rand_uniform_01() < float(newnvalid) / (1e-100 + float(nvalid))) {
            this->_x = newx;
            this->_g = newgrad;
            this->_logp = newlogp;
          }
          nvalid += newnvalid;
          // fprintf(stderr, "depth = %d, this->_logp = %g\n", depth, this->_logp);
          ++depth;
        }
        _lastdepth = depth;

        // Now we just have to update epsilon, if adaptation is on.
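        // adapt_stat is the average acceptance probability over every point
        // considered this iteration; dual averaging nudges log(epsilon) so
        // that this statistic approaches the target _delta.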
        double adapt_stat = prob_sum / float(n_considered);
        if (this->adapting()) {
          double adapt_g = adapt_stat - this->_delta;
          std::vector<double> gvec(1, -adapt_g);
          std::vector<double> result;
          this->_da.update(gvec, result);
          this->_epsilon = exp(result[0]);
        }
        std::vector<double> result;
        this->_da.xbar(result);
        // fprintf(stderr, "xbar = %f\n", exp(result[0]));
        double avg_eta = 1.0 / this->n_steps();
        this->update_mean_stat(avg_eta, adapt_stat);

        mcmc::sample s(this->_x, this->_z, this->_logp);
        return s;
      }

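      /**
       * Write out any sampler-specific parameter names for output.
       */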
      virtual void write_sampler_param_names(std::ostream& o) {
        o << "treedepth__,";
        if (this->_epsilon_adapt || this->varying_epsilon())
          o << "stepsize__,";
      }

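      /**
       * Write out any sampler-specific parameters for output.
       */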
      virtual void write_sampler_params(std::ostream& o) {
        o << _lastdepth << ',';
        if (this->_epsilon_adapt || this->varying_epsilon())
          o << this->_epsilon_last << ',';
      }

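      /**
       * Get any sampler-specific parameter names.
       */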
      virtual void get_sampler_param_names(std::vector<std::string>& names) {
        names.clear();
        names.push_back("treedepth__");
        if (this->_epsilon_adapt || this->varying_epsilon())
          names.push_back("stepsize__");
      }
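
      /**
       * Get any sampler-specific parameters.
       */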
      virtual void get_sampler_params(std::vector<double>& values) {
        values.clear();
        values.push_back(_lastdepth);
        if (this->_epsilon_adapt || this->varying_epsilon())
          values.push_back(this->_epsilon_last);
      }

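      /**
       * The core recursion in NUTS.
       *
       * Builds a balanced binary subtree of 2^depth leapfrog steps in
       * the given direction, updating the outermost states (xminus,
       * xplus, and their momenta and gradients), a candidate sample
       * drawn from the valid slice points, and the statistics used for
       * the stopping criterion and step-size adaptation.
       */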
      void build_tree(const std::vector<double>& x,
                      const std::vector<double>& m,
                      const std::vector<double>& grad,
                      double u,
                      int direction,
                      int depth,
                      double H0,
                      std::vector<double>& xminus,
                      std::vector<double>& mminus,
                      std::vector<double>& gradminus,
                      std::vector<double>& xplus,
                      std::vector<double>& mplus,
                      std::vector<double>& gradplus,
                      std::vector<double>& newx,
                      std::vector<double>& newgrad,
                      double& newlogp,
                      int& nvalid,
                      bool& criterion,
                      double& prob_sum,
                      int& n_considered) {
        if (depth == 0) { // base case: take a single leapfrog step
          xminus = x;
          gradminus = grad;
          mminus = m;
          // FIXME: leapfrog needs +/- this->_epsilon_pm
          newlogp = leapfrog(this->_model, this->_z, xminus, mminus, gradminus,
                             direction * this->_epsilon_last,
                             this->_error_msgs, this->_output_msgs);
          newx = xminus;
          newgrad = gradminus;
          xplus = xminus;
          mplus = mminus;
          gradplus = gradminus;
          double newH = newlogp - 0.5 * stan::math::dot_self(mminus);
          if (newH != newH) // treat NaN as -infinity
            newH = -std::numeric_limits<double>::infinity();
          nvalid = newH > u;
          criterion = newH - u > _maxchange;
          prob_sum = stan::math::min(1, exp(newH - H0));
          n_considered = 1;
          this->nfevals_plus_eq(1);
        } else { // depth >= 1
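          // Build the first half of the subtree, then, if it did not
          // trigger the stopping rule, extend outward with a second
          // half-subtree of 2^(depth-1) more leapfrog steps.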
          build_tree(x, m, grad, u, direction, depth - 1, H0, xminus, mminus,
                     gradminus, xplus, mplus, gradplus, newx, newgrad, newlogp,
                     nvalid, criterion, prob_sum, n_considered);
          if (criterion) {
            std::vector<double> dummy1, dummy2, dummy3;
            std::vector<double> newx2;
            std::vector<double> newgrad2;
            double newlogp2;
            int nvalid2;
            bool criterion2;
            double prob_sum2;
            int n_considered2;
            if (direction == -1)
              build_tree(xminus, mminus, gradminus, u, direction, depth - 1, H0,
                         xminus, mminus, gradminus, dummy1, dummy2, dummy3,
                         newx2, newgrad2, newlogp2, nvalid2, criterion2,
                         prob_sum2, n_considered2);
            else
              build_tree(xplus, mplus, gradplus, u, direction, depth - 1, H0,
                         dummy1, dummy2, dummy3, xplus, mplus, gradplus,
                         newx2, newgrad2, newlogp2, nvalid2, criterion2,
                         prob_sum2, n_considered2);
            if (criterion &&
                (this->_rand_uniform_01()
                 < float(nvalid2) / float(nvalid + nvalid2))) {
              newx = newx2;
              newgrad = newgrad2;
              newlogp = newlogp2;
            }
            n_considered += n_considered2;
            prob_sum += prob_sum2;
            criterion &= criterion2;
            nvalid += nvalid2;
          }
          criterion &= compute_criterion(xplus, xminus, mplus, mminus);
        }
      }

    };

  }

}

#endif
