Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
a7660331a
tesseract
提交
4e9665de
T
tesseract
项目概览
a7660331a
/
tesseract
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tesseract
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4e9665de
编写于
8月 02, 2017
作者:
R
Ray Smith
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added ADAM optimizer, unless git screwed it up, cos there is no diff
上级
2633fef0
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
386 addition
and
130 deletion
+386
-130
arch/Makefile.am
arch/Makefile.am
+1
-1
arch/simddetect.cpp
arch/simddetect.cpp
+14
-0
arch/simddetect.h
arch/simddetect.h
+13
-0
ccstruct/matrix.h
ccstruct/matrix.h
+9
-6
lstm/convolve.cpp
lstm/convolve.cpp
+1
-1
lstm/fullyconnected.cpp
lstm/fullyconnected.cpp
+18
-7
lstm/fullyconnected.h
lstm/fullyconnected.h
+10
-4
lstm/lstm.cpp
lstm/lstm.cpp
+21
-13
lstm/lstm.h
lstm/lstm.h
+10
-4
lstm/lstmrecognizer.cpp
lstm/lstmrecognizer.cpp
+15
-8
lstm/lstmrecognizer.h
lstm/lstmrecognizer.h
+10
-20
lstm/lstmtrainer.cpp
lstm/lstmtrainer.cpp
+90
-13
lstm/lstmtrainer.h
lstm/lstmtrainer.h
+15
-1
lstm/network.h
lstm/network.h
+13
-5
lstm/plumbing.cpp
lstm/plumbing.cpp
+18
-5
lstm/plumbing.h
lstm/plumbing.h
+10
-4
lstm/series.cpp
lstm/series.cpp
+20
-1
lstm/series.h
lstm/series.h
+6
-0
lstm/weightmatrix.cpp
lstm/weightmatrix.cpp
+59
-20
lstm/weightmatrix.h
lstm/weightmatrix.h
+16
-10
training/lstmtraining.cpp
training/lstmtraining.cpp
+17
-7
未找到文件。
arch/Makefile.am
浏览文件 @
4e9665de
AM_CPPFLAGS
+=
-I
$(top_srcdir)
/ccutil
-I
$(top_srcdir)
/viewer
AM_CPPFLAGS
+=
-I
$(top_srcdir)
/ccutil
-I
$(top_srcdir)
/viewer
-DUSE_STD_NAMESPACE
AUTOMAKE_OPTIONS
=
subdir-objects
AUTOMAKE_OPTIONS
=
subdir-objects
SUBDIRS
=
SUBDIRS
=
AM_CXXFLAGS
=
AM_CXXFLAGS
=
...
...
arch/simddetect.cpp
浏览文件 @
4e9665de
...
@@ -37,6 +37,9 @@ SIMDDetect SIMDDetect::detector;
...
@@ -37,6 +37,9 @@ SIMDDetect SIMDDetect::detector;
// If true, then AVX has been detected.
// If true, then AVX has been detected.
bool
SIMDDetect
::
avx_available_
;
bool
SIMDDetect
::
avx_available_
;
bool
SIMDDetect
::
avx2_available_
;
bool
SIMDDetect
::
avx512F_available_
;
bool
SIMDDetect
::
avx512BW_available_
;
// If true, then SSe4.1 has been detected.
// If true, then SSe4.1 has been detected.
bool
SIMDDetect
::
sse_available_
;
bool
SIMDDetect
::
sse_available_
;
...
@@ -50,8 +53,19 @@ SIMDDetect::SIMDDetect() {
...
@@ -50,8 +53,19 @@ SIMDDetect::SIMDDetect() {
#if defined(__GNUC__)
#if defined(__GNUC__)
unsigned
int
eax
,
ebx
,
ecx
,
edx
;
unsigned
int
eax
,
ebx
,
ecx
,
edx
;
if
(
__get_cpuid
(
1
,
&
eax
,
&
ebx
,
&
ecx
,
&
edx
)
!=
0
)
{
if
(
__get_cpuid
(
1
,
&
eax
,
&
ebx
,
&
ecx
,
&
edx
)
!=
0
)
{
// Note that these tests all use hex because the older compilers don't have
// the newer flags.
sse_available_
=
(
ecx
&
0x00080000
)
!=
0
;
sse_available_
=
(
ecx
&
0x00080000
)
!=
0
;
avx_available_
=
(
ecx
&
0x10000000
)
!=
0
;
avx_available_
=
(
ecx
&
0x10000000
)
!=
0
;
if
(
avx_available_
)
{
// There is supposed to be a __get_cpuid_count function, but this is all
// there is in my cpuid.h. It is a macro for an asm statement and cannot
// be used inside an if.
__cpuid_count
(
7
,
0
,
eax
,
ebx
,
ecx
,
edx
);
avx2_available_
=
(
ebx
&
0x00000020
)
!=
0
;
avx512F_available_
=
(
ebx
&
0x00010000
)
!=
0
;
avx512BW_available_
=
(
ebx
&
0x40000000
)
!=
0
;
}
}
}
#elif defined(_WIN32)
#elif defined(_WIN32)
int
cpuInfo
[
4
];
int
cpuInfo
[
4
];
...
...
arch/simddetect.h
浏览文件 @
4e9665de
...
@@ -24,6 +24,16 @@ class SIMDDetect {
...
@@ -24,6 +24,16 @@ class SIMDDetect {
public:
public:
// Returns true if AVX is available on this system.
// Returns true if AVX is available on this system.
static
inline
bool
IsAVXAvailable
()
{
return
detector
.
avx_available_
;
}
static
inline
bool
IsAVXAvailable
()
{
return
detector
.
avx_available_
;
}
// Returns true if AVX2 (integer support) is available on this system.
static
inline
bool
IsAVX2Available
()
{
return
detector
.
avx2_available_
;
}
// Returns true if AVX512 Foundation (float) is available on this system.
static
inline
bool
IsAVX512FAvailable
()
{
return
detector
.
avx512F_available_
;
}
// Returns true if AVX512 integer is available on this system.
static
inline
bool
IsAVX512BWAvailable
()
{
return
detector
.
avx512BW_available_
;
}
// Returns true if SSE4.1 is available on this system.
// Returns true if SSE4.1 is available on this system.
static
inline
bool
IsSSEAvailable
()
{
return
detector
.
sse_available_
;
}
static
inline
bool
IsSSEAvailable
()
{
return
detector
.
sse_available_
;
}
...
@@ -36,6 +46,9 @@ class SIMDDetect {
...
@@ -36,6 +46,9 @@ class SIMDDetect {
static
SIMDDetect
detector
;
static
SIMDDetect
detector
;
// If true, then AVX has been detected.
// If true, then AVX has been detected.
static
TESS_API
bool
avx_available_
;
static
TESS_API
bool
avx_available_
;
static
TESS_API
bool
avx2_available_
;
static
TESS_API
bool
avx512F_available_
;
static
TESS_API
bool
avx512BW_available_
;
// If true, then SSe4.1 has been detected.
// If true, then SSe4.1 has been detected.
static
TESS_API
bool
sse_available_
;
static
TESS_API
bool
sse_available_
;
};
};
ccstruct/matrix.h
浏览文件 @
4e9665de
...
@@ -360,19 +360,22 @@ class GENERIC_2D_ARRAY {
...
@@ -360,19 +360,22 @@ class GENERIC_2D_ARRAY {
}
}
// Accumulates the element-wise sums of squares of src into *this.
// Accumulates the element-wise sums of squares of src into *this.
void
SumSquares
(
const
GENERIC_2D_ARRAY
<
T
>&
src
)
{
void
SumSquares
(
const
GENERIC_2D_ARRAY
<
T
>&
src
,
T
decay_factor
)
{
T
update_factor
=
1.0
-
decay_factor
;
int
size
=
num_elements
();
int
size
=
num_elements
();
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
array_
[
i
]
+=
src
.
array_
[
i
]
*
src
.
array_
[
i
];
array_
[
i
]
=
array_
[
i
]
*
decay_factor
+
update_factor
*
src
.
array_
[
i
]
*
src
.
array_
[
i
];
}
}
}
}
// Scales each element using the ada-grad algorithm, ie array_[i] by
// Scales each element using the adam algorithm, ie array_[i] by
// sqrt(num_samples/max(1,sqsum[i])).
// sqrt(sqsum[i] + epsilon)).
void
AdaGradScaling
(
const
GENERIC_2D_ARRAY
<
T
>&
sqsum
,
int
num_samples
)
{
void
AdamUpdate
(
const
GENERIC_2D_ARRAY
<
T
>&
sum
,
const
GENERIC_2D_ARRAY
<
T
>&
sqsum
,
T
epsilon
)
{
int
size
=
num_elements
();
int
size
=
num_elements
();
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
array_
[
i
]
*=
sqrt
(
num_samples
/
MAX
(
1.0
,
sqsum
.
array_
[
i
])
);
array_
[
i
]
+=
sum
.
array_
[
i
]
/
(
sqrt
(
sqsum
.
array_
[
i
])
+
epsilon
);
}
}
}
}
...
...
lstm/convolve.cpp
浏览文件 @
4e9665de
...
@@ -112,7 +112,7 @@ bool Convolve::Backward(bool debug, const NetworkIO& fwd_deltas,
...
@@ -112,7 +112,7 @@ bool Convolve::Backward(bool debug, const NetworkIO& fwd_deltas,
}
}
}
}
}
while
(
src_index
.
Increment
());
}
while
(
src_index
.
Increment
());
back_deltas
->
Copy
WithNormalization
(
*
delta_sum
,
fwd_deltas
);
back_deltas
->
Copy
All
(
*
delta_sum
);
return
true
;
return
true
;
}
}
...
...
lstm/fullyconnected.cpp
浏览文件 @
4e9665de
...
@@ -79,11 +79,24 @@ void FullyConnected::SetEnableTraining(TrainingState state) {
...
@@ -79,11 +79,24 @@ void FullyConnected::SetEnableTraining(TrainingState state) {
// scale `range` picked according to the random number generator `randomizer`.
// scale `range` picked according to the random number generator `randomizer`.
int
FullyConnected
::
InitWeights
(
float
range
,
TRand
*
randomizer
)
{
int
FullyConnected
::
InitWeights
(
float
range
,
TRand
*
randomizer
)
{
Network
::
SetRandomizer
(
randomizer
);
Network
::
SetRandomizer
(
randomizer
);
num_weights_
=
weights_
.
InitWeightsFloat
(
no_
,
ni_
+
1
,
TestFlag
(
NF_ADA
_GRAD
),
num_weights_
=
weights_
.
InitWeightsFloat
(
no_
,
ni_
+
1
,
TestFlag
(
NF_ADA
M
),
range
,
randomizer
);
range
,
randomizer
);
return
num_weights_
;
return
num_weights_
;
}
}
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights. Only operates on Softmax layers with old_no outputs.
int
FullyConnected
::
RemapOutputs
(
int
old_no
,
const
std
::
vector
<
int
>&
code_map
)
{
if
(
type_
==
NT_SOFTMAX
&&
no_
==
old_no
)
{
num_weights_
=
weights_
.
RemapOutputs
(
code_map
);
no_
=
code_map
.
size
();
}
return
num_weights_
;
}
// Converts a float network to an int network.
// Converts a float network to an int network.
void
FullyConnected
::
ConvertToInt
()
{
void
FullyConnected
::
ConvertToInt
()
{
weights_
.
ConvertToInt
();
weights_
.
ConvertToInt
();
...
@@ -240,7 +253,6 @@ bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
...
@@ -240,7 +253,6 @@ bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
FinishBackward
(
*
errors_t
.
get
());
FinishBackward
(
*
errors_t
.
get
());
if
(
needs_to_backprop_
)
{
if
(
needs_to_backprop_
)
{
back_deltas
->
ZeroInvalidElements
();
back_deltas
->
ZeroInvalidElements
();
back_deltas
->
CopyWithNormalization
(
*
back_deltas
,
fwd_deltas
);
#if DEBUG_DETAIL > 0
#if DEBUG_DETAIL > 0
tprintf
(
"F Backprop:%s
\n
"
,
name_
.
string
());
tprintf
(
"F Backprop:%s
\n
"
,
name_
.
string
());
back_deltas
->
Print
(
10
);
back_deltas
->
Print
(
10
);
...
@@ -281,12 +293,11 @@ void FullyConnected::FinishBackward(const TransposedArray& errors_t) {
...
@@ -281,12 +293,11 @@ void FullyConnected::FinishBackward(const TransposedArray& errors_t) {
weights_
.
SumOuterTransposed
(
errors_t
,
*
external_source_
,
true
);
weights_
.
SumOuterTransposed
(
errors_t
,
*
external_source_
,
true
);
}
}
// Updates the weights using the given learning rate and momentum.
// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is the quotient to be used in the adagrad computation iff
// num_samples is used in the adam computation iff use_adam_ is true.
// use_ada_grad_ is true.
void
FullyConnected
::
Update
(
float
learning_rate
,
float
momentum
,
void
FullyConnected
::
Update
(
float
learning_rate
,
float
momentum
,
int
num_samples
)
{
float
adam_beta
,
int
num_samples
)
{
weights_
.
Update
(
learning_rate
,
momentum
,
num_samples
);
weights_
.
Update
(
learning_rate
,
momentum
,
adam_beta
,
num_samples
);
}
}
// Sums the products of weight updates in *this and other, splitting into
// Sums the products of weight updates in *this and other, splitting into
...
...
lstm/fullyconnected.h
浏览文件 @
4e9665de
...
@@ -68,6 +68,12 @@ class FullyConnected : public Network {
...
@@ -68,6 +68,12 @@ class FullyConnected : public Network {
// Sets up the network for training. Initializes weights using weights of
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// scale `range` picked according to the random number generator `randomizer`.
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights. Only operates on Softmax layers with old_no outputs.
int
RemapOutputs
(
int
old_no
,
const
std
::
vector
<
int
>&
code_map
)
override
;
// Converts a float network to an int network.
// Converts a float network to an int network.
virtual
void
ConvertToInt
();
virtual
void
ConvertToInt
();
...
@@ -101,10 +107,10 @@ class FullyConnected : public Network {
...
@@ -101,10 +107,10 @@ class FullyConnected : public Network {
TransposedArray
*
errors_t
,
double
*
backprop
);
TransposedArray
*
errors_t
,
double
*
backprop
);
void
FinishBackward
(
const
TransposedArray
&
errors_t
);
void
FinishBackward
(
const
TransposedArray
&
errors_t
);
// Updates the weights using the given learning rate
and momentum
.
// Updates the weights using the given learning rate
, momentum and adam_beta
.
// num_samples is
the quotient to be used in the adagrad computation iff
// num_samples is
used in the adam computation iff use_adam_ is true.
// use_ada_grad_ is true.
void
Update
(
float
learning_rate
,
float
momentum
,
float
adam_beta
,
virtual
void
Update
(
float
learning_rate
,
float
momentum
,
int
num_samples
)
;
int
num_samples
)
override
;
// Sums the products of weight updates in *this and other, splitting into
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// positive (same direction) in *same and negative (different direction) in
// *changed.
// *changed.
...
...
lstm/lstm.cpp
浏览文件 @
4e9665de
...
@@ -132,7 +132,7 @@ int LSTM::InitWeights(float range, TRand* randomizer) {
...
@@ -132,7 +132,7 @@ int LSTM::InitWeights(float range, TRand* randomizer) {
for
(
int
w
=
0
;
w
<
WT_COUNT
;
++
w
)
{
for
(
int
w
=
0
;
w
<
WT_COUNT
;
++
w
)
{
if
(
w
==
GFS
&&
!
Is2D
())
continue
;
if
(
w
==
GFS
&&
!
Is2D
())
continue
;
num_weights_
+=
gate_weights_
[
w
].
InitWeightsFloat
(
num_weights_
+=
gate_weights_
[
w
].
InitWeightsFloat
(
ns_
,
na_
+
1
,
TestFlag
(
NF_ADA
_GRAD
),
range
,
randomizer
);
ns_
,
na_
+
1
,
TestFlag
(
NF_ADA
M
),
range
,
randomizer
);
}
}
if
(
softmax_
!=
NULL
)
{
if
(
softmax_
!=
NULL
)
{
num_weights_
+=
softmax_
->
InitWeights
(
range
,
randomizer
);
num_weights_
+=
softmax_
->
InitWeights
(
range
,
randomizer
);
...
@@ -140,6 +140,19 @@ int LSTM::InitWeights(float range, TRand* randomizer) {
...
@@ -140,6 +140,19 @@ int LSTM::InitWeights(float range, TRand* randomizer) {
return
num_weights_
;
return
num_weights_
;
}
}
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights. Only operates on Softmax layers with old_no outputs.
int
LSTM
::
RemapOutputs
(
int
old_no
,
const
std
::
vector
<
int
>&
code_map
)
{
if
(
softmax_
!=
NULL
)
{
num_weights_
-=
softmax_
->
num_weights
();
num_weights_
+=
softmax_
->
RemapOutputs
(
old_no
,
code_map
);
}
return
num_weights_
;
}
// Converts a float network to an int network.
// Converts a float network to an int network.
void
LSTM
::
ConvertToInt
()
{
void
LSTM
::
ConvertToInt
()
{
for
(
int
w
=
0
;
w
<
WT_COUNT
;
++
w
)
{
for
(
int
w
=
0
;
w
<
WT_COUNT
;
++
w
)
{
...
@@ -618,27 +631,22 @@ bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
...
@@ -618,27 +631,22 @@ bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
if
(
softmax_
!=
NULL
)
{
if
(
softmax_
!=
NULL
)
{
softmax_
->
FinishBackward
(
*
softmax_errors_t
);
softmax_
->
FinishBackward
(
*
softmax_errors_t
);
}
}
if
(
needs_to_backprop_
)
{
return
needs_to_backprop_
;
// Normalize the inputerr in back_deltas.
back_deltas
->
CopyWithNormalization
(
*
back_deltas
,
fwd_deltas
);
return
true
;
}
return
false
;
}
}
// Updates the weights using the given learning rate
and momentum
.
// Updates the weights using the given learning rate
, momentum and adam_beta
.
// num_samples is
the quotient to be used in the adagrad computation iff
// num_samples is
used in the adam computation iff use_adam_ is true.
// use_ada_grad_ is true.
void
LSTM
::
Update
(
float
learning_rate
,
float
momentum
,
float
adam_beta
,
void
LSTM
::
Update
(
float
learning_rate
,
float
momentum
,
int
num_samples
)
{
int
num_samples
)
{
#if DEBUG_DETAIL > 3
#if DEBUG_DETAIL > 3
PrintW
();
PrintW
();
#endif
#endif
for
(
int
w
=
0
;
w
<
WT_COUNT
;
++
w
)
{
for
(
int
w
=
0
;
w
<
WT_COUNT
;
++
w
)
{
if
(
w
==
GFS
&&
!
Is2D
())
continue
;
if
(
w
==
GFS
&&
!
Is2D
())
continue
;
gate_weights_
[
w
].
Update
(
learning_rate
,
momentum
,
num_samples
);
gate_weights_
[
w
].
Update
(
learning_rate
,
momentum
,
adam_beta
,
num_samples
);
}
}
if
(
softmax_
!=
NULL
)
{
if
(
softmax_
!=
NULL
)
{
softmax_
->
Update
(
learning_rate
,
momentum
,
num_samples
);
softmax_
->
Update
(
learning_rate
,
momentum
,
adam_beta
,
num_samples
);
}
}
#if DEBUG_DETAIL > 3
#if DEBUG_DETAIL > 3
PrintDW
();
PrintDW
();
...
...
lstm/lstm.h
浏览文件 @
4e9665de
...
@@ -76,6 +76,12 @@ class LSTM : public Network {
...
@@ -76,6 +76,12 @@ class LSTM : public Network {
// Sets up the network for training. Initializes weights using weights of
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// scale `range` picked according to the random number generator `randomizer`.
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights. Only operates on Softmax layers with old_no outputs.
int
RemapOutputs
(
int
old_no
,
const
std
::
vector
<
int
>&
code_map
)
override
;
// Converts a float network to an int network.
// Converts a float network to an int network.
virtual
void
ConvertToInt
();
virtual
void
ConvertToInt
();
...
@@ -99,10 +105,10 @@ class LSTM : public Network {
...
@@ -99,10 +105,10 @@ class LSTM : public Network {
virtual
bool
Backward
(
bool
debug
,
const
NetworkIO
&
fwd_deltas
,
virtual
bool
Backward
(
bool
debug
,
const
NetworkIO
&
fwd_deltas
,
NetworkScratch
*
scratch
,
NetworkScratch
*
scratch
,
NetworkIO
*
back_deltas
);
NetworkIO
*
back_deltas
);
// Updates the weights using the given learning rate
and momentum
.
// Updates the weights using the given learning rate
, momentum and adam_beta
.
// num_samples is
the quotient to be used in the adagrad computation iff
// num_samples is
used in the adam computation iff use_adam_ is true.
// use_ada_grad_ is true.
void
Update
(
float
learning_rate
,
float
momentum
,
float
adam_beta
,
virtual
void
Update
(
float
learning_rate
,
float
momentum
,
int
num_samples
)
;
int
num_samples
)
override
;
// Sums the products of weight updates in *this and other, splitting into
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// positive (same direction) in *same and negative (different direction) in
// *changed.
// *changed.
...
...
lstm/lstmrecognizer.cpp
浏览文件 @
4e9665de
...
@@ -55,9 +55,9 @@ LSTMRecognizer::LSTMRecognizer()
...
@@ -55,9 +55,9 @@ LSTMRecognizer::LSTMRecognizer()
training_iteration_
(
0
),
training_iteration_
(
0
),
sample_iteration_
(
0
),
sample_iteration_
(
0
),
null_char_
(
UNICHAR_BROKEN
),
null_char_
(
UNICHAR_BROKEN
),
weight_range_
(
0.0
f
),
learning_rate_
(
0.0
f
),
learning_rate_
(
0.0
f
),
momentum_
(
0.0
f
),
momentum_
(
0.0
f
),
adam_beta_
(
0.0
f
),
dict_
(
NULL
),
dict_
(
NULL
),
search_
(
NULL
),
search_
(
NULL
),
debug_win_
(
NULL
)
{}
debug_win_
(
NULL
)
{}
...
@@ -94,7 +94,7 @@ bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const {
...
@@ -94,7 +94,7 @@ bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const {
if
(
fp
->
FWrite
(
&
sample_iteration_
,
sizeof
(
sample_iteration_
),
1
)
!=
1
)
if
(
fp
->
FWrite
(
&
sample_iteration_
,
sizeof
(
sample_iteration_
),
1
)
!=
1
)
return
false
;
return
false
;
if
(
fp
->
FWrite
(
&
null_char_
,
sizeof
(
null_char_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FWrite
(
&
null_char_
,
sizeof
(
null_char_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FWrite
(
&
weight_range_
,
sizeof
(
weight_range
_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FWrite
(
&
adam_beta_
,
sizeof
(
adam_beta
_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FWrite
(
&
learning_rate_
,
sizeof
(
learning_rate_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FWrite
(
&
learning_rate_
,
sizeof
(
learning_rate_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FWrite
(
&
momentum_
,
sizeof
(
momentum_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FWrite
(
&
momentum_
,
sizeof
(
momentum_
),
1
)
!=
1
)
return
false
;
if
(
include_charsets
&&
IsRecoding
()
&&
!
recoder_
.
Serialize
(
fp
))
return
false
;
if
(
include_charsets
&&
IsRecoding
()
&&
!
recoder_
.
Serialize
(
fp
))
return
false
;
...
@@ -120,8 +120,7 @@ bool LSTMRecognizer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
...
@@ -120,8 +120,7 @@ bool LSTMRecognizer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
if
(
fp
->
FReadEndian
(
&
sample_iteration_
,
sizeof
(
sample_iteration_
),
1
)
!=
1
)
if
(
fp
->
FReadEndian
(
&
sample_iteration_
,
sizeof
(
sample_iteration_
),
1
)
!=
1
)
return
false
;
return
false
;
if
(
fp
->
FReadEndian
(
&
null_char_
,
sizeof
(
null_char_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FReadEndian
(
&
null_char_
,
sizeof
(
null_char_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FReadEndian
(
&
weight_range_
,
sizeof
(
weight_range_
),
1
)
!=
1
)
if
(
fp
->
FReadEndian
(
&
adam_beta_
,
sizeof
(
adam_beta_
),
1
)
!=
1
)
return
false
;
return
false
;
if
(
fp
->
FReadEndian
(
&
learning_rate_
,
sizeof
(
learning_rate_
),
1
)
!=
1
)
if
(
fp
->
FReadEndian
(
&
learning_rate_
,
sizeof
(
learning_rate_
),
1
)
!=
1
)
return
false
;
return
false
;
if
(
fp
->
FReadEndian
(
&
momentum_
,
sizeof
(
momentum_
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FReadEndian
(
&
momentum_
,
sizeof
(
momentum_
),
1
)
!=
1
)
return
false
;
...
@@ -207,14 +206,22 @@ void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output,
...
@@ -207,14 +206,22 @@ void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output,
STATS
stats
(
0
,
kOutputScale
+
1
);
STATS
stats
(
0
,
kOutputScale
+
1
);
for
(
int
t
=
0
;
t
<
outputs
.
Width
();
++
t
)
{
for
(
int
t
=
0
;
t
<
outputs
.
Width
();
++
t
)
{
int
best_label
=
outputs
.
BestLabel
(
t
,
NULL
);
int
best_label
=
outputs
.
BestLabel
(
t
,
NULL
);
if
(
best_label
!=
null_char_
||
t
==
0
)
{
if
(
best_label
!=
null_char_
)
{
float
best_output
=
outputs
.
f
(
t
)[
best_label
];
float
best_output
=
outputs
.
f
(
t
)[
best_label
];
stats
.
add
(
static_cast
<
int
>
(
kOutputScale
*
best_output
),
1
);
stats
.
add
(
static_cast
<
int
>
(
kOutputScale
*
best_output
),
1
);
}
}
}
}
*
min_output
=
static_cast
<
float
>
(
stats
.
min_bucket
())
/
kOutputScale
;
// If the output is all nulls it could be that the photometric interpretation
*
mean_output
=
stats
.
mean
()
/
kOutputScale
;
// is wrong, so make it look bad, so the other way can win, even if not great.
*
sd
=
stats
.
sd
()
/
kOutputScale
;
if
(
stats
.
get_total
()
==
0
)
{
*
min_output
=
0.0
f
;
*
mean_output
=
0.0
f
;
*
sd
=
1.0
f
;
}
else
{
*
min_output
=
static_cast
<
float
>
(
stats
.
min_bucket
())
/
kOutputScale
;
*
mean_output
=
stats
.
mean
()
/
kOutputScale
;
*
sd
=
stats
.
sd
()
/
kOutputScale
;
}
}
}
// Recognizes the image_data, returning the labels,
// Recognizes the image_data, returning the labels,
...
...
lstm/lstmrecognizer.h
浏览文件 @
4e9665de
...
@@ -45,8 +45,6 @@ class ImageData;
...
@@ -45,8 +45,6 @@ class ImageData;
// Enum indicating training mode control flags.
// Enum indicating training mode control flags.
enum
TrainingFlags
{
enum
TrainingFlags
{
TF_INT_MODE
=
1
,
TF_INT_MODE
=
1
,
TF_AUTO_HARDEN
=
2
,
TF_ROUND_ROBIN_TRAINING
=
16
,
TF_COMPRESS_UNICHARSET
=
64
,
TF_COMPRESS_UNICHARSET
=
64
,
};
};
...
@@ -69,9 +67,6 @@ class LSTMRecognizer {
...
@@ -69,9 +67,6 @@ class LSTMRecognizer {
double
learning_rate
()
const
{
double
learning_rate
()
const
{
return
learning_rate_
;
return
learning_rate_
;
}
}
bool
IsHardening
()
const
{
return
(
training_flags_
&
TF_AUTO_HARDEN
)
!=
0
;
}
LossType
OutputLossType
()
const
{
LossType
OutputLossType
()
const
{
if
(
network_
==
nullptr
)
return
LT_NONE
;
if
(
network_
==
nullptr
)
return
LT_NONE
;
StaticShape
shape
;
StaticShape
shape
;
...
@@ -84,11 +79,6 @@ class LSTMRecognizer {
...
@@ -84,11 +79,6 @@ class LSTMRecognizer {
bool
IsRecoding
()
const
{
bool
IsRecoding
()
const
{
return
(
training_flags_
&
TF_COMPRESS_UNICHARSET
)
!=
0
;
return
(
training_flags_
&
TF_COMPRESS_UNICHARSET
)
!=
0
;
}
}
// Returns the cache strategy for the DocumentCache.
CachingStrategy
CacheStrategy
()
const
{
return
training_flags_
&
TF_ROUND_ROBIN_TRAINING
?
CS_ROUND_ROBIN
:
CS_SEQUENTIAL
;
}
// Returns true if the network is a TensorFlow network.
// Returns true if the network is a TensorFlow network.
bool
IsTensorFlow
()
const
{
return
network_
->
type
()
==
NT_TENSORFLOW
;
}
bool
IsTensorFlow
()
const
{
return
network_
->
type
()
==
NT_TENSORFLOW
;
}
// Returns a vector of layer ids that can be passed to other layer functions
// Returns a vector of layer ids that can be passed to other layer functions
...
@@ -137,10 +127,10 @@ class LSTMRecognizer {
...
@@ -137,10 +127,10 @@ class LSTMRecognizer {
series
->
ScaleLayerLearningRate
(
&
id
[
1
],
factor
);
series
->
ScaleLayerLearningRate
(
&
id
[
1
],
factor
);
}
}
// True if the network is using adagrad to train.
bool
IsUsingAdaGrad
()
const
{
return
network_
->
TestFlag
(
NF_ADA_GRAD
);
}
// Provides access to the UNICHARSET that this classifier works with.
// Provides access to the UNICHARSET that this classifier works with.
const
UNICHARSET
&
GetUnicharset
()
const
{
return
ccutil_
.
unicharset
;
}
const
UNICHARSET
&
GetUnicharset
()
const
{
return
ccutil_
.
unicharset
;
}
// Provides access to the UnicharCompress that this classifier works with.
const
UnicharCompress
&
GetRecoder
()
const
{
return
recoder_
;
}
// Provides access to the Dict that this classifier works with.
// Provides access to the Dict that this classifier works with.
const
Dict
*
GetDict
()
const
{
return
dict_
;
}
const
Dict
*
GetDict
()
const
{
return
dict_
;
}
// Sets the sample iteration to the given value. The sample_iteration_
// Sets the sample iteration to the given value. The sample_iteration_
...
@@ -215,6 +205,12 @@ class LSTMRecognizer {
...
@@ -215,6 +205,12 @@ class LSTMRecognizer {
const
GenericVector
<
int
>&
label_coords
,
const
GenericVector
<
int
>&
label_coords
,
const
char
*
window_name
,
const
char
*
window_name
,
ScrollView
**
window
);
ScrollView
**
window
);
// Converts the network output to a sequence of labels. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The conversion method is determined by internal state.
void
LabelsFromOutputs
(
const
NetworkIO
&
outputs
,
GenericVector
<
int
>*
labels
,
GenericVector
<
int
>*
xcoords
);
protected:
protected:
// Sets the random seed from the sample_iteration_;
// Sets the random seed from the sample_iteration_;
...
@@ -241,12 +237,6 @@ class LSTMRecognizer {
...
@@ -241,12 +237,6 @@ class LSTMRecognizer {
void
DebugActivationRange
(
const
NetworkIO
&
outputs
,
const
char
*
label
,
void
DebugActivationRange
(
const
NetworkIO
&
outputs
,
const
char
*
label
,
int
best_choice
,
int
x_start
,
int
x_end
);
int
best_choice
,
int
x_start
,
int
x_end
);
// Converts the network output to a sequence of labels. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The conversion method is determined by internal state.
void
LabelsFromOutputs
(
const
NetworkIO
&
outputs
,
GenericVector
<
int
>*
labels
,
GenericVector
<
int
>*
xcoords
);
// As LabelsViaCTC except that this function constructs the best path that
// As LabelsViaCTC except that this function constructs the best path that
// contains only legal sequences of subcodes for recoder_.
// contains only legal sequences of subcodes for recoder_.
void
LabelsViaReEncode
(
const
NetworkIO
&
output
,
GenericVector
<
int
>*
labels
,
void
LabelsViaReEncode
(
const
NetworkIO
&
output
,
GenericVector
<
int
>*
labels
,
...
@@ -290,11 +280,11 @@ class LSTMRecognizer {
...
@@ -290,11 +280,11 @@ class LSTMRecognizer {
// Index in softmax of null character. May take the value UNICHAR_BROKEN or
// Index in softmax of null character. May take the value UNICHAR_BROKEN or
// ccutil_.unicharset.size().
// ccutil_.unicharset.size().
inT32
null_char_
;
inT32
null_char_
;
// Range used for the initial random numbers in the weights.
float
weight_range_
;
// Learning rate and momentum multipliers of deltas in backprop.
// Learning rate and momentum multipliers of deltas in backprop.
float
learning_rate_
;
float
learning_rate_
;
float
momentum_
;
float
momentum_
;
// Smoothing factor for 2nd moment of gradients.
float
adam_beta_
;
// === NOT SERIALIZED.
// === NOT SERIALIZED.
TRand
randomizer_
;
TRand
randomizer_
;
...
...
lstm/lstmtrainer.cpp
浏览文件 @
4e9665de
...
@@ -123,11 +123,45 @@ LSTMTrainer::~LSTMTrainer() {
...
@@ -123,11 +123,45 @@ LSTMTrainer::~LSTMTrainer() {
// Tries to deserialize a trainer from the given file and silently returns
// Tries to deserialize a trainer from the given file and silently returns
// false in case of failure.
// false in case of failure.
bool
LSTMTrainer
::
TryLoadingCheckpoint
(
const
char
*
filename
)
{
bool
LSTMTrainer
::
TryLoadingCheckpoint
(
const
char
*
filename
,
const
char
*
old_traineddata
)
{
GenericVector
<
char
>
data
;
GenericVector
<
char
>
data
;
if
(
!
(
*
file_reader_
)(
filename
,
&
data
))
return
false
;
if
(
!
(
*
file_reader_
)(
filename
,
&
data
))
return
false
;
tprintf
(
"Loaded file %s, unpacking...
\n
"
,
filename
);
tprintf
(
"Loaded file %s, unpacking...
\n
"
,
filename
);
<<<<<<<
Updated
upstream
return
checkpoint_reader_
->
Run
(
data
,
this
);
return
checkpoint_reader_
->
Run
(
data
,
this
);
=======
if
(
!
checkpoint_reader_
->
Run
(
data
,
this
))
return
false
;
StaticShape
shape
=
network_
->
OutputShape
(
network_
->
InputShape
());
if
(((
old_traineddata
==
nullptr
||
*
old_traineddata
==
'\0'
)
&&
network_
->
NumOutputs
()
==
recoder_
.
code_range
())
||
filename
==
old_traineddata
)
{
return
true
;
// Normal checkpoint load complete.
}
tprintf
(
"Code range changed from %d to %d!!
\n
"
,
network_
->
NumOutputs
(),
recoder_
.
code_range
());
if
(
old_traineddata
==
nullptr
||
*
old_traineddata
==
'\0'
)
{
tprintf
(
"Must supply the old traineddata for code conversion!
\n
"
);
return
false
;
}
TessdataManager
old_mgr
;
ASSERT_HOST
(
old_mgr
.
Init
(
old_traineddata
));
TFile
fp
;
if
(
!
old_mgr
.
GetComponent
(
TESSDATA_LSTM_UNICHARSET
,
&
fp
))
return
false
;
UNICHARSET
old_chset
;
if
(
!
old_chset
.
load_from_file
(
&
fp
,
false
))
return
false
;
if
(
!
old_mgr
.
GetComponent
(
TESSDATA_LSTM_RECODER
,
&
fp
))
return
false
;
UnicharCompress
old_recoder
;
if
(
!
old_recoder
.
DeSerialize
(
&
fp
))
return
false
;
std
::
vector
<
int
>
code_map
=
MapRecoder
(
old_chset
,
old_recoder
);
// Set the null_char_ to the new value.
int
old_null_char
=
null_char_
;
SetNullChar
();
// Map the softmax(s) in the network.
network_
->
RemapOutputs
(
old_recoder
.
code_range
(),
code_map
);
tprintf
(
"Previous null char=%d mapped to %d
\n
"
,
old_null_char
,
null_char_
);
return
true
;
>>>>>>>
Stashed
changes
}
}
// Initializes the trainer with a network_spec in the network description
// Initializes the trainer with a network_spec in the network description
...
@@ -138,11 +172,13 @@ bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
...
@@ -138,11 +172,13 @@ bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
// Note: Be sure to call InitCharSet before InitNetwork!
// Note: Be sure to call InitCharSet before InitNetwork!
bool
LSTMTrainer
::
InitNetwork
(
const
STRING
&
network_spec
,
int
append_index
,
bool
LSTMTrainer
::
InitNetwork
(
const
STRING
&
network_spec
,
int
append_index
,
int
net_flags
,
float
weight_range
,
int
net_flags
,
float
weight_range
,
float
learning_rate
,
float
momentum
)
{
float
learning_rate
,
float
momentum
,
float
adam_beta
)
{
mgr_
.
SetVersionString
(
mgr_
.
VersionString
()
+
":"
+
network_spec
.
string
());
mgr_
.
SetVersionString
(
mgr_
.
VersionString
()
+
":"
+
network_spec
.
string
());
weight_range_
=
weight_range
;
adam_beta_
=
adam_beta
;
learning_rate_
=
learning_rate
;
learning_rate_
=
learning_rate
;
momentum_
=
momentum
;
momentum_
=
momentum
;
SetNullChar
();
if
(
!
NetworkBuilder
::
InitNetwork
(
recoder_
.
code_range
(),
network_spec
,
if
(
!
NetworkBuilder
::
InitNetwork
(
recoder_
.
code_range
(),
network_spec
,
append_index
,
net_flags
,
weight_range
,
append_index
,
net_flags
,
weight_range
,
&
randomizer_
,
&
network_
))
{
&
randomizer_
,
&
network_
))
{
...
@@ -151,9 +187,10 @@ bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index,
...
@@ -151,9 +187,10 @@ bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index,
network_str_
+=
network_spec
;
network_str_
+=
network_spec
;
tprintf
(
"Built network:%s from request %s
\n
"
,
tprintf
(
"Built network:%s from request %s
\n
"
,
network_
->
spec
().
string
(),
network_spec
.
string
());
network_
->
spec
().
string
(),
network_spec
.
string
());
tprintf
(
"Training parameters:
\n
Debug interval = %d,"
tprintf
(
" weights = %g, learning rate = %g, momentum=%g
\n
"
,
"Training parameters:
\n
Debug interval = %d,"
debug_interval_
,
weight_range_
,
learning_rate_
,
momentum_
);
" weights = %g, learning rate = %g, momentum=%g
\n
"
,
debug_interval_
,
weight_range
,
learning_rate_
,
momentum_
);
tprintf
(
"null char=%d
\n
"
,
null_char_
);
tprintf
(
"null char=%d
\n
"
,
null_char_
);
return
true
;
return
true
;
}
}
...
@@ -606,8 +643,6 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
...
@@ -606,8 +643,6 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
LR_SAME
,
// Learning rate will stay the same.
LR_SAME
,
// Learning rate will stay the same.
LR_COUNT
// Size of arrays.
LR_COUNT
// Size of arrays.
};
};
// Epsilon is so small that it may as well be zero, but still positive.
const
double
kEpsilon
=
1.0e-30
;
GenericVector
<
STRING
>
layers
=
EnumerateLayers
();
GenericVector
<
STRING
>
layers
=
EnumerateLayers
();
int
num_layers
=
layers
.
size
();
int
num_layers
=
layers
.
size
();
GenericVector
<
int
>
num_weights
;
GenericVector
<
int
>
num_weights
;
...
@@ -636,7 +671,7 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
...
@@ -636,7 +671,7 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
LSTMTrainer
copy_trainer
;
LSTMTrainer
copy_trainer
;
samples_trainer
->
ReadTrainingDump
(
orig_trainer
,
&
copy_trainer
);
samples_trainer
->
ReadTrainingDump
(
orig_trainer
,
&
copy_trainer
);
// Clear the updates, doing nothing else.
// Clear the updates, doing nothing else.
copy_trainer
.
network_
->
Update
(
0.0
,
0.0
,
0
);
copy_trainer
.
network_
->
Update
(
0.0
,
0.0
,
0
.0
,
0
);
// Adjust the learning rate in each layer.
// Adjust the learning rate in each layer.
for
(
int
i
=
0
;
i
<
num_layers
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_layers
;
++
i
)
{
if
(
num_weights
[
i
]
==
0
)
continue
;
if
(
num_weights
[
i
]
==
0
)
continue
;
...
@@ -656,9 +691,11 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
...
@@ -656,9 +691,11 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
LSTMTrainer
layer_trainer
;
LSTMTrainer
layer_trainer
;
samples_trainer
->
ReadTrainingDump
(
updated_trainer
,
&
layer_trainer
);
samples_trainer
->
ReadTrainingDump
(
updated_trainer
,
&
layer_trainer
);
Network
*
layer
=
layer_trainer
.
GetLayer
(
layers
[
i
]);
Network
*
layer
=
layer_trainer
.
GetLayer
(
layers
[
i
]);
// Update the weights in just the layer, and also zero the updates
// Update the weights in just the layer, using Adam if enabled.
// matrix (to epsilon).
layer
->
Update
(
0.0
,
momentum_
,
adam_beta_
,
layer
->
Update
(
0.0
,
kEpsilon
,
0
);
layer_trainer
.
training_iteration_
+
1
);
// Zero the updates matrix again.
layer
->
Update
(
0.0
,
0.0
,
0.0
,
0
);
// Train again on the same sample, again holding back the updates.
// Train again on the same sample, again holding back the updates.
layer_trainer
.
TrainOnLine
(
trainingdata
,
true
);
layer_trainer
.
TrainOnLine
(
trainingdata
,
true
);
// Count the sign changes in the updates in layer vs in copy_trainer.
// Count the sign changes in the updates in layer vs in copy_trainer.
...
@@ -773,7 +810,7 @@ Trainability LSTMTrainer::TrainOnLine(const ImageData* trainingdata,
...
@@ -773,7 +810,7 @@ Trainability LSTMTrainer::TrainOnLine(const ImageData* trainingdata,
training_iteration
()
>
training_iteration
()
>
last_perfect_training_iteration_
+
perfect_delay_
))
{
last_perfect_training_iteration_
+
perfect_delay_
))
{
network_
->
Backward
(
debug
,
targets
,
&
scratch_space_
,
&
bp_deltas
);
network_
->
Backward
(
debug
,
targets
,
&
scratch_space_
,
&
bp_deltas
);
network_
->
Update
(
learning_rate_
,
batch
?
-
1.0
f
:
momentum_
,
network_
->
Update
(
learning_rate_
,
batch
?
-
1.0
f
:
momentum_
,
adam_beta_
,
training_iteration_
+
1
);
training_iteration_
+
1
);
}
}
#ifndef GRAPHICS_DISABLED
#ifndef GRAPHICS_DISABLED
...
@@ -928,6 +965,41 @@ void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) {
...
@@ -928,6 +965,41 @@ void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) {
error_rates_
[
type
]
=
100.0
*
new_error
;
error_rates_
[
type
]
=
100.0
*
new_error
;
}
}
// Helper generates a map from each current recoder_ code (ie softmax index)
// to the corresponding old_recoder code, or -1 if there isn't one.
std
::
vector
<
int
>
LSTMTrainer
::
MapRecoder
(
const
UNICHARSET
&
old_chset
,
const
UnicharCompress
&
old_recoder
)
const
{
int
num_new_codes
=
recoder_
.
code_range
();
int
num_new_unichars
=
GetUnicharset
().
size
();
std
::
vector
<
int
>
code_map
(
num_new_codes
,
-
1
);
for
(
int
c
=
0
;
c
<
num_new_codes
;
++
c
)
{
int
old_code
=
-
1
;
// Find all new unichar_ids that recode to something that includes c.
// The <= is to include the null char, which may be beyond the unicharset.
for
(
int
uid
=
0
;
uid
<=
num_new_unichars
;
++
uid
)
{
RecodedCharID
codes
;
int
length
=
recoder_
.
EncodeUnichar
(
uid
,
&
codes
);
int
code_index
=
0
;
while
(
code_index
<
length
&&
codes
(
code_index
)
!=
c
)
++
code_index
;
if
(
code_index
==
length
)
continue
;
// The old unicharset must have the same unichar.
int
old_uid
=
uid
<
num_new_unichars
?
old_chset
.
unichar_to_id
(
GetUnicharset
().
id_to_unichar
(
uid
))
:
old_chset
.
size
()
-
1
;
if
(
old_uid
==
INVALID_UNICHAR_ID
)
continue
;
// The encoding of old_uid at the same code_index is the old code.
RecodedCharID
old_codes
;
if
(
code_index
<
old_recoder
.
EncodeUnichar
(
old_uid
,
&
old_codes
))
{
old_code
=
old_codes
(
code_index
);
break
;
}
}
code_map
[
c
]
=
old_code
;
}
return
code_map
;
}
// Private version of InitCharSet above finishes the job after initializing
// Private version of InitCharSet above finishes the job after initializing
// the mgr_ data member.
// the mgr_ data member.
void
LSTMTrainer
::
InitCharSet
()
{
void
LSTMTrainer
::
InitCharSet
()
{
...
@@ -939,6 +1011,11 @@ void LSTMTrainer::InitCharSet() {
...
@@ -939,6 +1011,11 @@ void LSTMTrainer::InitCharSet() {
"Must provide a traineddata containing lstm_unicharset and"
"Must provide a traineddata containing lstm_unicharset and"
" lstm_recoder!
\n
"
!=
nullptr
);
" lstm_recoder!
\n
"
!=
nullptr
);
}
}
SetNullChar
();
}
// Helper computes and sets the null_char_.
void
LSTMTrainer
::
SetNullChar
()
{
null_char_
=
GetUnicharset
().
has_special_codes
()
?
UNICHAR_BROKEN
null_char_
=
GetUnicharset
().
has_special_codes
()
?
UNICHAR_BROKEN
:
GetUnicharset
().
size
();
:
GetUnicharset
().
size
();
RecodedCharID
code
;
RecodedCharID
code
;
...
...
lstm/lstmtrainer.h
浏览文件 @
4e9665de
...
@@ -98,8 +98,15 @@ class LSTMTrainer : public LSTMRecognizer {
...
@@ -98,8 +98,15 @@ class LSTMTrainer : public LSTMRecognizer {
virtual
~
LSTMTrainer
();
virtual
~
LSTMTrainer
();
// Tries to deserialize a trainer from the given file and silently returns
// Tries to deserialize a trainer from the given file and silently returns
<<<<<<<
Updated
upstream
// false in case of failure.
// false in case of failure.
bool
TryLoadingCheckpoint
(
const
char
*
filename
);
bool
TryLoadingCheckpoint
(
const
char
*
filename
);
=======
// false in case of failure. If old_traineddata is not null, then it is
// assumed that the character set is to be re-mapped from old_traininddata to
// the new, with consequent change in weight matrices etc.
bool
TryLoadingCheckpoint
(
const
char
*
filename
,
const
char
*
old_traineddata
);
>>>>>>>
Stashed
changes
// Initializes the character set encode/decode mechanism directly from a
// Initializes the character set encode/decode mechanism directly from a
// previously setup traineddata containing dawgs, UNICHARSET and
// previously setup traineddata containing dawgs, UNICHARSET and
...
@@ -120,7 +127,8 @@ class LSTMTrainer : public LSTMRecognizer {
...
@@ -120,7 +127,8 @@ class LSTMTrainer : public LSTMRecognizer {
// For other args see NetworkBuilder::InitNetwork.
// For other args see NetworkBuilder::InitNetwork.
// Note: Be sure to call InitCharSet before InitNetwork!
// Note: Be sure to call InitCharSet before InitNetwork!
bool
InitNetwork
(
const
STRING
&
network_spec
,
int
append_index
,
int
net_flags
,
bool
InitNetwork
(
const
STRING
&
network_spec
,
int
append_index
,
int
net_flags
,
float
weight_range
,
float
learning_rate
,
float
momentum
);
float
weight_range
,
float
learning_rate
,
float
momentum
,
float
adam_beta
);
// Initializes a trainer from a serialized TFNetworkModel proto.
// Initializes a trainer from a serialized TFNetworkModel proto.
// Returns the global step of TensorFlow graph or 0 if failed.
// Returns the global step of TensorFlow graph or 0 if failed.
// Building a compatible TF graph: See tfnetwork.proto.
// Building a compatible TF graph: See tfnetwork.proto.
...
@@ -320,11 +328,17 @@ class LSTMTrainer : public LSTMRecognizer {
...
@@ -320,11 +328,17 @@ class LSTMTrainer : public LSTMRecognizer {
// Fills the whole error buffer of the given type with the given value.
// Fills the whole error buffer of the given type with the given value.
void
FillErrorBuffer
(
double
new_error
,
ErrorTypes
type
);
void
FillErrorBuffer
(
double
new_error
,
ErrorTypes
type
);
// Helper generates a map from each current recoder_ code (ie softmax index)
// to the corresponding old_recoder code, or -1 if there isn't one.
std
::
vector
<
int
>
MapRecoder
(
const
UNICHARSET
&
old_chset
,
const
UnicharCompress
&
old_recoder
)
const
;
protected:
protected:
// Private version of InitCharSet above finishes the job after initializing
// Private version of InitCharSet above finishes the job after initializing
// the mgr_ data member.
// the mgr_ data member.
void
InitCharSet
();
void
InitCharSet
();
// Helper computes and sets the null_char_.
void
SetNullChar
();
// Factored sub-constructor sets up reasonable default values.
// Factored sub-constructor sets up reasonable default values.
void
EmptyConstructor
();
void
EmptyConstructor
();
...
...
lstm/network.h
浏览文件 @
4e9665de
...
@@ -85,7 +85,7 @@ enum NetworkType {
...
@@ -85,7 +85,7 @@ enum NetworkType {
enum
NetworkFlags
{
enum
NetworkFlags
{
// Network forward/backprop behavior.
// Network forward/backprop behavior.
NF_LAYER_SPECIFIC_LR
=
64
,
// Separate learning rate for each layer.
NF_LAYER_SPECIFIC_LR
=
64
,
// Separate learning rate for each layer.
NF_ADA
_GRAD
=
128
,
// Weight-specific learning rate.
NF_ADA
M
=
128
,
// Weight-specific learning rate.
};
};
// State of training and desired state used in SetEnableTraining.
// State of training and desired state used in SetEnableTraining.
...
@@ -172,6 +172,14 @@ class Network {
...
@@ -172,6 +172,14 @@ class Network {
// and should not be deleted by any of the networks.
// and should not be deleted by any of the networks.
// Returns the number of weights initialized.
// Returns the number of weights initialized.
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights. Only operates on Softmax layers with old_no outputs.
virtual
int
RemapOutputs
(
int
old_no
,
const
std
::
vector
<
int
>&
code_map
)
{
return
0
;
}
// Converts a float network to an int network.
// Converts a float network to an int network.
virtual
void
ConvertToInt
()
{}
virtual
void
ConvertToInt
()
{}
...
@@ -212,10 +220,10 @@ class Network {
...
@@ -212,10 +220,10 @@ class Network {
// Should be overridden by subclasses, but NOT called by their DeSerialize.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
virtual
bool
DeSerialize
(
TFile
*
fp
);
virtual
bool
DeSerialize
(
TFile
*
fp
);
// Updates the weights using the given learning rate
and momentum
.
// Updates the weights using the given learning rate
, momentum and adam_beta
.
// num_samples is
the quotient to be used in the adagrad computation iff
// num_samples is
used in the adam computation iff use_adam_ is true.
// use_ada_grad_ is true.
virtual
void
Update
(
float
learning_rate
,
float
momentum
,
float
adam_beta
,
virtual
void
Update
(
float
learning_rate
,
float
momentum
,
int
num_samples
)
{}
int
num_samples
)
{}
// Sums the products of weight updates in *this and other, splitting into
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// positive (same direction) in *same and negative (different direction) in
// *changed.
// *changed.
...
...
lstm/plumbing.cpp
浏览文件 @
4e9665de
...
@@ -57,6 +57,19 @@ int Plumbing::InitWeights(float range, TRand* randomizer) {
...
@@ -57,6 +57,19 @@ int Plumbing::InitWeights(float range, TRand* randomizer) {
return
num_weights_
;
return
num_weights_
;
}
}
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights. Only operates on Softmax layers with old_no outputs.
int
Plumbing
::
RemapOutputs
(
int
old_no
,
const
std
::
vector
<
int
>&
code_map
)
{
num_weights_
=
0
;
for
(
int
i
=
0
;
i
<
stack_
.
size
();
++
i
)
{
num_weights_
+=
stack_
[
i
]
->
RemapOutputs
(
old_no
,
code_map
);
}
return
num_weights_
;
}
// Converts a float network to an int network.
// Converts a float network to an int network.
void
Plumbing
::
ConvertToInt
()
{
void
Plumbing
::
ConvertToInt
()
{
for
(
int
i
=
0
;
i
<
stack_
.
size
();
++
i
)
for
(
int
i
=
0
;
i
<
stack_
.
size
();
++
i
)
...
@@ -204,10 +217,10 @@ bool Plumbing::DeSerialize(TFile* fp) {
...
@@ -204,10 +217,10 @@ bool Plumbing::DeSerialize(TFile* fp) {
return
true
;
return
true
;
}
}
// Updates the weights using the given learning rate
and momentum
.
// Updates the weights using the given learning rate
, momentum and adam_beta
.
// num_samples is
the quotient to be used in the adagrad computation iff
// num_samples is
used in the adam computation iff use_adam_ is true.
// use_ada_grad_ is true.
void
Plumbing
::
Update
(
float
learning_rate
,
float
momentum
,
float
adam_beta
,
void
Plumbing
::
Update
(
float
learning_rate
,
float
momentum
,
int
num_samples
)
{
int
num_samples
)
{
for
(
int
i
=
0
;
i
<
stack_
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
stack_
.
size
();
++
i
)
{
if
(
network_flags_
&
NF_LAYER_SPECIFIC_LR
)
{
if
(
network_flags_
&
NF_LAYER_SPECIFIC_LR
)
{
if
(
i
<
learning_rates_
.
size
())
if
(
i
<
learning_rates_
.
size
())
...
@@ -216,7 +229,7 @@ void Plumbing::Update(float learning_rate, float momentum, int num_samples) {
...
@@ -216,7 +229,7 @@ void Plumbing::Update(float learning_rate, float momentum, int num_samples) {
learning_rates_
.
push_back
(
learning_rate
);
learning_rates_
.
push_back
(
learning_rate
);
}
}
if
(
stack_
[
i
]
->
IsTraining
())
{
if
(
stack_
[
i
]
->
IsTraining
())
{
stack_
[
i
]
->
Update
(
learning_rate
,
momentum
,
num_samples
);
stack_
[
i
]
->
Update
(
learning_rate
,
momentum
,
adam_beta
,
num_samples
);
}
}
}
}
}
}
...
...
lstm/plumbing.h
浏览文件 @
4e9665de
...
@@ -57,6 +57,12 @@ class Plumbing : public Network {
...
@@ -57,6 +57,12 @@ class Plumbing : public Network {
// and should not be deleted by any of the networks.
// and should not be deleted by any of the networks.
// Returns the number of weights initialized.
// Returns the number of weights initialized.
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights. Only operates on Softmax layers with old_no outputs.
int
RemapOutputs
(
int
old_no
,
const
std
::
vector
<
int
>&
code_map
)
override
;
// Converts a float network to an int network.
// Converts a float network to an int network.
virtual
void
ConvertToInt
();
virtual
void
ConvertToInt
();
...
@@ -118,10 +124,10 @@ class Plumbing : public Network {
...
@@ -118,10 +124,10 @@ class Plumbing : public Network {
// Reads from the given file. Returns false in case of error.
// Reads from the given file. Returns false in case of error.
virtual
bool
DeSerialize
(
TFile
*
fp
);
virtual
bool
DeSerialize
(
TFile
*
fp
);
// Updates the weights using the given learning rate
and momentum
.
// Updates the weights using the given learning rate
, momentum and adam_beta
.
// num_samples is
the quotient to be used in the adagrad computation iff
// num_samples is
used in the adam computation iff use_adam_ is true.
// use_ada_grad_ is true.
void
Update
(
float
learning_rate
,
float
momentum
,
float
adam_beta
,
virtual
void
Update
(
float
learning_rate
,
float
momentum
,
int
num_samples
)
;
int
num_samples
)
override
;
// Sums the products of weight updates in *this and other, splitting into
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// positive (same direction) in *same and negative (different direction) in
// *changed.
// *changed.
...
...
lstm/series.cpp
浏览文件 @
4e9665de
...
@@ -49,7 +49,7 @@ StaticShape Series::OutputShape(const StaticShape& input_shape) const {
...
@@ -49,7 +49,7 @@ StaticShape Series::OutputShape(const StaticShape& input_shape) const {
// Note that series has its own implementation just for debug purposes.
// Note that series has its own implementation just for debug purposes.
int
Series
::
InitWeights
(
float
range
,
TRand
*
randomizer
)
{
int
Series
::
InitWeights
(
float
range
,
TRand
*
randomizer
)
{
num_weights_
=
0
;
num_weights_
=
0
;
tprintf
(
"Num outputs,weights in
serial
:
\n
"
);
tprintf
(
"Num outputs,weights in
Series
:
\n
"
);
for
(
int
i
=
0
;
i
<
stack_
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
stack_
.
size
();
++
i
)
{
int
weights
=
stack_
[
i
]
->
InitWeights
(
range
,
randomizer
);
int
weights
=
stack_
[
i
]
->
InitWeights
(
range
,
randomizer
);
tprintf
(
" %s:%d, %d
\n
"
,
tprintf
(
" %s:%d, %d
\n
"
,
...
@@ -60,6 +60,25 @@ int Series::InitWeights(float range, TRand* randomizer) {
...
@@ -60,6 +60,25 @@ int Series::InitWeights(float range, TRand* randomizer) {
return
num_weights_
;
return
num_weights_
;
}
}
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights. Only operates on Softmax layers with old_no outputs.
int
Series
::
RemapOutputs
(
int
old_no
,
const
std
::
vector
<
int
>&
code_map
)
{
num_weights_
=
0
;
tprintf
(
"Num (Extended) outputs,weights in Series:
\n
"
);
for
(
int
i
=
0
;
i
<
stack_
.
size
();
++
i
)
{
int
weights
=
stack_
[
i
]
->
RemapOutputs
(
old_no
,
code_map
);
tprintf
(
" %s:%d, %d
\n
"
,
stack_
[
i
]
->
spec
().
string
(),
stack_
[
i
]
->
NumOutputs
(),
weights
);
num_weights_
+=
weights
;
}
tprintf
(
"Total weights = %d
\n
"
,
num_weights_
);
no_
=
stack_
.
back
()
->
NumOutputs
();
return
num_weights_
;
}
// Sets needs_to_backprop_ to needs_backprop and returns true if
// Sets needs_to_backprop_ to needs_backprop and returns true if
// needs_backprop || any weights in this network so the next layer forward
// needs_backprop || any weights in this network so the next layer forward
// can be told to produce backprop for this layer if needed.
// can be told to produce backprop for this layer if needed.
...
...
lstm/series.h
浏览文件 @
4e9665de
...
@@ -46,6 +46,12 @@ class Series : public Plumbing {
...
@@ -46,6 +46,12 @@ class Series : public Plumbing {
// scale `range` picked according to the random number generator `randomizer`.
// scale `range` picked according to the random number generator `randomizer`.
// Returns the number of weights initialized.
// Returns the number of weights initialized.
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
virtual
int
InitWeights
(
float
range
,
TRand
*
randomizer
);
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights. Only operates on Softmax layers with old_no outputs.
int
RemapOutputs
(
int
old_no
,
const
std
::
vector
<
int
>&
code_map
)
override
;
// Sets needs_to_backprop_ to needs_backprop and returns true if
// Sets needs_to_backprop_ to needs_backprop and returns true if
// needs_backprop || any weights in this network so the next layer forward
// needs_backprop || any weights in this network so the next layer forward
...
...
lstm/weightmatrix.cpp
浏览文件 @
4e9665de
...
@@ -26,6 +26,11 @@
...
@@ -26,6 +26,11 @@
namespace
tesseract
{
namespace
tesseract
{
// Number of iterations after which the correction effectively becomes unity.
const
int
kAdamCorrectionIterations
=
200000
;
// Epsilon in Adam to prevent division by zero.
const
double
kAdamEpsilon
=
1e-8
;
// Copies the whole input transposed, converted to double, into *this.
// Copies the whole input transposed, converted to double, into *this.
void
TransposedArray
::
Transpose
(
const
GENERIC_2D_ARRAY
<
double
>&
input
)
{
void
TransposedArray
::
Transpose
(
const
GENERIC_2D_ARRAY
<
double
>&
input
)
{
int
width
=
input
.
dim1
();
int
width
=
input
.
dim1
();
...
@@ -36,7 +41,7 @@ void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double>& input) {
...
@@ -36,7 +41,7 @@ void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double>& input) {
// Sets up the network for training. Initializes weights using weights of
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// scale `range` picked according to the random number generator `randomizer`.
int
WeightMatrix
::
InitWeightsFloat
(
int
no
,
int
ni
,
bool
ada_grad
,
int
WeightMatrix
::
InitWeightsFloat
(
int
no
,
int
ni
,
bool
use_adam
,
float
weight_range
,
TRand
*
randomizer
)
{
float
weight_range
,
TRand
*
randomizer
)
{
int_mode_
=
false
;
int_mode_
=
false
;
wf_
.
Resize
(
no
,
ni
,
0.0
);
wf_
.
Resize
(
no
,
ni
,
0.0
);
...
@@ -47,11 +52,37 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
...
@@ -47,11 +52,37 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
}
}
}
}
}
}
use_ada
_grad_
=
ada_grad
;
use_ada
m_
=
use_adam
;
InitBackward
();
InitBackward
();
return
ni
*
no
;
return
ni
*
no
;
}
}
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights.
int
WeightMatrix
::
RemapOutputs
(
const
std
::
vector
<
int
>&
code_map
)
{
GENERIC_2D_ARRAY
<
double
>
old_wf
(
wf_
);
int
old_no
=
wf_
.
dim1
();
int
new_no
=
code_map
.
size
();
int
ni
=
wf_
.
dim2
();
std
::
vector
<
double
>
means
(
ni
,
0.0
);
for
(
int
c
=
0
;
c
<
old_no
;
++
c
)
{
const
double
*
weights
=
wf_
[
c
];
for
(
int
i
=
0
;
i
<
ni
;
++
i
)
means
[
i
]
+=
weights
[
i
];
}
for
(
double
&
mean
:
means
)
mean
/=
old_no
;
wf_
.
ResizeNoInit
(
new_no
,
ni
);
InitBackward
();
for
(
int
dest
=
0
;
dest
<
new_no
;
++
dest
)
{
int
src
=
code_map
[
dest
];
const
double
*
src_data
=
src
>=
0
?
old_wf
[
src
]
:
means
.
data
();
memcpy
(
wf_
[
dest
],
src_data
,
ni
*
sizeof
(
*
src_data
));
}
return
ni
*
new_no
;
}
// Converts a float network to an int network. Each set of input weights that
// Converts a float network to an int network. Each set of input weights that
// corresponds to a single output weight is converted independently:
// corresponds to a single output weight is converted independently:
// Compute the max absolute value of the weight set.
// Compute the max absolute value of the weight set.
...
@@ -90,13 +121,13 @@ void WeightMatrix::InitBackward() {
...
@@ -90,13 +121,13 @@ void WeightMatrix::InitBackward() {
dw_
.
Resize
(
no
,
ni
,
0.0
);
dw_
.
Resize
(
no
,
ni
,
0.0
);
updates_
.
Resize
(
no
,
ni
,
0.0
);
updates_
.
Resize
(
no
,
ni
,
0.0
);
wf_t_
.
Transpose
(
wf_
);
wf_t_
.
Transpose
(
wf_
);
if
(
use_ada
_grad
_
)
dw_sq_sum_
.
Resize
(
no
,
ni
,
0.0
);
if
(
use_ada
m
_
)
dw_sq_sum_
.
Resize
(
no
,
ni
,
0.0
);
}
}
// Flag on mode to indicate that this weightmatrix uses inT8.
// Flag on mode to indicate that this weightmatrix uses inT8.
const
int
kInt8Flag
=
1
;
const
int
kInt8Flag
=
1
;
// Flag on mode to indicate that this weightmatrix uses ada
grad
.
// Flag on mode to indicate that this weightmatrix uses ada
m
.
const
int
kAda
Grad
Flag
=
4
;
const
int
kAda
m
Flag
=
4
;
// Flag on mode to indicate that this weightmatrix uses double. Set
// Flag on mode to indicate that this weightmatrix uses double. Set
// independently of kInt8Flag as even in int mode the scales can
// independently of kInt8Flag as even in int mode the scales can
// be float or double.
// be float or double.
...
@@ -106,8 +137,8 @@ const int kDoubleFlag = 128;
...
@@ -106,8 +137,8 @@ const int kDoubleFlag = 128;
bool
WeightMatrix
::
Serialize
(
bool
training
,
TFile
*
fp
)
const
{
bool
WeightMatrix
::
Serialize
(
bool
training
,
TFile
*
fp
)
const
{
// For backward compatibility, add kDoubleFlag to mode to indicate the doubles
// For backward compatibility, add kDoubleFlag to mode to indicate the doubles
// format, without errs, so we can detect and read old format weight matrices.
// format, without errs, so we can detect and read old format weight matrices.
uinT8
mode
=
(
int_mode_
?
kInt8Flag
:
0
)
|
uinT8
mode
=
(
use_ada_grad_
?
kAdaGrad
Flag
:
0
)
|
kDoubleFlag
;
(
int_mode_
?
kInt8Flag
:
0
)
|
(
use_adam_
?
kAdam
Flag
:
0
)
|
kDoubleFlag
;
if
(
fp
->
FWrite
(
&
mode
,
sizeof
(
mode
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FWrite
(
&
mode
,
sizeof
(
mode
),
1
)
!=
1
)
return
false
;
if
(
int_mode_
)
{
if
(
int_mode_
)
{
if
(
!
wi_
.
Serialize
(
fp
))
return
false
;
if
(
!
wi_
.
Serialize
(
fp
))
return
false
;
...
@@ -115,7 +146,7 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
...
@@ -115,7 +146,7 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
}
else
{
}
else
{
if
(
!
wf_
.
Serialize
(
fp
))
return
false
;
if
(
!
wf_
.
Serialize
(
fp
))
return
false
;
if
(
training
&&
!
updates_
.
Serialize
(
fp
))
return
false
;
if
(
training
&&
!
updates_
.
Serialize
(
fp
))
return
false
;
if
(
training
&&
use_ada
_grad
_
&&
!
dw_sq_sum_
.
Serialize
(
fp
))
return
false
;
if
(
training
&&
use_ada
m
_
&&
!
dw_sq_sum_
.
Serialize
(
fp
))
return
false
;
}
}
return
true
;
return
true
;
}
}
...
@@ -126,7 +157,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
...
@@ -126,7 +157,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
uinT8
mode
=
0
;
uinT8
mode
=
0
;
if
(
fp
->
FRead
(
&
mode
,
sizeof
(
mode
),
1
)
!=
1
)
return
false
;
if
(
fp
->
FRead
(
&
mode
,
sizeof
(
mode
),
1
)
!=
1
)
return
false
;
int_mode_
=
(
mode
&
kInt8Flag
)
!=
0
;
int_mode_
=
(
mode
&
kInt8Flag
)
!=
0
;
use_ada
_grad_
=
(
mode
&
kAdaGrad
Flag
)
!=
0
;
use_ada
m_
=
(
mode
&
kAdam
Flag
)
!=
0
;
if
((
mode
&
kDoubleFlag
)
==
0
)
return
DeSerializeOld
(
training
,
fp
);
if
((
mode
&
kDoubleFlag
)
==
0
)
return
DeSerializeOld
(
training
,
fp
);
if
(
int_mode_
)
{
if
(
int_mode_
)
{
if
(
!
wi_
.
DeSerialize
(
fp
))
return
false
;
if
(
!
wi_
.
DeSerialize
(
fp
))
return
false
;
...
@@ -136,7 +167,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
...
@@ -136,7 +167,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
if
(
training
)
{
if
(
training
)
{
InitBackward
();
InitBackward
();
if
(
!
updates_
.
DeSerialize
(
fp
))
return
false
;
if
(
!
updates_
.
DeSerialize
(
fp
))
return
false
;
if
(
use_ada
_grad
_
&&
!
dw_sq_sum_
.
DeSerialize
(
fp
))
return
false
;
if
(
use_ada
m
_
&&
!
dw_sq_sum_
.
DeSerialize
(
fp
))
return
false
;
}
}
}
}
return
true
;
return
true
;
...
@@ -247,19 +278,27 @@ void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
...
@@ -247,19 +278,27 @@ void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
}
}
// Updates the weights using the given learning rate and momentum.
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the ada
grad
computation iff
// num_samples is the quotient to be used in the ada
m
computation iff
// use_ada
_grad
_ is true.
// use_ada
m
_ is true.
void
WeightMatrix
::
Update
(
double
learning_rate
,
double
momentum
,
void
WeightMatrix
::
Update
(
double
learning_rate
,
double
momentum
,
int
num_samples
)
{
double
adam_beta
,
int
num_samples
)
{
ASSERT_HOST
(
!
int_mode_
);
ASSERT_HOST
(
!
int_mode_
);
if
(
use_ada_grad_
&&
num_samples
>
0
)
{
if
(
use_adam_
&&
num_samples
>
0
&&
num_samples
<
kAdamCorrectionIterations
)
{
dw_sq_sum_
.
SumSquares
(
dw_
);
learning_rate
*=
sqrt
(
1.0
-
pow
(
adam_beta
,
num_samples
));
dw_
.
AdaGradScaling
(
dw_sq_sum_
,
num_samples
);
learning_rate
/=
1.0
-
pow
(
momentum
,
num_samples
);
}
if
(
use_adam_
&&
num_samples
>
0
&&
momentum
>
0.0
)
{
dw_sq_sum_
.
SumSquares
(
dw_
,
adam_beta
);
dw_
*=
learning_rate
*
(
1.0
-
momentum
);
updates_
*=
momentum
;
updates_
+=
dw_
;
wf_
.
AdamUpdate
(
updates_
,
dw_sq_sum_
,
learning_rate
*
kAdamEpsilon
);
}
else
{
dw_
*=
learning_rate
;
updates_
+=
dw_
;
if
(
momentum
>
0.0
)
wf_
+=
updates_
;
if
(
momentum
>=
0.0
)
updates_
*=
momentum
;
}
}
dw_
*=
learning_rate
;
updates_
+=
dw_
;
if
(
momentum
>
0.0
)
wf_
+=
updates_
;
if
(
momentum
>=
0.0
)
updates_
*=
momentum
;
wf_t_
.
Transpose
(
wf_
);
wf_t_
.
Transpose
(
wf_
);
}
}
...
...
lstm/weightmatrix.h
浏览文件 @
4e9665de
...
@@ -62,14 +62,20 @@ class TransposedArray : public GENERIC_2D_ARRAY<double> {
...
@@ -62,14 +62,20 @@ class TransposedArray : public GENERIC_2D_ARRAY<double> {
// backward steps with the matrix and updates to the weights.
// backward steps with the matrix and updates to the weights.
class
WeightMatrix
{
class
WeightMatrix
{
public:
public:
WeightMatrix
()
:
int_mode_
(
false
),
use_ada
_grad
_
(
false
)
{}
WeightMatrix
()
:
int_mode_
(
false
),
use_ada
m
_
(
false
)
{}
// Sets up the network for training. Initializes weights using weights of
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// scale `range` picked according to the random number generator `randomizer`.
// Note the order is outputs, inputs, as this is the order of indices to
// Note the order is outputs, inputs, as this is the order of indices to
// the matrix, so the adjacent elements are multiplied by the input during
// the matrix, so the adjacent elements are multiplied by the input during
// a forward operation.
// a forward operation.
int
InitWeightsFloat
(
int
no
,
int
ni
,
bool
ada_grad
,
float
weight_range
,
int
InitWeightsFloat
(
int
no
,
int
ni
,
bool
use_adam
,
float
weight_range
,
TRand
*
randomizer
);
TRand
*
randomizer
);
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights.
int
RemapOutputs
(
const
std
::
vector
<
int
>&
code_map
);
// Converts a float network to an int network. Each set of input weights that
// Converts a float network to an int network. Each set of input weights that
// corresponds to a single output weight is converted independently:
// corresponds to a single output weight is converted independently:
...
@@ -123,10 +129,10 @@ class WeightMatrix {
...
@@ -123,10 +129,10 @@ class WeightMatrix {
// Runs parallel if requested. Note that inputs must be transposed.
// Runs parallel if requested. Note that inputs must be transposed.
void
SumOuterTransposed
(
const
TransposedArray
&
u
,
const
TransposedArray
&
v
,
void
SumOuterTransposed
(
const
TransposedArray
&
u
,
const
TransposedArray
&
v
,
bool
parallel
);
bool
parallel
);
// Updates the weights using the given learning rate
and momentum
.
// Updates the weights using the given learning rate
, momentum and adam_beta
.
// num_samples is
the quotient to be used in the adagrad computation iff
// num_samples is
used in the Adam correction factor.
// use_ada_grad_ is true.
void
Update
(
double
learning_rate
,
double
momentum
,
double
adam_beta
,
void
Update
(
double
learning_rate
,
double
momentum
,
int
num_samples
);
int
num_samples
);
// Adds the dw_ in other to the dw_ is *this.
// Adds the dw_ in other to the dw_ is *this.
void
AddDeltas
(
const
WeightMatrix
&
other
);
void
AddDeltas
(
const
WeightMatrix
&
other
);
// Sums the products of weight updates in *this and other, splitting into
// Sums the products of weight updates in *this and other, splitting into
...
@@ -163,8 +169,8 @@ class WeightMatrix {
...
@@ -163,8 +169,8 @@ class WeightMatrix {
TransposedArray
wf_t_
;
TransposedArray
wf_t_
;
// Which of wf_ and wi_ are we actually using.
// Which of wf_ and wi_ are we actually using.
bool
int_mode_
;
bool
int_mode_
;
// True if we are running ada
grad
in this weight matrix.
// True if we are running ada
m
in this weight matrix.
bool
use_ada
_grad
_
;
bool
use_ada
m
_
;
// If we are using wi_, then scales_ is a factor to restore the row product
// If we are using wi_, then scales_ is a factor to restore the row product
// with a vector to the correct range.
// with a vector to the correct range.
GenericVector
<
double
>
scales_
;
GenericVector
<
double
>
scales_
;
...
@@ -172,8 +178,8 @@ class WeightMatrix {
...
@@ -172,8 +178,8 @@ class WeightMatrix {
// amount to be added to wf_/wi_.
// amount to be added to wf_/wi_.
GENERIC_2D_ARRAY
<
double
>
dw_
;
GENERIC_2D_ARRAY
<
double
>
dw_
;
GENERIC_2D_ARRAY
<
double
>
updates_
;
GENERIC_2D_ARRAY
<
double
>
updates_
;
// Iff use_ada
_grad
_, the sum of squares of dw_. The number of samples is
// Iff use_ada
m
_, the sum of squares of dw_. The number of samples is
// given to Update(). Serialized iff use_ada
_grad
_.
// given to Update(). Serialized iff use_ada
m
_.
GENERIC_2D_ARRAY
<
double
>
dw_sq_sum_
;
GENERIC_2D_ARRAY
<
double
>
dw_sq_sum_
;
};
};
...
...
training/lstmtraining.cpp
浏览文件 @
4e9665de
...
@@ -34,8 +34,9 @@ INT_PARAM_FLAG(perfect_sample_delay, 0,
...
@@ -34,8 +34,9 @@ INT_PARAM_FLAG(perfect_sample_delay, 0,
"How many imperfect samples between perfect ones."
);
"How many imperfect samples between perfect ones."
);
DOUBLE_PARAM_FLAG
(
target_error_rate
,
0.01
,
"Final error rate in percent."
);
DOUBLE_PARAM_FLAG
(
target_error_rate
,
0.01
,
"Final error rate in percent."
);
DOUBLE_PARAM_FLAG
(
weight_range
,
0.1
,
"Range of initial random weights."
);
DOUBLE_PARAM_FLAG
(
weight_range
,
0.1
,
"Range of initial random weights."
);
DOUBLE_PARAM_FLAG
(
learning_rate
,
1.0e-4
,
"Weight factor for new deltas."
);
DOUBLE_PARAM_FLAG
(
learning_rate
,
10.0e-4
,
"Weight factor for new deltas."
);
DOUBLE_PARAM_FLAG
(
momentum
,
0.9
,
"Decay factor for repeating deltas."
);
DOUBLE_PARAM_FLAG
(
momentum
,
0.5
,
"Decay factor for repeating deltas."
);
DOUBLE_PARAM_FLAG
(
adam_beta
,
0.999
,
"Decay factor for repeating deltas."
);
INT_PARAM_FLAG
(
max_image_MB
,
6000
,
"Max memory to use for images."
);
INT_PARAM_FLAG
(
max_image_MB
,
6000
,
"Max memory to use for images."
);
STRING_PARAM_FLAG
(
continue_from
,
""
,
"Existing model to extend"
);
STRING_PARAM_FLAG
(
continue_from
,
""
,
"Existing model to extend"
);
STRING_PARAM_FLAG
(
model_output
,
"lstmtrain"
,
"Basename for output models"
);
STRING_PARAM_FLAG
(
model_output
,
"lstmtrain"
,
"Basename for output models"
);
...
@@ -56,6 +57,11 @@ BOOL_PARAM_FLAG(debug_network, false,
...
@@ -56,6 +57,11 @@ BOOL_PARAM_FLAG(debug_network, false,
INT_PARAM_FLAG
(
max_iterations
,
0
,
"If set, exit after this many iterations"
);
INT_PARAM_FLAG
(
max_iterations
,
0
,
"If set, exit after this many iterations"
);
STRING_PARAM_FLAG
(
traineddata
,
""
,
STRING_PARAM_FLAG
(
traineddata
,
""
,
"Combined Dawgs/Unicharset/Recoder for language model"
);
"Combined Dawgs/Unicharset/Recoder for language model"
);
<<<<<<<
Updated
upstream
=======
STRING_PARAM_FLAG
(
old_traineddata
,
""
,
"Previous traineddata arg when changing the character set"
);
>>>>>>>
Stashed
changes
// Number of training images to train between calls to MaintainCheckpoints.
// Number of training images to train between calls to MaintainCheckpoints.
const
int
kNumPagesPerBatch
=
100
;
const
int
kNumPagesPerBatch
=
100
;
...
@@ -91,7 +97,7 @@ int main(int argc, char **argv) {
...
@@ -91,7 +97,7 @@ int main(int argc, char **argv) {
// Reading something from an existing model doesn't require many flags,
// Reading something from an existing model doesn't require many flags,
// so do it now and exit.
// so do it now and exit.
if
(
FLAGS_stop_training
||
FLAGS_debug_network
)
{
if
(
FLAGS_stop_training
||
FLAGS_debug_network
)
{
if
(
!
trainer
.
TryLoadingCheckpoint
(
FLAGS_continue_from
.
c_str
()))
{
if
(
!
trainer
.
TryLoadingCheckpoint
(
FLAGS_continue_from
.
c_str
()
,
nullptr
))
{
tprintf
(
"Failed to read continue from: %s
\n
"
,
tprintf
(
"Failed to read continue from: %s
\n
"
,
FLAGS_continue_from
.
c_str
());
FLAGS_continue_from
.
c_str
());
return
1
;
return
1
;
...
@@ -122,14 +128,17 @@ int main(int argc, char **argv) {
...
@@ -122,14 +128,17 @@ int main(int argc, char **argv) {
}
}
// Checkpoints always take priority if they are available.
// Checkpoints always take priority if they are available.
if
(
trainer
.
TryLoadingCheckpoint
(
checkpoint_file
.
string
())
||
if
(
trainer
.
TryLoadingCheckpoint
(
checkpoint_file
.
string
()
,
nullptr
)
||
trainer
.
TryLoadingCheckpoint
(
checkpoint_bak
.
string
()))
{
trainer
.
TryLoadingCheckpoint
(
checkpoint_bak
.
string
()
,
nullptr
))
{
tprintf
(
"Successfully restored trainer from %s
\n
"
,
tprintf
(
"Successfully restored trainer from %s
\n
"
,
checkpoint_file
.
string
());
checkpoint_file
.
string
());
}
else
{
}
else
{
if
(
!
FLAGS_continue_from
.
empty
())
{
if
(
!
FLAGS_continue_from
.
empty
())
{
// Load a past model file to improve upon.
// Load a past model file to improve upon.
if
(
!
trainer
.
TryLoadingCheckpoint
(
FLAGS_continue_from
.
c_str
()))
{
if
(
!
trainer
.
TryLoadingCheckpoint
(
FLAGS_continue_from
.
c_str
(),
FLAGS_append_index
>=
0
?
FLAGS_continue_from
.
c_str
()
:
FLAGS_old_traineddata
.
c_str
()))
{
tprintf
(
"Failed to continue from: %s
\n
"
,
FLAGS_continue_from
.
c_str
());
tprintf
(
"Failed to continue from: %s
\n
"
,
FLAGS_continue_from
.
c_str
());
return
1
;
return
1
;
}
}
...
@@ -147,7 +156,8 @@ int main(int argc, char **argv) {
...
@@ -147,7 +156,8 @@ int main(int argc, char **argv) {
// We are initializing from scratch.
// We are initializing from scratch.
if
(
!
trainer
.
InitNetwork
(
FLAGS_net_spec
.
c_str
(),
FLAGS_append_index
,
if
(
!
trainer
.
InitNetwork
(
FLAGS_net_spec
.
c_str
(),
FLAGS_append_index
,
FLAGS_net_mode
,
FLAGS_weight_range
,
FLAGS_net_mode
,
FLAGS_weight_range
,
FLAGS_learning_rate
,
FLAGS_momentum
))
{
FLAGS_learning_rate
,
FLAGS_momentum
,
FLAGS_adam_beta
))
{
tprintf
(
"Failed to create network from spec: %s
\n
"
,
tprintf
(
"Failed to create network from spec: %s
\n
"
,
FLAGS_net_spec
.
c_str
());
FLAGS_net_spec
.
c_str
());
return
1
;
return
1
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录