12 #ifndef MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP
13 #define MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP
// NOTE(review): this extract is fragmentary — lines are missing between
// statements, and each code line carries a fused original-file line number
// (extraction artifact; not part of the real code).
// Fragment: class template header and constructor of a prioritized
// experience-replay buffer parameterized on the RL environment type.
38 template <
typename EnvironmentType>
// Constructor parameters (partial view): buffer capacity, n-step return
// window length, and the encoded state dimension (defaults to the
// environment state's static dimension).
83 const size_t capacity,
85 const size_t nSteps = 1,
86 const size_t dimension = StateType::dimension) :
// Member initializers (partial view): beta is annealed over this many
// calls; states/nextStates are dimension x capacity column stores, one
// column per buffer slot.
94 replayBetaIters(10000),
96 states(dimension, capacity),
99 nextStates(dimension, capacity),
// Presumably grows the sum-tree size to the next power of two >= capacity
// — TODO confirm against the full source (loop body missing here).
103 while (size < capacity)
// Fragment: Store() — records one transition, aggregating rewards over an
// n-step window, and inserts it into the sum-tree with maximal priority.
// (Intervening lines are missing from this extract; leading integers are
// fused original-file line numbers.)
127 const double& discount)
129 nStepBuffer.push_back({state, action, reward, nextState, isEnd});
// Wait until the n-step window holds nSteps transitions before writing an
// aggregated sample into the replay matrices; keep the window at exactly
// nSteps by dropping the oldest entry.
132 if (nStepBuffer.size() < nSteps)
136 if (nStepBuffer.size() > nSteps)
137 nStepBuffer.pop_front();
140 assert(nStepBuffer.size() == nSteps);
// The stored sample starts at the oldest buffered state/action; reward,
// nextState and isEnd are presumably the n-step aggregates computed by
// GetNStepInfo() — the call site is among the missing lines. TODO confirm.
145 state = nStepBuffer.front().state;
146 action = nStepBuffer.front().action;
147 states.col(position) = state.Encode();
148 actions[position] = action;
149 rewards(position) = reward;
150 nextStates.col(position) = nextState.Encode();
151 isTerminal(position) = isEnd;
// New samples get the current maximum priority so they are sampled at
// least once. NOTE(review): the PER paper stores pow(priority, alpha);
// this writes maxPriority * alpha — confirm against the full source.
153 idxSum.
Set(position, maxPriority * alpha);
// Circular buffer: presumably wraps position to 0 and marks the buffer
// full once the write cursor reaches capacity (wrap lines missing here).
156 if (position == capacity)
// Fragment: GetNStepInfo() — folds the n-step buffer into an aggregate
// reward / nextState / isEnd, starting from the newest transition and
// walking backwards. (Lines missing between statements.)
174 const double& discount)
176 reward = nStepBuffer.back().reward;
177 nextState = nStepBuffer.back().nextState;
178 isEnd = nStepBuffer.back().isEnd;
// Backward pass: each earlier reward discounts the accumulated return;
// the (1 - iE) factor zeroes the bootstrapped remainder at a terminal.
181 for (
int i = nStepBuffer.size() - 2; i >= 0; i--)
183 bool iE = nStepBuffer[i].isEnd;
184 reward = nStepBuffer[i].reward + discount * reward * (1 - iE);
// If an episode ended mid-window, truncate the return there — presumably
// an `if (iE)` guard wraps this assignment (guard lines missing here).
187 nextState = nStepBuffer[i].nextState;
// Fragment: SampleProportional() — stratified proportional sampling: the
// total priority mass over the filled region is split into batchSize
// equal ranges and one sample is drawn uniformly within each range, then
// mapped back to an index via the sum-tree. (Lookup lines missing here.)
200 arma::ucolvec idxes(batchSize);
201 double totalSum = idxSum.
Sum(0, (full ? capacity : position));
202 double sumPerRange = totalSum / batchSize;
203 for (
size_t bt = 0; bt < batchSize; bt++)
// mass = uniform draw inside the bt-th stratum of the cumulative
// priority mass.
205 const double mass = arma::randu() * sumPerRange + bt * sumPerRange;
// Fragment: Sample() — gathers the transitions at sampledIndices into the
// caller-provided output references and computes the importance-sampling
// weights for the batch. (Signature and some lines missing here.)
222 std::vector<ActionType>& sampledActions,
223 arma::rowvec& sampledRewards,
224 arma::mat& sampledNextStates,
225 arma::irowvec& isTerminal)
// Gather columns / elements for the chosen indices.
230 sampledStates = states.cols(sampledIndices);
231 for (
size_t t = 0; t < sampledIndices.n_rows; t ++)
232 sampledActions.push_back(actions[sampledIndices[t]]);
233 sampledRewards = rewards.elem(sampledIndices).t();
234 sampledNextStates = nextStates.cols(sampledIndices);
235 isTerminal = this->isTerminal.elem(sampledIndices).t();
// Importance-sampling weights: w_i = (N * P(i))^(-beta), where
// P(i) = priority_i / total priority, normalized by the max weight so
// updates are only ever scaled down.
239 size_t numSample = full ? capacity : position;
240 weights = arma::rowvec(sampledIndices.n_rows);
242 for (
size_t i = 0; i < sampledIndices.n_rows; ++i)
244 double p_sample = idxSum.
Get(sampledIndices(i)) / idxSum.
Sum();
245 weights(i) = pow(numSample * p_sample, -beta);
247 weights /= weights.max();
// Fragment: UpdatePriorities() — writes new priorities for the sampled
// indices into the sum-tree and tracks the running maximum priority.
// NOTE(review): the PER paper uses pow(priority, alpha), not
// alpha * priorities — confirm intended against the full source/tests.
258 arma::colvec alphaPri = alpha * priorities;
259 maxPriority = std::max(maxPriority, arma::max(priorities));
// Fragment: Size() — number of stored transitions: full capacity once the
// circular buffer has wrapped, otherwise the current write cursor.
270 return full ? capacity : position;
// Fragment: BetaAnneal() — linearly anneals beta toward 1 by a fixed
// increment per call, over replayBetaIters calls. NOTE(review): the step
// uses (1 - initialBeta); presumably initialBeta is captured at
// construction — confirm against the full source.
278 beta = beta + (1 - initialBeta) * 1.0 / replayBetaIters;
// Fragment: Update() — computes absolute TD errors for the sampled
// actions (used as new priorities elsewhere) and scales the gradients by
// the mean importance-sampling weight. (Signature lines missing here.)
290 std::vector<ActionType> sampledActions,
291 arma::mat nextActionValues,
292 arma::mat& gradients)
294 arma::colvec tdError(target.n_cols);
295 for (
size_t i = 0; i < target.n_cols; i ++)
// Per-sample TD error: predicted Q minus target Q for the taken action.
297 tdError(i) = nextActionValues(sampledActions[i].action, i) -
298 target(sampledActions[i].action, i);
300 tdError = arma::abs(tdError);
// Scale the whole gradient by one mean IS weight. NOTE(review): PER
// weights each sample's loss individually; using a single mean factor is
// an approximation — confirm this is intended.
304 gradients = arma::mean(weights) * gradients;
// Read-only accessor: the n-step return window length used by the buffer.
308 const size_t&
NSteps()
const {
return nSteps; }
// Fragment: member variables (partial view; declarations missing between
// the visible lines).
// Number of calls over which beta is annealed toward 1 (see BetaAnneal).
337 size_t replayBetaIters;
// Indices chosen by the most recent Sample() call.
343 arma::ucolvec sampledIndices;
// Importance-sampling weights for the last sampled batch.
346 arma::rowvec weights;
// Sliding window of the most recent transitions, for n-step returns.
352 std::deque<Transition> nStepBuffer;
// Stored actions, one per buffer slot.
358 std::vector<ActionType> actions;
// Stored (aggregated n-step) rewards, one per buffer slot.
361 arma::rowvec rewards;
// Encoded next states, one column per buffer slot.
364 arma::mat nextStates;
// Terminal flags per slot (nonzero = episode ended at that transition).
367 arma::irowvec isTerminal;