1 #ifndef __STAN__GM__PARSER__STATEMENT_GRAMMAR_DEF__HPP__ 2 #define __STAN__GM__PARSER__STATEMENT_GRAMMAR_DEF__HPP__ 16 #include <boost/spirit/include/qi.hpp> 18 #include <boost/spirit/include/phoenix_core.hpp> 19 #include <boost/spirit/include/phoenix_function.hpp> 20 #include <boost/spirit/include/phoenix_fusion.hpp> 21 #include <boost/spirit/include/phoenix_object.hpp> 22 #include <boost/spirit/include/phoenix_operator.hpp> 23 #include <boost/spirit/include/phoenix_stl.hpp> 25 #include <boost/lexical_cast.hpp> 26 #include <boost/fusion/include/adapt_struct.hpp> 27 #include <boost/fusion/include/std_pair.hpp> 28 #include <boost/config/warning_disable.hpp> 29 #include <boost/spirit/include/qi.hpp> 30 #include <boost/spirit/include/qi_numeric.hpp> 31 #include <boost/spirit/include/classic_position_iterator.hpp> 32 #include <boost/spirit/include/phoenix_core.hpp> 33 #include <boost/spirit/include/phoenix_function.hpp> 34 #include <boost/spirit/include/phoenix_fusion.hpp> 35 #include <boost/spirit/include/phoenix_object.hpp> 36 #include <boost/spirit/include/phoenix_operator.hpp> 37 #include <boost/spirit/include/phoenix_stl.hpp> 38 #include <boost/spirit/include/support_multi_pass.hpp> 39 #include <boost/tuple/tuple.hpp> 40 #include <boost/variant/apply_visitor.hpp> 41 #include <boost/variant/recursive_variant.hpp> 56 (std::vector<stan::gm::expression>, dims_) )
59 (std::string, family_)
60 (std::vector<stan::gm::expression>, args_) )
63 (std::string, variable_)
68 (std::vector<stan::gm::printable>, printables_) )
76 (std::vector<stan::gm::var_decl>, local_decl_)
77 (std::vector<stan::gm::statement>, statements_) )
83 struct validate_assignment {
84 template <
typename T1,
typename T2,
typename T3,
typename T4>
85 struct result {
typedef bool type; };
87 bool operator()(assignment& a,
90 std::stringstream& error_msgs)
const {
93 std::string name = a.var_dims_.name_;
94 if (!vm.exists(name)) {
95 error_msgs <<
"unknown variable in assignment" 96 <<
"; lhs variable=" << a.var_dims_.name_
104 && lhs_origin != origin_allowed) {
105 error_msgs <<
"attempt to assign variable in wrong block." 106 <<
" left-hand-side variable origin=";
108 error_msgs << std::endl;
113 a.var_type_ = vm.get(name);
114 size_t lhs_var_num_dims = a.var_type_.dims_.size();
115 size_t num_index_dims = a.var_dims_.dims_.size();
121 if (lhs_type.is_ill_formed()) {
122 error_msgs <<
"too many indexes for variable " 123 <<
"; variable name = " << name
124 <<
"; num dimensions given = " << num_index_dims
125 <<
"; variable array dimensions = " << lhs_var_num_dims;
128 if (lhs_type.num_dims_ != a.expr_.expression_type().num_dims_) {
129 error_msgs <<
"mismatched dimensions on left- and right-hand side of assignment" 130 <<
"; left dims=" << lhs_type.num_dims_
131 <<
"; right dims=" << a.expr_.expression_type().num_dims_
137 base_expr_type rhs_base_type = a.expr_.expression_type().base_type_;
139 bool types_compatible
140 = lhs_base_type == rhs_base_type
142 if (!types_compatible) {
143 error_msgs <<
"base type mismatch in assignment" 144 <<
"; left variable=" << a.var_dims_.name_
145 <<
"; left base type=";
147 error_msgs <<
"; right base type=";
149 error_msgs << std::endl;
155 boost::phoenix::function<validate_assignment> validate_assignment_f;
157 struct validate_sample {
158 template <
typename T1,
typename T2>
159 struct result {
typedef bool type; };
161 bool is_double_return(
const std::string& function_name,
162 const std::vector<expr_type>& arg_types,
163 std::ostream& error_msgs)
const {
168 bool operator()(
const sample& s,
169 std::ostream& error_msgs)
const {
170 std::vector<expr_type> arg_types;
171 arg_types.push_back(s.expr_.expression_type());
172 for (
size_t i = 0; i < s.dist_.args_.size(); ++i)
173 arg_types.push_back(s.dist_.args_[i].expression_type());
174 std::string function_name(s.dist_.family_);
175 function_name +=
"_log";
180 if (!is_double_return(function_name,arg_types,error_msgs)) {
181 error_msgs <<
"unknown distribution=" << s.dist_.family_ << std::endl;
184 if (s.truncation_.has_low()) {
185 std::vector<expr_type> arg_types_trunc(arg_types);
186 arg_types_trunc[0] = s.truncation_.low_.expression_type();
187 std::string function_name_cdf(s.dist_.family_);
188 function_name_cdf +=
"_cdf";
189 if (!is_double_return(function_name_cdf,arg_types_trunc,error_msgs)) {
190 error_msgs <<
"lower truncation not defined for specified arguments to " 191 << s.dist_.family_ << std::endl;
194 if (!is_double_return(function_name_cdf,arg_types,error_msgs)) {
195 error_msgs <<
"lower bound in truncation type does not match" 196 <<
" sampled variate in distribution's type" 201 if (s.truncation_.has_high()) {
202 std::vector<expr_type> arg_types_trunc(arg_types);
203 arg_types_trunc[0] = s.truncation_.high_.expression_type();
204 std::string function_name_cdf(s.dist_.family_);
205 function_name_cdf +=
"_cdf";
206 if (!is_double_return(function_name_cdf,arg_types_trunc,error_msgs)) {
207 error_msgs <<
"upper truncation not defined for specified arguments to " 208 << s.dist_.family_ << std::endl;
211 if (!is_double_return(function_name_cdf,arg_types,error_msgs)) {
212 error_msgs <<
"upper bound in truncation type does not match" 213 <<
" sampled variate in distribution's type" 222 boost::phoenix::function<validate_sample> validate_sample_f;
224 struct unscope_locals {
225 template <
typename T1,
typename T2>
226 struct result {
typedef void type; };
227 void operator()(
const std::vector<var_decl>& var_decls,
228 variable_map& vm)
const {
229 for (
size_t i = 0; i < var_decls.size(); ++i)
230 vm.remove(var_decls[i].name());
233 boost::phoenix::function<unscope_locals> unscope_locals_f;
262 struct add_while_condition {
263 template <
typename T1,
typename T2,
typename T3>
264 struct result {
typedef bool type; };
265 bool operator()(while_statement& ws,
267 std::stringstream& error_msgs)
const {
268 if (!
e.expression_type().is_primitive()) {
269 error_msgs <<
"conditions in while statement must be primitive int or real;" 270 <<
" found type=" <<
e.expression_type() << std::endl;
277 boost::phoenix::function<add_while_condition> add_while_condition_f;
279 struct add_while_body {
280 template <
typename T1,
typename T2>
281 struct result {
typedef void type; };
282 void operator()(while_statement& ws,
283 const statement& s)
const {
287 boost::phoenix::function<add_while_body> add_while_body_f;
289 struct add_loop_identifier {
290 template <
typename T1,
typename T2,
typename T3,
typename T4>
291 struct result {
typedef bool type; };
292 bool operator()(
const std::string& name,
293 std::string& name_local,
295 std::stringstream& error_msgs)
const {
297 if (vm.exists(name)) {
298 error_msgs <<
"ERROR: loop variable already declared." 299 <<
" variable name=\"" << name <<
"\"" << std::endl;
303 base_var_decl(name,std::vector<expression>(),
309 boost::phoenix::function<add_loop_identifier> add_loop_identifier_f;
311 struct remove_loop_identifier {
312 template <
typename T1,
typename T2>
313 struct result {
typedef void type; };
314 void operator()(
const std::string& name,
315 variable_map& vm)
const {
319 boost::phoenix::function<remove_loop_identifier> remove_loop_identifier_f;
321 struct validate_int_expr2 {
322 template <
typename T1,
typename T2>
323 struct result {
typedef bool type; };
325 bool operator()(
const expression& expr,
326 std::stringstream& error_msgs)
const {
327 if (!expr.expression_type().is_primitive_int()) {
328 error_msgs <<
"expression denoting integer required; found type=" 329 << expr.expression_type() << std::endl;
335 boost::phoenix::function<validate_int_expr2> validate_int_expr2_f;
337 struct validate_allow_sample {
338 template <
typename T1,
typename T2>
339 struct result {
typedef bool type; };
341 bool operator()(
const bool& allow_sample,
342 std::stringstream& error_msgs)
const {
344 error_msgs <<
"ERROR: sampling only allowed in model." 351 boost::phoenix::function<validate_allow_sample> validate_allow_sample_f;
354 template <
typename Iterator>
356 std::stringstream& error_msgs)
357 : statement_grammar::base_type(statement_r),
359 error_msgs_(error_msgs),
360 expression_g(var_map,error_msgs),
361 var_decls_g(var_map,error_msgs),
362 statement_2_g(var_map,error_msgs,*this)
364 using boost::spirit::qi::_1;
365 using boost::spirit::qi::char_;
366 using boost::spirit::qi::eps;
367 using boost::spirit::qi::lexeme;
368 using boost::spirit::qi::lit;
369 using boost::spirit::qi::_pass;
370 using boost::spirit::qi::_val;
372 using boost::spirit::qi::labels::_a;
373 using boost::spirit::qi::labels::_r1;
374 using boost::spirit::qi::labels::_r2;
388 = validate_assignment_f(_1,_r2,boost::phoenix::ref(
var_map_),
390 |
sample_r(_r1) [_pass = validate_sample_f(_1,
402 > eps[unscope_locals_f(_a,boost::phoenix::ref(
var_map_))]
413 [_pass = add_while_condition_f(_val,_1,
417 [add_while_body_f(_val,_1)]
452 = add_loop_identifier_f(_1,_a,
460 [remove_loop_identifier_f(_a,boost::phoenix::ref(
var_map_))];
479 > lexeme[*char_(
"a-zA-Z0-9/~!@#$%^&*()_+`-={}|[]:;'<>?,./ ")]
484 %= (lexeme[char_(
"a-zA-Z")
485 >> *char_(
"a-zA-Z0-9_.")]);
487 range_r.name(
"range expression pair, colon");
490 [_pass = validate_int_expr2_f(_1,boost::phoenix::ref(
error_msgs_))]
493 [_pass = validate_int_expr2_f(_1,boost::phoenix::ref(
error_msgs_))];
503 var_lhs_r.name(
"variable and array dimensions");
508 opt_dims_r.name(
"array dimensions (optional)");
512 dims_r.name(
"array dimensions");
516 [_pass = validate_int_expr2_f(_1,boost::phoenix::ref(
error_msgs_))]
522 sample_r.name(
"distribution of expression");
528 = validate_allow_sample_f(_r1,boost::phoenix::ref(
error_msgs_))]
551 %= lit(
';') [_val = no_op_statement()];
boost::spirit::qi::rule< Iterator, distribution(), whitespace_grammar< Iterator > > distribution_r
boost::spirit::qi::rule< Iterator, sample(bool), whitespace_grammar< Iterator > > sample_r
static function_signatures & instance()
boost::spirit::qi::rule< Iterator, printable(), whitespace_grammar< Iterator > > printable_r
boost::spirit::qi::rule< Iterator, variable_dims(), whitespace_grammar< Iterator > > var_lhs_r
Probability, optimization and sampling library.
boost::spirit::qi::rule< Iterator, print_statement(), whitespace_grammar< Iterator > > print_statement_r
boost::spirit::qi::rule< Iterator, std::vector< var_decl >), whitespace_grammar< Iterator > > local_var_decls_r
boost::spirit::qi::rule< Iterator, std::string(), whitespace_grammar< Iterator > > identifier_r
boost::spirit::qi::rule< Iterator, std::vector< expression >), whitespace_grammar< Iterator > > opt_dims_r
void print_var_origin(std::ostream &o, const var_origin &vo)
boost::spirit::qi::rule< Iterator, boost::spirit::qi::locals< std::vector< var_decl > >, statements(bool, var_origin), whitespace_grammar< Iterator > > statement_seq_r
boost::spirit::qi::rule< Iterator, std::string(), whitespace_grammar< Iterator > > printable_string_r
statement_2_grammar< Iterator > statement_2_g
boost::spirit::qi::rule< Iterator, range(), whitespace_grammar< Iterator > > range_r
boost::spirit::qi::rule< Iterator, std::vector< expression >), whitespace_grammar< Iterator > > dims_r
boost::spirit::qi::rule< Iterator, boost::spirit::qi::locals< std::string >, for_statement(bool, var_origin), whitespace_grammar< Iterator > > for_statement_r
std::stringstream & error_msgs_
BOOST_FUSION_ADAPT_STRUCT(stan::gm::program,(std::vector< stan::gm::var_decl >, data_decl_)(DUMMY_STRUCT::type, derived_data_decl_)(std::vector< stan::gm::var_decl >, parameter_decl_)(DUMMY_STRUCT::type, derived_decl_)(stan::gm::statement, statement_)(DUMMY_STRUCT::type, generated_decl_)) namespace stan
boost::spirit::qi::rule< Iterator, assignment(), whitespace_grammar< Iterator > > assignment_r
var_decls_grammar< Iterator > var_decls_g
expression_grammar< Iterator > expression_g
boost::spirit::qi::rule< Iterator, range(), whitespace_grammar< Iterator > > truncation_range_r
boost::spirit::qi::rule< Iterator, no_op_statement(), whitespace_grammar< Iterator > > no_op_statement_r
double e()
Return the base of the natural logarithm.
expr_type infer_type_indexing(const base_expr_type &expr_base_type, size_t num_expr_dims, size_t num_index_dims)
boost::spirit::qi::rule< Iterator, while_statement(bool, var_origin), whitespace_grammar< Iterator > > while_statement_r
boost::spirit::qi::rule< Iterator, statement(bool, var_origin), whitespace_grammar< Iterator > > statement_r
bool is_primitive_double() const
expr_type get_result_type(const std::string &name, const std::vector< expr_type > &args, std::ostream &error_msgs)
std::ostream & write_base_expr_type(std::ostream &o, base_expr_type type)
statement_grammar(variable_map &var_map, std::stringstream &error_msgs)