Commit 80a5ee00
Authored Oct 17, 2017 by caoying03

fix forward and add backward.

Parent: 3123e3cf

Showing 3 changed files with 302 additions and 94 deletions (+302 −94):

  paddle/operators/linear_chain_crf_op.cc                        (+259 −75)
  paddle/operators/linear_chain_crf_op.h                         (+12 −8)
  python/paddle/v2/framework/tests/test_linear_chain_crf_op.py   (+31 −11)
paddle/operators/linear_chain_crf_op.cc

@@ -17,6 +17,22 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+namespace {
+template <typename T>
+T NormalizeL1(T* x, size_t len) {
+  T sum = 0.;
+  for (size_t i = 0; i < len; ++i) sum += x[i];
+  // (This comment is from the old LinearChainCRFLayer.)
+  // Right now, we just bet that sum won't be zero. If this really happens, we
+  // will figure out what should be done then.
+  PADDLE_ENFORCE(sum,
+                 "The unnormalized probabilites of all possible unfinished "
+                 "sequences must be greater than 0.");
+  for (size_t i = 0; i < len; ++i) x[i] /= sum;
+  return sum;
+}
+}  // namespace
+
 using framework::LoDTensor;
 using framework::LoD;
@@ -54,13 +70,25 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
               "each tag value \f$v\f$. This vector is called a forward vecotr and "
               "will also be used in backward computations.")
         .AsIntermediate();
+    AddOutput("EmissionExps",
+              "The exponentials of Input(Emission). This is an intermediate "
+              "computational result in forward computation, and will be reused "
+              "in backward computation.")
+        .AsIntermediate();
+    AddOutput("TransitionExps",
+              "The exponentials of Input(Transition). This is an intermediate "
+              "computational result in forward computation, and will be reused "
+              "in backward computation.")
+        .AsIntermediate();
     AddOutput("LogLikelihood",
-              "(Tensor, default: Tensor<float>). The logarithm of the conditional "
+              "(Tensor, default: Tensor<float>). The logarithm of the "
+              "conditional "
               "likelihood of each training sample in a mini-batch. This is a 2-D "
               "tensor with shape [S x 1], where S is the sequence number in a "
               "mini-batch. "
-              "Note: S is equal to the sequence number in a mini-batch. The output "
+              "Note: S is equal to the sequence number in a mini-batch. The "
+              "output "
               "is no longer a LoDTensor.");
     AddComment(R"DOC(
 Conditional Random Field defines an undirected probabilistic graph with nodes
@@ -129,6 +157,10 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Alpha"),
                    "Output(Alpha) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("EmissionExps"),
+                   "Output(EmissionExps) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("TransitionExps"),
+                   "Output(TransitionExps) should be not null.");
     PADDLE_ENFORCE(ctx->HasOutput("LogLikelihood"),
                    "Output(LogLikelihood) should be not null.");
@@ -143,7 +175,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(
         transition_dims[0] - 2, transition_dims[1],
         "An invalid dimension for the Input(Transition), which should "
-        "be a 2-D tensor with shape [D + 2 x D].");
+        "be a 2-D tensor with shape [(D + 2) x D].");
     PADDLE_ENFORCE_EQ(
         emission_dims[1], transition_dims[1],
         "The 2nd dimension of the Input(Emission) and the Input(Transition) "
@@ -157,11 +189,14 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
                       "should be the same.");
 
     ctx->SetOutputDim("Alpha", emission_dims);
+    ctx->SetOutputDim("EmissionExps", emission_dims);
+    ctx->SetOutputDim("TransitionExps", transition_dims);
     // (TODO caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
     // is the sequence number in a mini-batch. The dimension set here should be
     // resized to its correct size in the function Compute.
     ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1});
+    ctx->ShareLoD("Emission", /*->*/ "EmissionExps");
   }
 
  protected:
@@ -180,9 +215,12 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
     auto* emission_weights = ctx.Input<LoDTensor>("Emission");
     auto* transition_weights = ctx.Input<Tensor>("Transition");
+    auto* emission_exps = ctx.Output<LoDTensor>("EmissionExps");
+    emission_exps->mutable_data<T>(platform::CPUPlace());
+    auto* transition_exps = ctx.Output<Tensor>("TransitionExps");
+    transition_exps->mutable_data<T>(platform::CPUPlace());
     auto* label = ctx.Input<LoDTensor>("Label");
 
     auto in_lod = emission_weights->lod();
@@ -195,18 +233,29 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
     const size_t level = 0;
 
     auto emission_dims = emission_weights->dims();
+    const size_t batch_size = emission_dims[0];
+    const size_t tag_num = emission_dims[1];
     const size_t seq_num = in_lod[level].size() - 1;
 
-    // TODO(caoying) These local variables seems to be created and destroied
-    // every time this function is called. Will this bring additional overhead?
-    Tensor emission_exps;
     Tensor emission_row_max;
-    Tensor transition_exps;
-    emission_exps.mutable_data<T>(emission_dims, platform::CPUPlace());
-    emission_row_max.mutable_data<T>(
-        framework::make_ddim({emission_dims[0], 1}), platform::CPUPlace());
-    transition_exps.mutable_data<T>(transition_weights->dims(),
-                                    platform::CPUPlace());
+    emission_row_max.mutable_data<T>(
+        framework::make_ddim({static_cast<int>(batch_size), 1}),
+        platform::CPUPlace());
+
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
+    auto x = EigenMatrix<T>::From(*emission_weights);
+    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
+    x_row_max.device(place) =
+        x.maximum(Eigen::DSizes<int, 1>(1))
+            .reshape(Eigen::DSizes<int, 2>(int(batch_size), 1));
+
+    auto x_exps = EigenMatrix<T>::From(*emission_exps);
+    x_exps.device(place) =
+        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
+
+    auto w = EigenMatrix<T>::From(*transition_weights);
+    auto w_exps = EigenMatrix<T>::From(*transition_exps);
+    w_exps.device(place) = w.exp();
 
     auto* alpha = ctx.Output<LoDTensor>("Alpha");
     alpha->mutable_data<T>(ctx.GetPlace());
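Editor's note: hoisting the exponentials out of ForwardOneSequence means the row-max subtraction (the usual log-sum-exp guard against overflow) now happens once per mini-batch. A small NumPy sketch of what Output("EmissionExps") holds, using the same convention as the test file below (shapes are illustrative):

import numpy as np

emission = np.random.uniform(-1, 1, [5, 3]).astype("float32")  # [batch, tags]
emission_row_max = np.amax(emission, axis=1, keepdims=True)
# Subtracting the per-row max before exp keeps every exponential in (0, 1],
# which is what the kernel stores in EmissionExps.
emission_exps = np.exp(emission - emission_row_max)
assert emission_exps.max() <= 1.0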
@@ -214,117 +263,124 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
     // resize the output tensor to the correct dimension.
     ll->Resize({static_cast<int>(seq_num), 1});
     T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
     for (size_t i = 0; i < seq_num; ++i) {
       int start_pos = static_cast<int>(in_lod[level][i]);
       int end_pos = static_cast<int>(in_lod[level][i + 1]);
       const Tensor one_seq = emission_weights->Slice<T>(start_pos, end_pos);
       Tensor one_seq_row_max = emission_row_max.Slice<T>(start_pos, end_pos);
-      Tensor one_seq_exps = emission_exps.Slice<T>(start_pos, end_pos);
+      Tensor one_seq_exps = emission_exps->Slice<T>(start_pos, end_pos);
       const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
       Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
       log_likelihood[i] = ForwardOneSequence(
-          ctx.device_context(), one_seq, one_seq_row_max, one_seq_exps,
-          (*transition_weights), transition_exps, one_seq_label, one_seq_alpha);
+          &one_seq, &one_seq_row_max, &one_seq_exps, transition_weights,
+          transition_exps, &one_seq_label, &one_seq_alpha);
     }
   }
 
  protected:
-  T ForwardOneSequence(const platform::DeviceContext& ctx,
-                       const Tensor& emission, Tensor& emission_row_max,
-                       Tensor& emission_exps, const Tensor& trans_weights,
-                       Tensor& trans_weight_exps, const Tensor& label,
-                       Tensor& alpha) const {
-    // (TODO caoying) Evaluate and optimize this.
-    // The Eigen compution kernel will be invoked for multiple times.
-    // Some computations regardless of sequence inforamtion could be performed
-    // only one time for the entire batch. This potentially could be optimized.
-    auto x_dims = emission.dims();
-    const size_t seq_length = x_dims[0];
-    const size_t tag_num = x_dims[1];
-    T* alpha_value = alpha.data<T>();
-    auto x = EigenMatrix<T>::From(emission);
-    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
-    const int class_dim = 1;
-    x_row_max.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
-        x.maximum(Eigen::DSizes<int, 1>(class_dim))
-            .reshape(Eigen::DSizes<int, 2>(int(seq_length), 1));
-    auto x_exps = EigenMatrix<T>::From(emission_exps);
-    x_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
-        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
-    auto w = EigenMatrix<T>::From(trans_weights);
-    auto w_exps = EigenMatrix<T>::From(trans_weight_exps);
-    w_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) = w.exp();
+  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
+                       const Tensor* emission_exps, const Tensor* trans_weights,
+                       const Tensor* trans_weight_exps, const Tensor* label,
+                       Tensor* alpha) const {
+    const T* x = emission->data<T>();
+    const T* x_row_max = emission_row_max->data<T>();
+    const T* x_exps = emission_exps->data<T>();
+    const T* w = trans_weights->data<T>();
+    const T* w_exps = trans_weight_exps->data<T>();
+    T* alpha_value = alpha->data<T>();
+    auto x_dims = emission->dims();
+    const size_t seq_length = x_dims[0];
+    const size_t tag_num = x_dims[1];
     // The 1st row of w are transition weights for start mask.
-    const size_t start_ridx = 0;
     // The 2nd row of w are transition weights for end mask.
-    const size_t end_ridx = 1;
     // Transition weights among other tags begins from the 3rd row of w.
-    const size_t state_base_ridx = 2;
+    const size_t state_trans_base_idx = 2;
 
     for (size_t i = 0; i < tag_num; ++i) {
-      alpha_value[i] = w_exps(start_ridx, i) * x_exps(0, i);
+      alpha_value[i] = w_exps[i] * x_exps[i];
     }
-    T ll = -x_row_max(0, 1) - std::log(NormalizeL1(alpha_value, tag_num));
+    T ll = -x_row_max[0] - std::log(NormalizeL1<T>(alpha_value, tag_num));
 
     for (size_t k = 1; k < seq_length; ++k) {
       for (size_t i = 0; i < tag_num; ++i) {
         T sum = 0.;
         for (size_t j = 0; j < tag_num; ++j) {
           sum += alpha_value[(k - 1) * tag_num + j] *
-                 w_exps(j + state_base_ridx, i);
+                 w_exps[(j + state_trans_base_idx) * tag_num + i];
         }
-        alpha_value[k * tag_num + i] = x_exps(k, i) * sum;
+        alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
       }
-      ll -= x_row_max(k, 1) +
-            std::log(NormalizeL1(alpha_value + k * tag_num, tag_num));
+      ll -= x_row_max[k] +
+            std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
     }
     T sum = 0.;
     for (size_t i = 0; i < tag_num; ++i) {
-      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps(end_ridx, i);
+      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
     }
     ll -= std::log(sum);
 
-    const int* lbl = label.data<int>();
+    const int* lbl = label->data<int>();
     PADDLE_ENFORCE_LT(
         *std::max_element(lbl, lbl + seq_length), tag_num,
         "An invalid tag label that execesses the largest tag number.");
 
     // Calculate the nominator part, which depends on the label sequence.
-    ll += w(start_ridx, lbl[0]) + x(start_ridx, lbl[0]) +
-          w(end_ridx, lbl[seq_length - 1]);
+    ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] +
+          w[tag_num + lbl[seq_length - 1]] /*end transition*/;
     for (size_t k = 1; k < seq_length; ++k)
-      ll += x(k, lbl[k]) + w(lbl[k - 1], lbl[k]);
+      ll += x[k * tag_num + lbl[k]] + w[lbl[k - 1] * tag_num + lbl[k]];
     return -ll;
   }
-
- private:
-  T NormalizeL1(T* x, size_t len) const {
-    T sum = 0.;
-    for (size_t i = 0; i < len; ++i) sum += x[i];
-    // (This comment is from the old LinearChainCRFLayer.)
-    // Right now, we just bet that sum won't be zero. If this really happens, we
-    // will figure out what should be done then.
-    PADDLE_ENFORCE(sum,
-                   "The unnormalized probabilites of all possible unfinished "
-                   "sequences must be greater than 0.");
-    for (size_t i = 0; i < len; ++i) x[i] /= sum;
-    return sum;
-  }
 };
 
 class LinearChainCrfGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {}
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("EmissionExps"),
+                   "Input(EmissionExps) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("TransitionExps"),
+                   "Input(TransitionExps) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")),
+                   "Input(LogLikelihood@GRAD) shoudl be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Emission")),
+                   "Output(Emission@GRAD) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Transition")),
+                   "Output(Transition@GRAD) should be not null.");
+
+    auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
+    auto transition_exps_dims =
+        ctx->GetInputDim(framework::GradVarName("TransitionExps"));
+    auto label_dims = ctx->GetInputDim("Label");
+    PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2UL,
+                      "The Input(EmissionExps) should be a 2-D tensor.");
+    PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2UL,
+                      "The Input(TransitionExps) should be a 2-D tensor.");
+    PADDLE_ENFORCE_EQ(
+        transition_exps_dims[0] - 2, transition_exps_dims[1],
+        "An invalid dimension for the Input(TransitionExps), which should "
+        "be a 2-D tensor with shape [(D + 2) x D].");
+    PADDLE_ENFORCE_EQ(
+        emission_exps_dims[1], transition_exps_dims[1],
+        "The 2nd dimension of the Input(EmissionExps) and the "
+        "Input(TransitionExps) should be equal to the tag number.");
+    PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
+                   "The Input(Label) should be a 2-D tensor with the 2nd "
+                   "dimensions fixed to 1.");
+    PADDLE_ENFORCE_EQ(
+        emission_exps_dims[0], label_dims[0],
+        "The height of Input(EmissionExps) and the height of Input(Label) "
+        "should be the same.");
+
+    ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
+    ctx->SetOutputDim(framework::GradVarName("Transition"),
+                      transition_exps_dims);
+  }
 };
 
 template <typename T>
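Editor's note: the recursion ForwardOneSequence implements is the standard linear-chain CRF forward pass: alpha[k, i] sums over predecessor tags, each row is L1-normalized for stability, and the discarded normalizers are accumulated into the log-likelihood. A condensed NumPy sketch for one dense sequence (written by me to mirror the kernel, not taken from the patch; x_row_max is a 1-D array of per-position row maxima):

import numpy as np

def crf_forward(x_exps, x_row_max, a_exps, b_exps, w_exps):
    # x_exps: [seq_len, tag_num] emission exps; a_exps/b_exps: start/end rows
    # of the transition exps; w_exps: [tag_num, tag_num] tag-to-tag exps.
    seq_len, tag_num = x_exps.shape
    alpha = np.zeros_like(x_exps)
    alpha[0] = a_exps * x_exps[0]
    s = alpha[0].sum()
    alpha[0] /= s
    ll = -x_row_max[0] - np.log(s)
    for k in range(1, seq_len):
        # alpha[k, i] = x_exps[k, i] * sum_j alpha[k - 1, j] * w_exps[j, i]
        alpha[k] = x_exps[k] * (alpha[k - 1][:, None] * w_exps).sum(axis=0)
        s = alpha[k].sum()
        alpha[k] /= s
        ll -= x_row_max[k] + np.log(s)
    ll -= np.log((alpha[-1] * b_exps).sum())
    # The caller adds the label-dependent nominator and negates, as the
    # kernel does before returning -ll.
    return alpha, ll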
@@ -334,6 +390,134 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
+    auto* ll_grad =
+        ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"));
+    auto* label = ctx.Input<LoDTensor>("Label");
+    auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
+    auto* transition_exps = ctx.Input<Tensor>("TransitionExps");
+    auto* alpha = ctx.Input<Tensor>("Alpha");
+
+    auto* emission_grad = ctx.Output<Tensor>(framework::GradVarName("Emission"));
+    emission_grad->mutable_data<T>(platform::CPUPlace());
+
+    auto* trans_grad = ctx.Output<Tensor>(framework::GradVarName("Transition"));
+    if (trans_grad) trans_grad->mutable_data<T>(platform::CPUPlace());
+
+    auto emission_dims = emission_exps->dims();
+
+    // Beta is the memo table used in dynamic programming to calculate the
+    // backwark vectors. For a backward vector i (the i-th row of beta), it
+    // captures the unnormalized probabilities of partial sequences starting at
+    // position i.
+    Tensor beta;
+    beta.mutable_data<T>(emission_dims, platform::CPUPlace());
+
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
+    auto x_grad = EigenMatrix<T>::From(*emission_grad);
+    auto out_grad = EigenMatrix<T>::From(*ll_grad);
+    x_grad.device(place) =
+        x_grad * out_grad.broadcast(Eigen::DSizes<int, 2>(1, emission_dims[1]));
+
+    const size_t level = 0;  // currently, only support sequence.
+    auto lod = emission_exps->lod();
+    for (size_t i = 0; i < lod[level].size() - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+
+      const Tensor one_seq_emission_exps =
+          emission_exps->Slice<T>(start_pos, end_pos);
+      const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
+      const Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
+      Tensor one_seq_beta = beta.Slice<T>(start_pos, end_pos);
+      Tensor one_seq_emission_grad = emission_grad->Slice<T>(start_pos, end_pos);
+
+      BackwardOneSequence(ctx.device_context(), &one_seq_emission_exps,
+                          transition_exps, &one_seq_alpha, &one_seq_label,
+                          &one_seq_beta, trans_grad, &one_seq_emission_grad);
+    }
+  }
+
+ protected:
+  void BackwardOneSequence(const platform::DeviceContext& ctx,
+                           const Tensor* emission_exps,
+                           const Tensor* transition_exps, const Tensor* alpha,
+                           const Tensor* label, Tensor* beta,
+                           Tensor* transition_grad,
+                           Tensor* emission_grad) const {
+    const T* w_exps = transition_exps->data<T>();
+    const T* x_exps = emission_exps->data<T>();
+    const int* label_value = label->data<int>();
+    T* beta_value = beta->data<T>();
+
+    auto x_dims = emission_exps->dims();
+    const size_t seq_length = x_dims[0];
+    const size_t tag_num = x_dims[1];
+    const size_t state_trans_base_idx = 2;
+
+    // Calculate the backwark vectors beta.
+    for (int i = 0; i < tag_num; ++i)
+      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
+    NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);
+
+    for (int k = seq_length - 2; k >= 0; --k) {
+      for (int i = 0; i < tag_num; ++i) {
+        T sum = 0.;
+        for (int j = 0; j < tag_num; ++j) {
+          sum += x_exps[(i + state_trans_base_idx) * tag_num + j] *
+                 beta_value[(k + 1) * tag_num + j] *
+                 x_exps[(k + 1) * tag_num + j];
+        }
+        beta_value[k * tag_num + i] = sum;
+      }
+      NormalizeL1<T>(beta_value + k * tag_num, tag_num);
+    }
+
+    auto alpha_mat = EigenMatrix<T>::From(*alpha);
+    auto beta_mat = EigenMatrix<T>::From(*beta);
+    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
+    auto* place = ctx.GetEigenDevice<platform::CPUPlace>();
+    x_grad_mat.device(*place) = alpha_mat * beta_mat;
+    x_grad_mat /= x_grad_mat.sum(Eigen::DSizes<int, 1>(1))
+                      .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
+                      .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+
+    for (int k = 0; k < seq_length; ++k)
+      x_grad_mat(k, label_value[k]) -= static_cast<T>(1);
+
+    if (transition_grad) {
+      T* trans_grad = transition_grad->data<T>();
+      for (size_t k = 0; k < tag_num; ++k) {
+        trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
+        trans_grad[tag_num + k] +=
+            x_grad_mat(/*to end state*/ seq_length - 1, k);
+      }
+
+      auto x_exps_mat = EigenMatrix<T>::From(*emission_exps);
+      beta_mat = beta_mat * x_exps_mat;
+      beta_mat /= beta_mat.sum(Eigen::DSizes<int, 1>(1))
+                      .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
+                      .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+
+      for (int k = 1; k < seq_length; ++k) {
+        T sum = 0.;
+        for (int i = 0; i < tag_num; ++i) {
+          for (int j = 0; j < tag_num; ++j)
+            sum += x_exps_mat(i, j) * alpha_mat(k - 1, i) * beta_mat(k, j);
+        }
+        sum = static_cast<T>(1) / sum;
+        for (int i = 0; i < tag_num; ++i) {
+          for (int j = 0; j < tag_num; ++j) {
+            trans_grad[(i + 2) * tag_num + j] +=
+                sum * x_exps_mat(i, j) * alpha_mat(k - 1, i) * beta_mat(k, j);
+          }
+        }
+        trans_grad[label_value[k - 1] * tag_num + label_value[k]] -=
+            static_cast<T>(1);
+      }
+    }
+  }
 };
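Editor's note: the new backward pass mirrors the forward one: a beta table is filled right to left, the emission gradient is the posterior marginal alpha * beta (renormalized per row) minus a one-hot encoding of the gold label, and the transition gradient accumulates expected pairwise counts minus observed ones. A compact NumPy sketch of the emission-gradient part (my own illustration of the math, not the patch's code):

import numpy as np

def emission_grad(alpha, beta, labels):
    # alpha, beta: [seq_len, tag_num] forward/backward tables of one sequence;
    # labels: [seq_len] gold tag indices.
    marginal = alpha * beta
    marginal /= marginal.sum(axis=1, keepdims=True)  # posterior p(tag_k = i | x)
    grad = marginal.copy()
    grad[np.arange(len(labels)), labels] -= 1.0  # subtract one-hot gold labels
    return grad  # d(-log p(y | x)) / d(emission)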
paddle/operators/linear_chain_crf_op.h

@@ -30,20 +30,24 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override;
 
  protected:
-  T ForwardOneSequence(const platform::DeviceContext& ctx,
-                       const Tensor& emission, Tensor& emission_row_max,
-                       Tensor& emission_exps, const Tensor& trans_weights,
-                       Tensor& trans_weight_exps, const Tensor& label,
-                       Tensor& a) const;
-
- private:
-  T NormalizeL1(T* x, size_t len) const;
+  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
+                       const Tensor* emission_exps, const Tensor* trans_weights,
+                       const Tensor* trans_weight_exps, const Tensor* label,
+                       Tensor* alpha) const;
 };
 
 template <typename Place, typename T>
 class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override;
+
+ protected:
+  void BackwardOneSequence(const platform::DeviceContext& ctx,
+                           const Tensor* emission_exps,
+                           const Tensor* transition_exps, const Tensor* alpha,
+                           const Tensor* label, Tensor* beta,
+                           Tensor* transition_grad,
+                           Tensor* emission_grad) const;
 };
 }  // namespace operators
python/paddle/v2/framework/tests/test_linear_chain_crf_op.py

@@ -4,10 +4,12 @@ import numpy as np
 from op_test import OpTest
 
+import pdb
+
 
 class LinearChainCrfForward(object):
-    def __init__(self, seq_start_positions, emission_weights,
-                 transition_weights, labels):
+    def __init__(self, seq_start_positions, emission_weights, emission_row_max,
+                 emission_exps, transition_weights, transition_exps, labels):
         self.tag_num = emission_weights.shape[1]
         self.seq_num = len(seq_start_positions) - 1

@@ -15,25 +17,25 @@ class LinearChainCrfForward(object):
         self.labels = labels
         self.x = emission_weights
 
-        self.x_row_max = np.amax(self.x, axis=1, keepdims=True)
-        self.x_exps = np.exp(self.x - self.x_row_max)
+        self.x_row_max = emission_row_max
+        self.x_exps = emission_exps
 
         # unnormalized logits of the transition weights for the start mark.
         self.a = transition_weights[0, :]
-        self.a_exps = np.exp(self.a)
+        self.a_exps = transition_exps[0, :]
         # unnormalized logits of the transition weights for the end mark.
         self.b = transition_weights[1, :]
-        self.b_exps = np.exp(self.b)
+        self.b_exps = transition_exps[1, :]
         # unnormalized logits of the transition weights for all the other tags.
         self.w = transition_weights[2:, :]
-        self.w_exps = np.exp(self.w)
+        self.w_exps = transition_exps[2:, :]
 
         # The output of linear chain crf operator.
         # alpha is a memo table in dynamic programming to caculate
         # nomalization factor.
         self.alpha = np.zeros(
             (seq_start_positions[-1], self.tag_num), dtype="float32")
-        self.log_likelihood = np.zeros((self.tag_num, 1))
+        self.log_likelihood = np.zeros((self.seq_num, 1))
 
     def _l1_norm(self, x):
         s = np.sum(x)

@@ -91,11 +93,15 @@ class TestLinearChainCrfOp(OpTest):
         lod = [[0]]
         for i in range(SEQ_NUM):
             lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN))
         emission = np.random.uniform(-1, 1,
                                      [lod[-1][-1], TAG_NUM]).astype("float32")
+        emission_row_max = np.amax(emission, axis=1, keepdims=True)
+        emission_exps = np.exp(emission - emission_row_max)
 
         transition = np.random.uniform(-0.5, 0.5,
                                        [TAG_NUM + 2, TAG_NUM]).astype("float32")
+        transition_exps = np.exp(transition)
 
         labels = np.random.randint(
             low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")

@@ -105,10 +111,17 @@ class TestLinearChainCrfOp(OpTest):
             "Label": (labels, lod)
         }
 
-        crf = LinearChainCrfForward(lod[0], emission, transition, labels)
+        crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
+                                    emission_exps, transition, transition_exps,
+                                    labels)
         alpha, log_likelihood = crf.crf_forward_compute()
 
-        self.outputs = {"Alpha": alpha, "LogLikelihood": log_likelihood}
+        self.outputs = {
+            "Alpha": alpha,
+            "EmissionExps": emission_exps,
+            "TransitionExps": transition_exps,
+            "LogLikelihood": log_likelihood
+        }
 
     def setUp(self):
         self.op_type = "linear_chain_crf"

@@ -117,6 +130,13 @@ class TestLinearChainCrfOp(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(["Emission", "Transition"], "LogLikelihood")
+
+    def test_check_grad_ignore_transition(self):
+        self.check_grad(
+            ["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
+
 
 if __name__ == "__main__":
     unittest.main()
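Editor's note: the two new tests exercise the analytic gradients through OpTest's numeric gradient check, which conceptually perturbs each input element and compares a central difference against the kernel's output. A toy standalone sketch of that idea (my own illustration, not the OpTest implementation):

import numpy as np

def numeric_grad(f, x, eps=1e-3):
    # Central-difference gradient of a scalar function f at array x.
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    while not it.finished:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + eps
        hi = f(x)
        x[idx] = orig - eps
        lo = f(x)
        x[idx] = orig
        grad[idx] = (hi - lo) / (2 * eps)
        it.iternext()
    return grad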