Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
45c8a88a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
45c8a88a
编写于
12月 05, 2017
作者:
Q
Qiao Longfei
提交者:
GitHub
12月 05, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add crf_decoding layer (#6274)
* add crf_decoding layer * fix some typo * fix test_crf_decoding_op
上级
e760641a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
61 addition
and
25 deletion
+61
-25
paddle/operators/crf_decoding_op.cc
paddle/operators/crf_decoding_op.cc
+9
-8
paddle/operators/crf_decoding_op.h
paddle/operators/crf_decoding_op.h
+5
-5
python/paddle/v2/fluid/framework.py
python/paddle/v2/fluid/framework.py
+1
-1
python/paddle/v2/fluid/layer_helper.py
python/paddle/v2/fluid/layer_helper.py
+7
-1
python/paddle/v2/fluid/layers.py
python/paddle/v2/fluid/layers.py
+18
-0
python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py
...n/paddle/v2/fluid/tests/book/test_label_semantic_roles.py
+9
-3
python/paddle/v2/fluid/tests/test_crf_decoding_op.py
python/paddle/v2/fluid/tests/test_crf_decoding_op.py
+6
-6
python/paddle/v2/fluid/tests/test_layers.py
python/paddle/v2/fluid/tests/test_layers.py
+6
-1
未找到文件。
paddle/operators/crf_decoding_op.cc
浏览文件 @
45c8a88a
...
@@ -36,17 +36,18 @@ class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -36,17 +36,18 @@ class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
"w. See more details in comments of the linear_chain_crf operator."
);
"w. See more details in comments of the linear_chain_crf operator."
);
AddInput
(
AddInput
(
"Label"
,
"Label"
,
"(LoDTensor, LoDTensor<int>). The ground truth with shape "
"(LoDTensor, LoDTensor<int
64_t
>). The ground truth with shape "
"[N x 1]. This input is optional. See more details in the operator's "
"[N x 1]. This input is optional. See more details in the operator's "
"comments."
)
"comments."
)
.
AsDispensable
();
.
AsDispensable
();
AddOutput
(
"ViterbiPath"
,
AddOutput
(
"(LoDTensor, LoDTensor<int>). The decoding results. What to "
"ViterbiPath"
,
"return changes depending on whether the Input(Label) (the groud "
"(LoDTensor, LoDTensor<int64_t>). The decoding results. What to "
"truth) is given. See more details in the operator's comment."
);
"return changes depending on whether the Input(Label) (the ground "
"truth) is given. See more details in the operator's comment."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
The crf_decoding operator reads the emission feature weights and the transition
The crf_decoding operator reads the emission feature weights and the transition
f
r
eature weights learned by the linear_chain_crf operator. It implements the
feature weights learned by the linear_chain_crf operator. It implements the
Viterbi algorithm which is a dynamic programming algorithm for finding the most
Viterbi algorithm which is a dynamic programming algorithm for finding the most
likely sequence of hidden states, called the Viterbi path, that results in a
likely sequence of hidden states, called the Viterbi path, that results in a
sequence of observed tags.
sequence of observed tags.
...
@@ -60,14 +61,14 @@ operator.
...
@@ -60,14 +61,14 @@ operator.
When Input(Label) is given, the crf_decoding operator returns a row vector
When Input(Label) is given, the crf_decoding operator returns a row vector
with shape [N x 1] whose values are fixed to be 0, indicating an incorrect
with shape [N x 1] whose values are fixed to be 0, indicating an incorrect
prediction, or 1 indicating a tag is correctly predicted. Such an ouput is the
prediction, or 1 indicating a tag is correctly predicted. Such an ou
t
put is the
input to chunk_eval operator.
input to chunk_eval operator.
2. Input(Label) is not given:
2. Input(Label) is not given:
This is the standard decoding process.
This is the standard decoding process.
The crf_decoding operator returns a row vec
ot
r with shape [N x 1] whose values
The crf_decoding operator returns a row vec
to
r with shape [N x 1] whose values
range from 0 to maximum tag number - 1. Each element indicates an index of a
range from 0 to maximum tag number - 1. Each element indicates an index of a
predicted tag.
predicted tag.
)DOC"
);
)DOC"
);
...
...
paddle/operators/crf_decoding_op.h
浏览文件 @
45c8a88a
...
@@ -43,9 +43,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
...
@@ -43,9 +43,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
const
size_t
level
=
0
;
const
size_t
level
=
0
;
const
size_t
seq_num
=
lod
[
level
].
size
()
-
1
;
const
size_t
seq_num
=
lod
[
level
].
size
()
-
1
;
int
*
path
=
decoded_path
->
mutable_data
<
in
t
>
(
platform
::
CPUPlace
());
int
64_t
*
path
=
decoded_path
->
mutable_data
<
int64_
t
>
(
platform
::
CPUPlace
());
math
::
SetConstant
<
platform
::
CPUPlace
,
int
>
()(
ctx
.
device_context
(),
math
::
SetConstant
<
platform
::
CPUPlace
,
int
64_t
>
()(
ctx
.
device_context
(),
decoded_path
,
0
);
decoded_path
,
0
);
for
(
size_t
i
=
0
;
i
<
seq_num
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
seq_num
;
++
i
)
{
int
start_pos
=
static_cast
<
int
>
(
lod
[
level
][
i
]);
int
start_pos
=
static_cast
<
int
>
(
lod
[
level
][
i
]);
int
end_pos
=
static_cast
<
int
>
(
lod
[
level
][
i
+
1
]);
int
end_pos
=
static_cast
<
int
>
(
lod
[
level
][
i
+
1
]);
...
@@ -57,7 +57,7 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
...
@@ -57,7 +57,7 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
if
(
label
)
{
if
(
label
)
{
PADDLE_ENFORCE_EQ
(
label
->
NumLevels
(),
1UL
,
PADDLE_ENFORCE_EQ
(
label
->
NumLevels
(),
1UL
,
"The Input(Label) should be a sequence."
);
"The Input(Label) should be a sequence."
);
const
int
*
label_value
=
label
->
data
<
in
t
>
();
const
int
64_t
*
label_value
=
label
->
data
<
int64_
t
>
();
size_t
batch_size
=
emission_weights
->
dims
()[
0
];
size_t
batch_size
=
emission_weights
->
dims
()[
0
];
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
path
[
i
]
=
label_value
[
i
]
==
path
[
i
]
?
1
:
0
;
path
[
i
]
=
label_value
[
i
]
==
path
[
i
]
?
1
:
0
;
...
@@ -76,7 +76,7 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
...
@@ -76,7 +76,7 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
const
T
*
x
=
emission_weights
.
data
<
T
>
();
const
T
*
x
=
emission_weights
.
data
<
T
>
();
const
T
*
w
=
transition_weights
.
data
<
T
>
();
const
T
*
w
=
transition_weights
.
data
<
T
>
();
int
*
path
=
decoded_path
->
data
<
in
t
>
();
int
64_t
*
path
=
decoded_path
->
data
<
int64_
t
>
();
// alpha is a memo table. An element alpha(k, v) records the score of the
// alpha is a memo table. An element alpha(k, v) records the score of the
// best sequence of tags from position 1 to position k with v being the end
// best sequence of tags from position 1 to position k with v being the end
...
...
python/paddle/v2/fluid/framework.py
浏览文件 @
45c8a88a
...
@@ -237,7 +237,7 @@ class Operator(object):
...
@@ -237,7 +237,7 @@ class Operator(object):
def
find_name
(
var_list
,
name
):
def
find_name
(
var_list
,
name
):
for
var_name
in
var_list
:
for
var_name
in
var_list
:
if
var_name
==
name
:
if
var_
list
[
var_name
]
is
not
None
and
var_
name
==
name
:
return
True
return
True
return
False
return
False
...
...
python/paddle/v2/fluid/layer_helper.py
浏览文件 @
45c8a88a
import
copy
import
copy
import
itertools
import
itertools
from
framework
import
Variable
,
default_main_program
,
default_startup_program
,
\
from
framework
import
Variable
,
Parameter
,
default_main_program
,
default_startup_program
,
\
unique_name
,
dtype_is_floating
unique_name
,
dtype_is_floating
from
paddle.v2.fluid.initializer
import
Constant
,
Xavier
from
paddle.v2.fluid.initializer
import
Constant
,
Xavier
from
param_attr
import
ParamAttr
from
param_attr
import
ParamAttr
...
@@ -122,6 +122,12 @@ class LayerHelper(object):
...
@@ -122,6 +122,12 @@ class LayerHelper(object):
return
self
.
main_program
.
global_block
().
create_parameter
(
return
self
.
main_program
.
global_block
().
create_parameter
(
dtype
=
dtype
,
shape
=
shape
,
**
attr
.
to_kwargs
())
dtype
=
dtype
,
shape
=
shape
,
**
attr
.
to_kwargs
())
def
get_parameter
(
self
,
name
):
param
=
self
.
main_program
.
global_block
().
var
(
name
)
if
not
isinstance
(
param
,
Parameter
):
raise
ValueError
(
"no Parameter name %s found"
%
name
)
return
param
def
create_tmp_variable
(
self
,
dtype
):
def
create_tmp_variable
(
self
,
dtype
):
return
self
.
main_program
.
current_block
().
create_var
(
return
self
.
main_program
.
current_block
().
create_var
(
name
=
unique_name
(
"."
.
join
([
self
.
name
,
'tmp'
])),
name
=
unique_name
(
"."
.
join
([
self
.
name
,
'tmp'
])),
...
...
python/paddle/v2/fluid/layers.py
浏览文件 @
45c8a88a
...
@@ -477,6 +477,24 @@ def linear_chain_crf(input,
...
@@ -477,6 +477,24 @@ def linear_chain_crf(input,
return
log_likelihood
return
log_likelihood
def
crf_decoding
(
input
,
param_attr
,
label
=
None
,
main_program
=
None
,
startup_program
=
None
):
helper
=
LayerHelper
(
'crf_decoding'
,
**
locals
())
transition
=
helper
.
get_parameter
(
param_attr
.
name
)
viterbi_path
=
helper
.
create_tmp_variable
(
dtype
=
helper
.
input_dtype
())
helper
.
append_op
(
type
=
'crf_decoding'
,
inputs
=
{
"Emission"
:
[
input
],
"Transition"
:
transition
,
"Label"
:
label
},
outputs
=
{
"ViterbiPath"
:
[
viterbi_path
]})
return
viterbi_path
def
assign
(
input
,
output
,
main_program
=
None
,
startup_program
=
None
):
def
assign
(
input
,
output
,
main_program
=
None
,
startup_program
=
None
):
helper
=
LayerHelper
(
'assign'
,
**
locals
())
helper
=
LayerHelper
(
'assign'
,
**
locals
())
helper
.
append_op
(
helper
.
append_op
(
...
...
python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py
浏览文件 @
45c8a88a
...
@@ -137,12 +137,19 @@ def main():
...
@@ -137,12 +137,19 @@ def main():
param_attr
=
fluid
.
ParamAttr
(
param_attr
=
fluid
.
ParamAttr
(
name
=
'crfw'
,
learning_rate
=
mix_hidden_lr
))
name
=
'crfw'
,
learning_rate
=
mix_hidden_lr
))
avg_cost
=
fluid
.
layers
.
mean
(
x
=
crf_cost
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
crf_cost
)
# TODO(qiao)
# TODO(qiao)
# 1. add crf_decode_layer and evaluator
# check other optimizers and check why out will be NAN
# 2. use other optimizer and check why out will be NAN
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.0001
)
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.0001
)
sgd_optimizer
.
minimize
(
avg_cost
)
sgd_optimizer
.
minimize
(
avg_cost
)
# TODO(qiao)
# add dependency track and move this config before optimizer
crf_decode
=
fluid
.
layers
.
crf_decoding
(
input
=
feature_out
,
label
=
target
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'crfw'
))
train_data
=
paddle
.
batch
(
train_data
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
conll05
.
test
(),
buf_size
=
8192
),
paddle
.
dataset
.
conll05
.
test
(),
buf_size
=
8192
),
...
@@ -168,7 +175,6 @@ def main():
...
@@ -168,7 +175,6 @@ def main():
feed
=
feeder
.
feed
(
data
),
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_cost
])
fetch_list
=
[
avg_cost
])
avg_cost_val
=
np
.
array
(
outs
[
0
])
avg_cost_val
=
np
.
array
(
outs
[
0
])
if
batch_id
%
10
==
0
:
if
batch_id
%
10
==
0
:
print
(
"avg_cost="
+
str
(
avg_cost_val
))
print
(
"avg_cost="
+
str
(
avg_cost_val
))
...
...
python/paddle/v2/fluid/tests/test_crf_decoding_op.py
浏览文件 @
45c8a88a
...
@@ -20,14 +20,14 @@ class CRFDecoding(object):
...
@@ -20,14 +20,14 @@ class CRFDecoding(object):
self
.
w
=
transition_weights
[
2
:,
:]
self
.
w
=
transition_weights
[
2
:,
:]
self
.
track
=
np
.
zeros
(
self
.
track
=
np
.
zeros
(
(
seq_start_positions
[
-
1
],
self
.
tag_num
),
dtype
=
"int
32
"
)
(
seq_start_positions
[
-
1
],
self
.
tag_num
),
dtype
=
"int
64
"
)
self
.
decoded_path
=
np
.
zeros
(
self
.
decoded_path
=
np
.
zeros
(
(
seq_start_positions
[
-
1
],
1
),
dtype
=
"int
32
"
)
(
seq_start_positions
[
-
1
],
1
),
dtype
=
"int
64
"
)
def
_decode_one_sequence
(
self
,
decoded_path
,
x
):
def
_decode_one_sequence
(
self
,
decoded_path
,
x
):
seq_len
,
tag_num
=
x
.
shape
seq_len
,
tag_num
=
x
.
shape
alpha
=
np
.
zeros
((
seq_len
,
tag_num
),
dtype
=
"float64"
)
alpha
=
np
.
zeros
((
seq_len
,
tag_num
),
dtype
=
"float64"
)
track
=
np
.
zeros
((
seq_len
,
tag_num
),
dtype
=
"int
32
"
)
track
=
np
.
zeros
((
seq_len
,
tag_num
),
dtype
=
"int
64
"
)
for
i
in
range
(
tag_num
):
for
i
in
range
(
tag_num
):
alpha
[
0
,
i
]
=
self
.
a
[
i
]
+
x
[
0
,
i
]
alpha
[
0
,
i
]
=
self
.
a
[
i
]
+
x
[
0
,
i
]
...
@@ -125,10 +125,10 @@ class TestCRFDecodingOp2(OpTest):
...
@@ -125,10 +125,10 @@ class TestCRFDecodingOp2(OpTest):
axis
=
0
)
axis
=
0
)
labels
=
np
.
random
.
randint
(
labels
=
np
.
random
.
randint
(
low
=
0
,
high
=
TAG_NUM
,
size
=
(
lod
[
-
1
][
-
1
],
1
),
dtype
=
"int
32
"
)
low
=
0
,
high
=
TAG_NUM
,
size
=
(
lod
[
-
1
][
-
1
],
1
),
dtype
=
"int
64
"
)
predicted_labels
=
np
.
ones
(
predicted_labels
=
np
.
ones
(
(
lod
[
-
1
][
-
1
],
1
),
dtype
=
"int
32
"
)
*
(
TAG_NUM
-
1
)
(
lod
[
-
1
][
-
1
],
1
),
dtype
=
"int
64
"
)
*
(
TAG_NUM
-
1
)
expected_output
=
(
labels
==
predicted_labels
).
astype
(
"int
32
"
)
expected_output
=
(
labels
==
predicted_labels
).
astype
(
"int
64
"
)
self
.
inputs
=
{
self
.
inputs
=
{
"Emission"
:
(
emission
,
lod
),
"Emission"
:
(
emission
,
lod
),
...
...
python/paddle/v2/fluid/tests/test_layers.py
浏览文件 @
45c8a88a
...
@@ -4,6 +4,7 @@ import unittest
...
@@ -4,6 +4,7 @@ import unittest
import
paddle.v2.fluid.layers
as
layers
import
paddle.v2.fluid.layers
as
layers
import
paddle.v2.fluid.nets
as
nets
import
paddle.v2.fluid.nets
as
nets
from
paddle.v2.fluid.framework
import
Program
,
program_guard
from
paddle.v2.fluid.framework
import
Program
,
program_guard
from
paddle.v2.fluid.param_attr
import
ParamAttr
class
TestBook
(
unittest
.
TestCase
):
class
TestBook
(
unittest
.
TestCase
):
...
@@ -132,8 +133,12 @@ class TestBook(unittest.TestCase):
...
@@ -132,8 +133,12 @@ class TestBook(unittest.TestCase):
images
=
layers
.
data
(
name
=
'pixel'
,
shape
=
[
784
],
dtype
=
'float32'
)
images
=
layers
.
data
(
name
=
'pixel'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int32'
)
label
=
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int32'
)
hidden
=
layers
.
fc
(
input
=
images
,
size
=
128
)
hidden
=
layers
.
fc
(
input
=
images
,
size
=
128
)
crf
=
layers
.
linear_chain_crf
(
input
=
hidden
,
label
=
label
)
crf
=
layers
.
linear_chain_crf
(
input
=
hidden
,
label
=
label
,
param_attr
=
ParamAttr
(
name
=
"crfw"
))
crf_decode
=
layers
.
crf_decoding
(
input
=
hidden
,
param_attr
=
ParamAttr
(
name
=
"crfw"
))
self
.
assertNotEqual
(
crf
,
None
)
self
.
assertNotEqual
(
crf
,
None
)
self
.
assertNotEqual
(
crf_decode
,
None
)
print
(
str
(
program
))
print
(
str
(
program
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录