提交 80a5ee00 编写于 作者: C caoying03

fix forward and add backward.

上级 3123e3cf
......@@ -30,20 +30,24 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override;
protected:
T ForwardOneSequence(const platform::DeviceContext& ctx,
const Tensor& emission, Tensor& emission_row_max,
Tensor& emission_exps, const Tensor& trans_weights,
Tensor& trans_weight_exps, const Tensor& label,
Tensor& a) const;
private:
T NormalizeL1(T* x, size_t len) const;
T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
const Tensor* emission_exps, const Tensor* trans_weights,
const Tensor* trans_weight_exps, const Tensor* label,
Tensor* alpha) const;
};
template <typename Place, typename T>
class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
protected:
void BackwardOneSequence(const platform::DeviceContext& ctx,
const Tensor* emission_exps,
const Tensor* transition_exps, const Tensor* alpha,
const Tensor* label, Tensor* beta,
Tensor* transition_grad,
Tensor* emission_grad) const;
};
} // namespace operators
......
......@@ -4,10 +4,12 @@ import numpy as np
from op_test import OpTest
import pdb
class LinearChainCrfForward(object):
def __init__(self, seq_start_positions, emission_weights,
transition_weights, labels):
def __init__(self, seq_start_positions, emission_weights, emission_row_max,
emission_exps, transition_weights, transition_exps, labels):
self.tag_num = emission_weights.shape[1]
self.seq_num = len(seq_start_positions) - 1
......@@ -15,25 +17,25 @@ class LinearChainCrfForward(object):
self.labels = labels
self.x = emission_weights
self.x_row_max = np.amax(self.x, axis=1, keepdims=True)
self.x_exps = np.exp(self.x - self.x_row_max)
self.x_row_max = emission_row_max
self.x_exps = emission_exps
# unnormalized logits of the transition weights for the start mark.
self.a = transition_weights[0, :]
self.a_exps = np.exp(self.a)
self.a_exps = transition_exps[0, :]
# unnormalized logits of the transition weights for the end mark.
self.b = transition_weights[1, :]
self.b_exps = np.exp(self.b)
self.b_exps = transition_exps[1, :]
# unnormalized logits of the transition weights for all the other tags.
self.w = transition_weights[2:, :]
self.w_exps = np.exp(self.w)
self.w_exps = transition_exps[2:, :]
# The output of linear chain crf operator.
# alpha is a memo table in dynamic programming to caculate
# nomalization factor.
self.alpha = np.zeros(
(seq_start_positions[-1], self.tag_num), dtype="float32")
self.log_likelihood = np.zeros((self.tag_num, 1))
self.log_likelihood = np.zeros((self.seq_num, 1))
def _l1_norm(self, x):
s = np.sum(x)
......@@ -91,11 +93,15 @@ class TestLinearChainCrfOp(OpTest):
lod = [[0]]
for i in range(SEQ_NUM):
lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN))
emission = np.random.uniform(-1, 1,
[lod[-1][-1], TAG_NUM]).astype("float32")
emission_row_max = np.amax(emission, axis=1, keepdims=True)
emission_exps = np.exp(emission - emission_row_max)
transition = np.random.uniform(-0.5, 0.5,
[TAG_NUM + 2, TAG_NUM]).astype("float32")
transition_exps = np.exp(transition)
labels = np.random.randint(
low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")
......@@ -105,10 +111,17 @@ class TestLinearChainCrfOp(OpTest):
"Label": (labels, lod)
}
crf = LinearChainCrfForward(lod[0], emission, transition, labels)
crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
emission_exps, transition, transition_exps,
labels)
alpha, log_likelihood = crf.crf_forward_compute()
self.outputs = {"Alpha": alpha, "LogLikelihood": log_likelihood}
self.outputs = {
"Alpha": alpha,
"EmissionExps": emission_exps,
"TransitionExps": transition_exps,
"LogLikelihood": log_likelihood
}
def setUp(self):
self.op_type = "linear_chain_crf"
......@@ -117,6 +130,13 @@ class TestLinearChainCrfOp(OpTest):
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["Emission", "Transition"], "LogLikelihood")
def test_check_grad_ignore_transition(self):
self.check_grad(
["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册