MLPACK
1.0.11
Main Page
Related Pages
Namespaces
Classes
Files
File List
File Members
src
mlpack
methods
amf
termination_policies
simple_tolerance_termination.hpp
Go to the documentation of this file.
1
22
#ifndef _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
#define _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED

#include <mlpack/core.hpp>

#include <cfloat>
#include <cmath>
#include <cstddef>
27
namespace
mlpack {
28
namespace
amf {
29
40
template
<
class
MatType>
41
class
SimpleToleranceTermination
42
{
43
public
:
45
SimpleToleranceTermination
(
const
double
tolerance
= 1e-5,
46
const
size_t
maxIterations
= 10000,
47
const
size_t
reverseStepTolerance
= 3)
48
:
tolerance
(
tolerance
),
49
maxIterations
(
maxIterations
),
50
reverseStepTolerance
(
reverseStepTolerance
) {}
51
57
void
Initialize
(
const
MatType&
V
)
58
{
59
residueOld
= DBL_MAX;
60
iteration
= 1;
61
residue
= DBL_MIN;
62
reverseStepCount
= 0;
63
isCopy
=
false
;
64
65
this->V = &
V
;
66
67
c_index
= 0;
68
c_indexOld
= 0;
69
70
reverseStepCount
= 0;
71
}
72
79
bool
IsConverged
(arma::mat&
W
, arma::mat&
H
)
80
{
81
arma::mat WH;
82
83
WH = W *
H
;
84
85
// compute residue
86
residueOld
=
residue
;
87
size_t
n =
V
->n_rows;
88
size_t
m =
V
->n_cols;
89
double
sum = 0;
90
size_t
count = 0;
91
for
(
size_t
i = 0;i < n;i++)
92
{
93
for
(
size_t
j = 0;j < m;j++)
94
{
95
double
temp = 0;
96
if
((temp = (*
V
)(i,j)) != 0)
97
{
98
temp = (temp - WH(i, j));
99
temp = temp * temp;
100
sum += temp;
101
count++;
102
}
103
}
104
}
105
residue
= sum / count;
106
residue
= sqrt(
residue
);
107
108
// increment iteration count
109
iteration
++;
110
111
// if residue tolerance is not satisfied
112
if
((
residueOld
-
residue
) / residueOld < tolerance && iteration > 4)
113
{
114
// check if this is a first of successive drops
115
if
(
reverseStepCount
== 0 &&
isCopy
==
false
)
116
{
117
// store a copy of W and H matrix
118
isCopy
=
true
;
119
this->W =
W
;
120
this->H =
H
;
121
// store residue values
122
c_index
=
residue
;
123
c_indexOld
=
residueOld
;
124
}
125
// increase successive drop count
126
reverseStepCount
++;
127
}
128
// if tolerance is satisfied
129
else
130
{
131
// initialize successive drop count
132
reverseStepCount
= 0;
133
// if residue is droped below minimum scrap stored values
134
if
(
residue
<=
c_indexOld
&&
isCopy
==
true
)
135
{
136
isCopy
=
false
;
137
}
138
}
139
140
// check if termination criterion is met
141
if
(
reverseStepCount
==
reverseStepTolerance
||
iteration
>
maxIterations
)
142
{
143
// if stored values are present replace them with current value as they
144
// represent the minimum residue point
145
if
(
isCopy
)
146
{
147
W = this->
W
;
148
H = this->
H
;
149
residue
=
c_index
;
150
}
151
return
true
;
152
}
153
else
return
false
;
154
}
155
157
const
double
&
Index
()
const
{
return
residue
; }
158
160
const
size_t
&
Iteration
()
const
{
return
iteration
; }
161
163
const
size_t
&
MaxIterations
()
const
{
return
maxIterations
; }
164
size_t
&
MaxIterations
() {
return
maxIterations
; }
165
167
const
double
&
Tolerance
()
const
{
return
tolerance
; }
168
double
&
Tolerance
() {
return
tolerance
; }
169
170
private
:
172
double
tolerance
;
174
size_t
maxIterations
;
175
177
const
MatType*
V
;
178
180
size_t
iteration
;
181
183
double
residueOld
;
184
double
residue
;
185
double
normOld
;
186
188
size_t
reverseStepTolerance
;
190
size_t
reverseStepCount
;
191
194
bool
isCopy
;
195
197
arma::mat
W
;
198
arma::mat
H
;
199
double
c_indexOld
;
200
double
c_index
;
201
};
// class SimpleToleranceTermination
202
203
};
// namespace amf
204
};
// namespace mlpack
205
206
#endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
207
Generated by Doxygen 1.8.3.1