mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
load_csv.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_DATA_LOAD_CSV_HPP
13 #define MLPACK_CORE_DATA_LOAD_CSV_HPP
14 
15 #include <boost/spirit/include/qi.hpp>
16 #include <boost/algorithm/string/trim.hpp>
17 
18 #include <mlpack/core.hpp>
19 #include <mlpack/core/util/log.hpp>
20 
21 #include <set>
22 #include <string>
23 
24 #include "extension.hpp"
25 #include "format.hpp"
26 #include "dataset_mapper.hpp"
27 
28 namespace mlpack {
29 namespace data {
30 
36 class LoadCSV
37 {
38 public:
39  explicit LoadCSV(std::string file, bool fatal = false);
40 
41  template<typename T, typename PolicyType>
42  void Load(arma::Mat<T> &inout, DatasetMapper<PolicyType> &infoSet, bool transpose = true)
43  {
44  if(!CanOpen())
45  {
46  return;
47  }
48 
49  if(transpose)
50  {
51  TranposeParse(inout, infoSet);
52  }
53  else
54  {
55  NonTranposeParse(inout, infoSet);
56  }
57  }
58 
59  size_t ColSize();
60  size_t RowSize();
61 
72  template<typename T, typename MapPolicy>
73  void GetMatrixSize(size_t& rows, size_t& cols, DatasetMapper<MapPolicy>& info)
74  {
75  // Take a pass through the file. If the DatasetMapper policy requires it,
76  // we will pass everything string through MapString(). This might be useful
77  // if, e.g., the MapPolicy needs to find which dimensions are numeric or
78  // categorical.
79 
80  // Reset to the start of the file.
81  inFile.clear();
82  inFile.seekg(0, std::ios::beg);
83  rows = 0;
84  cols = 0;
85 
86  // First, count the number of rows in the file (this is the dimensionality).
87  std::string line;
88  while (std::getline(inFile, line))
89  {
90  ++rows;
91  }
92  info = DatasetMapper<MapPolicy>(rows);
93 
94  // Now, jump back to the beginning of the file.
95  inFile.clear();
96  inFile.seekg(0, std::ios::beg);
97  rows = 0;
98  while (std::getline(inFile, line))
99  {
100  ++rows;
101 
102  if (rows == 1)
103  {
104  // Extract the number of columns.
105  auto findColSize = [&cols](iter_type) { ++cols; };
106  boost::spirit::qi::phrase_parse(line.begin(), line.end(),
107  CreateCharRule()[findColSize] % ",", boost::spirit::ascii::space);
108  }
109 
110  // I guess this is technically a second pass, but that's ok... still the
111  // same idea...
112  if (MapPolicy::NeedsFirstPass)
113  {
114  // In this case we must pass everything we parse to the MapPolicy.
115  auto firstPassMap = [&](const iter_type& iter)
116  {
117  std::string str(iter.begin(), iter.end());
118  if (str == "\t")
119  str.clear();
120  boost::trim(str);
121 
122  info.template MapFirstPass<T>(std::move(str), rows - 1);
123  };
124 
125  // Now parse the line.
126  boost::spirit::qi::phrase_parse(line.begin(), line.end(),
127  CreateCharRule()[firstPassMap] % ",", boost::spirit::ascii::space);
128  }
129  }
130  }
131 
132  template<typename T, typename MapPolicy>
133  void GetTransposeMatrixSize(size_t& rows, size_t& cols, DatasetMapper<MapPolicy>& info)
134  {
135  // Take a pass through the file. If the DatasetMapper policy requires it,
136  // we will pass everything string through MapString(). This might be useful
137  // if, e.g., the MapPolicy needs to find which dimensions are numeric or
138  // categorical.
139 
140  // Reset to the start of the file.
141  inFile.clear();
142  inFile.seekg(0, std::ios::beg);
143  rows = 0;
144  cols = 0;
145 
146  std::string line;
147  while (std::getline(inFile, line))
148  {
149  ++cols;
150 
151  if (cols == 1)
152  {
153  // Extract the number of dimensions.
154  auto findRowSize = [&rows](iter_type) { ++rows; };
155  boost::spirit::qi::phrase_parse(line.begin(), line.end(),
156  CreateCharRule()[findRowSize] % ",", boost::spirit::ascii::space);
157 
158  // Now that we know the dimensionality, initialize the DatasetMapper.
159  info = DatasetMapper<MapPolicy>(rows);
160  }
161 
162  // If we need to do a first pass for the DatasetMapper, do it.
163  if (MapPolicy::NeedsFirstPass)
164  {
165  size_t dim = 0;
166 
167  // In this case we must pass everything we parse to the MapPolicy.
168  auto firstPassMap = [&](const iter_type& iter)
169  {
170  std::string str(iter.begin(), iter.end());
171  if (str == "\t")
172  str.clear();
173  boost::trim(str);
174 
175  info.template MapFirstPass<T>(std::move(str), dim++);
176  };
177 
178  // Now parse the line.
179  boost::spirit::qi::phrase_parse(line.begin(), line.end(),
180  CreateCharRule()[firstPassMap] % ",", boost::spirit::ascii::space);
181  }
182  }
183  }
184 
185 private:
186  using iter_type = boost::iterator_range<std::string::iterator>;
187 
188  struct ElemParser
189  {
190  //return int_parser if the type of T is_integral
191  template<typename T>
192  static typename std::enable_if<std::is_integral<T>::value,
193  boost::spirit::qi::int_parser<T>>::type
194  Parser()
195  {
196  return boost::spirit::qi::int_parser<T>();
197  }
198 
199  //return real_parser if T is floating_point
200  template<typename T>
201  static typename std::enable_if<std::is_floating_point<T>::value,
202  boost::spirit::qi::real_parser<T>>::type
203  Parser()
204  {
205  return boost::spirit::qi::real_parser<T>();
206  }
207  };
208 
209  bool CanOpen();
210 
211  template<typename T, typename PolicyType>
212  void NonTranposeParse(arma::Mat<T> &inout, DatasetMapper<PolicyType> &infoSet)
213  {
214  using namespace boost::spirit;
215 
216  // Get the size of the matrix.
217  size_t rows, cols;
218  GetMatrixSize<T>(rows, cols, infoSet);
219 
220  // Set up output matrix.
221  inout.set_size(rows, cols);
222  size_t row = 0;
223  size_t col = 0;
224 
225  // Reset file position.
226  std::string line;
227  inFile.clear();
228  inFile.seekg(0, std::ios::beg);
229 
230  auto setCharClass = [&](iter_type const &iter)
231  {
232  std::string str(iter.begin(), iter.end());
233  if (str == "\t")
234  {
235  str.clear();
236  }
237  boost::trim(str);
238 
239  inout(row, col++) = infoSet.template MapString<T>(std::move(str), row);
240  };
241 
242  auto charRule = CreateCharRule();
243  while (std::getline(inFile, line))
244  {
245  //parse the numbers from a line(ex : 1,2,3,4), if the parser find the number
246  //it will execute the setNum function
247  const bool canParse = qi::phrase_parse(line.begin(), line.end(),
248  charRule[setCharClass] % ",", ascii::space);
249 
250  if (!canParse)
251  {
252  throw std::runtime_error("LoadCSV cannot parse categories");
253  }
254 
255  ++row; col = 0;
256  }
257  }
258 
259  template<typename T, typename PolicyType>
260  void TranposeParse(arma::Mat<T> &inout, DatasetMapper<PolicyType> &infoSet)
261  {
262  // Get matrix size. This also initializes infoSet correctly.
263  size_t rows, cols;
264  GetTransposeMatrixSize<T>(rows, cols, infoSet);
265 
266  // Set the matrix size.
267  inout.set_size(rows, cols);
268  TranposeParseImpl(inout, infoSet);
269  }
270 
271  template<typename T, typename PolicyType>
272  bool TranposeParseImpl(arma::Mat<T>& inout,
273  DatasetMapper<PolicyType>& infoSet)
274  {
275  using namespace boost::spirit;
276 
277  size_t row = 0;
278  size_t col = 0;
279  std::string line;
280  inFile.clear();
281  inFile.seekg(0, std::ios::beg);
282 
283  auto setCharClass = [&](iter_type const &iter)
284  {
285  // All parsed values must be mapped.
286  std::string str(iter.begin(), iter.end());
287  if (str == "\t")
288  str.clear();
289  boost::trim(str);
290 
291  inout(row, col) = infoSet.template MapString<T>(std::move(str), row);
292  ++row;
293  };
294 
295  auto charRule = CreateCharRule();
296  while (std::getline(inFile, line))
297  {
298  row = 0;
299  //parse number of characters from a line, it will execute setNum if it is number,
300  //else execute setCharClass, "|" means "if not a, then b"
301  // Assemble the rule
302 
303  const bool canParse = qi::phrase_parse(line.begin(), line.end(),
304  charRule[setCharClass] % ",",
305  ascii::space);
306  if(!canParse)
307  {
308  throw std::runtime_error("LoadCSV cannot parse categories");
309  }
310  ++col;
311  }
312 
313  return true;
314  }
315 
316  template<typename T>
317  boost::spirit::qi::rule<std::string::iterator, T(), boost::spirit::ascii::space_type>
318  CreateNumRule() const
319  {
320  using namespace boost::spirit;
321 
322  //elemParser will generate integer or real parser based on T
323  auto elemParser = ElemParser::Parser<T>();
324  //qi::skip can specify which characters you want to skip,
325  //in this example, elemParser will parse int or double value,
326  //we use qi::skip to skip space
327 
328  //qi::omit can omit the attributes of spirit, every parser of spirit
329  //has attribute(the type will pass into actions(functor))
330  //if you do not omit it, the attribute combine with attribute may
331  //change the attribute
332 
333  //input like 2-200 or 2DM will make the parser fail,
334  //so we use "look ahead parser--&" to make sure next
335  //character is "," or end of line(eof) or end of file(eoi)
336  //looks ahead parser will not consume any input or generate
337  //any attribute
338  if(extension == "csv" || extension == "txt")
339  {
340  return elemParser >> &(qi::lit(",") | qi::eol | qi::eoi);
341  }
342  else
343  {
344  return elemParser >> &(qi::lit("\t") | qi::eol | qi::eoi);
345  }
346  }
347 
348  boost::spirit::qi::rule<std::string::iterator, iter_type(), boost::spirit::ascii::space_type>
349  CreateCharRule() const;
350 
351  std::string extension;
352  bool fatalIfOpenFail;
353  std::string fileName;
354  std::ifstream inFile;
355 };
356 
357 } // namespace data
358 } // namespace mlpack
359 
360 #endif
Auxiliary information for a dataset, including mappings to/from strings and the datatype of each dime...
Load the CSV file. This class uses boost::spirit to implement the parser; please refer to the following links...
Definition: load_csv.hpp:36
void Load(arma::Mat< T > &inout, DatasetMapper< PolicyType > &infoSet, bool transpose=true)
Definition: load_csv.hpp:42
void GetTransposeMatrixSize(size_t &rows, size_t &cols, DatasetMapper< MapPolicy > &info)
Definition: load_csv.hpp:133
void GetMatrixSize(size_t &rows, size_t &cols, DatasetMapper< MapPolicy > &info)
Peek at the file to determine the number of rows and columns in the matrix, assuming a non-transposed...
Definition: load_csv.hpp:73
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen docu...
LoadCSV(std::string file, bool fatal=false)