From ebd992ec7923d7230bb33efa02e2d3544d514947 Mon Sep 17 00:00:00 2001
From: caoying03
Date: Tue, 31 Oct 2017 23:13:37 +0800
Subject: [PATCH] backpropagate the gradients the CRF operator receives.

---
 paddle/operators/linear_chain_crf_op.h | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/paddle/operators/linear_chain_crf_op.h b/paddle/operators/linear_chain_crf_op.h
index 24c8b4052d5..56fb0c9102b 100644
--- a/paddle/operators/linear_chain_crf_op.h
+++ b/paddle/operators/linear_chain_crf_op.h
@@ -35,6 +35,14 @@ static inline T NormalizeL1(T* x, size_t len) {
   return sum;
 }
 
+template <typename T>
+struct ScalarMul {
+  explicit ScalarMul(const T& scalar) : scalar(scalar) {}
+  T operator()(const T& val) const { return val * scalar; }
+
+  T scalar;
+};
+
 using framework::LoDTensor;
 using framework::LoD;
 using framework::Tensor;
@@ -349,8 +357,6 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
     // data reader operator, it can have no gradients.
     PADDLE_ENFORCE(emission_grad, "Output(Emission@Grad) should not be null.");
     emission_grad->mutable_data<T>(platform::CPUPlace());
-    math::SetConstant<platform::CPUPlace, T>()(ctx.device_context(),
-                                               emission_grad, 0.);
     if (transition_grad) {
       transition_grad->mutable_data<T>(platform::CPUPlace());
       math::SetConstant<platform::CPUPlace, T>()(ctx.device_context(),
@@ -480,15 +486,18 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
     auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
                        .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                        .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
-    x_grad_mat.device(*place) = prob / row_sum;
+    x_grad_mat.device(*place) =
+        (prob / row_sum).unaryExpr(ScalarMul<T>(ll_grad));
 
     for (size_t k = 0; k < seq_length; ++k) {
-      x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
+      x_grad_mat(k, label_value[k]) -= static_cast<T>(ll_grad);
     }
 
     if (transition_grad) {
       T* trans_grad = transition_grad->data<T>();
       for (size_t k = 0; k < tag_num; ++k) {
+        // Do not multiply by the output gradient here, because x_grad_mat has
+        // already done this.
         trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
         trans_grad[tag_num + k] +=
             x_grad_mat(/*to end state*/ seq_length - 1, k);
@@ -496,8 +505,8 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
 
       auto x_exps_mat = EigenMatrix<T>::From(emission_exps);
 
-      // TODO(caoying): Fix this to avoid using this local variable if when can
-      // profiling the training process.
+      // TODO(caoying): Fix this to avoid using this local variable if we can
+      // profile the training process.
       Tensor tmp;
       tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
       auto tmp_mat = EigenMatrix<T>::From(tmp);
@@ -520,11 +529,11 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
           for (size_t j = 0; j < tag_num; ++j) {
             trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
                 sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
-                alpha_mat(k - 1, i) * tmp_mat(k, j);
+                alpha_mat(k - 1, i) * tmp_mat(k, j) * ll_grad;
           }
         }
         trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
-                   label_value[k]] -= static_cast<T>(1.);
+                   label_value[k]] -= static_cast<T>(ll_grad);
       }
     }
   }
-- 
GitLab
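
Note on the change above: the backward kernel now applies the chain rule, scaling the locally computed gradient of the log-likelihood by ll_grad, the gradient the CRF operator receives on its output, instead of implicitly treating that incoming gradient as 1; the start/end transition terms taken from x_grad_mat are left unscaled because x_grad_mat already carries the factor. The snippet below is an illustrative sketch only, not part of the patch: it assumes Eigen is available and uses made-up example values to show how a ScalarMul-style functor combines with Eigen's unaryExpr to perform this elementwise scaling.

// Illustrative sketch (not part of the patch). Scales a local gradient by the
// upstream gradient ll_grad via a ScalarMul functor and Eigen's unaryExpr,
// mirroring how x_grad_mat is scaled in the kernel. All values are made up.
#include <Eigen/Dense>
#include <iostream>

template <typename T>
struct ScalarMul {
  explicit ScalarMul(const T& scalar) : scalar(scalar) {}
  T operator()(const T& val) const { return val * scalar; }
  T scalar;
};

int main() {
  // Stand-in for d(log-likelihood)/d(emission) computed by the backward pass.
  Eigen::MatrixXf local_grad(2, 3);
  local_grad << 0.2f, 0.3f, -0.5f,
                0.1f, -0.6f, 0.5f;
  // Gradient received on the operator's output (would be 1.0 if the CRF
  // log-likelihood were the final loss itself).
  float ll_grad = 0.5f;
  // Chain rule: d(final loss)/d(emission) = ll_grad * d(log-likelihood)/d(emission).
  Eigen::MatrixXf emission_grad = local_grad.unaryExpr(ScalarMul<float>(ll_grad));
  std::cout << emission_grad << "\n";  // every entry multiplied by 0.5
  return 0;
}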