12 #ifndef _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
13 #define _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
36 template <
class MatType>
50 size_t num_test_points,
51 double tolerance = 1e-5,
52 size_t maxIterations = 10000,
53 size_t reverseStepTolerance = 3)
54 : tolerance(tolerance),
55 maxIterations(maxIterations),
56 num_test_points(num_test_points),
57 reverseStepTolerance(reverseStepTolerance)
63 test_points.zeros(num_test_points, 3);
66 for(
size_t i = 0; i < num_test_points; i++)
77 }
while((t_val = V(t_row, t_col)) == 0);
80 test_points(i, 0) = t_row;
81 test_points(i, 1) = t_col;
82 test_points(i, 2) = t_val;
104 reverseStepCount = 0;
125 for(
size_t i = 0; i < num_test_points; i++)
127 size_t t_row = test_points(i, 0);
128 size_t t_col = test_points(i, 1);
129 double t_val = test_points(i, 2);
130 double temp = (t_val - WH(t_row, t_col));
134 rmse /= num_test_points;
142 if ((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
145 if (reverseStepCount == 0 && isCopy ==
false)
152 c_indexOld = rmseOld;
162 reverseStepCount = 0;
164 if (rmse <= c_indexOld && isCopy ==
true)
171 if (reverseStepCount == reverseStepTolerance || iteration > maxIterations)
187 const double&
Index()
const {
return rmse; }
207 size_t maxIterations;
209 size_t num_test_points;
215 arma::mat test_points;
222 size_t reverseStepTolerance;
224 size_t reverseStepCount;
241 #endif // _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
const size_t & NumTestPoints() const
Get number of validation points.
void Initialize(const MatType &)
Initializes the termination policy before starting the factorization.
const double & Index() const
Get the current value of the residue (the RMSE computed on the validation set).
const size_t & Iteration() const
Get current iteration count.
The core includes that mlpack expects: standard C++ includes and Armadillo.
This class implements validation termination policy based on RMSE index.
const double & Tolerance() const
Access tolerance value.
const size_t & MaxIterations() const
Access upper limit of iteration count.
ValidationRMSETermination(MatType &V, size_t num_test_points, double tolerance=1e-5, size_t maxIterations=10000, size_t reverseStepTolerance=3)
Creates a validation set according to the given parameters and nullifies this set in the data matrix (training ...
bool IsConverged(arma::mat &W, arma::mat &H)
Check if the termination criterion is met.