mlpack  2.0.1
serialization.hpp
Go to the documentation of this file.
1 
14 #ifndef __MLPACK_TESTS_SERIALIZATION_HPP
15 #define __MLPACK_TESTS_SERIALIZATION_HPP
16 
17 #include <boost/serialization/serialization.hpp>
18 #include <boost/archive/xml_iarchive.hpp>
19 #include <boost/archive/xml_oarchive.hpp>
20 #include <boost/archive/text_iarchive.hpp>
21 #include <boost/archive/text_oarchive.hpp>
22 #include <boost/archive/binary_iarchive.hpp>
23 #include <boost/archive/binary_oarchive.hpp>
24 #include <mlpack/core.hpp>
25 
26 #include <boost/test/unit_test.hpp>
28 
29 namespace mlpack {
30 
31 // Test function for loading and saving Armadillo objects.
32 template<typename MatType,
33  typename IArchiveType,
34  typename OArchiveType>
36 {
37  // First save it.
38  std::ofstream ofs("test", std::ios::binary);
39  OArchiveType o(ofs);
40 
41  bool success = true;
42  try
43  {
44  o << BOOST_SERIALIZATION_NVP(x);
45  }
46  catch (boost::archive::archive_exception& e)
47  {
48  success = false;
49  }
50 
51  BOOST_REQUIRE_EQUAL(success, true);
52  ofs.close();
53 
54  // Now load it.
55  MatType orig(x);
56  success = true;
57  std::ifstream ifs("test", std::ios::binary);
58  IArchiveType i(ifs);
59 
60  try
61  {
62  i >> BOOST_SERIALIZATION_NVP(x);
63  }
64  catch (boost::archive::archive_exception& e)
65  {
66  success = false;
67  }
68 
69  BOOST_REQUIRE_EQUAL(success, true);
70 
71  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
72  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
73  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
74 
75  for (size_t i = 0; i < x.n_cols; ++i)
76  for (size_t j = 0; j < x.n_rows; ++j)
77  if (double(orig(j, i)) == 0.0)
78  BOOST_REQUIRE_SMALL(double(x(j, i)), 1e-8);
79  else
80  BOOST_REQUIRE_CLOSE(double(orig(j, i)), double(x(j, i)), 1e-8);
81 
82  remove("test");
83 }
84 
85 // Test all serialization strategies.
86 template<typename MatType>
88 {
89  TestArmadilloSerialization<MatType, boost::archive::xml_iarchive,
90  boost::archive::xml_oarchive>(x);
91  TestArmadilloSerialization<MatType, boost::archive::text_iarchive,
92  boost::archive::text_oarchive>(x);
93  TestArmadilloSerialization<MatType, boost::archive::binary_iarchive,
94  boost::archive::binary_oarchive>(x);
95 }
96 
97 // Save and load an mlpack object.
98 // The re-loaded copy is placed in 'newT'.
99 template<typename T, typename IArchiveType, typename OArchiveType>
100 void SerializeObject(T& t, T& newT)
101 {
102  std::ofstream ofs("test", std::ios::binary);
103  OArchiveType o(ofs);
104 
105  bool success = true;
106  try
107  {
108  o << data::CreateNVP(t, "t");
109  }
110  catch (boost::archive::archive_exception& e)
111  {
112  success = false;
113  }
114  ofs.close();
115 
116  BOOST_REQUIRE_EQUAL(success, true);
117 
118  std::ifstream ifs("test", std::ios::binary);
119  IArchiveType i(ifs);
120 
121  try
122  {
123  i >> data::CreateNVP(newT, "t");
124  }
125  catch (boost::archive::archive_exception& e)
126  {
127  success = false;
128  }
129  ifs.close();
130 
131  BOOST_REQUIRE_EQUAL(success, true);
132 }
133 
134 // Test mlpack serialization with all three archive types.
135 template<typename T>
136 void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
137 {
138  SerializeObject<T, boost::archive::text_iarchive,
139  boost::archive::text_oarchive>(t, textT);
140  SerializeObject<T, boost::archive::binary_iarchive,
141  boost::archive::binary_oarchive>(t, binaryT);
142  SerializeObject<T, boost::archive::xml_iarchive,
143  boost::archive::xml_oarchive>(t, xmlT);
144 }
145 
146 // Save and load a non-default-constructible mlpack object.
147 template<typename T, typename IArchiveType, typename OArchiveType>
148 void SerializePointerObject(T* t, T*& newT)
149 {
150  std::ofstream ofs("test", std::ios::binary);
151  OArchiveType o(ofs);
152 
153  bool success = true;
154  try
155  {
156  o << data::CreateNVP(*t, "t");
157  }
158  catch (boost::archive::archive_exception& e)
159  {
160  success = false;
161  }
162  ofs.close();
163 
164  BOOST_REQUIRE_EQUAL(success, true);
165 
166  std::ifstream ifs("test", std::ios::binary);
167  IArchiveType i(ifs);
168 
169  try
170  {
171  newT = new T(i);
172  }
173  catch (std::exception& e)
174  {
175  success = false;
176  }
177  ifs.close();
178 
179  BOOST_REQUIRE_EQUAL(success, true);
180 }
181 
182 template<typename T>
183 void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
184 {
185  SerializePointerObject<T, boost::archive::text_iarchive,
186  boost::archive::text_oarchive>(t, textT);
187  SerializePointerObject<T, boost::archive::binary_iarchive,
188  boost::archive::binary_oarchive>(t, binaryT);
189  SerializePointerObject<T, boost::archive::xml_iarchive,
190  boost::archive::xml_oarchive>(t, xmlT);
191 }
192 
193 // Utility function to check the equality of two Armadillo matrices.
194 void CheckMatrices(const arma::mat& x,
195  const arma::mat& xmlX,
196  const arma::mat& textX,
197  const arma::mat& binaryX);
198 
199 void CheckMatrices(const arma::Mat<size_t>& x,
200  const arma::Mat<size_t>& xmlX,
201  const arma::Mat<size_t>& textX,
202  const arma::Mat<size_t>& binaryX);
203 
204 } // namespace mlpack
205 
206 #endif
void SerializePointerObject(T *t, T *&newT)
Linear algebra utility functions, generally performed on matrices or vectors.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename boost::enable_if< HasSerialize< T >>::type *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
void CheckMatrices(const arma::mat &x, const arma::mat &xmlX, const arma::mat &textX, const arma::mat &binaryX)
void TestAllArmadilloSerialization(MatType &x)
void SerializePointerObjectAll(T *t, T *&xmlT, T *&textT, T *&binaryT)
void SerializeObjectAll(T &t, T &xmlT, T &textT, T &binaryT)
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen documentation.
void SerializeObject(T &t, T &newT)
void TestArmadilloSerialization(MatType &x)