未验证 提交 50b1cab1 编写于 作者: Y Yibing Liu 提交者: GitHub

Add padding support for crf_decoding (#19057)

* Add padding support for crf_decoding

* Fixes in comupte kernel

test=develop

* Update API Spec

test=develop

* Update API.spec

test=develop

* Avoid using paddle_enforce

test=develop

* Fix enforce

test=develop
上级 45fb031f
......@@ -116,7 +116,7 @@ paddle.fluid.layers.dynamic_lstmp (ArgSpec(args=['input', 'size', 'proj_size', '
paddle.fluid.layers.dynamic_gru (ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False)), ('document', '83617c165827e030636c80486d5de6f3'))
paddle.fluid.layers.gru_unit (ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False)), ('document', '33974b9bfa69f2f1eb85e6f956dff04e'))
paddle.fluid.layers.linear_chain_crf (ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,)), ('document', '34f96be41684b0959897a9e735997e20'))
paddle.fluid.layers.crf_decoding (ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,)), ('document', 'c469f22029f7b5d41ecd44dfa1e81ffd'))
paddle.fluid.layers.crf_decoding (ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,)), ('document', '5ce117258e243be1c81539e254178d90'))
paddle.fluid.layers.cos_sim (ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None), ('document', '8e6ce424cf9e261ef32ee229c06a6e66'))
paddle.fluid.layers.cross_entropy (ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100)), ('document', 'f43c659ca1749a3f0ff2231e6dfda07d'))
paddle.fluid.layers.bpr_loss (ArgSpec(args=['input', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6263dfdeb6c670fa0922c9cbc8fb1bf4'))
......
......@@ -19,14 +19,17 @@ namespace operators {
class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Emission",
"(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
"[N x D] where N is the size of the mini-batch and D is the total "
"tag number. This input is the unscaled emission weight matrix of "
"the linear_chain_crf operator.");
AddInput(
"Emission",
"(Tensor<float>/LoDTensor<float>). For a LoDTensor input, its "
"shape is [N x D] where N is the total sequence length of the "
"mini-batch and D is the total tag number. While for a tensor "
"input, its shape is [B X S X D] with B the batch size and S the "
"sequence length of each sample after padding. This input is the "
"unscaled emission weight matrix of the linear_chain_crf operator.");
AddInput(
"Transition",
"(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
"(Tensor<float>). A Tensor with shape [(D + 2) x D]. "
"This input is the transition weights learned by the linear_chain_crf "
"operator, denoted as w. The 1st row of w are transition weights for "
"the start mask. The 2nd row of w are transition weights for the end "
......@@ -34,15 +37,24 @@ class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
"w. See more details in comments of the linear_chain_crf operator.");
AddInput(
"Label",
"(LoDTensor, LoDTensor<int64_t>). The ground truth with shape "
"[N x 1]. This input is optional. See more details in the operator's "
"comments.")
"(Tensor<int64_t>/LoDTensor<int64_t>). The ground truth with shape "
"[N x 1] (for LoDTensor) or [B x S] (for Tensor). This input is "
"optional. "
"See more details in the operator's comments.")
.AsDispensable();
AddOutput(
"ViterbiPath",
"(LoDTensor, LoDTensor<int64_t>). The decoding results. What to "
"(Tensor<int64_t>/LoDTensor<int64_t>). The decoding results. What to "
"return changes depending on whether the Input(Label) (the ground "
"truth) is given. See more details in the operator's comment.");
AddInput("Length",
"(Tensor<int64_t>). The actual length of each sample before "
"padding with shape [B x 1]. It means the Input(Emission), "
"Input(Label) "
"and Output(ViterbiPath) are common tensors with padding when "
"this input "
"is given.")
.AsDispensable();
AddComment(R"DOC(
The crf_decoding operator reads the emission feature weights and the transition
feature weights learned by the linear_chain_crf operator. It implements the
......@@ -55,15 +67,16 @@ The output of this operator changes according to whether Input(Label) is given:
1. Input(Label) is given:
This happens in training. This operator is used to co-work with the chunk_eval
operator.
When Input(Label) is given, the crf_decoding operator returns a row vector
with shape [N x 1] whose values are fixed to be 0, indicating an incorrect
prediction, or 1 indicating a tag is correctly predicted. Such an output is the
input to chunk_eval operator.
When Input(Label) is given, the crf_decoding operator returns tensor with the
sampe shape as Input(Label) whose values are fixed to be 0, indicating an
incorrect prediction, or 1 indicating a tag is correctly predicted. Such an
output is the input to chunk_eval operator.
2. Input(Label) is not given:
This is the standard decoding process.
The crf_decoding operator returns a row vector with shape [N x 1] whose values
The crf_decoding operator returns a row vector with shape [N x 1]/[B x S], here
the shape depends on the inputs are LoDTensors or common tensors, whose values
range from 0 to maximum tag number - 1, Each element indicates an index of a
predicted tag.
)DOC");
......@@ -75,37 +88,46 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Emission"),
PADDLE_ENFORCE_EQ(ctx->HasInput("Emission"), true,
"Input(Emission) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Transition"),
PADDLE_ENFORCE_EQ(ctx->HasInput("Transition"), true,
"Input(Transition) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("ViterbiPath"),
PADDLE_ENFORCE_EQ(ctx->HasOutput("ViterbiPath"), true,
"Output(ViterbiPath) should be not null.");
auto emission_dims = ctx->GetInputDim("Emission");
bool has_length = ctx->HasInput("Length");
if (has_length) {
PADDLE_ENFORCE_EQ(emission_dims.size(), 3,
"The Input(Emission) should be a 3-D tensor.");
} else {
PADDLE_ENFORCE_EQ(emission_dims.size(), 2,
"The Input(Emission) should be a 2-D tensor.");
PADDLE_ENFORCE(emission_dims[0], "An empty mini-batch is not allowed.");
}
PADDLE_ENFORCE_NE(emission_dims[0], 0,
"An empty mini-batch is not allowed.");
auto transition_dims = ctx->GetInputDim("Transition");
PADDLE_ENFORCE_EQ(transition_dims.size(), 2,
PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL,
"The Input(Transition) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
transition_dims[0] - 2, transition_dims[1],
"An invalid dimension for the Input(Transition), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
if (ctx->IsRuntime() || (emission_dims[1] > 0 && transition_dims[1] > 0)) {
if (ctx->IsRuntime() || (emission_dims[emission_dims.size() - 1] > 0 &&
transition_dims[transition_dims.size() - 1] > 0)) {
PADDLE_ENFORCE_EQ(
emission_dims[1], transition_dims[1],
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
emission_dims[emission_dims.size() - 1],
transition_dims[transition_dims.size() - 1],
"The last dimension of the Input(Emission) and the Input(Transition) "
"should be equal to the tag number.");
}
if (ctx->HasInput("Label")) {
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
"The Input(Label) should be a 2-D tensor");
if (ctx->IsRuntime() || (emission_dims[0] > 0 && label_dims[0] > 0)) {
PADDLE_ENFORCE_EQ(
emission_dims[0], label_dims[0],
......@@ -115,8 +137,12 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}
ctx->ShareLoD("Emission", /*->*/ "ViterbiPath");
if (has_length) {
ctx->SetOutputDim("ViterbiPath", {emission_dims[0], emission_dims[1]});
} else {
ctx->SetOutputDim("ViterbiPath", {emission_dims[0], 1});
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
......
......@@ -35,16 +35,42 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
auto* label = ctx.Input<LoDTensor>("Label");
auto* decoded_path = ctx.Output<Tensor>("ViterbiPath");
int64_t* path = decoded_path->mutable_data<int64_t>(platform::CPUPlace());
math::SetConstant<DeviceContext, int64_t>()(
ctx.template device_context<DeviceContext>(), decoded_path, 0);
bool has_length = ctx.HasInput("Length");
if (has_length) {
auto* length = ctx.Input<Tensor>("Length");
const size_t seq_num = length->numel();
const int64_t* length_data = length->data<int64_t>();
auto in_dims = emission_weights->dims();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
framework::Tensor emission_weights_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(emission_weights->dims(),
dev_ctx);
emission_weights_tmp.ShareDataWith(*emission_weights);
emission_weights_tmp.Resize({in_dims[0] * in_dims[1], in_dims[2]});
decoded_path->Resize({in_dims[0] * in_dims[1], 1});
for (size_t i = 0; i < seq_num; ++i) {
if (length_data[i] == 0) continue;
int start_pos = i * in_dims[1];
int end_pos = start_pos + static_cast<int>(length_data[i]);
Tensor decoded_path_one_seq = decoded_path->Slice(start_pos, end_pos);
Decode(emission_weights_tmp.Slice(start_pos, end_pos),
*transition_weights, &decoded_path_one_seq);
}
decoded_path->Resize({in_dims[0], in_dims[1]});
} else {
PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL,
"The Input(Emission) should be a sequence.");
auto lod = emission_weights->lod();
PADDLE_ENFORCE(lod.size(), "Input(Emission) must be a sequence.");
PADDLE_ENFORCE_GT(lod.size(), 0, "Input(Emission) must be a sequence.");
const size_t level = 0;
const size_t seq_num = lod[level].size() - 1;
int64_t* path = decoded_path->mutable_data<int64_t>(platform::CPUPlace());
math::SetConstant<DeviceContext, int64_t>()(
ctx.template device_context<DeviceContext>(), decoded_path, 0);
for (size_t i = 0; i < seq_num; ++i) {
if (lod[level][i] == lod[level][i + 1]) continue;
int start_pos = static_cast<int>(lod[level][i]);
......@@ -53,13 +79,15 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Decode(emission_weights->Slice(start_pos, end_pos), *transition_weights,
&decoded_path_one_seq);
}
}
if (label) {
if (!has_length) {
PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
"The Input(Label) should be a sequence.");
}
const int64_t* label_value = label->data<int64_t>();
size_t batch_size = emission_weights->dims()[0];
for (size_t i = 0; i < batch_size; ++i) {
size_t numel = label->numel();
for (size_t i = 0; i < numel; ++i) {
path[i] = label_value[i] == path[i] ? 1 : 0;
}
}
......
......@@ -176,5 +176,55 @@ class TestCRFDecodingOp4(TestCRFDecodingOp2):
self.lod = [[0, 2, 3, 0]]
class TestCRFDecodingOp5(OpTest):
"""
Compare the dynamic program with random generated parameters and inputs
with grouth truth not being given.
"""
def seq_pad(self, data, length):
max_len = np.max(length)
shape = [len(length), max_len] + list(data.shape[1:])
padded = np.zeros(shape).astype(data.dtype)
offset = 0
for i, l in enumerate(length):
padded[i, 0:l] = data[offset:offset + l]
offset += l
return np.squeeze(padded)
def set_test_data(self):
SEQ_NUM = 3
TAG_NUM = 17
MAX_SEQ_LEN = 10
lod = [[]]
total_len = 0
for i in range(SEQ_NUM):
lod[-1].append(random.randint(1, MAX_SEQ_LEN))
total_len += lod[-1][-1]
emission = np.random.uniform(-1, 1,
[total_len, TAG_NUM]).astype("float64")
transition = np.random.uniform(-0.5, 0.5,
[TAG_NUM + 2, TAG_NUM]).astype("float64")
self.inputs = {
"Emission": self.seq_pad(emission, lod[0]),
"Transition": transition,
"Length": np.array(lod).astype('int64'),
}
decoder = CRFDecoding(emission, transition, lod[0])
decoded_path = decoder.decode()
self.outputs = {"ViterbiPath": self.seq_pad(decoded_path, lod[0])}
def setUp(self):
self.op_type = "crf_decoding"
self.set_test_data()
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册