13 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_FUNCTION_HPP
14 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_FUNCTION_HPP
/**
 * Evaluates the tanh-based GELU approximation for a single value.
 *
 * @param x Input value.
 * @return 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
 */
static double Fn(const double x)
{
  // Cubic polynomial in x that is scaled and then passed through tanh.
  const double cubicPoly = x + 0.044715 * std::pow(x, 3);
  const double scale = std::sqrt(2 / M_PI);

  // Value in [0, 2] that multiplies the halved input.
  const double gate = 1 + std::tanh(scale * cubicPoly);

  return 0.5 * x * gate;
}
52 template<
typename InputVecType,
typename OutputVecType>
53 static void Fn(
const InputVecType& x, OutputVecType& y)
55 y = 0.5 * x % (1 + arma::tanh(std::sqrt(2 /
M_PI) *
56 (x + 0.044715 * arma::pow(x, 3))));
/**
 * Evaluates the first derivative of the tanh-based GELU approximation for a
 * single value.
 *
 * The expression is the derivative of
 * 0.5 * t * (1 + tanh(0.797885 * t + 0.0356774 * t^3)) with respect to t,
 * evaluated at t = y; the numeric coefficients are rounded constants.
 *
 * @param y Point at which the derivative is evaluated.
 * @return Value of the derivative at y.
 */
static double Deriv(const double y)
{
  const double cube = std::pow(y, 3);

  // Argument shared by the tanh and cosh terms.
  const double inner = 0.0356774 * cube + 0.797885 * y;

  // Hyperbolic secant of the shared argument; squared below.
  const double sech = 1 / std::cosh(inner);

  return 0.5 * std::tanh(inner) +
      (0.0535161 * cube + 0.398942 * y) * std::pow(sech, 2) + 0.5;
}
79 template<
typename InputVecType,
typename OutputVecType>
80 static void Deriv(
const InputVecType& y, OutputVecType& x)
82 x = 0.5 * arma::tanh(0.0356774 * arma::pow(y, 3) + 0.797885 * y) +
83 (0.0535161 * arma::pow(y, 3) + 0.398942 * y) %
84 arma::pow(1 / arma::cosh(0.0356774 * arma::pow(y, 3) +
85 0.797885 * y), 2) + 0.5;