Stan  1.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Friends Macros Pages
nuts.hpp
Go to the documentation of this file.
1 #ifndef __STAN__MCMC__NUTS_H__
2 #define __STAN__MCMC__NUTS_H__
3 
4 #include <ctime>
5 #include <cstddef>
6 #include <iostream>
7 #include <vector>
8 
9 #include <boost/random/normal_distribution.hpp>
10 #include <boost/random/mersenne_twister.hpp>
11 #include <boost/random/variate_generator.hpp>
12 #include <boost/random/uniform_01.hpp>
13 
14 #include <stan/math/util.hpp>
17 #include <stan/mcmc/hmc_base.hpp>
18 #include <stan/mcmc/util.hpp>
19 
20 namespace stan {
21 
22  namespace mcmc {
23 
31  template <class BaseRNG = boost::mt19937>
32  class nuts : public hmc_base<BaseRNG> {
33  private:
34 
35  // Divergence threshold (the constructor sets it to -1000): a leapfrog
// state stops tree building when newH - u <= _maxchange, i.e. when its log
// joint density lies |_maxchange| or more below the log slice variable u.
36  const double _maxchange;
37 
38  // Limit tree depth: next_impl keeps doubling while depth <= _maxdepth
// (so depths 0.._maxdepth inclusive); a negative value means no limit.
39  const int _maxdepth;
40 
41  // depth of last sample taken (-1 before any samples); this is the value
// of next_impl's doubling counter when its last trajectory stopped.
42  int _lastdepth;
43 
/**
 * No-U-Turn criterion: true while the trajectory is still expanding at
 * both ends, i.e. the span (xplus - xminus) has a strictly positive dot
 * product with the momentum at each end.
 *
 * All arguments are read-only. They are assumed to have the same length
 * as xplus.
 *
 * @param xplus   forward-most position of the trajectory
 * @param xminus  backward-most position of the trajectory
 * @param mplus   momentum at the forward-most state
 * @param mminus  momentum at the backward-most state
 * @return true if both dot products are > 0 (keep doubling), else false
 */
inline static bool compute_criterion(const std::vector<double>& xplus,
                                     const std::vector<double>& xminus,
                                     const std::vector<double>& mplus,
                                     const std::vector<double>& mminus) {
  // Accumulate both dot products in one pass instead of materialising the
  // difference vector; this predicate runs at every node of every tree.
  double dot_minus = 0.0;
  double dot_plus = 0.0;
  for (std::size_t i = 0; i < xplus.size(); ++i) {
    const double span = xplus[i] - xminus[i];
    dot_minus += span * mminus[i];
    dot_plus += span * mplus[i];
  }
  return dot_minus > 0 && dot_plus > 0;
}
60 
61  public:
62 
90  int maxdepth = 10,
91  double epsilon = -1,
92  double epsilon_pm = 0.0,
93  bool epsilon_adapt = true,
94  double delta = 0.6,
95  double gamma = 0.05,
96  BaseRNG base_rng = BaseRNG(std::time(0)),
97  const std::vector<double>* params_r = 0,
98  const std::vector<int>* params_i = 0)
99  : hmc_base<BaseRNG>(model,
100  epsilon,
101  epsilon_pm,
102  epsilon_adapt,
103  delta,
104  gamma,
105  base_rng,
106  params_r,
107  params_i),
108  _maxchange(-1000),
109  _maxdepth(maxdepth),
110  _lastdepth(-1)
111  {
112  // start at 10 * epsilon because NUTS cheaper for larger epsilon
113  this->adaptation_init(10.0);
114  }
115 
121  ~nuts() { }
122 
128  virtual sample next_impl() {
129  // Initialize the algorithm
130  std::vector<double> mminus(this->_model.num_params_r());
131  for (size_t i = 0; i < mminus.size(); ++i)
132  mminus[i] = this->_rand_unit_norm();
133  std::vector<double> mplus(mminus);
134  // The log-joint probability of the momentum and position terms, i.e.
135  // -(kinetic energy + potential energy)
136  double H0 = -0.5 * stan::math::dot_self(mminus) + this->_logp;
137 
138  std::vector<double> gradminus(this->_g);
139  std::vector<double> gradplus(this->_g);
140  std::vector<double> xminus(this->_x);
141  std::vector<double> xplus(this->_x);
142 
143  // Sample the slice variable
144  double u = log(this->_rand_uniform_01()) + H0;
145  int nvalid = 1;
146  int direction = 2 * (this->_rand_uniform_01() > 0.5) - 1;
147  bool criterion = true;
148 
149  // Repeatedly double the set of points we've visited
150  std::vector<double> newx, newgrad, dummy1, dummy2, dummy3;
151  double newlogp = -1;
152  double prob_sum = -1;
153  int newnvalid = -1;
154  int n_considered = 0;
155  // for-loop with depth outside to set lastdepth
156  int depth = 0;
157 
158  double epsilon = this->_epsilon;
159  // only vary epsilon after done adapting
160  if (!this->adapting() && this->varying_epsilon()) {
161  double low = epsilon * (1.0 - this->_epsilon_pm);
162  double high = epsilon * (1.0 + this->_epsilon_pm);
163  double range = high - low;
164  epsilon = low + (range * this->_rand_uniform_01());
165  }
166  this->_epsilon_last = epsilon; // use epsilon_last in tree build
167 
168  while (criterion && (_maxdepth < 0 || depth <= _maxdepth)) {
169  direction = 2 * (this->_rand_uniform_01() > 0.5) - 1;
170  if (direction == -1)
171  build_tree(xminus, mminus, gradminus, u, direction, depth,
172  H0, xminus, mminus, gradminus, dummy1, dummy2, dummy3,
173  newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
174  n_considered);
175  else
176  build_tree(xplus, mplus, gradplus, u, direction, depth,
177  H0, dummy1, dummy2, dummy3, xplus, mplus, gradplus,
178  newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
179  n_considered);
180  // We can't look at the results of this last doubling if criterion==false
181  if (!criterion)
182  break;
183  criterion = compute_criterion(xplus, xminus, mplus, mminus);
184  // Metropolis-Hastings to determine if we can jump to a point in
185  // the new half-tree
186  if (this->_rand_uniform_01() < float(newnvalid) / (1e-100+float(nvalid))) {
187  this->_x = newx;
188  this->_g = newgrad;
189  this->_logp = newlogp;
190  }
191  nvalid += newnvalid;
192 // fprintf(stderr, "depth = %d, this->_logp = %g\n", depth, this->_logp);
193  ++depth;
194  }
195  _lastdepth = depth;
196 
197  // Now we just have to update epsilon, if adaptation is on.
198  double adapt_stat = prob_sum / float(n_considered);
199  if (this->adapting()) {
200  double adapt_g = adapt_stat - this->_delta;
201  std::vector<double> gvec(1, -adapt_g);
202  std::vector<double> result;
203  this->_da.update(gvec, result);
204  this->_epsilon = exp(result[0]);
205  }
206  std::vector<double> result;
207  this->_da.xbar(result);
208 // fprintf(stderr, "xbar = %f\n", exp(result[0]));
209  double avg_eta = 1.0 / this->n_steps();
210  this->update_mean_stat(avg_eta,adapt_stat);
211 
212  mcmc::sample s(this->_x, this->_z, this->_logp);
213  return s;
214  }
215 
216  virtual void write_sampler_param_names(std::ostream& o) {
217  o << "treedepth__,";
218  if (this->_epsilon_adapt || this->varying_epsilon())
219  o << "stepsize__,";
220  }
221 
222  virtual void write_sampler_params(std::ostream& o) {
223  o << _lastdepth << ',';
224  if (this->_epsilon_adapt || this->varying_epsilon())
225  o << this->_epsilon_last << ',';
226  }
227 
228  virtual void get_sampler_param_names(std::vector<std::string>& names) {
229  names.clear();
230  names.push_back("treedepth__");
231  if (this->_epsilon_adapt || this->varying_epsilon())
232  names.push_back("stepsize__");
233  }
234  virtual void get_sampler_params(std::vector<double>& values) {
235  values.clear();
236  values.push_back(_lastdepth);
237  if (this->_epsilon_adapt || this->varying_epsilon())
238  values.push_back(this->_epsilon_last);
239  }
240 
/**
 * Recursively build a NUTS subtree of 2^depth leapfrog steps, starting from
 * (x, m, grad) and stepping by direction * _epsilon_last.
 *
 * @param x, m, grad  starting position, momentum and gradient. Callers
 *                    (next_impl and the recursion below) pass the same
 *                    vectors as the same-side outputs (e.g. xminus for a
 *                    backward build). The inputs are read before those
 *                    outputs are overwritten.
 * @param u           log slice variable; a leaf is valid iff its H > u.
 * @param direction   -1 (backward) or +1 (forward).
 * @param depth       tree depth; depth 0 is a single leapfrog step.
 * @param H0          log joint probability of the trajectory's initial state.
 * @param xminus, mminus, gradminus  [out] backward-most state of the subtree.
 * @param xplus, mplus, gradplus     [out] forward-most state of the subtree.
 *                    Callers pass throw-away vectors for the end they do not
 *                    need.
 * @param newx, newgrad, newlogp     [out] state proposed from this subtree.
 * @param nvalid        [out] number of leaves with H > u.
 * @param criterion     [out] false if a leaf had newH - u <= _maxchange
 *                      (divergence), or if a U-turn was detected between the
 *                      subtree's endpoints.
 * @param prob_sum      [out] sum of min(1, exp(H - H0)) over the leaves.
 * @param n_considered  [out] number of leaves summed into prob_sum.
 */
275  void build_tree(const std::vector<double>& x,
276  const std::vector<double>& m,
277  const std::vector<double>& grad,
278  double u,
279  int direction,
280  int depth,
281  double H0,
282  std::vector<double>& xminus,
283  std::vector<double>& mminus,
284  std::vector<double>& gradminus,
285  std::vector<double>& xplus,
286  std::vector<double>& mplus,
287  std::vector<double>& gradplus,
288  std::vector<double>& newx,
289  std::vector<double>& newgrad,
290  double& newlogp,
291  int& nvalid,
292  bool& criterion,
293  double& prob_sum,
294  int& n_considered) {
295  if (depth == 0) { // base case
296  xminus = x;
297  gradminus = grad;
298  mminus = m;
299  // FIXME: leapfrog needs +/- this->_epsilon_pm
// NOTE(review): next_impl already jitters _epsilon_last by +/- _epsilon_pm
// (when not adapting and varying_epsilon()), so this FIXME looks stale --
// confirm before acting on it.
300  newlogp = leapfrog(this->_model, this->_z, xminus, mminus, gradminus,
301  direction * this->_epsilon_last,
302  this->_error_msgs, this->_output_msgs);
// The single leaf is both endpoints of this subtree and its proposal.
303  newx = xminus;
304  newgrad = gradminus;
305  xplus = xminus;
306  mplus = mminus;
307  gradplus = gradminus;
308  double newH = newlogp - 0.5 * stan::math::dot_self(mminus);
309  if (newH != newH) // treat nan as -inf
310  newH = -std::numeric_limits<double>::infinity();
// Valid iff inside the slice; tree building stops when the leaf falls
// _maxchange (negative) or further below the slice.
311  nvalid = newH > u;
312  criterion = newH - u > _maxchange;
313  prob_sum = stan::math::min(1, exp(newH - H0));
314  n_considered = 1;
315  this->nfevals_plus_eq(1);
316  } else { // depth >= 1
// First half-subtree; it also fills the outputs for the far end.
317  build_tree(x, m, grad, u, direction, depth-1, H0, xminus, mminus,
318  gradminus, xplus, mplus, gradplus, newx, newgrad, newlogp,
319  nvalid, criterion, prob_sum, n_considered);
320  if (criterion) {
321  std::vector<double> dummy1, dummy2, dummy3;
322  std::vector<double> newx2;
323  std::vector<double> newgrad2;
324  double newlogp2;
325  int nvalid2;
326  bool criterion2;
327  double prob_sum2;
328  int n_considered2;
// Second half-subtree, extending the end in the direction of travel.
329  if (direction == -1)
330  build_tree(xminus, mminus, gradminus, u, direction, depth-1, H0,
331  xminus, mminus, gradminus, dummy1, dummy2, dummy3,
332  newx2, newgrad2, newlogp2, nvalid2, criterion2,
333  prob_sum2, n_considered2);
334  else
335  build_tree(xplus, mplus, gradplus, u, direction, depth-1, H0,
336  dummy1, dummy2, dummy3, xplus, mplus, gradplus,
337  newx2, newgrad2, newlogp2, nvalid2, criterion2,
338  prob_sum2, n_considered2);
// criterion is always true here (checked above), so only the uniform draw
// decides: the second half's proposal wins with probability
// nvalid2 / (nvalid + nvalid2).
339  if (criterion &&
340  (this->_rand_uniform_01()
341  < float(nvalid2) / float(nvalid+nvalid2))) {
342  newx = newx2;
343  newgrad = newgrad2;
344  newlogp = newlogp2;
345  }
346  n_considered += n_considered2;
347  prob_sum += prob_sum2;
348  criterion &= criterion2;
349  nvalid += nvalid2;
350  }
// U-turn check across the endpoints of the combined subtree.
351  criterion &= compute_criterion(xplus, xminus, mplus, mminus);
352  }
353  }
354 
355 
356 
357  };
358 
359 
360  }
361 
362 }
363 
364 #endif

     [ Stan Home Page ] © 2011–2012, Stan Development Team.