12 #ifndef MLPACK_METHODS_HMM_HMM_REGRESSION_HPP
13 #define MLPACK_METHODS_HMM_HMM_REGRESSION_HPP
91 const double tolerance = 1e-5) :
92 HMM<distribution::RegressionDistribution>(states, emissions, tolerance)
121 const std::vector<distribution::RegressionDistribution>&
emission,
122 const double tolerance = 1e-5) :
123 HMM<distribution::RegressionDistribution>(initial, transition, emission,
158 void Train(
const std::vector<arma::mat>& predictors,
159 const std::vector<arma::vec>& responses);
180 void Train(
const std::vector<arma::mat>& predictors,
181 const std::vector<arma::vec>& responses,
182 const std::vector<arma::Row<size_t> >& stateSeq);
203 double Estimate(
const arma::mat& predictors,
204 const arma::vec& responses,
205 arma::mat& stateProb,
206 arma::mat& forwardProb,
207 arma::mat& backwardProb,
208 arma::vec& scales)
const;
222 double Estimate(
const arma::mat& predictors,
223 const arma::vec& responses,
224 arma::mat& stateProb)
const;
237 double Predict(
const arma::mat& predictors,
238 const arma::vec& responses,
239 arma::Row<size_t>& stateSeq)
const;
249 const arma::vec& responses)
const;
264 void Filter(
const arma::mat& predictors,
265 const arma::vec& responses,
266 arma::vec& filterSeq,
267 size_t ahead = 0)
const;
282 void Smooth(
const arma::mat& predictors,
283 const arma::vec& responses,
284 arma::vec& smoothSeq)
const;
290 void StackData(
const std::vector<arma::mat>& predictors,
291 const std::vector<arma::vec>& responses,
292 std::vector<arma::mat>& dataSeq)
const;
294 void StackData(
const arma::mat& predictors,
295 const arma::vec& responses,
296 arma::mat& dataSeq)
const;
// NOTE(review): the leading numbers (309, 310, 312) look like source-listing
// line numbers, not C++, and make this declaration ill-formed. The listing
// also jumps from 310 to 312, so a parameter line appears to be missing here.
// Backward takes `const arma::vec& scales`, so the missing line may be an
// `arma::vec& scales` output parameter. Confirm against the real header
// before cleaning this declaration up.
/**
 * Forward pass of the model over one predictor/response sequence.
 *
 * @param predictors Predictor matrix for the sequence.
 * @param responses Response vector for the same sequence.
 * @param forwardProb (output) Forward probabilities.
 */
309 void Forward(
const arma::mat& predictors,
310 const arma::vec& responses,
312 arma::mat& forwardProb)
const;
326 void Backward(
const arma::mat& predictors,
327 const arma::vec& responses,
328 const arma::vec& scales,
329 arma::mat& backwardProb)
const;
338 #include "hmm_regression_impl.hpp"
std::vector< distribution::RegressionDistribution > emission
Set of emission probability distributions; one for each state.
double Predict(const arma::mat &predictors, const arma::vec &responses, arma::Row< size_t > &stateSeq) const
Compute the most probable hidden state sequence for the given predictors and responses, using the Viterbi algorithm, returning the log-likelihood of the most likely state sequence.
The core includes that mlpack expects: standard C++ headers and Armadillo.
A class that represents a Hidden Markov Model Regression (HMMR).
arma::mat transition
Transition probability matrix.
void Smooth(const arma::mat &predictors, const arma::vec &responses, arma::vec &smoothSeq) const
HMMR smoothing.
A class that represents a univariate conditionally Gaussian distribution.
A class that represents a Hidden Markov Model with an arbitrary type of emission distribution.
double Estimate(const arma::mat &predictors, const arma::vec &responses, arma::mat &stateProb, arma::mat &forwardProb, arma::mat &backwardProb, arma::vec &scales) const
Estimate the probabilities of each hidden state at each time step for each given data observation...
HMMRegression(const arma::vec &initial, const arma::mat &transition, const std::vector< distribution::RegressionDistribution > &emission, const double tolerance=1e-5)
Create the Hidden Markov Model Regression with the given initial probability vector, the given transition matrix, and the given regression emission distributions.
void Filter(const arma::mat &predictors, const arma::vec &responses, arma::vec &filterSeq, size_t ahead=0) const
HMMR filtering.
void Train(const std::vector< arma::mat > &predictors, const std::vector< arma::vec > &responses)
Train the model using the Baum-Welch algorithm, with only the given predictors and responses...
double LogLikelihood(const arma::mat &predictors, const arma::vec &responses) const
Compute the log-likelihood of the given predictors and responses.
HMMRegression(const size_t states, const distribution::RegressionDistribution emissions, const double tolerance=1e-5)
Create the Hidden Markov Model Regression with the given number of hidden states and the given default regression emission distribution.