Commit c74107bf authored by caoying03

fix backward computation.

Parent 6a630f27
......@@ -101,8 +101,10 @@ void CRFLayer::backward(const UpdateCallback& callback) {
                              : real(1.0f);
    instanceWeight *= coeff_;
+   if (output.grad) {
      MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]);
      grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight);
+   }
    if (needWGrad) {
      weight_->getWGrad()->add(
          *crfs_[i].getWGrad(), real(1.0f), instanceWeight);
......
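The hunk above guards the emission-gradient accumulation behind `if (output.grad)` and scales each sequence's CRF gradient by its instance weight before folding it in. The call `grad->add(src, real(1.0f), instanceWeight)` appears to compute `grad = 1.0f * grad + instanceWeight * src`. A minimal sketch of that scaled-accumulate pattern, with plain arrays standing in for Paddle's Matrix type (names here are illustrative, not Paddle's API):

// Illustrative only: dst = 1.0f * dst + w * src, applied per sequence so
// each instance's CRF gradient is weighted before being accumulated.
void scaledAccumulate(float* dst, const float* src, float w, int n) {
  for (int i = 0; i < n; ++i) {
    dst[i] += w * src[i];
  }
}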
......@@ -102,7 +102,6 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
}
void LinearChainCRF::backward(real* x, int* s, int length, bool needWGrad) {
  MatrixPtr matX = Matrix::create(x, length, numClasses_);
  Matrix::resizeOrCreate(matGrad_, length, numClasses_);
  Matrix::resizeOrCreate(beta_, length, numClasses_);
  real* b = b_->getData();
......
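Two allocation idioms appear in this hunk: `Matrix::create(x, length, numClasses_)` wraps the caller's raw buffer as a matrix view without copying, while `Matrix::resizeOrCreate` lazily allocates `matGrad_` and `beta_` on first use so repeated backward calls can reuse the buffers. A rough sketch of the lazy-allocation idea under those assumptions (not Paddle's actual implementation):

#include <cstddef>
#include <memory>
#include <vector>

struct Buffer2D {
  std::vector<float> data;
  int rows = 0, cols = 0;
};

// Allocate on first use; reallocate only when the requested shape changes.
void resizeOrCreate(std::shared_ptr<Buffer2D>& m, int rows, int cols) {
  if (!m) m = std::make_shared<Buffer2D>();
  if (m->rows != rows || m->cols != cols) {
    m->data.assign(static_cast<std::size_t>(rows) * cols, 0.0f);
    m->rows = rows;
    m->cols = cols;
  }
}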
......@@ -272,7 +272,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
      int end_pos = static_cast<int>(in_lod[level][i + 1]);
      if (end_pos == start_pos) {
        // If an empty input sequence is given, pad 0 for its cost.
-       log_likelihood[i] = static_cast<T>(0.);
+       log_likelihood[i] = 0.;
        continue;
      }
......@@ -305,7 +305,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
    const size_t tag_num = x_dims[1];
    // The 1st row of w are transition weights for start mask.
    // The 2nd row of w are transition weights for end mask.
-   // Transition weights among other tags begins from the 3rd row of w.
+   // Transition weights among other tags begin from the 3rd row of w.
    const size_t state_trans_base_idx = 2;
    for (size_t i = 0; i < tag_num; ++i) {
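Besides the grammar fix, the comment block documents the weight layout that the rest of the kernel relies on: `w` is row-major with `tag_num` columns, row 0 holds the weight for starting a sequence at each tag, row 1 the weight for ending at each tag, and the pairwise transition weights occupy the rows from `state_trans_base_idx` (2) onward. Hypothetical accessors (not part of the operator) that spell the indexing out:

#include <cstddef>

// Assumed layout: w[(from + 2) * tag_num + to] is the transition weight
// from tag `from` to tag `to`; rows 0 and 1 are the start/end weights.
template <typename T>
T StartWeight(const T* w, std::size_t tag_num, std::size_t tag) {
  return w[0 * tag_num + tag];
}
template <typename T>
T EndWeight(const T* w, std::size_t tag_num, std::size_t tag) {
  return w[1 * tag_num + tag];
}
template <typename T>
T TransWeight(const T* w, std::size_t tag_num, std::size_t from,
              std::size_t to) {
  return w[(from + 2) * tag_num + to];  // 2 == state_trans_base_idx
}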
......@@ -315,7 +315,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
    for (size_t k = 1; k < seq_length; ++k) {
      for (size_t i = 0; i < tag_num; ++i) {
-       T sum = static_cast<T>(0.);
+       T sum = 0.;
        for (size_t j = 0; j < tag_num; ++j) {
          sum += alpha_value[(k - 1) * tag_num + j] *
                 w_exps[(j + state_trans_base_idx) * tag_num + i];
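This loop is the forward (alpha) recursion of the linear-chain CRF in the exp domain: alpha[k][i] = x_exp[k][i] * sum_j alpha[k-1][j] * w_exp[j -> i]. A self-contained sketch of the recursion under the same layout assumptions (illustrative names; the per-row L1 normalization the kernel performs via NormalizeL1 is noted inline):

#include <cstddef>

// Sketch only; `alpha` and `x_exp` are seq_len x tag_num, row-major.
void ForwardAlpha(const double* x_exp, const double* w_exp, double* alpha,
                  std::size_t seq_len, std::size_t tag_num) {
  const std::size_t base = 2;  // state_trans_base_idx
  for (std::size_t i = 0; i < tag_num; ++i) {
    alpha[i] = w_exp[i] * x_exp[i];  // start weight times first emission
  }
  for (std::size_t k = 1; k < seq_len; ++k) {
    for (std::size_t i = 0; i < tag_num; ++i) {
      double sum = 0.;
      for (std::size_t j = 0; j < tag_num; ++j) {
        sum += alpha[(k - 1) * tag_num + j] * w_exp[(j + base) * tag_num + i];
      }
      alpha[k * tag_num + i] = x_exp[k * tag_num + i] * sum;
    }
    // The kernel L1-normalizes each alpha row here to avoid underflow.
  }
}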
......@@ -476,17 +476,17 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
    const size_t tag_num = x_dims[1];
    const size_t state_trans_base_idx = 2;
-   // Calculate the backwark vectors beta.
+   // Calculate the backward vectors: beta.
    // First, calculate the initialition state.
-   for (int i = 0; i < tag_num; ++i) {
+   for (size_t i = 0; i < tag_num; ++i) {
      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
    }
    NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);
-   for (int k = seq_length - 2; k >= 0; --k) {
-     for (int i = 0; i < tag_num; ++i) {
-       T sum = static_cast<T>(0.);
-       for (int j = 0; j < tag_num; ++j) {
+   for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
+     for (size_t i = 0; i < tag_num; ++i) {
+       T sum = 0.;
+       for (size_t j = 0; j < tag_num; ++j) {
          sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
                 x_exps[(k + 1) * tag_num + j] *
                 beta_value[(k + 1) * tag_num + j];
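One detail of the changed loop header deserves a note: `seq_length` is a `size_t`, so the old `seq_length - 2` wraps around for sequences shorter than two and is then narrowed to `int` with implementation-defined results; casting to `int` before subtracting keeps the arithmetic signed, so the reverse sweep is cleanly skipped. A small demonstration of the pitfall:

#include <cstddef>
#include <cstdio>

int main() {
  std::size_t seq_length = 1;
  // Unsigned subtraction wraps: 1 - 2 == SIZE_MAX, not -1.
  std::printf("%zu\n", seq_length - 2);  // prints a huge value
  // Casting first keeps the arithmetic signed; for seq_length == 1 the
  // backward sweep below is never entered, which is the intended behavior.
  for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
    std::printf("k = %d\n", k);
  }
  return 0;
}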
......@@ -500,13 +500,14 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
    auto beta_mat = EigenMatrix<T>::From(*beta);
    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
    auto* place = ctx.GetEigenDevice<platform::CPUPlace>();
-   x_grad_mat.device(*place) = alpha_mat * beta_mat;
-   x_grad_mat /= x_grad_mat.sum(Eigen::DSizes<int, 1>(1))
-                     .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
-                     .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+   auto prob = alpha_mat * beta_mat;
+   auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
+                      .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
+                      .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+   x_grad_mat.device(*place) = prob / row_sum;
-   for (int k = 0; k < seq_length; ++k) {
-     x_grad_mat(k, label_value[k]) -= static_cast<T>(1);
+   for (size_t k = 0; k < seq_length; ++k) {
+     x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
    }
    if (transition_grad) {
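The rewritten block computes the per-position marginals and the emission gradient of the negative log-likelihood: prob(k, i) = alpha(k, i) * beta(k, i) / row_sum(k) gives P(y_k = i | x) (the per-step normalization factors cancel in the ratio), and the gradient is that marginal minus the one-hot gold tag. The old code normalized `x_grad_mat` in place with `/=`; the rewrite evaluates `prob / row_sum` in a single device-assigned expression. A scalar sketch of the same computation, with illustrative names:

#include <cstddef>

// grad(k, i) = P(y_k = i | x) - 1{label[k] == i}, sketched with raw loops.
void EmissionGrad(const double* alpha, const double* beta, const int* label,
                  double* grad, std::size_t seq_len, std::size_t tag_num) {
  for (std::size_t k = 0; k < seq_len; ++k) {
    double row_sum = 0.;
    for (std::size_t i = 0; i < tag_num; ++i) {
      row_sum += alpha[k * tag_num + i] * beta[k * tag_num + i];
    }
    for (std::size_t i = 0; i < tag_num; ++i) {
      grad[k * tag_num + i] =
          alpha[k * tag_num + i] * beta[k * tag_num + i] / row_sum;
    }
    grad[k * tag_num + label[k]] -= 1.;  // subtract the one-hot gold tag
  }
}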
......@@ -518,29 +519,35 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
    }
    auto x_exps_mat = EigenMatrix<T>::From(*emission_exps);
-   beta_mat = beta_mat * x_exps_mat;
-   beta_mat /= beta_mat.sum(Eigen::DSizes<int, 1>(1))
-                   .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
-                   .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+   // TODO(caoying): Fix this to avoid using this local variable.
+   Tensor tmp;
+   tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
+   auto tmp_mat = EigenMatrix<T>::From(tmp);
+   auto prob = beta_mat * x_exps_mat;
+   auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
+                      .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
+                      .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+   tmp_mat.device(*place) = prob / row_sum;
-   for (int k = 1; k < seq_length; ++k) {
-     T sum = static_cast<T>(0.);
-     for (int i = 0; i < tag_num; ++i) {
-       for (int j = 0; j < tag_num; ++j) {
+   for (size_t k = 1; k < seq_length; ++k) {
+     T sum = 0.;
+     for (size_t i = 0; i < tag_num; ++i) {
+       for (size_t j = 0; j < tag_num; ++j) {
          sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
-                alpha_mat(k - 1, i) * beta_mat(k, j);
+                alpha_mat(k - 1, i) * tmp_mat(k, j);
        }
      }
-     sum = static_cast<T>(1.) / sum;
-     for (int i = 0; i < tag_num; ++i) {
-       for (int j = 0; j < tag_num; ++j) {
+     sum = 1. / sum;
+     for (size_t i = 0; i < tag_num; ++i) {
+       for (size_t j = 0; j < tag_num; ++j) {
          trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
              sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
-             alpha_mat(k - 1, i) * beta_mat(k, j);
+             alpha_mat(k - 1, i) * tmp_mat(k, j);
        }
      }
-     trans_grad[label_value[k - 1] * tag_num + label_value[k]] -=
-         static_cast<T>(1.);
+     trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
+                label_value[k]] -= static_cast<T>(1.);
    }
  }
}
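This hunk carries the actual fix named in the commit message. The expected-count accumulation already writes into the transition block at rows offset by `state_trans_base_idx`, but the observed-count subtraction used `label_value[k - 1] * tag_num + label_value[k]`, which lands in the start- and end-weight rows (and the wrong transition rows) instead; the corrected index offsets the source tag by `state_trans_base_idx` so expectation and observation address the same entries. The corrected observed-count update in isolation (illustrative names):

#include <cstddef>

// For each gold transition (label[k-1] -> label[k]), subtract 1 from the
// matching entry of the transition block, which starts at row 2.
void SubtractObservedTransitions(double* trans_grad, const int* label,
                                 std::size_t seq_len, std::size_t tag_num) {
  const std::size_t base = 2;  // state_trans_base_idx
  for (std::size_t k = 1; k < seq_len; ++k) {
    const std::size_t from = static_cast<std::size_t>(label[k - 1]);
    const std::size_t to = static_cast<std::size_t>(label[k]);
    trans_grad[(from + base) * tag_num + to] -= 1.;
  }
}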
......@@ -554,9 +561,7 @@ REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
            linear_chain_crf_grad, ops::LinearChainCrfGradOp);
REGISTER_OP_CPU_KERNEL(
    linear_chain_crf,
-   ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>,
-   ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, double>);
+   ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    linear_chain_crf_grad,
-   ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>,
-   ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, double>);
+   ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>);
......@@ -83,11 +83,10 @@ class LinearChainCrfForward(object):
class TestLinearChainCrfOp(OpTest):
    def set_test_data(self):
-       SEQ_NUM = 2
+       SEQ_NUM = 3
        TAG_NUM = 17
        MAX_SEQ_LEN = 5
        random.seed(1)
        # the linear_chain_crf operator only supports sequence (LoD level = 1)
        lod = [[0]]
        for i in range(SEQ_NUM):
......@@ -109,7 +108,6 @@ class TestLinearChainCrfOp(OpTest):
"Transition": transition,
"Label": (labels, lod)
}
crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
emission_exps, transition, transition_exps,
labels)
......@@ -130,11 +128,17 @@ class TestLinearChainCrfOp(OpTest):
        self.check_output()

    def test_check_grad(self):
-       self.check_grad(["Emission", "Transition"], "LogLikelihood")
+       self.check_grad(
+           ["Emission", "Transition"],
+           "LogLikelihood",
+           max_relative_error=0.05)

    def test_check_grad_ignore_transition(self):
        self.check_grad(
-           ["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
+           ["Emission"],
+           "LogLikelihood",
+           max_relative_error=0.05,
+           no_grad_set=set("Transition"))

if __name__ == "__main__":
......
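The test changes loosen both gradient checks to `max_relative_error=0.05`. Conceptually, `check_grad` compares the analytic gradients implemented above against central finite differences and accepts them when the relative error stays under that bound; the sketch below shows the comparison for a single scalar (an illustration of the idea, not OpTest's actual implementation):

#include <algorithm>
#include <cmath>
#include <functional>

// Central-difference check: accept when |analytic - numeric| is within
// max_relative_error of the larger magnitude.
bool GradCheck(const std::function<double(double)>& f, double x,
               double analytic_grad, double max_relative_error) {
  const double eps = 1e-4;
  const double numeric_grad = (f(x + eps) - f(x - eps)) / (2. * eps);
  const double denom =
      std::max(std::abs(analytic_grad), std::abs(numeric_grad));
  if (denom == 0.) return true;  // both gradients are exactly zero
  return std::abs(analytic_grad - numeric_grad) / denom <= max_relative_error;
}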