mlpack  3.4.2
nca.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_NCA_NCA_HPP
13 #define MLPACK_METHODS_NCA_NCA_HPP
14 
15 #include <mlpack/prereqs.hpp>
17 #include <ensmallen.hpp>
18 
20 
21 namespace mlpack {
22 namespace nca {
23 
/**
 * An implementation of Neighborhood Components Analysis, a linear
 * dimensionality reduction technique.
 *
 * The object holds references to a dataset and its labels, a metric instance,
 * and an optimizer instance; LearnDistance() performs the analysis.
 *
 * @tparam MetricType Distance metric type; defaults to
 *     metric::SquaredEuclideanDistance (the squared Euclidean (L2) distance).
 * @tparam OptimizerType Optimizer type; defaults to ens::StandardSGD.
 */
47 template<typename MetricType = metric::SquaredEuclideanDistance,
48  typename OptimizerType = ens::StandardSGD>
49 class NCA
50 {
51  public:
  /**
   * Construct the Neighborhood Components Analysis object.
   *
   * NOTE(review): the class stores `dataset` and `labels` as const reference
   * members; presumably they are bound to these arguments, in which case the
   * caller must keep both alive for the lifetime of this object -- confirm in
   * nca_impl.hpp.
   *
   * @param dataset Input dataset.
   * @param labels Labels for the dataset.
   * @param metric Metric instance; default-constructed when omitted.
   */
61  NCA(const arma::mat& dataset,
62  const arma::Row<size_t>& labels,
63  MetricType metric = MetricType());
64 
  /**
   * Perform Neighborhood Components Analysis.
   *
   * @tparam CallbackTypes Types of the optional trailing callback arguments.
   * @param outputMatrix Matrix passed by non-const reference; presumably
   *     receives the learned result -- TODO confirm in nca_impl.hpp.
   * @param callbacks Zero or more callbacks, taken as forwarding references;
   *     presumably handed to the optimizer -- TODO confirm in nca_impl.hpp.
   */
77  template<typename... CallbackTypes>
78  void LearnDistance(arma::mat& outputMatrix, CallbackTypes&&... callbacks);
79 
  //! Get the dataset reference.
81  const arma::mat& Dataset() const { return dataset; }
  //! Get the labels reference.
83  const arma::Row<size_t>& Labels() const { return labels; }
84 
  //! Get the optimizer.
86  const OptimizerType& Optimizer() const { return optimizer; }
  //! Get a mutable reference to the optimizer, so it can be modified in place.
87  OptimizerType& Optimizer() { return optimizer; }
88 
89  private:
  //! Dataset reference (not owned by this object).
91  const arma::mat& dataset;
  //! Labels reference (not owned by this object).
93  const arma::Row<size_t>& labels;
94 
  //! Metric instance, stored by value.
96  MetricType metric;
97 
  // NOTE(review): this listing jumps from source line 97 to 100, and the
  // reference section names mlpack::nca::SoftmaxErrorFunction, which is not
  // declared anywhere in the visible class. A member declaration may have been
  // lost when the listing was extracted -- confirm against the original header.
100 
  //! Optimizer instance, stored by value; exposed through Optimizer().
102  OptimizerType optimizer;
103 };
104 
105 } // namespace nca
106 } // namespace mlpack
107 
108 // Include the implementation.
109 #include "nca_impl.hpp"
110 
111 #endif
mlpack::nca::SoftmaxErrorFunction
The "softmax" stochastic neighbor assignment probability function.
Definition: nca_softmax_error_function.hpp:46
prereqs.hpp
The core includes that mlpack expects; standard C++ includes and Armadillo.
mlpack::nca::NCA::Labels
const arma::Row< size_t > & Labels() const
Get the labels reference.
Definition: nca.hpp:83
lmetric.hpp
mlpack::metric::SquaredEuclideanDistance
LMetric< 2, false > SquaredEuclideanDistance
The squared Euclidean (L2) distance.
Definition: lmetric.hpp:107
nca_softmax_error_function.hpp
mlpack
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_cli11.hpp:21
mlpack::nca::NCA
An implementation of Neighborhood Components Analysis, both a linear dimensionality reduction technique and a distance learning technique.
Definition: nca.hpp:50
mlpack::nca::NCA::LearnDistance
void LearnDistance(arma::mat &outputMatrix, CallbackTypes &&... callbacks)
Perform Neighborhood Components Analysis.
mlpack::nca::NCA::Optimizer
const OptimizerType & Optimizer() const
Get the optimizer.
Definition: nca.hpp:86
mlpack::nca::NCA::Dataset
const arma::mat & Dataset() const
Get the dataset reference.
Definition: nca.hpp:81
mlpack::nca::NCA::Optimizer
OptimizerType & Optimizer()
Definition: nca.hpp:87
mlpack::nca::NCA::NCA
NCA(const arma::mat &dataset, const arma::Row< size_t > &labels, MetricType metric=MetricType())
Construct the Neighborhood Components Analysis object.