Unverified commit 4a7e6deb authored by J JesseyXujin, committed by GitHub

add padding in linear_chain_crf op (#19583)

* add padding in linear_chain_crf op

* modify API.spec

* add linear_chain_crf_op.cc and linear_chain_crf_op.h

* remove useless unit test , test=develop

* modify API.spec, test=develop

* remove some blanks in nn.py , test=develop

* fix some bugs on nn.py and API.spec ,test=develop

* fix nn.py, test=develop

* fix API.spec ,test=develop

* fix bug of CI test in test_linear_chain_crf_op.py

* fix bug of CI test in test_linear_chain_crf_op.py, test=develop

* remove paddle_enforce, test=develop

* remove paddle_enforce, test=develop

* remove paddle_enforce, test=develop

* remove paddle_enforce, test=develop

* remove paddle_enforce, test=develop

* remove paddle_enforce, test=develop

* modify nn.py, test=develop

* fix API.spec, test=develop

* fix unittest bug, test=develop
Parent 19474019
......@@ -115,7 +115,7 @@ paddle.fluid.layers.dynamic_lstm (ArgSpec(args=['input', 'size', 'h_0', 'c_0', '
paddle.fluid.layers.dynamic_lstmp (ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name', 'h_0', 'c_0', 'cell_clip', 'proj_clip'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None, None, None, None, None)), ('document', 'c37d51aad655c8a9f9b045c64717320a'))
paddle.fluid.layers.dynamic_gru (ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False)), ('document', '83617c165827e030636c80486d5de6f3'))
paddle.fluid.layers.gru_unit (ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False)), ('document', '33974b9bfa69f2f1eb85e6f956dff04e'))
paddle.fluid.layers.linear_chain_crf (ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,)), ('document', '34f96be41684b0959897a9e735997e20'))
paddle.fluid.layers.linear_chain_crf (ArgSpec(args=['input', 'label', 'param_attr', 'length'], varargs=None, keywords=None, defaults=(None, None)), ('document', '715f8f12d68ae90504a7b768e82be6f4'))
paddle.fluid.layers.crf_decoding (ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,)), ('document', '5ce117258e243be1c81539e254178d90'))
paddle.fluid.layers.cos_sim (ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None), ('document', '8e6ce424cf9e261ef32ee229c06a6e66'))
paddle.fluid.layers.cross_entropy (ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100)), ('document', 'f43c659ca1749a3f0ff2231e6dfda07d'))
......
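The API.spec line above records the signature change only: linear_chain_crf gains an optional `length` argument (default None). A minimal sketch of the two call forms this enables, with illustrative layer names and sizes (not taken from the spec):

```python
import paddle.fluid as fluid

# LoD form (unchanged): emission is a 2-D LoDTensor, label a LoD int64 tensor.
emission = fluid.layers.data(name='emission', shape=[10], dtype='float32', lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='int64', lod_level=1)
cost = fluid.layers.linear_chain_crf(input=emission, label=label)

# Padded form (new): inputs carry an explicit max-length dimension and the
# real length of every sequence is fed through `length`.
emission_pad = fluid.layers.data(name='emission_pad', shape=[10, 10], dtype='float32')
label_pad = fluid.layers.data(name='label_pad', shape=[10, 1], dtype='int64')
seq_len = fluid.layers.data(name='seq_len', shape=[1], dtype='int64')
cost_pad = fluid.layers.linear_chain_crf(input=emission_pad, label=label_pad, length=seq_len)
```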
......@@ -23,21 +23,28 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Emission",
"(LoDTensor, default LoDTensor<float>) "
"A 2-D LoDTensor with shape [N x D], where N is the size of the "
"(LoDTensor/Tensor<float>). When a LoDTensor input,A 2-D LoDTensor"
" with shape [N x D], where N is the size of the "
"mini-batch and D is the total tag number. The unscaled emission "
"weight matrix for the linear chain CRF. ");
"weight matrix for the linear chain CRF. When a Tensor input,"
"A Tensor with shape [N x S x D], where N is batch number,"
"S is max length of sequences, D is the total tag number.");
AddInput("Transition",
"(Tensor, default Tensor<float>) A 2-D Tensor with shape "
"[(D + 2) x D]. The learnable parameter for the linear_chain_crf "
"operator. See more details in the operator's comments.");
AddInput("Label",
"(LoDTensor, default LoDTensor<int64_t>) A LoDTensor with shape "
"(LoDTensor/Tensor<int64_t>), when a LoDTensor input, "
"[N x 1], where N is the total element number in a mini-batch. "
"The ground truth.");
"when a Tensor input, [N x S], where N is batch number. "
"S is max length of sequences. The ground truth.");
AddInput("length",
"(Tensor, default Tensor<int64_t>) A Tensor with shape "
"[M x 1], where M is the sequence number in a mini-batch.")
.AsDispensable();
AddOutput(
"Alpha",
"(Tensor, default Tensor<float>) A 2-D Tensor with shape [N x D]. "
"(Tensor, default Tensor<float>), the same shape with Emission. "
"The forward vectors for the entire batch. Denote it as $\alpha$. "
"$\alpha$ is a memo table used to calculate the normalization "
"factor in CRF. $\alpha[k, v]$ stores the unnormalized "
......@@ -49,7 +56,7 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
.AsIntermediate();
AddOutput(
"EmissionExps",
"(Tensor, default Tensor<float>) A 2-D Tensor with shape [N x D]. "
"(Tensor, default Tensor<float>), the same shape with Emission. "
"The exponentials of Input(Emission). This is an intermediate "
"computational result in forward computation, and will be reused in "
"backward computation.")
......@@ -145,11 +152,6 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("LogLikelihood"),
"Output(LogLikelihood) should be not null.");
auto emission_dims = ctx->GetInputDim("Emission");
PADDLE_ENFORCE_EQ(emission_dims.size(), 2,
"The Input(Emission) should be a 2-D tensor.");
PADDLE_ENFORCE(emission_dims[0], "An empty mini-batch is not allowed.");
auto transition_dims = ctx->GetInputDim("Transition");
PADDLE_ENFORCE_EQ(transition_dims.size(), 2,
"The Input(Transition) should be a 2-D tensor.");
......@@ -164,20 +166,40 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
"An invalid dimension for the Input(Transition), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
}
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[1], transition_dims[1],
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
"should be equal to the tag number.");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[0], label_dims[0],
"The height of Input(Emission) and the height of Input(Label) "
"should be the same.");
auto emission_dims = ctx->GetInputDim("Emission");
PADDLE_ENFORCE_NE(emission_dims[0], 0,
"An empty mini-batch is not allowed.");
if (ctx->HasInput("length")) {
PADDLE_ENFORCE_EQ(emission_dims.size(), 3,
"The Input(Emission) should be a 3-D tensor.");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 3,
"The Input(Label) should be a 3-D tensor");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[0], label_dims[0],
"The batch size of Input(Emission) and Input(Label) "
"should be the same.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[1], label_dims[1],
"The max length of Input(Emission) and Input(Label) "
"should be the same.");
} else {
PADDLE_ENFORCE_EQ(emission_dims.size(), 2,
"The Input(Emission) should be a 2-D tensor.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[1], transition_dims[1],
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
"should be equal to the tag number.");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[0], label_dims[0],
"The height of Input(Emission) and the height of Input(Label) "
"should be the same.");
}
ctx->SetOutputDim("Alpha", emission_dims);
ctx->SetOutputDim("EmissionExps", emission_dims);
ctx->SetOutputDim("TransitionExps", transition_dims);
......@@ -210,12 +232,6 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")),
"Input(LogLikelihood@GRAD) shoudl be not null.");
auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2,
"The Input(EmissionExps) should be a 2-D tensor.");
PADDLE_ENFORCE(emission_exps_dims[0],
"An empty mini-batch is not allowed.");
auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2,
"The Input(TransitionExps) should be a 2-D tensor.");
......@@ -230,15 +246,34 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
"An invalid dimension for the Input(TransitionExps), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
}
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[1], transition_exps_dims[1],
"The 2nd dimension of the Input(EmissionExps) and the "
"Input(TransitionExps) should be equal to the tag number.");
auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
if (ctx->HasInput("length")) {
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 3,
"The Input(EmissionExps) should be a 3-D tensor.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[2], transition_exps_dims[1],
"The 3nd dimension of the Input(EmissionExps) and the "
"Input(TransitionExps) should be equal to the tag number.");
PADDLE_ENFORCE_EQ(label_dims.size(), 3,
"The Input(Label) should be a 3-D tensor with the 3nd "
"dimensions fixed to 1.");
} else {
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2,
"The Input(EmissionExps) should be a 2-D tensor.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[1], transition_exps_dims[1],
"The 2nd dimension of the Input(EmissionExps) and the "
"Input(TransitionExps) should be equal to the tag number.");
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
"The Input(Label) should be a 2-D tensor");
PADDLE_ENFORCE_EQ(label_dims[1], 1,
"The Input(Label) 2nd dimensions fixed to 1.");
}
PADDLE_ENFORCE_NE(emission_exps_dims[0], 0,
"An empty mini-batch is not allowed.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[0], label_dims[0],
"The height of Input(EmissionExps) and the height of Input(Label) "
......@@ -246,8 +281,12 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(framework::GradVarName("Emission"))) {
ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
ctx->ShareLoD("Emission", framework::GradVarName("Emission"));
if (ctx->HasInput("length") == false) {
ctx->ShareLoD("Emission", framework::GradVarName("Emission"));
}
}
// ctx->SetOutputDim(framework::GradVarName("Emission"),
// emission_exps_dims);
if (ctx->HasOutput(framework::GradVarName("Transition"))) {
ctx->SetOutputDim(framework::GradVarName("Transition"),
transition_exps_dims);
......@@ -275,15 +314,15 @@ class LinearChainCRFGradDescMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("linear_chain_crf_grad");
op->SetAttrMap(Attrs());
op->SetInput("Emission", Input("Emission"));
op->SetInput("Transition", Input("Transition"));
op->SetInput("Label", Input("Label"));
op->SetInput("Alpha", Output("Alpha"));
op->SetInput("EmissionExps", Output("EmissionExps"));
op->SetInput("TransitionExps", Output("TransitionExps"));
if (ForwardOp().Inputs().count("length") > 0) {
op->SetInput("length", Input("length"));
}
op->SetInput(framework::GradVarName("LogLikelihood"),
OutputGrad("LogLikelihood"));
......
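The reworked InferShape above branches on whether Input(length) is present: with padding, Emission becomes a fixed-shape [N x S x D] tensor and Label is padded to the same batch/max-length layout; without it, both stay 2-D LoD tensors. A rough NumPy sketch of the two shape contracts these checks enforce (sizes are illustrative):

```python
import numpy as np

TAG_NUM = 17            # D: total tag number (illustrative)
lengths = [3, 5, 2]     # real length of each sequence

# Padded mode: Input(length) is fed, so Emission is [N x S x D] and Label is
# padded to the same layout (the unit test below uses [N x S x 1]).
N, S = len(lengths), max(lengths)
emission_pad = np.zeros((N, S, TAG_NUM), dtype='float64')
label_pad = np.zeros((N, S, 1), dtype='int64')
length = np.array(lengths, dtype='int64')
assert emission_pad.shape[0] == label_pad.shape[0] == length.shape[0]  # same batch size
assert emission_pad.shape[1] == label_pad.shape[1]                     # same max length

# LoD mode: no Input(length); sequences are concatenated along dim 0, so
# Emission is [T x D] and Label is [T x 1] with T = sum of the lengths.
T = sum(lengths)
emission_lod = np.zeros((T, TAG_NUM), dtype='float64')
label_lod = np.zeros((T, 1), dtype='int64')
assert emission_lod.shape[1] == TAG_NUM   # 2nd dim must equal the tag number
```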
......@@ -54,20 +54,9 @@ template <typename DeviceContext, typename T>
class LinearChainCRFOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// TODO(caoying) The checks related to LoD information should be
// moved into InferShape once after the InferShape is refactored.
PADDLE_ENFORCE_EQ(ctx.Input<LoDTensor>("Emission")->NumLevels(), 1UL,
"The Input(Emission) should be a sequence.");
PADDLE_ENFORCE_EQ(ctx.Input<LoDTensor>("Label")->NumLevels(), 1UL,
"The Input(Label) should be a sequence.");
auto in_lod = ctx.Input<LoDTensor>("Label")->lod();
PADDLE_ENFORCE(in_lod.size(), "Input(Label) must be a sequence.");
const size_t level = 0;
const size_t seq_num = in_lod[level].size() - 1;
const LoDTensor* emission_weights = ctx.Input<LoDTensor>("Emission");
const Tensor* transition_weights = ctx.Input<Tensor>("Transition");
const LoDTensor* label = ctx.Input<LoDTensor>("Label");
const Tensor* emission_weights = ctx.Input<framework::Tensor>("Emission");
const Tensor* transition_weights =
ctx.Input<framework::Tensor>("Transition");
Tensor* emission_exps = ctx.Output<Tensor>("EmissionExps");
Tensor* transition_exps = ctx.Output<Tensor>("TransitionExps");
......@@ -76,56 +65,103 @@ class LinearChainCRFOpKernel : public framework::OpKernel<T> {
// Because the computation codes only runs on CPU, here the memory for all
// the outputs is FIXED to be allocated on the CPU memory.
emission_exps->mutable_data<T>(platform::CPUPlace());
auto* emission_exps_data =
emission_exps->mutable_data<T>(platform::CPUPlace());
auto* alpha_data = alpha->mutable_data<T>(platform::CPUPlace());
transition_exps->mutable_data<T>(platform::CPUPlace());
alpha->mutable_data<T>(platform::CPUPlace());
// Resize the output tensor to its correct dimension.
memset(emission_exps_data, 0, emission_exps->numel() * sizeof(T));
memset(alpha_data, 0, alpha->numel() * sizeof(T));
auto emission_dims = emission_weights->dims();
const Tensor* label = ctx.Input<framework::Tensor>("Label");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
Tensor emission_weights_tmp = ctx.AllocateTmpTensor<T, DeviceContext>(
emission_weights->dims(), dev_ctx);
emission_weights_tmp.ShareDataWith(*emission_weights);
Tensor label_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(label->dims(), dev_ctx);
label_tmp.ShareDataWith(*label);
Tensor emission_exps_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(emission_exps->dims(), dev_ctx);
emission_exps_tmp.ShareDataWith(*emission_exps);
Tensor alpha_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(alpha->dims(), dev_ctx);
alpha_tmp.ShareDataWith(*alpha);
size_t seq_num = 0;
size_t batch_size;
size_t tag_num;
const int64_t* length_data;
framework::Vector<size_t> in_lod;
if (ctx.HasInput("length")) {
const Tensor* label_length = ctx.Input<framework::Tensor>("length");
length_data = label_length->data<int64_t>();
seq_num = label_length->numel();
batch_size = emission_dims[0] * emission_dims[1];
tag_num = emission_dims[2];
emission_weights_tmp.Resize(
{emission_dims[0] * emission_dims[1], emission_dims[2]});
auto label_dims = label->dims();
label_tmp.Resize({label_dims[0] * label_dims[1], label_dims[2]});
alpha_tmp.Resize({emission_dims[0] * emission_dims[1], emission_dims[2]});
emission_exps_tmp.Resize(
{emission_dims[0] * emission_dims[1], emission_dims[2]});
PADDLE_ENFORCE_EQ(seq_num, emission_dims[0],
"the size of Input(length) must be equal to "
"emission_dims[0].");
PADDLE_ENFORCE_EQ(seq_num, label_dims[0],
"the size of Input(length) must be equal to "
"label_dims[0].");
} else {
seq_num = ctx.Input<LoDTensor>("Label")->lod()[0].size() - 1;
batch_size = emission_dims[0];
tag_num = emission_dims[1];
in_lod = ctx.Input<LoDTensor>("Label")->lod()[0];
PADDLE_ENFORCE_NE(in_lod.size(), 0, "Input(Label) must be a sequence.");
}
ll->Resize({static_cast<int>(seq_num), 1});
ll->mutable_data<T>(platform::CPUPlace());
// Now, all the inputs and outputs should be on the CPU memory.
auto emission_dims = emission_weights->dims();
const size_t batch_size = emission_dims[0];
const size_t tag_num = emission_dims[1];
Tensor emission_row_max;
emission_row_max.mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(batch_size), 1}),
platform::CPUPlace());
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto x = EigenMatrix<T>::From(*emission_weights);
auto x = EigenMatrix<T>::From(emission_weights_tmp);
auto x_row_max = EigenMatrix<T>::From(emission_row_max);
x_row_max.device(place) =
x.maximum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(static_cast<int>(batch_size), 1));
auto x_exps = EigenMatrix<T>::From(*emission_exps);
auto x_exps = EigenMatrix<T>::From(emission_exps_tmp);
x_exps.device(place) =
(x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
auto w = EigenMatrix<T>::From(*transition_weights);
auto w_exps = EigenMatrix<T>::From(*transition_exps);
w_exps.device(place) = w.exp();
T* log_likelihood = ll->data<T>();
for (size_t i = 0; i < seq_num; ++i) {
int start_pos = static_cast<int>(in_lod[level][i]);
int end_pos = static_cast<int>(in_lod[level][i + 1]);
int start_pos = 0;
int end_pos = 0;
if (ctx.HasInput("length")) {
if (length_data[i] == 0) continue;
start_pos = i * emission_dims[1];
end_pos = start_pos + static_cast<int>(length_data[i]);
} else {
start_pos = static_cast<int>(in_lod[i]);
end_pos = static_cast<int>(in_lod[i + 1]);
}
if (end_pos == start_pos) {
// If an empty input sequence is given, pad 0 for its cost.
log_likelihood[i] = 0.;
continue;
}
const Tensor one_seq = emission_weights->Slice(start_pos, end_pos);
const Tensor one_seq = emission_weights_tmp.Slice(start_pos, end_pos);
Tensor one_seq_row_max = emission_row_max.Slice(start_pos, end_pos);
Tensor one_seq_exps = emission_exps->Slice(start_pos, end_pos);
const Tensor one_seq_label = label->Slice(start_pos, end_pos);
Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos);
Tensor one_seq_exps = emission_exps_tmp.Slice(start_pos, end_pos);
const Tensor one_seq_label = label_tmp.Slice(start_pos, end_pos);
Tensor one_seq_alpha = alpha_tmp.Slice(start_pos, end_pos);
log_likelihood[i] = ForwardOneSequence(
one_seq, one_seq_row_max, one_seq_exps, *transition_weights,
*transition_exps, one_seq_label, &one_seq_alpha);
......@@ -197,52 +233,91 @@ template <typename DeviceContext, typename T>
class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const size_t level = 0; // currently, only support sequence.
auto lod = ctx.Input<LoDTensor>("Label")->lod();
PADDLE_ENFORCE(lod.size(), "Input(Label) must be a sequence.");
const Tensor* label = ctx.Input<LoDTensor>("Label");
const Tensor* label = ctx.Input<Tensor>("Label");
const Tensor* emission_exps = ctx.Input<Tensor>("EmissionExps");
const Tensor* transition_exps = ctx.Input<Tensor>("TransitionExps");
const Tensor* alpha = ctx.Input<Tensor>("Alpha");
const T* ll_grad =
ctx.Input<Tensor>(framework::GradVarName("LogLikelihood"))->data<T>();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
Tensor* emission_grad =
ctx.Output<Tensor>(framework::GradVarName("Emission"));
auto* emission_grad_data =
emission_grad->mutable_data<T>(platform::CPUPlace());
memset(emission_grad_data, 0, emission_grad->numel() * sizeof(T));
Tensor alpha_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(alpha->dims(), dev_ctx);
alpha_tmp.ShareDataWith(*alpha);
Tensor label_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(label->dims(), dev_ctx);
label_tmp.ShareDataWith(*label);
Tensor emission_exps_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(emission_exps->dims(), dev_ctx);
emission_exps_tmp.ShareDataWith(*emission_exps);
Tensor emission_grad_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(emission_grad->dims(), dev_ctx);
emission_grad_tmp.ShareDataWith(*emission_grad);
// Get seq_num, which depends on whether Input(length) (padding) is used.
size_t seq_num = 0;
framework::Vector<size_t> lod;
const int64_t* length_data;
if (ctx.HasInput("length")) {
const Tensor* label_length = ctx.Input<framework::Tensor>("length");
length_data = label_length->data<int64_t>();
seq_num = label_length->numel();
auto emission_dims = emission_grad->dims();
auto label_dims = label->dims();
emission_grad_tmp.Resize(
{emission_dims[0] * emission_dims[1], emission_dims[2]});
label_tmp.Resize({label_dims[0] * label_dims[1], label_dims[2]});
alpha_tmp.Resize({emission_dims[0] * emission_dims[1], emission_dims[2]});
emission_exps_tmp.Resize(
{emission_dims[0] * emission_dims[1], emission_dims[2]});
} else {
seq_num = ctx.Input<LoDTensor>("Label")->lod()[0].size() - 1;
lod = ctx.Input<LoDTensor>("Label")->lod()[0];
PADDLE_ENFORCE_NE(lod.size(), 0, "Input(Label) must be a sequence.");
}
Tensor* transition_grad =
ctx.Output<Tensor>(framework::GradVarName("Transition"));
// TODO(caoying) Fix this constraint. When the Input(Emission) is from the
// data reader operator, it can have no gradients.
PADDLE_ENFORCE(emission_grad, "Output(Emission@Grad) should not be null.");
emission_grad->mutable_data<T>(platform::CPUPlace());
if (transition_grad) {
transition_grad->mutable_data<T>(platform::CPUPlace());
math::set_constant(ctx.device_context(), transition_grad, 0.);
}
// Now, all the inputs and outputs should be on the CPU memory.
auto emission_dims = emission_exps->dims();
// Beta is the memo table used in dynamic programming to calculate the
// backward vectors. For a backward vector i (the i-th row of beta), it
// captures the unnormalized probabilities of partial sequences starting
// at position i.
Tensor beta;
beta.mutable_data<T>(emission_dims, platform::CPUPlace());
for (size_t i = 0; i < lod[level].size() - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
if (end_pos == start_pos) continue;
auto* beta_data = beta.mutable_data<T>(emission_dims, platform::CPUPlace());
memset(beta_data, 0, beta.numel() * sizeof(T));
if (ctx.HasInput("length")) {
beta.Resize({emission_dims[0] * emission_dims[1], emission_dims[2]});
}
for (size_t i = 0; i < seq_num; ++i) {
int start_pos = 0;
int end_pos = 0;
if (ctx.HasInput("length")) {
if (length_data[i] == 0) continue;
start_pos = i * emission_dims[1];
end_pos = start_pos + static_cast<int>(length_data[i]);
} else {
start_pos = static_cast<int>(lod[i]);
end_pos = static_cast<int>(lod[i + 1]);
}
const Tensor one_seq_emission_exps =
emission_exps->Slice(start_pos, end_pos);
const Tensor one_seq_label = label->Slice(start_pos, end_pos);
const Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos);
emission_exps_tmp.Slice(start_pos, end_pos);
const Tensor one_seq_label = label_tmp.Slice(start_pos, end_pos);
const Tensor one_seq_alpha = alpha_tmp.Slice(start_pos, end_pos);
Tensor one_seq_beta = beta.Slice(start_pos, end_pos);
Tensor one_seq_emission_grad = emission_grad->Slice(start_pos, end_pos);
Tensor one_seq_emission_grad =
emission_grad_tmp.Slice(start_pos, end_pos);
BackwardOneSequence(
ctx.template device_context<platform::CPUDeviceContext>(), ll_grad[i],
one_seq_emission_exps, *transition_exps, one_seq_alpha, one_seq_label,
......@@ -261,7 +336,6 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
const T* x_exps = emission_exps.data<T>();
const int64_t* label_value = label.data<int64_t>();
T* beta_value = beta->data<T>();
auto x_dims = emission_exps.dims();
const size_t seq_length = x_dims[0];
const size_t tag_num = x_dims[1];
......
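In the padded path the kernel flattens the [N x S x D] emission into an [N*S x D] view and takes sequence i from rows [i*S, i*S + length[i]); in the LoD path it keeps slicing by the offsets in lod[0]. A small NumPy sketch of that indexing, as an illustration rather than the kernel itself:

```python
import numpy as np

N, S, D = 3, 5, 4                       # batch size, max length, tag number (illustrative)
lengths = np.array([3, 5, 2], dtype='int64')
emission = np.random.rand(N, S, D)

# Padded path: view the tensor as [N*S, D] and slice out each real sequence.
flat = emission.reshape(N * S, D)
for i, l in enumerate(lengths):
    if l == 0:                          # empty sequences contribute 0 to the cost
        continue
    start, end = i * S, i * S + int(l)
    one_seq = flat[start:end]           # rows of sequence i, padding excluded
    assert one_seq.shape == (int(l), D)

# LoD path: sequences are concatenated, and lod[0] holds the start offsets.
lod0 = np.concatenate([[0], np.cumsum(lengths)])
concat = np.concatenate([emission[i, :l] for i, l in enumerate(lengths)])
for i in range(len(lengths)):
    one_seq = concat[lod0[i]:lod0[i + 1]]
```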
......@@ -1404,7 +1404,7 @@ def gru_unit(input,
@templatedoc()
def linear_chain_crf(input, label, param_attr=None):
def linear_chain_crf(input, label, param_attr=None, length=None):
"""
Linear Chain CRF.
......@@ -1414,6 +1414,7 @@ def linear_chain_crf(input, label, param_attr=None):
input(${emission_type}): ${emission_comment}
input(${transition_type}): ${transition_comment}
label(${label_type}): ${label_comment}
length(${length_type}): ${length_comment}
param_attr(ParamAttr): The attribute of the learnable parameter.
Returns:
......@@ -1424,16 +1425,62 @@ def linear_chain_crf(input, label, param_attr=None):
Examples:
.. code-block:: python
import paddle.fluid as fluid
emission = fluid.layers.data(name='emission', shape=[1000], dtype='float32')
target = fluid.layers.data(name='target', shape=[1], dtype='int32')
crf_cost = fluid.layers.linear_chain_crf(
input=emission,
label=target,
param_attr=fluid.ParamAttr(
import paddle.fluid as fluid
import numpy as np
#define net structure, using LoDTensor
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
input_data = fluid.layers.data(name='input_data', shape=[10], dtype='float32', lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='int64', lod_level=1)
emission= fluid.layers.fc(input=input_data, size=10, act="tanh")
crf_cost = fluid.layers.linear_chain_crf(
input=emission,
label=label,
param_attr=fluid.ParamAttr(
name='crfw',
learning_rate=0.01))
use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
#define data, using LoDTensor
a = fluid.create_lod_tensor(np.random.rand(12,10).astype('float32'), [[3,3,4,2]], place)
b = fluid.create_lod_tensor(np.array([[1],[1],[2],[3],[1],[1],[1],[3],[1],[1],[1],[1]]).astype('int64'), [[3,3,4,2]], place)
feed1 = {'input_data':a,'label':b}
loss= exe.run(train_program,feed=feed1, fetch_list=[crf_cost])
print(loss)
#define net structure, using padding
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
input_data2 = fluid.layers.data(name='input_data2', shape=[10,10], dtype='float32')
label2 = fluid.layers.data(name='label2', shape=[10,1], dtype='int64')
label_length = fluid.layers.data(name='length', shape=[1], dtype='int64')
emission2= fluid.layers.fc(input=input_data2, size=10, act="tanh", num_flatten_dims=2)
crf_cost2 = fluid.layers.linear_chain_crf(
input=emission2,
label=label2,
length=label_length,
param_attr=fluid.ParamAttr(
name='crfw',
learning_rate=0.2))
learning_rate=0.01))
use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
#define data, using padding
cc=np.random.rand(4,10,10).astype('float32')
dd=np.random.rand(4,10,1).astype('int64')
ll=np.array([[3],[3],[4],[2]]).astype('int64')
feed2 = {'input_data2':cc,'label2':dd,'length':ll}
loss2= exe.run(train_program,feed=feed2, fetch_list=[crf_cost2])
print(loss2)
"""
helper = LayerHelper('linear_chain_crf', **locals())
size = input.shape[1]
......@@ -1449,11 +1496,16 @@ def linear_chain_crf(input, label, param_attr=None):
dtype=helper.input_dtype())
log_likelihood = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
this_inputs = {
"Emission": [input],
"Transition": transition,
"Label": [label]
}
if length:
this_inputs['length'] = [length]
helper.append_op(
type='linear_chain_crf',
inputs={"Emission": [input],
"Transition": transition,
"Label": label},
inputs=this_inputs,
outputs={
"Alpha": [alpha],
"EmissionExps": [emission_exps],
......
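Feeding the padded form means packing variable-length NumPy arrays into a fixed [N x S x ...] block plus a per-sequence length vector, the same transformation the unit test's seq_pad performs below. A minimal helper sketch (the function name is ours, not part of the API):

```python
import numpy as np

def pad_sequences(seqs):
    """Pack a list of [len_i x ...] arrays into one [N x S x ...] array plus a length vector."""
    lengths = np.array([len(s) for s in seqs], dtype='int64').reshape(-1, 1)
    max_len = int(lengths.max())
    padded = np.zeros([len(seqs), max_len] + list(seqs[0].shape[1:]), dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = s
    return padded, lengths

# Example: three emission sequences with 10 tags each.
seqs = [np.random.rand(l, 10).astype('float32') for l in (3, 5, 2)]
emission_pad, length = pad_sequences(seqs)  # emission_pad: (3, 5, 10), length: (3, 1)
```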
......@@ -111,7 +111,7 @@ class TestLinearChainCrfOp(OpTest):
lod = [[]]
seq_start_pos = [0]
for i in range(SEQ_NUM):
lod[-1].append(random.randint(0, MAX_SEQ_LEN))
lod[-1].append(random.randint(1, MAX_SEQ_LEN))
seq_start_pos.append(seq_start_pos[-1] + lod[-1][-1])
emission = np.random.uniform(
-1, 1, [seq_start_pos[-1], TAG_NUM]).astype("float64")
......@@ -157,5 +157,81 @@ class TestLinearChainCrfOp(OpTest):
["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
class TestLinearChainCrfPaddingTensor(OpTest):
def seq_pad(self, data, length):
max_len = np.max(length)
shape = [len(length), max_len] + list(data.shape[1:])
padded = np.zeros(shape).astype(data.dtype)
offset = 0
for i, l in enumerate(length):
padded[i, 0:l] = data[offset:offset + l]
offset += l
return padded
def seq_pad_exps(self, data, length):
# Pad with ones: padded emission rows produce exp(0) = 1 in EmissionExps.
max_len = np.max(length)
shape = [len(length), max_len] + list(data.shape[1:])
padded = np.ones(shape).astype(data.dtype)
offset = 0
for i, l in enumerate(length):
padded[i, 0:l] = data[offset:offset + l]
offset += l
return padded
def set_test_data_1(self):
# Build the test data in padded form and feed an extra Input(length) tensor.
SEQ_NUM = 3
TAG_NUM = 17
MAX_SEQ_LEN = 5
# the linear_chain_crf operator only supports sequence (LoD level = 1)
lod = [[]]
seq_start_pos = [0]
for i in range(SEQ_NUM):
lod[-1].append(random.randint(1, MAX_SEQ_LEN))
seq_start_pos.append(seq_start_pos[-1] + lod[-1][-1])
emission = np.random.uniform(
-1, 1, [seq_start_pos[-1], TAG_NUM]).astype("float64")
emission_row_max = np.amax(emission, axis=1, keepdims=True)
emission_exps = np.exp(emission - emission_row_max)
transition = np.random.uniform(-0.5, 0.5,
[TAG_NUM + 2, TAG_NUM]).astype("float64")
transition_exps = np.exp(transition)
labels = np.random.randint(
low=0, high=TAG_NUM, size=(seq_start_pos[-1], 1), dtype="int64")
self.inputs = {
"Emission": self.seq_pad(emission, lod[0]),
"Transition": transition,
"Label": self.seq_pad(labels, lod[0]),
"length": np.array(lod).astype("int64")
}
crf = LinearChainCrfForward(seq_start_pos, emission, emission_row_max,
emission_exps, transition, transition_exps,
labels)
alpha, log_likelihood = crf.crf_forward_compute()
self.outputs = {
"Alpha": self.seq_pad(alpha, lod[0]),
"EmissionExps": self.seq_pad_exps(emission_exps, lod[0]),
"TransitionExps": transition_exps,
"LogLikelihood": log_likelihood
}
def setUp(self):
self.op_type = "linear_chain_crf"
self.set_test_data_1()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["Emission", "Transition"], "LogLikelihood")
def test_check_grad_ignore_transition(self):
self.check_grad(
["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
if __name__ == "__main__":
unittest.main()