Go to the documentation of this file.
13 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
14 #define MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
49 State(
const arma::colvec& data) : data(data)
53 arma::colvec
Data()
const {
return data; }
55 arma::colvec&
Data() {
return data; }
68 double Angle(
const size_t i)
const {
return data[2 * i]; }
70 double&
Angle(
const size_t i) {
return data[2 * i]; }
78 const arma::colvec&
Encode()
const {
return data; }
124 const double m1 = 0.1,
125 const double m2 = 0.01,
126 const double l1 = 0.5,
127 const double l2 = 0.05,
128 const double gravity = 9.8,
129 const double massCart = 1.0,
130 const double forceMag = 10.0,
131 const double tau = 0.02,
132 const double thetaThresholdRadians = 36 * 2 * 3.1416 / 360,
133 const double xThreshold = 2.4,
134 const double doneReward = 0.0) :
144 thetaThresholdRadians(thetaThresholdRadians),
145 xThreshold(xThreshold),
146 doneReward(doneReward),
166 arma::vec dydx(6, arma::fill::zeros);
170 Dsdt(state, action, dydx);
171 RK4(state, action, dydx, nextState);
177 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
201 double totalForce = action.
action ? forceMag : -forceMag;
202 double totalMass = massCart;
205 double sinTheta1 = std::sin(state.
Angle(1));
206 double sinTheta2 = std::sin(state.
Angle(2));
207 double cosTheta1 = std::cos(state.
Angle(1));
208 double cosTheta2 = std::cos(state.
Angle(2));
211 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
212 std::sin(2 * state.
Angle(1));
213 totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
214 std::sin(2 * state.
Angle(2));
217 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
218 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
221 double xAcc = totalForce / totalMass;
225 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
226 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
243 const double hh = tau * 0.5;
244 const double h6 = tau / 6;
249 yt = state.
Data() + (hh * dydx);
254 yt = state.
Data() + (hh * dyt);
260 yt = state.
Data() + (tau * dym);
267 nextState.
Data() = state.
Data() + h6 * (dydx + dyt + 2 * dym);
281 return Sample(state, action, nextState);
292 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
303 if (maxSteps != 0 && stepsPerformed >= maxSteps)
305 Log::Info <<
"Episode terminated due to the maximum number of steps"
309 if (std::abs(state.
Position()) > xThreshold)
311 Log::Info <<
"Episode terminated due to cart crossing threshold";
314 if (std::abs(state.
Angle(1)) > thetaThresholdRadians ||
315 std::abs(state.
Angle(2)) > thetaThresholdRadians)
317 Log::Info <<
"Episode terminated due to pole falling";
360 double thetaThresholdRadians;
369 size_t stepsPerformed;
Implementation of Double Pole Cart Balancing task.
State(const arma::colvec &data)
Construct a state instance from given data.
size_t MaxSteps() const
Get the maximum number of steps allowed.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state through the RK4 method.
Implementation of action of Double Pole Cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
State()
Construct a state instance.
double & Position()
Modify the position of the cart.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double & Velocity()
Modify the velocity of the cart.
static constexpr size_t dimension
Dimension of the encoded state.
double Velocity() const
Get the velocity of the cart.
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
double Position() const
Get the position of the cart.
Linear algebra utility functions, generally performed on matrices or vectors.
arma::colvec Data() const
Get the internal representation of the state.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
Implementation of the state of Double Pole Cart.
size_t & MaxSteps()
Set the maximum number of steps allowed.
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
arma::colvec & Data()
Modify the internal representation of the state.
double Sample(const State &state, const Action &action)
Dynamics of Double Pole Cart.
State InitialSample()
The initial state representation is randomly generated within [-0.05, 0.05].
DoublePoleCart(const size_t maxSteps=0, const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0)
Construct a Double Pole Cart instance using the given constants.
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
const arma::colvec & Encode() const
Encode the state to a vector.
size_t StepsPerformed() const
Get the number of steps performed.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Double Pole Cart instance.