diff --git a/paddle/operators/linear_chain_crf_op.cc b/paddle/operators/linear_chain_crf_op.cc
index e127811a101f133802fc9c038d42e843d45d4368..14ae74ab662748c1cc774293680e758bc3cb560e 100644
--- a/paddle/operators/linear_chain_crf_op.cc
+++ b/paddle/operators/linear_chain_crf_op.cc
@@ -17,6 +17,22 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+namespace {
+template <typename T>
+T NormalizeL1(T* x, size_t len) {
+  T sum = 0.;
+  for (size_t i = 0; i < len; ++i) sum += x[i];
+  // (This comment is from the old LinearChainCRFLayer.)
+  // Right now, we just bet that sum won't be zero. If this really happens, we
+  // will figure out what should be done then.
+  PADDLE_ENFORCE(sum,
+                 "The unnormalized probabilities of all possible unfinished "
+                 "sequences must be greater than 0.");
+  for (size_t i = 0; i < len; ++i) x[i] /= sum;
+  return sum;
+}
+}  // namespace
+
 using framework::LoDTensor;
 using framework::LoD;
 
@@ -54,13 +70,25 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
              "each tag value \f$v$\f. This vector is called a forward vecotr and "
              "will also be used in backward computations.")
         .AsIntermediate();
+    AddOutput("EmissionExps",
+              "The exponentials of Input(Emission). This is an intermediate "
+              "computational result in forward computation, and will be reused "
+              "in backward computation.")
+        .AsIntermediate();
+    AddOutput("TransitionExps",
+              "The exponentials of Input(Transition). This is an intermediate "
+              "computational result in forward computation, and will be reused "
+              "in backward computation.")
+        .AsIntermediate();
     AddOutput(
         "LogLikelihood",
-        "(Tensor, default: Tensor<float>). The logarithm of the conditional "
+        "(Tensor, default: Tensor<float>). The logarithm of the "
+        "conditional "
         "likelihood of each training sample in a mini-batch. This is a 2-D "
        "tensor with shape [S x 1], where S is the sequence number in a "
        "mini-batch. "
-        "Note: S is equal to the sequence number in a mini-batch. The output "
+        "Note: S is equal to the sequence number in a mini-batch. The "
+        "output "
        "is no longer a LoDTensor.");
     AddComment(R"DOC(
 Conditional Random Field defines an undirected probabilistic graph with nodes
@@ -129,6 +157,10 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
 
     PADDLE_ENFORCE(ctx->HasOutput("Alpha"),
                    "Output(Alpha) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("EmissionExps"),
+                   "Output(EmissionExps) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("TransitionExps"),
+                   "Output(TransitionExps) should be not null.");
     PADDLE_ENFORCE(ctx->HasOutput("LogLikelihood"),
                    "Output(LogLikelihood) should be not null.");
 
@@ -143,7 +175,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(
         transition_dims[0] - 2, transition_dims[1],
         "An invalid dimension for the Input(Transition), which should "
-        "be a 2-D tensor with shape [D + 2 x D].");
+        "be a 2-D tensor with shape [(D + 2) x D].");
     PADDLE_ENFORCE_EQ(
         emission_dims[1], transition_dims[1],
         "The 2nd dimension of the Input(Emission) and the Input(Transition) "
@@ -157,11 +189,14 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
         "should be the same.");
 
     ctx->SetOutputDim("Alpha", emission_dims);
-
+    ctx->SetOutputDim("EmissionExps", emission_dims);
+    ctx->SetOutputDim("TransitionExps", transition_dims);
     // (TODO caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
     // is the sequence number in a mini-batch. The dimension set here should be
     // resized to its correct size in the function Compute.
     ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1});
+
+    ctx->ShareLoD("Emission", /*->*/ "EmissionExps");
   }
 
  protected:
@@ -180,9 +215,12 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
-
     auto* emission_weights = ctx.Input<LoDTensor>("Emission");
     auto* transition_weights = ctx.Input<Tensor>("Transition");
+    auto* emission_exps = ctx.Output<LoDTensor>("EmissionExps");
+    emission_exps->mutable_data<T>(platform::CPUPlace());
+    auto* transition_exps = ctx.Output<Tensor>("TransitionExps");
+    transition_exps->mutable_data<T>(platform::CPUPlace());
     auto* label = ctx.Input<LoDTensor>("Label");
 
     auto in_lod = emission_weights->lod();
@@ -195,18 +233,29 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
     const size_t level = 0;
 
     auto emission_dims = emission_weights->dims();
+    const size_t batch_size = emission_dims[0];
+    const size_t tag_num = emission_dims[1];
     const size_t seq_num = in_lod[level].size() - 1;
 
-    // TODO(caoying) These local variables seems to be created and destroied
-    // every time this function is called. Will this bring additional overhead?
-    Tensor emission_exps;
     Tensor emission_row_max;
-    Tensor transition_exps;
-    emission_exps.mutable_data<T>(emission_dims, platform::CPUPlace());
     emission_row_max.mutable_data<T>(
-        framework::make_ddim({emission_dims[0], 1}), platform::CPUPlace());
-    transition_exps.mutable_data<T>(transition_weights->dims(),
-                                    platform::CPUPlace());
+        framework::make_ddim({static_cast<int>(batch_size), 1}),
+        platform::CPUPlace());
+
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
+    auto x = EigenMatrix<T>::From(*emission_weights);
+    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
+    x_row_max.device(place) =
+        x.maximum(Eigen::DSizes<int, 1>(1))
+            .reshape(Eigen::DSizes<int, 2>(int(batch_size), 1));
+
+    auto x_exps = EigenMatrix<T>::From(*emission_exps);
+    x_exps.device(place) =
+        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
+
+    auto w = EigenMatrix<T>::From(*transition_weights);
+    auto w_exps = EigenMatrix<T>::From(*transition_exps);
+    w_exps.device(place) = w.exp();
 
     auto* alpha = ctx.Output<LoDTensor>("Alpha");
     alpha->mutable_data<T>(ctx.GetPlace());
@@ -214,117 +263,124 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
     // resize the output tensor to the correct dimension.
     ll->Resize({static_cast<int>(seq_num), 1});
     T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
-
     for (size_t i = 0; i < seq_num; ++i) {
       int start_pos = static_cast<int>(in_lod[level][i]);
       int end_pos = static_cast<int>(in_lod[level][i + 1]);
 
       const Tensor one_seq = emission_weights->Slice<T>(start_pos, end_pos);
       Tensor one_seq_row_max = emission_row_max.Slice<T>(start_pos, end_pos);
-      Tensor one_seq_exps = emission_exps.Slice<T>(start_pos, end_pos);
+      Tensor one_seq_exps = emission_exps->Slice<T>(start_pos, end_pos);
       const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
       Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
 
       log_likelihood[i] = ForwardOneSequence(
-          ctx.device_context(), one_seq, one_seq_row_max, one_seq_exps,
-          (*transition_weights), transition_exps, one_seq_label, one_seq_alpha);
+          &one_seq, &one_seq_row_max, &one_seq_exps, transition_weights,
+          transition_exps, &one_seq_label, &one_seq_alpha);
     }
   }
 
  protected:
-  T ForwardOneSequence(const platform::DeviceContext& ctx,
-                       const Tensor& emission, Tensor& emission_row_max,
-                       Tensor& emission_exps, const Tensor& trans_weights,
-                       Tensor& trans_weight_exps, const Tensor& label,
-                       Tensor& alpha) const {
-    // (TODO caoying) Evaluate and optimize this.
-    // The Eigen compution kernel will be invoked for multiple times.
-    // Some computations regardless of sequence inforamtion could be performed
-    // only one time for the entire batch. This potentially could be optimized.
-
-    auto x_dims = emission.dims();
+  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
+                       const Tensor* emission_exps, const Tensor* trans_weights,
+                       const Tensor* trans_weight_exps, const Tensor* label,
+                       Tensor* alpha) const {
+    const T* x = emission->data<T>();
+    const T* x_row_max = emission_row_max->data<T>();
+    const T* x_exps = emission_exps->data<T>();
+    const T* w = trans_weights->data<T>();
+    const T* w_exps = trans_weight_exps->data<T>();
+    T* alpha_value = alpha->data<T>();
+
+    auto x_dims = emission->dims();
     const size_t seq_length = x_dims[0];
     const size_t tag_num = x_dims[1];
-
-    T* alpha_value = alpha.data<T>();
-
-    auto x = EigenMatrix<T>::From(emission);
-    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
-    const int class_dim = 1;
-    x_row_max.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
-        x.maximum(Eigen::DSizes<int, 1>(class_dim))
-            .reshape(Eigen::DSizes<int, 2>(int(seq_length), 1));
-
-    auto x_exps = EigenMatrix<T>::From(emission_exps);
-    x_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
-        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
-
-    auto w = EigenMatrix<T>::From(trans_weights);
-    auto w_exps = EigenMatrix<T>::From(trans_weight_exps);
-    w_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) = w.exp();
     // The 1st row of w are transition weights for start mask.
-    const size_t start_ridx = 0;
     // The 2nd row of w are transition weights for end mask.
-    const size_t end_ridx = 1;
     // Transition weights among other tags begins from the 3rd row of w.
-    const size_t state_base_ridx = 2;
+    const size_t state_trans_base_idx = 2;
 
     for (size_t i = 0; i < tag_num; ++i) {
-      alpha_value[i] = w_exps(start_ridx, i) * x_exps(0, i);
+      alpha_value[i] = w_exps[i] * x_exps[i];
     }
-    T ll = -x_row_max(0, 1) - std::log(NormalizeL1(alpha_value, tag_num));
+    T ll = -x_row_max[0] - std::log(NormalizeL1<T>(alpha_value, tag_num));
 
     for (size_t k = 1; k < seq_length; ++k) {
       for (size_t i = 0; i < tag_num; ++i) {
         T sum = 0.;
         for (size_t j = 0; j < tag_num; ++j) {
           sum += alpha_value[(k - 1) * tag_num + j] *
-                 w_exps(j + state_base_ridx, i);
+                 w_exps[(j + state_trans_base_idx) * tag_num + i];
         }
-        alpha_value[k * tag_num + i] = x_exps(k, i) * sum;
+        alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
       }
-      ll -= x_row_max(k, 1) +
-            std::log(NormalizeL1(alpha_value + k * tag_num, tag_num));
+      ll -= x_row_max[k] +
+            std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
     }
     T sum = 0.;
     for (size_t i = 0; i < tag_num; ++i) {
-      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps(end_ridx, i);
+      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
    }
     ll -= std::log(sum);
 
-    const int* lbl = label.data<int>();
+    const int* lbl = label->data<int>();
     PADDLE_ENFORCE_LT(
         *std::max_element(lbl, lbl + seq_length), tag_num,
         "An invalid tag label that execesses the largest tag number.");
 
-    // Calculate the nominator part, which depends on the label sequence.
-    ll += w(start_ridx, lbl[0]) + x(start_ridx, lbl[0]) +
-          w(end_ridx, lbl[seq_length - 1]);
+    ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] +
+          w[tag_num + lbl[seq_length - 1]] /*end transition*/;
     for (size_t k = 1; k < seq_length; ++k)
-      ll += x(k, lbl[k]) + w(lbl[k - 1], lbl[k]);
+      ll += x[k * tag_num + lbl[k]] +
+            w[(lbl[k - 1] + state_trans_base_idx) * tag_num + lbl[k]];
     return -ll;
   }
-
- private:
-  T NormalizeL1(T* x, size_t len) const {
-    T sum = 0.;
-    for (size_t i = 0; i < len; ++i) sum += x[i];
-    // (This comment is from the old LinearChainCRFLayer.)
-    // Right now, we just bet that sum won't be zero. If this really happens, we
-    // will figure out what should be done then.
-    PADDLE_ENFORCE(sum,
-                   "The unnormalized probabilites of all possible unfinished "
-                   "sequences must be greater than 0.");
-    for (size_t i = 0; i < len; ++i) x[i] /= sum;
-    return sum;
-  }
 };
 
 class LinearChainCrfGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {}
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("EmissionExps"),
+                   "Input(EmissionExps) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("TransitionExps"),
+                   "Input(TransitionExps) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")),
+                   "Input(LogLikelihood@GRAD) should be not null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Emission")),
+                   "Output(Emission@GRAD) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Transition")),
+                   "Output(Transition@GRAD) should be not null.");
+
+    auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
+    auto transition_exps_dims =
+        ctx->GetInputDim("TransitionExps");
+    auto label_dims = ctx->GetInputDim("Label");
+
+    PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2UL,
+                      "The Input(EmissionExps) should be a 2-D tensor.");
+    PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2UL,
+                      "The Input(TransitionExps) should be a 2-D tensor.");
+    PADDLE_ENFORCE_EQ(
+        transition_exps_dims[0] - 2, transition_exps_dims[1],
+        "An invalid dimension for the Input(TransitionExps), which should "
+        "be a 2-D tensor with shape [(D + 2) x D].");
+    PADDLE_ENFORCE_EQ(
+        emission_exps_dims[1], transition_exps_dims[1],
+        "The 2nd dimension of the Input(EmissionExps) and the "
+        "Input(TransitionExps) should be equal to the tag number.");
+    PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
+                   "The Input(Label) should be a 2-D tensor with the 2nd "
+                   "dimension fixed to 1.");
+    PADDLE_ENFORCE_EQ(
+        emission_exps_dims[0], label_dims[0],
+        "The height of Input(EmissionExps) and the height of Input(Label) "
+        "should be the same.");
+
+    ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
+    ctx->SetOutputDim(framework::GradVarName("Transition"),
+                      transition_exps_dims);
+  }
 };
 
 template <typename T>
@@ -334,6 +390,134 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
+    auto* ll_grad =
+        ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"));
+    auto* label = ctx.Input<LoDTensor>("Label");
+    auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
+    auto* transition_exps = ctx.Input<Tensor>("TransitionExps");
+    auto* alpha = ctx.Input<Tensor>("Alpha");
+
+    auto* emission_grad =
+        ctx.Output<Tensor>(framework::GradVarName("Emission"));
+    emission_grad->mutable_data<T>(platform::CPUPlace());
+
+    auto* trans_grad = ctx.Output<Tensor>(framework::GradVarName("Transition"));
+    if (trans_grad) trans_grad->mutable_data<T>(platform::CPUPlace());
+
+    auto emission_dims = emission_exps->dims();
+
+    // Beta is the memo table used in dynamic programming to calculate the
+    // backward vectors. For a backward vector i (the i-th row of beta), it
+    // captures the unnormalized probabilities of partial sequences starting at
+    // position i.
+    Tensor beta;
+    beta.mutable_data<T>(emission_dims, platform::CPUPlace());
+
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
+    auto x_grad = EigenMatrix<T>::From(*emission_grad);
+    auto out_grad = EigenMatrix<T>::From(*ll_grad);
+    x_grad.device(place) =
+        x_grad * out_grad.broadcast(Eigen::DSizes<int, 2>(1, emission_dims[1]));
+
+    const size_t level = 0;  // currently, only support sequence.
+    auto lod = emission_exps->lod();
+    for (size_t i = 0; i < lod[level].size() - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+
+      const Tensor one_seq_emission_exps =
+          emission_exps->Slice<T>(start_pos, end_pos);
+      const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
+      const Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
+      Tensor one_seq_beta = beta.Slice<T>(start_pos, end_pos);
+      Tensor one_seq_emission_grad =
+          emission_grad->Slice<T>(start_pos, end_pos);
+
+      BackwardOneSequence(ctx.device_context(), &one_seq_emission_exps,
+                          transition_exps, &one_seq_alpha, &one_seq_label,
+                          &one_seq_beta, trans_grad, &one_seq_emission_grad);
+    }
+  }
+
+ protected:
+  void BackwardOneSequence(const platform::DeviceContext& ctx,
+                           const Tensor* emission_exps,
+                           const Tensor* transition_exps, const Tensor* alpha,
+                           const Tensor* label, Tensor* beta,
+                           Tensor* transition_grad,
+                           Tensor* emission_grad) const {
+    const T* w_exps = transition_exps->data<T>();
+    const T* x_exps = emission_exps->data<T>();
+    const int* label_value = label->data<int>();
+    T* beta_value = beta->data<T>();
+
+    auto x_dims = emission_exps->dims();
+    const size_t seq_length = x_dims[0];
+    const size_t tag_num = x_dims[1];
+    const size_t state_trans_base_idx = 2;
+
+    // Calculate the backward vectors beta.
+    for (int i = 0; i < tag_num; ++i)
+      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
+    NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);
+
+    for (int k = seq_length - 2; k >= 0; --k) {
+      for (int i = 0; i < tag_num; ++i) {
+        T sum = 0.;
+        for (int j = 0; j < tag_num; ++j) {
+          sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
+                 beta_value[(k + 1) * tag_num + j] *
+                 x_exps[(k + 1) * tag_num + j];
+        }
+        beta_value[k * tag_num + i] = sum;
+      }
+      NormalizeL1<T>(beta_value + k * tag_num, tag_num);
+    }
+
+    auto alpha_mat = EigenMatrix<T>::From(*alpha);
+    auto beta_mat = EigenMatrix<T>::From(*beta);
+    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
+
+    auto* place = ctx.GetEigenDevice<platform::CPUPlace>();
+    x_grad_mat.device(*place) = alpha_mat * beta_mat;
+    x_grad_mat /= x_grad_mat.sum(Eigen::DSizes<int, 1>(1))
+                      .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
+                      .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+
+    for (int k = 0; k < seq_length; ++k)
+      x_grad_mat(k, label_value[k]) -= static_cast<T>(1);
+
+    if (transition_grad) {
+      T* trans_grad = transition_grad->data<T>();
+      for (size_t k = 0; k < tag_num; ++k) {
+        trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
+        trans_grad[tag_num + k] +=
+            x_grad_mat(/*to end state*/ seq_length - 1, k);
+      }
+
+      auto x_exps_mat = EigenMatrix<T>::From(*emission_exps);
+      beta_mat = beta_mat * x_exps_mat;
+      beta_mat /= beta_mat.sum(Eigen::DSizes<int, 1>(1))
+                      .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
+                      .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+
+      for (int k = 1; k < seq_length; ++k) {
+        T sum = 0.;
+        for (int i = 0; i < tag_num; ++i) {
+          for (int j = 0; j < tag_num; ++j)
+            sum += x_exps_mat(i, j) * alpha_mat(k - 1, i) * beta_mat(k, j);
+        }
+        sum = static_cast<T>(1) / sum;
+        for (int i = 0; i < tag_num; ++i) {
+          for (int j = 0; j < tag_num; ++j) {
+            trans_grad[(i + 2) * tag_num + j] +=
+                sum * x_exps_mat(i, j) * alpha_mat(k - 1, i) * beta_mat(k, j);
+          }
+        }
+        trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
+                   label_value[k]] -= static_cast<T>(1);
+      }
+    }
   }
 };
diff --git a/paddle/operators/linear_chain_crf_op.h b/paddle/operators/linear_chain_crf_op.h
index a656e233c2c6331affca345283a3c22ee32852e1..e9852de5959fd9ab56bfa62e33fc3acc519c9f3a 100644
--- a/paddle/operators/linear_chain_crf_op.h
+++ b/paddle/operators/linear_chain_crf_op.h
@@ -30,20 +30,24 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override;
 
  protected:
-  T ForwardOneSequence(const platform::DeviceContext& ctx,
-                       const Tensor& emission, Tensor& emission_row_max,
-                       Tensor& emission_exps, const Tensor& trans_weights,
-                       Tensor& trans_weight_exps, const Tensor& label,
-                       Tensor& a) const;
-
- private:
-  T NormalizeL1(T* x, size_t len) const;
+  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
+                       const Tensor* emission_exps, const Tensor* trans_weights,
+                       const Tensor* trans_weight_exps, const Tensor* label,
+                       Tensor* alpha) const;
 };
 
 template <typename Place, typename T>
 class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override;
+
+ protected:
+  void BackwardOneSequence(const platform::DeviceContext& ctx,
+                           const Tensor* emission_exps,
+                           const Tensor* transition_exps, const Tensor* alpha,
+                           const Tensor* label, Tensor* beta,
+                           Tensor* transition_grad,
+                           Tensor* emission_grad) const;
 };
 
 }  // namespace operators
diff --git a/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py b/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py
index 413210e75b8feeaf76710eb3965a007446aba852..9b73e26eb98dbfac65166a83ab570d244362f2d2 100644
--- a/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py
+++ b/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py
@@ -4,10 +4,10 @@
 import numpy as np
 
 from op_test import OpTest
 
 
 class LinearChainCrfForward(object):
-    def __init__(self, seq_start_positions, emission_weights,
-                 transition_weights, labels):
+    def __init__(self, seq_start_positions, emission_weights, emission_row_max,
+                 emission_exps, transition_weights, transition_exps, labels):
         self.tag_num = emission_weights.shape[1]
         self.seq_num = len(seq_start_positions) - 1
 
@@ -15,25 +17,25 @@ class LinearChainCrfForward(object):
         self.labels = labels
         self.x = emission_weights
 
-        self.x_row_max = np.amax(self.x, axis=1, keepdims=True)
-        self.x_exps = np.exp(self.x - self.x_row_max)
+        self.x_row_max = emission_row_max
+        self.x_exps = emission_exps
 
         # unnormalized logits of the transition weights for the start mark.
         self.a = transition_weights[0, :]
-        self.a_exps = np.exp(self.a)
+        self.a_exps = transition_exps[0, :]
         # unnormalized logits of the transition weights for the end mark.
         self.b = transition_weights[1, :]
-        self.b_exps = np.exp(self.b)
+        self.b_exps = transition_exps[1, :]
         # unnormalized logits of the transition weights for all the other tags.
         self.w = transition_weights[2:, :]
-        self.w_exps = np.exp(self.w)
+        self.w_exps = transition_exps[2:, :]
 
         # The output of linear chain crf operator.
         # alpha is a memo table in dynamic programming to caculate
         # nomalization factor.
         self.alpha = np.zeros(
             (seq_start_positions[-1], self.tag_num), dtype="float32")
-        self.log_likelihood = np.zeros((self.tag_num, 1))
+        self.log_likelihood = np.zeros((self.seq_num, 1))
 
     def _l1_norm(self, x):
         s = np.sum(x)
@@ -91,11 +93,15 @@ class TestLinearChainCrfOp(OpTest):
         lod = [[0]]
         for i in range(SEQ_NUM):
             lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN))
-
         emission = np.random.uniform(-1, 1,
                                      [lod[-1][-1], TAG_NUM]).astype("float32")
+        emission_row_max = np.amax(emission, axis=1, keepdims=True)
+        emission_exps = np.exp(emission - emission_row_max)
+
         transition = np.random.uniform(-0.5, 0.5,
                                        [TAG_NUM + 2, TAG_NUM]).astype("float32")
+        transition_exps = np.exp(transition)
+
         labels = np.random.randint(
             low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")
 
@@ -105,10 +111,17 @@ class TestLinearChainCrfOp(OpTest):
             "Label": (labels, lod)
         }
 
-        crf = LinearChainCrfForward(lod[0], emission, transition, labels)
+        crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
+                                    emission_exps, transition, transition_exps,
+                                    labels)
         alpha, log_likelihood = crf.crf_forward_compute()
-        self.outputs = {"Alpha": alpha, "LogLikelihood": log_likelihood}
+        self.outputs = {
+            "Alpha": alpha,
+            "EmissionExps": emission_exps,
+            "TransitionExps": transition_exps,
+            "LogLikelihood": log_likelihood
+        }
 
     def setUp(self):
         self.op_type = "linear_chain_crf"
@@ -117,6 +130,13 @@ class TestLinearChainCrfOp(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(["Emission", "Transition"], "LogLikelihood")
+
+    def test_check_grad_ignore_transition(self):
+        self.check_grad(
+            ["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
+
 
 if __name__ == "__main__":
     unittest.main()
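
For reviewers who want to sanity-check the recursion that ForwardOneSequence implements, here is a minimal, self-contained NumPy sketch of the same computation for a single sequence. It is illustrative only: the function name crf_neg_log_likelihood and its arguments are made up for this note and are not part of the patch or of the Paddle API; the [(D + 2) x D] transition layout (row 0 = start weights, row 1 = end weights, rows 2.. = tag-to-tag weights) follows the description in the op maker.

# Minimal NumPy sketch of the forward recursion implemented by
# LinearChainCrfOpKernel::ForwardOneSequence. Names are illustrative only.
import numpy as np


def crf_neg_log_likelihood(x, transition, labels):
    """x: [T, D] emission logits; transition: [(D + 2), D] with row 0 = start
    weights, row 1 = end weights, rows 2.. = tag-to-tag weights; labels: [T]."""
    seq_len, tag_num = x.shape
    a, b, w = transition[0], transition[1], transition[2:]

    # Row-wise max subtraction mirrors Output(EmissionExps) in the kernel.
    x_row_max = np.amax(x, axis=1, keepdims=True)
    x_exps = np.exp(x - x_row_max)
    a_exps, b_exps, w_exps = np.exp(a), np.exp(b), np.exp(w)

    # Forward table with per-step L1 normalization (NormalizeL1); the dropped
    # scale factors are folded back into the log partition term log_z.
    alpha = a_exps * x_exps[0]
    log_z = x_row_max[0, 0] + np.log(alpha.sum())
    alpha = alpha / alpha.sum()
    for k in range(1, seq_len):
        alpha = x_exps[k] * (alpha @ w_exps)  # sum_j alpha[j] * w_exps[j, i]
        log_z += x_row_max[k, 0] + np.log(alpha.sum())
        alpha = alpha / alpha.sum()
    log_z += np.log(np.sum(alpha * b_exps))

    # Unnormalized score of the gold label path (the numerator part).
    score = a[labels[0]] + x[0, labels[0]] + b[labels[-1]]
    for k in range(1, seq_len):
        score += x[k, labels[k]] + w[labels[k - 1], labels[k]]

    # ForwardOneSequence returns the negated sequence log-likelihood.
    return log_z - score

For random inputs with the shapes used in TestLinearChainCrfOp, this sketch should agree, up to float32 rounding, with the per-sequence values the kernel writes into Output(LogLikelihood).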