12 #ifndef MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP
13 #define MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP
43 template <
typename EnvironmentType>
79 const size_t capacity,
80 const size_t nSteps = 1,
81 const size_t dimension = StateType::dimension) :
87 states(dimension, capacity),
90 nextStates(dimension, capacity),
109 const double& discount)
111 nStepBuffer.push_back({state, action, reward, nextState, isEnd});
114 if (nStepBuffer.size() < nSteps)
118 if (nStepBuffer.size() > nSteps)
119 nStepBuffer.pop_front();
122 assert(nStepBuffer.size() == nSteps);
127 state = nStepBuffer.front().state;
128 action = nStepBuffer.front().action;
130 states.col(position) = state.Encode();
131 actions[position] = action;
132 rewards(position) = reward;
133 nextStates.col(position) = nextState.Encode();
134 isTerminal(position) = isEnd;
136 if (position == capacity)
154 const double& discount)
156 reward = nStepBuffer.back().reward;
157 nextState = nStepBuffer.back().nextState;
158 isEnd = nStepBuffer.back().isEnd;
161 for (
int i = nStepBuffer.size() - 2; i >= 0; i--)
163 bool iE = nStepBuffer[i].isEnd;
164 reward = nStepBuffer[i].reward + discount * reward * (1 - iE);
167 nextState = nStepBuffer[i].nextState;
184 std::vector<ActionType>& sampledActions,
185 arma::rowvec& sampledRewards,
186 arma::mat& sampledNextStates,
187 arma::irowvec& isTerminal)
189 size_t upperBound = full ? capacity : position;
190 arma::uvec sampledIndices = arma::randi<arma::uvec>(
191 batchSize, arma::distr_param(0, upperBound - 1));
193 sampledStates = states.cols(sampledIndices);
194 for (
size_t t = 0; t < sampledIndices.n_rows; t ++)
195 sampledActions.push_back(actions[sampledIndices[t]]);
196 sampledRewards = rewards.elem(sampledIndices).t();
197 sampledNextStates = nextStates.cols(sampledIndices);
198 isTerminal = this->isTerminal.elem(sampledIndices).t();
208 return full ? capacity : position;
220 std::vector<ActionType> ,
228 const size_t&
NSteps()
const {
return nSteps; }
247 std::deque<Transition> nStepBuffer;
253 std::vector<ActionType> actions;
256 arma::rowvec rewards;
259 arma::mat nextStates;
262 arma::irowvec isTerminal;