1 #ifndef __STAN__GM__PARSER__STATEMENT_GRAMMAR_DEF__HPP__
2 #define __STAN__GM__PARSER__STATEMENT_GRAMMAR_DEF__HPP__
16 #include <boost/spirit/include/qi.hpp>
18 #include <boost/spirit/include/phoenix_core.hpp>
19 #include <boost/spirit/include/phoenix_function.hpp>
20 #include <boost/spirit/include/phoenix_fusion.hpp>
21 #include <boost/spirit/include/phoenix_object.hpp>
22 #include <boost/spirit/include/phoenix_operator.hpp>
23 #include <boost/spirit/include/phoenix_stl.hpp>
25 #include <boost/lexical_cast.hpp>
26 #include <boost/fusion/include/adapt_struct.hpp>
27 #include <boost/fusion/include/std_pair.hpp>
28 #include <boost/config/warning_disable.hpp>
29 #include <boost/spirit/include/qi.hpp>
30 #include <boost/spirit/include/qi_numeric.hpp>
31 #include <boost/spirit/include/classic_position_iterator.hpp>
32 #include <boost/spirit/include/phoenix_core.hpp>
33 #include <boost/spirit/include/phoenix_function.hpp>
34 #include <boost/spirit/include/phoenix_fusion.hpp>
35 #include <boost/spirit/include/phoenix_object.hpp>
36 #include <boost/spirit/include/phoenix_operator.hpp>
37 #include <boost/spirit/include/phoenix_stl.hpp>
38 #include <boost/spirit/include/support_multi_pass.hpp>
39 #include <boost/tuple/tuple.hpp>
40 #include <boost/variant/apply_visitor.hpp>
41 #include <boost/variant/recursive_variant.hpp>
// NOTE(review): fragmentary residue of BOOST_FUSION_ADAPT_STRUCT member lists.
// The macro heads (which name the adapted structs) and several member entries
// are missing from this listing, so these lines do not compile as shown.
// The visible entries adapt members such as dims_, family_/args_, variable_,
// printables_, and local_decl_/statements_. Presumably this lets the Spirit
// rules below synthesize those structs as fusion sequences -- confirm against
// the original header.
56 (std::vector<stan::gm::expression>, dims_) )
59 (std::string, family_)
60 (std::vector<stan::gm::expression>, args_) )
63 (std::string, variable_)
68 (std::vector<stan::gm::printable>, printables_) )
76 (std::vector<stan::gm::var_decl>, local_decl_)
77 (std::vector<stan::gm::statement>, statements_) )
// Phoenix-callable functor that validates an assignment statement while it
// is being parsed. It returns bool. At its grammar call site the result is
// assigned to `_pass`, so returning false makes that parse alternative fail.
// Diagnostics are streamed into `error_msgs`.
//
// The nested `result` template (four type parameters, `type` = bool) tells
// boost::phoenix::function the return type of the four-argument call.
//
// NOTE(review): this listing is fragmentary. The original lines 88-89 (two
// operator() parameters), 97-103, 109-112, 116-120, 126-127, 132-136, 138,
// 141 and 150-154 (including the closing braces) are missing.
// The call site passes (_1, _r2, ref(var_map_), ref(error_msgs_)), and the
// body reads `vm`, `lhs_origin` and `origin_allowed`. The missing parameters
// are therefore presumably an allowed origin and a variable_map -- confirm.
83 struct validate_assignment {
84 template <
typename T1,
typename T2,
typename T3,
typename T4>
85 struct result {
typedef bool type; };
87 bool operator()(assignment& a,
90 std::stringstream& error_msgs)
const {
// The left-hand-side variable must already be declared in the variable map.
93 std::string name = a.var_dims_.name_;
94 if (!vm.exists(name)) {
95 error_msgs <<
"unknown variable in assignment"
96 <<
"; lhs variable=" << a.var_dims_.name_
// (condition head missing) Reports an error when the variable's origin
// differs from the origin allowed here -- partial condition; confirm.
104 && lhs_origin != origin_allowed) {
105 error_msgs <<
"attempt to assign variable in wrong block."
106 <<
" left-hand-side variable origin=";
108 error_msgs << std::endl;
// Side effect: records the declared type of the variable on the assignment.
113 a.var_type_ = vm.get(name);
114 size_t lhs_var_num_dims = a.var_type_.dims_.size();
115 size_t num_index_dims = a.var_dims_.dims_.size();
// `lhs_type` is computed on missing lines. Presumably it is the type after
// indexing (cf. infer_type_indexing); it is ill-formed when more indexes are
// supplied than the variable has dimensions -- confirm.
121 if (lhs_type.is_ill_formed()) {
122 error_msgs <<
"too many indexes for variable "
123 <<
"; variable name = " << name
124 <<
"; num dimensions given = " << num_index_dims
125 <<
"; variable array dimensions = " << lhs_var_num_dims;
// The indexed left-hand side and the right-hand expression must have the
// same number of dimensions.
128 if (lhs_type.num_dims_ != a.expr_.expression_type().num_dims_) {
129 error_msgs <<
"mismatched dimensions on left- and right-hand side of assignment"
130 <<
"; left dims=" << lhs_type.num_dims_
131 <<
"; right dims=" << a.expr_.expression_type().num_dims_
// Base types must be compatible. The rest of this condition (line 141) is
// missing, so other compatible combinations may be accepted -- confirm.
137 base_expr_type rhs_base_type = a.expr_.expression_type().base_type_;
139 bool types_compatible
140 = lhs_base_type == rhs_base_type
142 if (!types_compatible) {
143 error_msgs <<
"base type mismatch in assignment"
144 <<
"; left variable=" << a.var_dims_.name_
145 <<
"; left base type=";
147 error_msgs <<
"; right base type=";
149 error_msgs << std::endl;
// Lazy Phoenix wrapper invoked from the grammar's semantic actions.
155 boost::phoenix::function<validate_assignment> validate_assignment_f;
// Phoenix-callable functor that type-checks a sampling statement.
// The argument-type list is the sampled expression's type followed by the
// distribution argument types. It is checked against `<family>_log`. When a
// lower and/or upper truncation bound is present, `<family>_cdf` is also
// checked with the bound's type substituted for the first argument.
// Diagnostics go to `error_msgs`. At its grammar call site the bool result
// is assigned to `_pass`.
//
// NOTE(review): fragmentary listing. The is_double_return body (lines
// 164-167), lines 176-179, the return statements and the closing braces are
// missing.
83
// NOTE(review): in each truncation branch the second is_double_return call
// passes `arg_types` (the untruncated list), which does not involve the bound.
// Yet its message reports a bound/variate type mismatch -- confirm whether a
// different comparison was intended.
157 struct validate_sample {
158 template <
typename T1,
typename T2>
159 struct result {
typedef bool type; };
// Presumably returns true when `function_name` has a signature for
// `arg_types` with a double result (cf. get_result_type /
// is_primitive_double) -- body missing here; confirm.
161 bool is_double_return(
const std::string& function_name,
162 const std::vector<expr_type>& arg_types,
163 std::ostream& error_msgs)
const {
168 bool operator()(
const sample& s,
169 std::ostream& error_msgs)
const {
// Argument types: the sampled variate first, then the distribution args.
170 std::vector<expr_type> arg_types;
171 arg_types.push_back(s.expr_.expression_type());
172 for (
size_t i = 0; i < s.dist_.args_.size(); ++i)
173 arg_types.push_back(s.dist_.args_[i].expression_type());
// The distribution is looked up as the function "<family>_log".
174 std::string function_name(s.dist_.family_);
175 function_name +=
"_log";
180 if (!is_double_return(function_name,arg_types,error_msgs)) {
181 error_msgs <<
"unknown distribution=" << s.dist_.family_ << std::endl;
// Lower truncation: "<family>_cdf" must accept the lower bound in place of
// the variate. arg_types always has at least one element (pushed above), so
// index 0 is valid.
184 if (s.truncation_.has_low()) {
185 std::vector<expr_type> arg_types_trunc(arg_types);
186 arg_types_trunc[0] = s.truncation_.low_.expression_type();
187 std::string function_name_cdf(s.dist_.family_);
188 function_name_cdf +=
"_cdf";
189 if (!is_double_return(function_name_cdf,arg_types_trunc,error_msgs)) {
190 error_msgs <<
"lower truncation not defined for specified arguments to "
191 << s.dist_.family_ << std::endl;
194 if (!is_double_return(function_name_cdf,arg_types,error_msgs)) {
195 error_msgs <<
"lower bound in truncation type does not match"
196 <<
" sampled variate in distribution's type"
// Upper truncation: same checks as above, using the upper bound's type.
201 if (s.truncation_.has_high()) {
202 std::vector<expr_type> arg_types_trunc(arg_types);
203 arg_types_trunc[0] = s.truncation_.high_.expression_type();
204 std::string function_name_cdf(s.dist_.family_);
205 function_name_cdf +=
"_cdf";
206 if (!is_double_return(function_name_cdf,arg_types_trunc,error_msgs)) {
207 error_msgs <<
"upper truncation not defined for specified arguments to "
208 << s.dist_.family_ << std::endl;
211 if (!is_double_return(function_name_cdf,arg_types,error_msgs)) {
212 error_msgs <<
"upper bound in truncation type does not match"
213 <<
" sampled variate in distribution's type"
// Lazy Phoenix wrapper invoked from the grammar's semantic actions.
222 boost::phoenix::function<validate_sample> validate_sample_f;
// Phoenix-callable functor that removes every local variable declared at the
// top of a statement sequence from the variable map. In the grammar it runs
// after the sequence's statements have been parsed, so those names go out of
// scope. Returns void (see `result`).
// NOTE(review): the closing braces (original lines 231-232) are missing from
// this listing.
224 struct unscope_locals {
225 template <
typename T1,
typename T2>
226 struct result {
typedef void type; };
227 void operator()(
const std::vector<var_decl>& var_decls,
228 variable_map& vm)
const {
229 for (
size_t i = 0; i < var_decls.size(); ++i)
230 vm.remove(var_decls[i].name());
// Lazy Phoenix wrapper invoked from the grammar's semantic actions.
233 boost::phoenix::function<unscope_locals> unscope_locals_f;
// Phoenix-callable functor invoked once a while-loop condition has been
// parsed. Conditions whose type is not primitive are rejected, with a
// diagnostic written to `error_msgs`. At its grammar call site the bool
// result is assigned to `_pass`.
// NOTE(review): original line 266 (the parameter that declares `e`,
// presumably `const expression& e`) and lines 271-276 are missing from this
// listing. The missing lines presumably store the condition into `ws` and
// return -- confirm.
262 struct add_while_condition {
263 template <
typename T1,
typename T2,
typename T3>
264 struct result {
typedef bool type; };
265 bool operator()(while_statement& ws,
267 std::stringstream& error_msgs)
const {
268 if (!
e.expression_type().is_primitive()) {
269 error_msgs <<
"conditions in while statement must be primitive int or real;"
270 <<
" found type=" <<
e.expression_type() << std::endl;
// Lazy Phoenix wrapper invoked from the grammar's semantic actions.
277 boost::phoenix::function<add_while_condition> add_while_condition_f;
// Phoenix-callable functor invoked with the while statement being built and
// its parsed body statement. Returns void.
// NOTE(review): the body (original lines 284-286) is missing from this
// listing. It presumably stores `s` as the loop body of `ws` -- confirm.
279 struct add_while_body {
280 template <
typename T1,
typename T2>
281 struct result {
typedef void type; };
282 void operator()(while_statement& ws,
283 const statement& s)
const {
// Lazy Phoenix wrapper invoked from the grammar's semantic actions.
287 boost::phoenix::function<add_while_body> add_while_body_f;
// Phoenix-callable functor invoked on a for-loop's identifier. It rejects the
// loop if a variable with that name is already declared in `vm`, writing a
// diagnostic to `error_msgs`. Otherwise it (partially visible) registers a
// base_var_decl for the name with no array dimensions. At its grammar call
// site the bool result is assigned to `_pass`.
// NOTE(review): original line 294 (presumably a `variable_map& vm`
// parameter, since the call site passes ref(var_map_) third), line 296, and
// lines 300-302/304-308 are missing. `name_local` is the rule local `_a`,
// which the grammar later hands to remove_loop_identifier. Presumably the
// name is copied into it here -- confirm.
289 struct add_loop_identifier {
290 template <
typename T1,
typename T2,
typename T3,
typename T4>
291 struct result {
typedef bool type; };
292 bool operator()(
const std::string& name,
293 std::string& name_local,
295 std::stringstream& error_msgs)
const {
297 if (vm.exists(name)) {
298 error_msgs <<
"ERROR: loop variable already declared."
299 <<
" variable name=\"" << name <<
"\"" << std::endl;
303 base_var_decl(name,std::vector<expression>(),
// Lazy Phoenix wrapper invoked from the grammar's semantic actions.
309 boost::phoenix::function<add_loop_identifier> add_loop_identifier_f;
// Phoenix-callable functor invoked after a for-loop body has been parsed,
// with the loop variable's name and the variable map. Returns void.
// NOTE(review): the body (original lines 316-318) is missing from this
// listing. It presumably removes `name` from `vm` -- confirm.
311 struct remove_loop_identifier {
312 template <
typename T1,
typename T2>
313 struct result {
typedef void type; };
314 void operator()(
const std::string& name,
315 variable_map& vm)
const {
// Lazy Phoenix wrapper invoked from the grammar's semantic actions.
319 boost::phoenix::function<remove_loop_identifier> remove_loop_identifier_f;
// Phoenix-callable functor that requires an expression to have primitive
// int type and writes a diagnostic to `error_msgs` otherwise. The grammar
// applies it (via `_pass`) to range bounds and array dimension expressions.
// NOTE(review): the return statements and closing braces (original lines
// 330-334) are missing from this listing.
321 struct validate_int_expr2 {
322 template <
typename T1,
typename T2>
323 struct result {
typedef bool type; };
325 bool operator()(
const expression& expr,
326 std::stringstream& error_msgs)
const {
327 if (!expr.expression_type().is_primitive_int()) {
328 error_msgs <<
"expression denoting integer required; found type="
329 << expr.expression_type() << std::endl;
// Lazy Phoenix wrapper invoked from the grammar's semantic actions.
335 boost::phoenix::function<validate_int_expr2> validate_int_expr2_f;
// Phoenix-callable functor that gates sampling statements on the
// `allow_sample` flag. The grammar passes the inherited attribute `_r1` of
// `sample_r` and assigns the bool result to `_pass`.
// NOTE(review): the guarding condition (original line 343), the message tail
// and the returns (345-350) are missing. Presumably it reports the error and
// fails when `allow_sample` is false -- confirm.
337 struct validate_allow_sample {
338 template <
typename T1,
typename T2>
339 struct result {
typedef bool type; };
341 bool operator()(
const bool& allow_sample,
342 std::stringstream& error_msgs)
const {
344 error_msgs <<
"ERROR: sampling only allowed in model."
// Lazy Phoenix wrapper invoked from the grammar's semantic actions.
351 boost::phoenix::function<validate_allow_sample> validate_allow_sample_f;
// Constructor of statement_grammar<Iterator>. The full signature, per the
// index at the end of this file, is
// statement_grammar(variable_map& var_map, std::stringstream& error_msgs).
// `statement_r` is the start rule. The constructor wires the sub-grammars
// (expression, variable declarations, and statement_2_g, which receives
// *this) to the shared variable map and error stream, then defines and names
// the rules. Semantic actions call the Phoenix functors defined above and
// assign their bool results to `_pass` to accept or reject alternatives.
// NOTE(review): this listing is fragmentary. Most rule definitions are
// missing lines (e.g. 355, 358, 363, 375-378, 385-387, 392-395, 418-446), so
// only the visible pieces are documented here.
354 template <
typename Iterator>
356 std::stringstream& error_msgs)
357 : statement_grammar::base_type(statement_r),
359 error_msgs_(error_msgs),
360 expression_g(var_map,error_msgs),
361 var_decls_g(var_map,error_msgs),
362 statement_2_g(var_map,error_msgs,*this)
364 using boost::spirit::qi::_1;
365 using boost::spirit::qi::char_;
366 using boost::spirit::qi::eps;
367 using boost::spirit::qi::lexeme;
368 using boost::spirit::qi::lit;
369 using boost::spirit::qi::_pass;
370 using boost::spirit::qi::_val;
372 using boost::spirit::qi::labels::_a;
373 using boost::spirit::qi::labels::_r1;
374 using boost::spirit::qi::labels::_r2;
379 statement_r.name(
"statement");
381 %= statement_seq_r(_r1,_r2)
382 | for_statement_r(_r1,_r2)
383 | while_statement_r(_r1,_r2)
384 | statement_2_g(_r1,_r2)
// Assignment and sampling alternatives are validated as they are parsed.
388 = validate_assignment_f(_1,_r2,boost::phoenix::ref(var_map_),
389 boost::phoenix::ref(error_msgs_))]
390 | sample_r(_r1) [_pass = validate_sample_f(_1,
391 boost::phoenix::ref(error_msgs_))]
396 statement_seq_r.name(
"sequence of statements");
// Local declarations are stashed in `_a` and removed from the variable map
// after the enclosed statements have been parsed.
399 > local_var_decls_r[_a = _1]
400 > *statement_r(_r1,_r2)
402 > eps[unscope_locals_f(_a,boost::phoenix::ref(var_map_))]
408 while_statement_r.name(
"while statement");
413 [_pass = add_while_condition_f(_val,_1,
414 boost::phoenix::ref(error_msgs_))]
416 > statement_r(_r1,_r2)
417 [add_while_body_f(_val,_1)]
447 for_statement_r.name(
"for statement");
// The loop identifier is registered before the body is parsed (kept in `_a`)
// and handed to remove_loop_identifier afterwards.
451 > identifier_r [_pass
452 = add_loop_identifier_f(_1,_a,
453 boost::phoenix::ref(var_map_),
454 boost::phoenix::ref(error_msgs_))]
458 > statement_r(_r1,_r2)
460 [remove_loop_identifier_f(_a,boost::phoenix::ref(var_map_))];
463 print_statement_r.name(
"print statement");
467 > (printable_r %
',')
471 printable_r.name(
"printable");
473 %= printable_string_r
476 printable_string_r.name(
"printable quoted string");
// NOTE(review): in a Qi char_ set definition, "x-y" denotes a range. In the
// definition below the hyphen sits between '`' (0x60) and '=' (0x3D), which
// forms a descending range that is most likely empty. As a result, '`' and
// '-' are probably NOT accepted inside print strings. A trailing '-' is taken
// literally, so moving '-' to the end of the definition would admit it.
// Not changed here because the surrounding rule is incomplete in this listing.
479 > lexeme[*char_(
"a-zA-Z0-9/~!@#$%^&*()_+`-={}|[]:;'<>?,./ ")]
482 identifier_r.name(
"identifier");
// Identifiers: a letter, then letters, digits, '_' or '.'.
484 %= (lexeme[char_(
"a-zA-Z")
485 >> *char_(
"a-zA-Z0-9_.")]);
487 range_r.name(
"range expression pair, colon");
// Both range bounds must be integer expressions.
490 [_pass = validate_int_expr2_f(_1,boost::phoenix::ref(error_msgs_))]
493 [_pass = validate_int_expr2_f(_1,boost::phoenix::ref(error_msgs_))];
495 assignment_r.name(
"variable assignment by expression");
503 var_lhs_r.name(
"variable and array dimensions");
508 opt_dims_r.name(
"array dimensions (optional)");
512 dims_r.name(
"array dimensions");
// Each array dimension expression must be an integer expression.
516 [_pass = validate_int_expr2_f(_1,boost::phoenix::ref(error_msgs_))]
522 sample_r.name(
"distribution of expression");
// Sampling is gated on the inherited flag `_r1`.
528 = validate_allow_sample_f(_r1,boost::phoenix::ref(error_msgs_))]
530 > -truncation_range_r
533 distribution_r.name(
"distribution and parameters");
537 >> -(expression_g %
',')
540 truncation_range_r.name(
"range pair");
549 no_op_statement_r.name(
"no op statement");
551 %= lit(
';') [_val = no_op_statement()];
// NOTE(review): the following lines are not compilable C++. They appear to be
// a documentation-generator index of symbols referenced by this file:
// signatures without bodies, two description lines, and a flattened
// BOOST_FUSION_ADAPT_STRUCT for stan::gm::program. They are kept verbatim;
// consult the original sources for the real declarations.
static function_signatures & instance()
expr_type get_result_type(const std::string &name, const std::vector< expr_type > &args, std::ostream &error_msgs)
expr_type infer_type_indexing(const base_expr_type &expr_base_type, size_t num_expr_dims, size_t num_index_dims)
std::ostream & write_base_expr_type(std::ostream &o, base_expr_type type)
void print_var_origin(std::ostream &o, const var_origin &vo)
double e()
Return the base of the natural logarithm.
Probability, optimization and sampling library.
BOOST_FUSION_ADAPT_STRUCT(stan::gm::program,(std::vector< stan::gm::var_decl >, data_decl_)(DUMMY_STRUCT::type, derived_data_decl_)(std::vector< stan::gm::var_decl >, parameter_decl_)(DUMMY_STRUCT::type, derived_decl_)(stan::gm::statement, statement_)(DUMMY_STRUCT::type, generated_decl_)) namespace stan
bool is_primitive_double() const
statement_grammar(variable_map &var_map, std::stringstream &error_msgs)