Go to the documentation of this file.
12 #ifndef MLPACK_METHODS_ANN_LAYER_POSITIONAL_ENCODING_HPP
13 #define MLPACK_METHODS_ANN_LAYER_POSITIONAL_ENCODING_HPP
34 typename InputDataType = arma::mat,
35 typename OutputDataType = arma::mat
52 const size_t maxSequenceLength);
62 void Forward(
const arma::Mat<eT>& input, arma::Mat<eT>& output);
75 const arma::Mat<eT>& gy,
//! Get the delta (read-only access to the stored delta member).
89 OutputDataType
const&
Delta()
const {
return delta; }
//! Modify the delta (mutable reference to the stored delta member).
91 OutputDataType&
Delta() {
return delta; }
//! Get the stored positional encoding (read-only reference to the
//! positionalEncoding member).
94 InputDataType
const&
Encoding()
const {
return positionalEncoding; }
99 template<
typename Archive>
//! Helper that sets up the positional encoding.
//! NOTE(review): the definition is presumably in positional_encoding_impl.hpp
//! (included at the end of this header) and fills positionalEncoding — confirm.
106 void InitPositionalEncoding();
//! Locally-stored maximum sequence length.
//! NOTE(review): presumably set from the constructor argument of the same name.
112 size_t maxSequenceLength;
//! Locally-stored positional encoding; returned by Encoding().
115 InputDataType positionalEncoding;
//! Locally-stored delta object; returned by Delta().
118 OutputDataType delta;
//! Locally-stored input parameter object.
//! NOTE(review): presumably backs the InputParameter() accessors — confirm.
121 InputDataType inputParameter;
//! Locally-stored output parameter object.
//! NOTE(review): presumably backs the OutputParameter() accessors — confirm.
124 OutputDataType outputParameter;
131 #include "positional_encoding_impl.hpp"
The core includes that mlpack expects: standard C++ includes and Armadillo.
OutputDataType & OutputParameter()
Modify the output parameter.
InputDataType & InputParameter()
Modify the input parameter.
InputDataType const & Encoding() const
Get the positional encoding vector.
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Linear algebra utility functions, generally performed on matrices or vectors.
Positional Encoding injects some information about the relative or absolute position of the tokens in the sequence.
PositionalEncoding()
Create PositionalEncoding object.
OutputDataType const & Delta() const
Get the delta.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
InputDataType const & InputParameter() const
Get the input parameter.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
PositionalEncoding(const size_t embedDim, const size_t maxSequenceLength)
Create the PositionalEncoding layer object using the specified parameters.
OutputDataType & Delta()
Modify the delta.
OutputDataType const & OutputParameter() const
Get the output parameter.