12 #ifndef MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP
13 #define MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP
21 #include "../training_config.hpp"
76 const bool isNoisy =
false,
77 InitType init = InitType(),
78 OutputLayerType outputLayer = OutputLayerType()):
79 network(outputLayer, init),
80 atomSize(config.AtomSize()),
85 network.Add(
new Linear<>(inputDim, h1));
89 noisyLayerIndex.push_back(network.Model().size());
92 noisyLayerIndex.push_back(network.Model().size());
99 network.Add(
new Linear<>(h2, outputDim * atomSize));
112 const bool isNoisy =
false):
113 network(
std::move(network)),
114 atomSize(config.AtomSize()),
131 void Predict(
const arma::mat state, arma::mat& actionValue)
134 network.Predict(state, q_atoms);
135 activations.copy_size(q_atoms);
136 actionValue.set_size(q_atoms.n_rows / atomSize, q_atoms.n_cols);
137 arma::rowvec support = arma::linspace<arma::rowvec>(vMin, vMax, atomSize);
138 for (
size_t i = 0; i < q_atoms.n_rows; i += atomSize)
140 arma::mat activation = activations.rows(i, i + atomSize - 1);
141 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
142 softMax.Forward(input, activation);
143 activations.rows(i, i + atomSize - 1) = activation;
144 actionValue.row(i/atomSize) = support * activation;
154 void Forward(
const arma::mat state, arma::mat& dist)
157 network.Forward(state, q_atoms);
158 activations.copy_size(q_atoms);
159 for (
size_t i = 0; i < q_atoms.n_rows; i += atomSize)
161 arma::mat activation = activations.rows(i, i + atomSize - 1);
162 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
163 softMax.Forward(input, activation);
164 activations.rows(i, i + atomSize - 1) = activation;
174 network.ResetParameters();
182 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
184 boost::get<NoisyLinear<>*>
185 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
190 const arma::mat&
Parameters()
const {
return network.Parameters(); }
202 arma::mat& lossGradients,
205 arma::mat activationGradients(arma::size(activations));
206 for (
size_t i = 0; i < activations.n_rows; i += atomSize)
208 arma::mat activationGrad;
209 arma::mat lossGrad = lossGradients.rows(i, i + atomSize - 1);
210 arma::mat activation = activations.rows(i, i + atomSize - 1);
211 softMax.Backward(activation, lossGrad, activationGrad);
212 activationGradients.rows(i, i + atomSize - 1) = activationGrad;
214 network.Backward(state, activationGradients, gradient);
//! Positions in network.Model() at which noisy layers were added; each entry
//! is cast to NoisyLinear<>* when the noise of those layers is reset.
234 std::vector<size_t> noisyLayerIndex;
//! Per-atom softmax outputs written by Predict()/Forward() (one block of
//! atomSize rows per action) and read back when computing the gradient
//! through the softmax.
240 arma::mat activations;