提交 c74107bf 编写于 作者: C caoying03

fix backward computation.

上级 6a630f27
...@@ -101,8 +101,10 @@ void CRFLayer::backward(const UpdateCallback& callback) { ...@@ -101,8 +101,10 @@ void CRFLayer::backward(const UpdateCallback& callback) {
: real(1.0f); : real(1.0f);
instanceWeight *= coeff_; instanceWeight *= coeff_;
if (output.grad) {
MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]); MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]);
grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight); grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight);
}
if (needWGrad) { if (needWGrad) {
weight_->getWGrad()->add( weight_->getWGrad()->add(
*crfs_[i].getWGrad(), real(1.0f), instanceWeight); *crfs_[i].getWGrad(), real(1.0f), instanceWeight);
......
...@@ -102,7 +102,6 @@ real LinearChainCRF::forward(real* x, int* s, int length) { ...@@ -102,7 +102,6 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
} }
void LinearChainCRF::backward(real* x, int* s, int length, bool needWGrad) { void LinearChainCRF::backward(real* x, int* s, int length, bool needWGrad) {
MatrixPtr matX = Matrix::create(x, length, numClasses_);
Matrix::resizeOrCreate(matGrad_, length, numClasses_); Matrix::resizeOrCreate(matGrad_, length, numClasses_);
Matrix::resizeOrCreate(beta_, length, numClasses_); Matrix::resizeOrCreate(beta_, length, numClasses_);
real* b = b_->getData(); real* b = b_->getData();
......
...@@ -272,7 +272,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T> ...@@ -272,7 +272,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
int end_pos = static_cast<int>(in_lod[level][i + 1]); int end_pos = static_cast<int>(in_lod[level][i + 1]);
if (end_pos == start_pos) { if (end_pos == start_pos) {
// If an empty input sequence is given, pad 0 for its cost. // If an empty input sequence is given, pad 0 for its cost.
log_likelihood[i] = static_cast<T>(0.); log_likelihood[i] = 0.;
continue; continue;
} }
...@@ -305,7 +305,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T> ...@@ -305,7 +305,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
const size_t tag_num = x_dims[1]; const size_t tag_num = x_dims[1];
// The 1st row of w holds the transition weights for the start mask. // The 1st row of w holds the transition weights for the start mask.
// The 2nd row of w holds the transition weights for the end mask. // The 2nd row of w holds the transition weights for the end mask.
// Transition weights among other tags begins from the 3rd row of w. // Transition weights among other tags begin from the 3rd row of w.
const size_t state_trans_base_idx = 2; const size_t state_trans_base_idx = 2;
for (size_t i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {
...@@ -315,7 +315,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T> ...@@ -315,7 +315,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
for (size_t k = 1; k < seq_length; ++k) { for (size_t k = 1; k < seq_length; ++k) {
for (size_t i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {
T sum = static_cast<T>(0.); T sum = 0.;
for (size_t j = 0; j < tag_num; ++j) { for (size_t j = 0; j < tag_num; ++j) {
sum += alpha_value[(k - 1) * tag_num + j] * sum += alpha_value[(k - 1) * tag_num + j] *
w_exps[(j + state_trans_base_idx) * tag_num + i]; w_exps[(j + state_trans_base_idx) * tag_num + i];
...@@ -476,17 +476,17 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T> ...@@ -476,17 +476,17 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
const size_t tag_num = x_dims[1]; const size_t tag_num = x_dims[1];
const size_t state_trans_base_idx = 2; const size_t state_trans_base_idx = 2;
// Calculate the backwark vectors beta. // Calculate the backward vectors: beta.
// First, calculate the initialization state. // First, calculate the initialization state.
for (int i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {
beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i]; beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
} }
NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num); NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);
for (int k = seq_length - 2; k >= 0; --k) { for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
for (int i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {
T sum = static_cast<T>(0.); T sum = 0.;
for (int j = 0; j < tag_num; ++j) { for (size_t j = 0; j < tag_num; ++j) {
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
x_exps[(k + 1) * tag_num + j] * x_exps[(k + 1) * tag_num + j] *
beta_value[(k + 1) * tag_num + j]; beta_value[(k + 1) * tag_num + j];
...@@ -500,13 +500,14 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T> ...@@ -500,13 +500,14 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
auto beta_mat = EigenMatrix<T>::From(*beta); auto beta_mat = EigenMatrix<T>::From(*beta);
auto x_grad_mat = EigenMatrix<T>::From(*emission_grad); auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
auto* place = ctx.GetEigenDevice<platform::CPUPlace>(); auto* place = ctx.GetEigenDevice<platform::CPUPlace>();
x_grad_mat.device(*place) = alpha_mat * beta_mat; auto prob = alpha_mat * beta_mat;
x_grad_mat /= x_grad_mat.sum(Eigen::DSizes<int, 1>(1)) auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(seq_length, 1)) .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
.broadcast(Eigen::DSizes<int, 2>(1, tag_num)); .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
x_grad_mat.device(*place) = prob / row_sum;
for (int k = 0; k < seq_length; ++k) { for (size_t k = 0; k < seq_length; ++k) {
x_grad_mat(k, label_value[k]) -= static_cast<T>(1); x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
} }
if (transition_grad) { if (transition_grad) {
...@@ -518,29 +519,35 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T> ...@@ -518,29 +519,35 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
} }
auto x_exps_mat = EigenMatrix<T>::From(*emission_exps); auto x_exps_mat = EigenMatrix<T>::From(*emission_exps);
beta_mat = beta_mat * x_exps_mat;
beta_mat /= beta_mat.sum(Eigen::DSizes<int, 1>(1)) // TODO(caoying): Fix this to avoid using this local variable.
Tensor tmp;
tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
auto tmp_mat = EigenMatrix<T>::From(tmp);
auto prob = beta_mat * x_exps_mat;
auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(seq_length, 1)) .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
.broadcast(Eigen::DSizes<int, 2>(1, tag_num)); .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
tmp_mat.device(*place) = prob / row_sum;
for (int k = 1; k < seq_length; ++k) { for (size_t k = 1; k < seq_length; ++k) {
T sum = static_cast<T>(0.); T sum = 0.;
for (int i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {
for (int j = 0; j < tag_num; ++j) { for (size_t j = 0; j < tag_num; ++j) {
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
alpha_mat(k - 1, i) * beta_mat(k, j); alpha_mat(k - 1, i) * tmp_mat(k, j);
} }
} }
sum = static_cast<T>(1.) / sum; sum = 1. / sum;
for (int i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {
for (int j = 0; j < tag_num; ++j) { for (size_t j = 0; j < tag_num; ++j) {
trans_grad[(i + state_trans_base_idx) * tag_num + j] += trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
sum * w_exps[(i + state_trans_base_idx) * tag_num + j] * sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
alpha_mat(k - 1, i) * beta_mat(k, j); alpha_mat(k - 1, i) * tmp_mat(k, j);
} }
} }
trans_grad[label_value[k - 1] * tag_num + label_value[k]] -= trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
static_cast<T>(1.); label_value[k]] -= static_cast<T>(1.);
} }
} }
} }
...@@ -554,9 +561,7 @@ REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker, ...@@ -554,9 +561,7 @@ REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
linear_chain_crf_grad, ops::LinearChainCrfGradOp); linear_chain_crf_grad, ops::LinearChainCrfGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
linear_chain_crf, linear_chain_crf,
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>, ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>);
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
linear_chain_crf_grad, linear_chain_crf_grad,
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>, ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>);
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, double>);
...@@ -83,11 +83,10 @@ class LinearChainCrfForward(object): ...@@ -83,11 +83,10 @@ class LinearChainCrfForward(object):
class TestLinearChainCrfOp(OpTest): class TestLinearChainCrfOp(OpTest):
def set_test_data(self): def set_test_data(self):
SEQ_NUM = 2 SEQ_NUM = 3
TAG_NUM = 17 TAG_NUM = 17
MAX_SEQ_LEN = 5 MAX_SEQ_LEN = 5
random.seed(1)
# the linear_chain_crf operator only supports sequence (LoD level = 1) # the linear_chain_crf operator only supports sequence (LoD level = 1)
lod = [[0]] lod = [[0]]
for i in range(SEQ_NUM): for i in range(SEQ_NUM):
...@@ -109,7 +108,6 @@ class TestLinearChainCrfOp(OpTest): ...@@ -109,7 +108,6 @@ class TestLinearChainCrfOp(OpTest):
"Transition": transition, "Transition": transition,
"Label": (labels, lod) "Label": (labels, lod)
} }
crf = LinearChainCrfForward(lod[0], emission, emission_row_max, crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
emission_exps, transition, transition_exps, emission_exps, transition, transition_exps,
labels) labels)
...@@ -130,11 +128,17 @@ class TestLinearChainCrfOp(OpTest): ...@@ -130,11 +128,17 @@ class TestLinearChainCrfOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["Emission", "Transition"], "LogLikelihood") self.check_grad(
["Emission", "Transition"],
"LogLikelihood",
max_relative_error=0.05)
def test_check_grad_ignore_transition(self): def test_check_grad_ignore_transition(self):
self.check_grad( self.check_grad(
["Emission"], "LogLikelihood", no_grad_set=set("Transition")) ["Emission"],
"LogLikelihood",
max_relative_error=0.05,
no_grad_set=set("Transition"))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册