mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ra_search_rules.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
15 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
16 
18 
19 namespace mlpack {
20 namespace neighbor {
21 
/**
 * The RASearchRules class is a template helper class used by the RASearch
 * class when performing rank-approximate search.  Its BaseCase(), Score() and
 * Rescore() members provide the per-point hooks (queryIndex overloads) and the
 * per-node hooks (queryNode overloads) for single-tree and dual-tree
 * traversals.
 *
 * NOTE(review): this text was scraped from a Doxygen source listing.  Each
 * line is prefixed with its original file line number, comment-only lines
 * were dropped, and some hyperlinked code lines are missing as well:
 *   - line 31: presumably the `class RASearchRules` declarator (line 328
 *     closes "class RASearchRules");
 *   - line 229: the declarator of the accessor whose body is lines 230-235;
 *   - line 237: per the Doxygen tooltip,
 *     `typedef tree::TraversalInfo<TreeType> TraversalInfoType;`.
 * Regenerate from the real header before treating this text as code.
 */
30 template<typename SortPolicy, typename MetricType, typename TreeType>
32 {
33  public:
 /**
  * Construct the RASearchRules object.
  *
  * @param referenceSet Reference dataset.  The class keeps a
  *     `const arma::mat&` member with this name, so the dataset presumably
  *     must outlive this object -- confirm in the constructor.
  * @param querySet Query dataset; same lifetime caveat as referenceSet.
  * @param k Presumably the number of neighbors to find per query point.
  * @param metric Metric instance; kept as a `MetricType&` member.
  * @param tau Default 5.  NOTE(review): presumably the rank-approximation
  *     tolerance -- confirm in ra_search_rules_impl.hpp.
  * @param alpha Default 0.95.  NOTE(review): presumably the desired success
  *     probability -- confirm in ra_search_rules_impl.hpp.
  * @param naive Default false; presumably selects naive (non-tree) sampling.
  * @param sampleAtLeaves Default false; presumably permits sampling at leaves.
  * @param firstLeafExact Default false; presumably forces the first leaf
  *     visited to be searched exactly.
  * @param singleSampleLimit Default 20; presumably a node-size threshold for
  *     sampling -- TODO confirm.
  * @param sameSet Default false; presumably true when the query and reference
  *     sets are the same dataset.
  */
55  RASearchRules(const arma::mat& referenceSet,
56  const arma::mat& querySet,
57  const size_t k,
58  MetricType& metric,
59  const double tau = 5,
60  const double alpha = 0.95,
61  const bool naive = false,
62  const bool sampleAtLeaves = false,
63  const bool firstLeafExact = false,
64  const size_t singleSampleLimit = 20,
65  const bool sameSet = false);
66 
 /**
  * Store the list of candidates for each query point in the given matrices.
  *
  * @param neighbors Output matrix of neighbor indices.
  * @param distances Output matrix of neighbor distances.
  */
74  void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
75 
 /**
  * Get the distance from the query point to the reference point.
  *
  * @param queryIndex Index of the query point.
  * @param referenceIndex Index of the reference point.
  */
83  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
84 
 //! Get the score for recursion order (single query point vs. reference node).
107  double Score(const size_t queryIndex, TreeType& referenceNode);
108 
 //! Get the score for recursion order (single query point vs. reference node),
 //! given a previously computed base-case result.
132  double Score(const size_t queryIndex,
133  TreeType& referenceNode,
134  const double baseCaseResult);
135 
 //! Re-evaluate the score for recursion order, given the old score.
153  double Rescore(const size_t queryIndex,
154  TreeType& referenceNode,
155  const double oldScore);
156 
 //! Get the score for recursion order (query node vs. reference node).
175  double Score(TreeType& queryNode, TreeType& referenceNode);
176 
 //! Get the score for recursion order (query node vs. reference node), given a
 //! previously computed base-case result.
197  double Score(TreeType& queryNode,
198  TreeType& referenceNode,
199  const double baseCaseResult);
200 
 //! Re-evaluate the score for recursion order of a node pair, given the old
 //! score.
223  double Rescore(TreeType& queryNode,
224  TreeType& referenceNode,
225  const double oldScore);
226 
227 
 //! Return the numDistComputations counter.
228  size_t NumDistComputations() { return numDistComputations; }
 // NOTE(review): the declarator for this accessor (original line 229) is
 // missing from the listing.  The body returns 0 when numSamplesMade is empty
 // and the sum of its entries otherwise.
230  {
231  if (numSamplesMade.n_elem == 0)
232  return 0;
233  else
234  return arma::sum(numSamplesMade);
235  }
236 
238 
 //! Access the traversal information (read-only).
239  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
 //! Modify the traversal information.
240  TraversalInfoType& TraversalInfo() { return traversalInfo; }
241 
242  private:
 //! The reference set (held by reference, not copied).
244  const arma::mat& referenceSet;
245 
 //! The query set (held by reference, not copied).
247  const arma::mat& querySet;
248 
 //! A candidate neighbor.  `first` is compared through SortPolicy::IsBetter;
 //! presumably the pair is (distance, reference index).
250  typedef std::pair<double, size_t> Candidate;
251 
 // Heap comparator: comp(c1, c2) is true when c2 is not better than c1, so a
 // std::priority_queue using it keeps the worst stored candidate at top().
 //
 // NOTE(review): if SortPolicy::IsBetter is a strict comparison (TODO confirm
 // for the sort policies in use), comp(c, c) evaluates to true.  That is not
 // a strict weak ordering, which std::priority_queue requires.  An equivalent
 // strict form is `return SortPolicy::IsBetter(c1.first, c2.first);` (it may
 // change which candidate survives exact-distance ties).  operator() could
 // also be `const`, and the `;` after its body is redundant.
253  struct CandidateCmp {
254  bool operator()(const Candidate& c1, const Candidate& c2)
255  {
256  return !SortPolicy::IsBetter(c2.first, c1.first);
257  };
258  };
259 
 //! Heap of candidates ordered by CandidateCmp (worst candidate at top()).
261  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
262  CandidateList;
263 
 //! Candidate lists; presumably one per query point (see GetResults()).
265  std::vector<CandidateList> candidates;
266 
 //! Presumably the number of neighbors to find per query point.
268  const size_t k;
269 
 //! The metric (held by reference, not copied).
271  MetricType& metric;
272 
 //! Presumably a copy of the constructor option of the same name.
274  bool sampleAtLeaves;
275 
 //! Presumably a copy of the constructor option of the same name.
277  bool firstLeafExact;
278 
 //! Presumably a copy of the constructor option of the same name.
280  size_t singleSampleLimit;
281 
 //! Presumably the number of samples required -- TODO confirm in the impl.
283  size_t numSamplesReqd;
284 
 //! Sample counts; the accessor at lines 230-235 returns their sum.
 //! Presumably one entry per query point -- TODO confirm.
286  arma::Col<size_t> numSamplesMade;
287 
 //! Presumably the ratio used to size samples -- TODO confirm in the impl.
289  double samplingRatio;
290 
291  // TO REMOVE: just for testing
 // Counter returned by NumDistComputations().
292  size_t numDistComputations;
293 
 //! Presumably true when the query and reference sets are the same dataset.
295  bool sameSet;
296 
 //! Traversal information returned by TraversalInfo().
297  TraversalInfoType traversalInfo;
298 
 //! Presumably inserts (neighbor, distance) into the candidate list of
 //! queryIndex -- confirm in ra_search_rules_impl.hpp.
306  void InsertNeighbor(const size_t queryIndex,
307  const size_t neighbor,
308  const double distance);
309 
 //! Internal Score() helper (query point vs. reference node) taking an already
 //! computed distance and best distance.
313  double Score(const size_t queryIndex,
314  TreeType& referenceNode,
315  const double distance,
316  const double bestDistance);
317 
 //! Internal Score() helper (query node vs. reference node) taking an already
 //! computed distance and best distance.
321  double Score(TreeType& queryNode,
322  TreeType& referenceNode,
323  const double distance,
324  const double bestDistance);
325 
 // Compile-time requirement on the tree type:
 // tree::TreeTraits<TreeType>::UniqueNumDescendants must be true.
326  static_assert(tree::TreeTraits<TreeType>::UniqueNumDescendants, "TreeType "
327  "must provide a unique number of descendants points.");
328 }; // class RASearchRules
329 
330 } // namespace neighbor
331 } // namespace mlpack
332 
333 // Include implementation.
334 #include "ra_search_rules_impl.hpp"
335 
336 #endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
RASearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const size_t k, MetricType &metric, const double tau=5, const double alpha=0.95, const bool naive=false, const bool sampleAtLeaves=false, const bool firstLeafExact=false, const size_t singleSampleLimit=20, const bool sameSet=false)
Construct the RASearchRules object.
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) traversals.
tree::TraversalInfo< TreeType > TraversalInfoType
const TraversalInfoType & TraversalInfo() const
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
See subsection cli_alt_reg_tut ("Alternate DET regularization"): the usual regularized error $R_\alpha(t)$ of a node $t$ is given by
Definition: det.txt:340
TraversalInfoType & TraversalInfo()
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
The RASearchRules class is a template helper class used by the RASearch class when performing rank-approximate search.