Commit ebd992ec authored by caoying03

Backpropagate the gradients the CRF operator receives.

Parent 2ac9a3d8
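In outline: the backward kernel below scales every analytic gradient term by `ll_grad`, the gradient the operator receives for its log-likelihood output, instead of implicitly treating that upstream gradient as 1. A minimal sketch of the chain rule being applied to the emission gradient of one timestep (a hypothetical standalone function, not the operator's real signature):

```cpp
#include <cstddef>
#include <vector>

// Sketch: emission gradient for a single timestep of a linear-chain CRF.
// marginal[i] is the normalized marginal probability of tag i, label is the
// gold tag index, and ll_grad is the gradient received from downstream for
// the log-likelihood output. All names here are illustrative.
std::vector<float> EmissionGradStep(const std::vector<float>& marginal,
                                    std::size_t label, float ll_grad) {
  std::vector<float> grad(marginal.size());
  for (std::size_t i = 0; i < marginal.size(); ++i) {
    // d(loss)/d(emission_i) = (marginal_i - 1{i == label}) * ll_grad
    grad[i] = (marginal[i] - (i == label ? 1.f : 0.f)) * ll_grad;
  }
  return grad;
}
```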
@@ -35,6 +35,14 @@ static inline T NormalizeL1(T* x, size_t len) {
   return sum;
 }
 
+template <typename T>
+struct ScalarMul {
+  explicit ScalarMul(const T& scalar) : scalar(scalar) {}
+  T operator()(const T& val) const { return val * scalar; }
+
+  T scalar;
+};
+
 using framework::LoDTensor;
 using framework::LoD;
 using framework::Tensor;
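`ScalarMul` is a plain functor, so it composes with Eigen's `unaryExpr` to scale a tensor expression coefficient-wise by a runtime scalar, which is how the kernel multiplies the marginal probabilities by `ll_grad` further down. A small standalone illustration, shown on a dense Eigen matrix rather than the framework's tensor types (the matrix and the 0.5 factor are made up):

```cpp
#include <Eigen/Dense>
#include <iostream>

template <typename T>
struct ScalarMul {
  explicit ScalarMul(const T& scalar) : scalar(scalar) {}
  T operator()(const T& val) const { return val * scalar; }

  T scalar;
};

int main() {
  Eigen::Matrix2f m;
  m << 1.f, 2.f,
       3.f, 4.f;
  // unaryExpr applies the functor to every coefficient, producing the
  // expression m * 0.5 without materializing a scalar-valued tensor.
  Eigen::Matrix2f scaled = m.unaryExpr(ScalarMul<float>(0.5f));
  std::cout << scaled << std::endl;
  return 0;
}
```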
@@ -349,8 +357,6 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
     // data reader operator, it can have no gradients.
     PADDLE_ENFORCE(emission_grad, "Output(Emission@Grad) should not be null.");
     emission_grad->mutable_data<T>(platform::CPUPlace());
-    math::SetConstant<platform::CPUPlace, T>()(ctx.device_context(),
-                                               emission_grad, 0.);
     if (transition_grad) {
       transition_grad->mutable_data<T>(platform::CPUPlace());
       math::SetConstant<platform::CPUPlace, T>()(ctx.device_context(),
@@ -480,15 +486,18 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
     auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
                        .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                        .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
-    x_grad_mat.device(*place) = prob / row_sum;
+    x_grad_mat.device(*place) =
+        (prob / row_sum).unaryExpr(ScalarMul<T>(ll_grad));
 
     for (size_t k = 0; k < seq_length; ++k) {
-      x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
+      x_grad_mat(k, label_value[k]) -= static_cast<T>(ll_grad);
     }
 
     if (transition_grad) {
       T* trans_grad = transition_grad->data<T>();
       for (size_t k = 0; k < tag_num; ++k) {
+        // Do not multiply by the output gradient here, because x_grad_mat has
+        // already done this.
         trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
         trans_grad[tag_num + k] +=
             x_grad_mat(/*to end state*/ seq_length - 1, k);
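The comment added in this hunk relies on linearity of the chain rule: `x_grad_mat` is already scaled by `ll_grad`, so accumulating its rows into `trans_grad` must not apply the factor a second time. A tiny self-contained check of that identity (the numbers are made up):

```cpp
#include <cassert>
#include <cmath>

int main() {
  const float x[3] = {0.1f, 0.7f, 0.2f};  // one unscaled gradient row
  const float ll_grad = 2.5f;             // upstream output gradient

  float scaled_then_summed = 0.f;  // scale each term, then accumulate
  float summed = 0.f;              // accumulate, then scale once
  for (int i = 0; i < 3; ++i) {
    scaled_then_summed += x[i] * ll_grad;
    summed += x[i];
  }
  // Both orders agree (up to rounding), so the factor is applied exactly
  // once, inside x_grad_mat.
  assert(std::fabs(scaled_then_summed - summed * ll_grad) < 1e-5f);
  return 0;
}
```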
@@ -496,8 +505,8 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
       auto x_exps_mat = EigenMatrix<T>::From(emission_exps);
 
-      // TODO(caoying): Fix this to avoid using this local variable if when can
-      // profiling the training process.
+      // TODO(caoying): Fix this to avoid using this local variable if we can
+      // profile the training process.
       Tensor tmp;
       tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
       auto tmp_mat = EigenMatrix<T>::From(tmp);
@@ -520,11 +529,11 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
           for (size_t j = 0; j < tag_num; ++j) {
             trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
                 sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
-                alpha_mat(k - 1, i) * tmp_mat(k, j);
+                alpha_mat(k - 1, i) * tmp_mat(k, j) * ll_grad;
           }
         }
         trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
-                   label_value[k]] -= static_cast<T>(1.);
+                   label_value[k]] -= static_cast<T>(ll_grad);
       }
     }
   }