12 #ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP
13 #define MLPACK_METHODS_RL_SIMPLE_DQN_HPP
60 const bool isNoisy =
false,
61 InitType init = InitType(),
62 OutputLayerType outputLayer = OutputLayerType()):
63 network(outputLayer, init),
66 network.Add(
new Linear<>(inputDim, h1));
70 noisyLayerIndex.push_back(network.Model().size());
73 noisyLayerIndex.push_back(network.Model().size());
80 network.Add(
new Linear<>(h2, outputDim));
90 SimpleDQN(NetworkType& network,
const bool isNoisy =
false):
106 void Predict(
const arma::mat state, arma::mat& actionValue)
108 network.Predict(state, actionValue);
117 void Forward(
const arma::mat state, arma::mat& target)
119 network.Forward(state, target);
127 network.ResetParameters();
135 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
137 boost::get<NoisyLinear<>*>
138 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
143 const arma::mat&
Parameters()
const {
return network.Parameters(); }
154 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
156 network.Backward(state, target, gradient);
167 std::vector<size_t> noisyLayerIndex;