Commit 80a5ee00 authored by caoying03

fix forward and add backward.

Parent 3123e3cf
@@ -30,20 +30,24 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override;

  protected:
-  T ForwardOneSequence(const platform::DeviceContext& ctx,
-                       const Tensor& emission, Tensor& emission_row_max,
-                       Tensor& emission_exps, const Tensor& trans_weights,
-                       Tensor& trans_weight_exps, const Tensor& label,
-                       Tensor& a) const;
+  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
+                       const Tensor* emission_exps, const Tensor* trans_weights,
+                       const Tensor* trans_weight_exps, const Tensor* label,
+                       Tensor* alpha) const;
+
+ private:
+  T NormalizeL1(T* x, size_t len) const;
 };

 template <typename Place, typename T>
 class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override;
+
+ protected:
+  void BackwardOneSequence(const platform::DeviceContext& ctx,
+                           const Tensor* emission_exps,
+                           const Tensor* transition_exps, const Tensor* alpha,
+                           const Tensor* label, Tensor* beta,
+                           Tensor* transition_grad,
+                           Tensor* emission_grad) const;
 };

 }  // namespace operators
...
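Aside for readers of this hunk: the new `BackwardOneSequence` declaration corresponds to the textbook backward (beta) recursion of a linear-chain CRF. Below is a minimal NumPy sketch of that recursion with the same per-row L1 normalization that `NormalizeL1` hints at; the function name and the exact gradient bookkeeping are illustrative assumptions, not the kernel's actual implementation.

```python
import numpy as np


def backward_one_sequence(x_exps, b_exps, w_exps, alpha, label):
    """Illustrative reference for one sequence; not the actual kernel code.

    x_exps : (T, n) exponentiated emissions, exp(x - row_max)
    b_exps : (n,)   exponentiated transitions into the end mark
    w_exps : (n, n) exponentiated tag-to-tag transitions
    alpha  : (T, n) forward table produced by the forward pass
    label  : (T,)   gold tag indices
    """
    seq_len, tag_num = x_exps.shape
    beta = np.zeros_like(x_exps)

    # beta[t, i]: (normalized) score of all paths from t+1 to the end,
    # given tag i at position t.
    beta[-1, :] = b_exps
    beta[-1, :] /= beta[-1, :].sum()  # per-row L1 normalization for stability
    for t in range(seq_len - 2, -1, -1):
        beta[t, :] = (w_exps * (x_exps[t + 1, :] * beta[t + 1, :])).sum(axis=1)
        beta[t, :] /= beta[t, :].sum()

    # Per-position marginals p(y_t = i | x) are proportional to alpha * beta.
    marginal = alpha * beta
    marginal /= marginal.sum(axis=1, keepdims=True)

    # Gradient of the negative log likelihood w.r.t. the emission scores:
    # expected tag counts minus the observed one-hot labels.
    emission_grad = marginal.copy()
    emission_grad[np.arange(seq_len), label] -= 1.0
    return beta, emission_grad
```

The transition gradient would be accumulated analogously from pairwise marginals over adjacent positions; it is left out to keep the sketch short.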
@@ -4,10 +4,12 @@ import numpy as np
 from op_test import OpTest
+import pdb

 class LinearChainCrfForward(object):
-    def __init__(self, seq_start_positions, emission_weights,
-                 transition_weights, labels):
+    def __init__(self, seq_start_positions, emission_weights, emission_row_max,
+                 emission_exps, transition_weights, transition_exps, labels):
         self.tag_num = emission_weights.shape[1]
         self.seq_num = len(seq_start_positions) - 1
@@ -15,25 +17,25 @@ class LinearChainCrfForward(object):
         self.labels = labels
         self.x = emission_weights
-        self.x_row_max = np.amax(self.x, axis=1, keepdims=True)
-        self.x_exps = np.exp(self.x - self.x_row_max)
+        self.x_row_max = emission_row_max
+        self.x_exps = emission_exps
         # unnormalized logits of the transition weights for the start mark.
         self.a = transition_weights[0, :]
-        self.a_exps = np.exp(self.a)
+        self.a_exps = transition_exps[0, :]
         # unnormalized logits of the transition weights for the end mark.
         self.b = transition_weights[1, :]
-        self.b_exps = np.exp(self.b)
+        self.b_exps = transition_exps[1, :]
         # unnormalized logits of the transition weights for all the other tags.
         self.w = transition_weights[2:, :]
-        self.w_exps = np.exp(self.w)
+        self.w_exps = transition_exps[2:, :]
         # The output of the linear chain CRF operator.
         # alpha is a memo table in dynamic programming to calculate
         # the normalization factor.
         self.alpha = np.zeros(
             (seq_start_positions[-1], self.tag_num), dtype="float32")
-        self.log_likelihood = np.zeros((self.tag_num, 1))
+        self.log_likelihood = np.zeros((self.seq_num, 1))

     def _l1_norm(self, x):
         s = np.sum(x)
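For context on this hunk: `alpha` is the dynamic-programming memo table for the normalization factor, and `_l1_norm` rescales each row so the recursion stays in a numerically safe range. A rough NumPy sketch of how such a forward pass could fill `alpha` and produce one entry of the `(seq_num, 1)` log-likelihood vector follows; the helper below is an assumption for illustration, not the test's actual `crf_forward_compute`.

```python
import numpy as np


def forward_one_sequence(x, x_row_max, x_exps, a_exps, b_exps, w_exps, label):
    """Sketch of the forward pass for one sequence; names are illustrative.

    x         : (T, n) raw emission scores
    x_row_max : (T, 1) row-wise max of x
    x_exps    : (T, n) exp(x - x_row_max)
    a_exps    : (n,)   exponentiated start-mark transitions
    b_exps    : (n,)   exponentiated end-mark transitions
    w_exps    : (n, n) exponentiated tag-to-tag transitions
    label     : (T,)   gold tag indices
    """
    seq_len, tag_num = x.shape
    alpha = np.zeros((seq_len, tag_num), dtype="float64")
    log_z = 0.0  # log of the normalization factor

    # First position: start-mark transition times the first emission.
    alpha[0, :] = a_exps * x_exps[0, :]
    s = alpha[0, :].sum()
    alpha[0, :] /= s                      # L1-normalize the row (stability)
    log_z += np.log(s) + x_row_max[0, 0]  # add the factored-out constants back

    for t in range(1, seq_len):
        # alpha[t, j] = sum_i alpha[t-1, i] * w_exps[i, j] * x_exps[t, j]
        alpha[t, :] = alpha[t - 1, :].dot(w_exps) * x_exps[t, :]
        s = alpha[t, :].sum()
        alpha[t, :] /= s
        log_z += np.log(s) + x_row_max[t, 0]

    # Close the lattice with the end-mark transition.
    log_z += np.log((alpha[-1, :] * b_exps).sum())

    # Score of the gold path: start + emissions + pairwise + end transitions.
    gold = np.log(a_exps[label[0]]) + x[0, label[0]]
    for t in range(1, seq_len):
        gold += np.log(w_exps[label[t - 1], label[t]]) + x[t, label[t]]
    gold += np.log(b_exps[label[-1]])

    return alpha, gold - log_z  # per-sequence log likelihood
```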
@@ -91,11 +93,15 @@ class TestLinearChainCrfOp(OpTest):
         lod = [[0]]
         for i in range(SEQ_NUM):
             lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN))
         emission = np.random.uniform(-1, 1,
                                      [lod[-1][-1], TAG_NUM]).astype("float32")
+        emission_row_max = np.amax(emission, axis=1, keepdims=True)
+        emission_exps = np.exp(emission - emission_row_max)
         transition = np.random.uniform(-0.5, 0.5,
                                        [TAG_NUM + 2, TAG_NUM]).astype("float32")
+        transition_exps = np.exp(transition)
         labels = np.random.randint(
             low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")
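One note on the precomputed arrays added above: subtracting the row-wise max before exponentiation is the usual guard against overflow in `np.exp`, and the subtracted shift is added back when the log normalizer is accumulated, so the result is unchanged. A tiny self-contained illustration (the numbers are arbitrary):

```python
import numpy as np

x = np.array([[1000.0, 1001.0]])            # scores this large would overflow np.exp
row_max = np.amax(x, axis=1, keepdims=True)
exps = np.exp(x - row_max)                  # finite: [[exp(-1), 1.0]]
log_z = np.log(exps.sum(axis=1)) + row_max[:, 0]  # shift added back
print(log_z)  # equals log(exp(1000) + exp(1001)), about 1001.31
```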
@@ -105,10 +111,17 @@ class TestLinearChainCrfOp(OpTest):
             "Label": (labels, lod)
         }
-        crf = LinearChainCrfForward(lod[0], emission, transition, labels)
+        crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
+                                    emission_exps, transition, transition_exps,
+                                    labels)
         alpha, log_likelihood = crf.crf_forward_compute()
-        self.outputs = {"Alpha": alpha, "LogLikelihood": log_likelihood}
+        self.outputs = {
+            "Alpha": alpha,
+            "EmissionExps": emission_exps,
+            "TransitionExps": transition_exps,
+            "LogLikelihood": log_likelihood
+        }

     def setUp(self):
         self.op_type = "linear_chain_crf"
@@ -117,6 +130,13 @@ class TestLinearChainCrfOp(OpTest):
     def test_check_output(self):
         self.check_output()

+    def test_check_grad(self):
+        self.check_grad(["Emission", "Transition"], "LogLikelihood")
+
+    def test_check_grad_ignore_transition(self):
+        self.check_grad(
+            ["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
+

 if __name__ == "__main__":
     unittest.main()