Commit 80a5ee00 authored by caoying03

fix forward and add backward.

Parent 3123e3cf
@@ -17,6 +17,22 @@ limitations under the License. */

namespace paddle {
namespace operators {

namespace {
template <typename T>
T NormalizeL1(T* x, size_t len) {
  T sum = 0.;
  for (size_t i = 0; i < len; ++i) sum += x[i];
  // (This comment is from the old LinearChainCRFLayer.)
  // Right now, we just bet that sum won't be zero. If this really happens, we
  // will figure out what should be done then.
  PADDLE_ENFORCE(sum,
                 "The unnormalized probabilities of all possible unfinished "
                 "sequences must be greater than 0.");
  for (size_t i = 0; i < len; ++i) x[i] /= sum;
  return sum;
}
}  // namespace

using framework::LoDTensor;
using framework::LoD;
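NormalizeL1 above rescales one row of the dynamic-programming table in place and returns the normalization constant. Returning the constant is what lets the kernels accumulate its logarithm step by step instead of multiplying raw exponentials, which would quickly underflow for long sequences. A sketch of the identity being used, in our own notation (not from the source):

    s_k = \sum_i \alpha_k(i), \qquad \hat\alpha_k(i) = \alpha_k(i) / s_k
    \log Z = \sum_k \big( m_k + \log s_k \big) + \log \sum_i \hat\alpha_{T-1}(i) \, b^{exp}_i

where m_k is the per-row maximum subtracted before exponentiating the emissions, and b^{exp} holds the exponentiated end-transition weights (the second row of TransitionExps).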
@@ -54,13 +70,25 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
              "each tag value \f$v\f$. This vector is called a forward vector and "
              "will also be used in backward computations.")
        .AsIntermediate();
    AddOutput("EmissionExps",
              "The exponentials of Input(Emission). This is an intermediate "
              "computational result in forward computation, and will be reused "
              "in backward computation.")
        .AsIntermediate();
    AddOutput("TransitionExps",
              "The exponentials of Input(Transition). This is an intermediate "
              "computational result in forward computation, and will be reused "
              "in backward computation.")
        .AsIntermediate();
    AddOutput(
        "LogLikelihood",
        "(Tensor, default: Tensor<float>). The logarithm of the conditional "
        "likelihood of each training sample in a mini-batch. This is a 2-D "
        "tensor with shape [S x 1], where S is the sequence number in a "
        "mini-batch. "
        "Note: S is equal to the sequence number in a mini-batch. The output "
        "is no longer a LoDTensor.");
    AddComment(R"DOC(
Conditional Random Field defines an undirected probabilistic graph with nodes
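The doc comment above (cut off by the diff) describes the standard linear-chain CRF. For reference, the conditional likelihood that the kernels below evaluate has the usual form; in our own notation, with a and b the start/end transition rows, w the tag-to-tag transitions, and x the emissions:

    P(y \mid x) = \frac{1}{Z(x)} \exp\Big( a_{y_1} + x_{1,y_1} + \sum_{t=2}^{T} \big( w_{y_{t-1},y_t} + x_{t,y_t} \big) + b_{y_T} \Big)

where Z(x) sums the same exponent over all possible tag sequences. The LogLikelihood output declared above is documented to hold the logarithm of this quantity for every sequence in the mini-batch.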
@@ -129,6 +157,10 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
    PADDLE_ENFORCE(ctx->HasOutput("Alpha"),
                   "Output(Alpha) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("EmissionExps"),
                   "Output(EmissionExps) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("TransitionExps"),
                   "Output(TransitionExps) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("LogLikelihood"),
                   "Output(LogLikelihood) should be not null.");
@@ -143,7 +175,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
    PADDLE_ENFORCE_EQ(
        transition_dims[0] - 2, transition_dims[1],
        "An invalid dimension for the Input(Transition), which should "
        "be a 2-D tensor with shape [(D + 2) x D].");
    PADDLE_ENFORCE_EQ(
        emission_dims[1], transition_dims[1],
        "The 2nd dimension of the Input(Emission) and the Input(Transition) "
@@ -157,11 +189,14 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
        "should be the same.");
ctx->SetOutputDim("Alpha", emission_dims); ctx->SetOutputDim("Alpha", emission_dims);
ctx->SetOutputDim("EmissionExps", emission_dims);
ctx->SetOutputDim("TransitionExps", transition_dims);
// (TODO caoying) This is tricky. The 1st dimension of Output(LogLikelihood) // (TODO caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
// is the sequence number in a mini-batch. The dimension set here should be // is the sequence number in a mini-batch. The dimension set here should be
// resized to its correct size in the function Compute. // resized to its correct size in the function Compute.
ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1}); ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1});
ctx->ShareLoD("Emission", /*->*/ "EmissionExps");
} }
protected: protected:
@@ -180,9 +215,12 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                   "This kernel only runs on CPU.");
    auto* emission_weights = ctx.Input<LoDTensor>("Emission");
    auto* transition_weights = ctx.Input<Tensor>("Transition");
    auto* emission_exps = ctx.Output<LoDTensor>("EmissionExps");
    emission_exps->mutable_data<T>(platform::CPUPlace());
    auto* transition_exps = ctx.Output<Tensor>("TransitionExps");
    transition_exps->mutable_data<T>(platform::CPUPlace());
    auto* label = ctx.Input<LoDTensor>("Label");

    auto in_lod = emission_weights->lod();
@@ -195,18 +233,29 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
    const size_t level = 0;

    auto emission_dims = emission_weights->dims();
    const size_t batch_size = emission_dims[0];
    const size_t tag_num = emission_dims[1];
    const size_t seq_num = in_lod[level].size() - 1;

    Tensor emission_row_max;
    emission_row_max.mutable_data<T>(
        framework::make_ddim({static_cast<int>(batch_size), 1}),
        platform::CPUPlace());

    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
    auto x = EigenMatrix<T>::From(*emission_weights);
    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
    x_row_max.device(place) =
        x.maximum(Eigen::DSizes<int, 1>(1))
            .reshape(Eigen::DSizes<int, 2>(int(batch_size), 1));

    auto x_exps = EigenMatrix<T>::From(*emission_exps);
    x_exps.device(place) =
        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();

    auto w = EigenMatrix<T>::From(*transition_weights);
    auto w_exps = EigenMatrix<T>::From(*transition_exps);
    w_exps.device(place) = w.exp();

    auto* alpha = ctx.Output<LoDTensor>("Alpha");
    alpha->mutable_data<T>(ctx.GetPlace());
@@ -214,117 +263,124 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
    // resize the output tensor to the correct dimension.
    ll->Resize({static_cast<int>(seq_num), 1});
    T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
    for (size_t i = 0; i < seq_num; ++i) {
      int start_pos = static_cast<int>(in_lod[level][i]);
      int end_pos = static_cast<int>(in_lod[level][i + 1]);

      const Tensor one_seq = emission_weights->Slice<T>(start_pos, end_pos);
      Tensor one_seq_row_max = emission_row_max.Slice<T>(start_pos, end_pos);
      Tensor one_seq_exps = emission_exps->Slice<T>(start_pos, end_pos);
      const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
      Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);

      log_likelihood[i] = ForwardOneSequence(
          &one_seq, &one_seq_row_max, &one_seq_exps, transition_weights,
          transition_exps, &one_seq_label, &one_seq_alpha);
    }
  }

 protected:
  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
                       const Tensor* emission_exps, const Tensor* trans_weights,
                       const Tensor* trans_weight_exps, const Tensor* label,
                       Tensor* alpha) const {
    const T* x = emission->data<T>();
    const T* x_row_max = emission_row_max->data<T>();
    const T* x_exps = emission_exps->data<T>();
    const T* w = trans_weights->data<T>();
    const T* w_exps = trans_weight_exps->data<T>();
    T* alpha_value = alpha->data<T>();

    auto x_dims = emission->dims();
    const size_t seq_length = x_dims[0];
    const size_t tag_num = x_dims[1];
    // The 1st row of w holds the transition weights for the start mark.
    // The 2nd row of w holds the transition weights for the end mark.
    // Transition weights among other tags begin from the 3rd row of w.
    const size_t state_trans_base_idx = 2;

    for (size_t i = 0; i < tag_num; ++i) {
      alpha_value[i] = w_exps[i] * x_exps[i];
    }
    T ll = -x_row_max[0] - std::log(NormalizeL1<T>(alpha_value, tag_num));

    for (size_t k = 1; k < seq_length; ++k) {
      for (size_t i = 0; i < tag_num; ++i) {
        T sum = 0.;
        for (size_t j = 0; j < tag_num; ++j) {
          sum += alpha_value[(k - 1) * tag_num + j] *
                 w_exps[(j + state_trans_base_idx) * tag_num + i];
        }
        alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
      }
      ll -= x_row_max[k] +
            std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
    }
    T sum = 0.;
    for (size_t i = 0; i < tag_num; ++i) {
      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
    }
    ll -= std::log(sum);

    const int* lbl = label->data<int>();
    PADDLE_ENFORCE_LT(
        *std::max_element(lbl, lbl + seq_length), tag_num,
        "An invalid tag label that exceeds the largest tag number.");

    // Calculate the nominator part, which depends on the label sequence.
    ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] +
          w[tag_num + lbl[seq_length - 1]] /*end transition*/;
    for (size_t k = 1; k < seq_length; ++k)
      ll += x[k * tag_num + lbl[k]] + w[lbl[k - 1] * tag_num + lbl[k]];
    return -ll;
  }
};
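For readers who find the flattened index arithmetic in ForwardOneSequence hard to follow, here is a compact NumPy sketch of the same forward pass for one sequence. It is illustrative only: the names are ours, the normalization mirrors the kernel's per-step L1 rescaling, and the label-path score follows the standard formulation (tag-to-tag transitions offset by the two special rows).

import numpy as np

def forward_one_sequence(x, w, labels):
    # x: [seq_len, tag_num] emissions; w: [tag_num + 2, tag_num] transitions
    # (row 0: start, row 1: end, rows 2..: tag-to-tag).
    seq_len, tag_num = x.shape
    x_row_max = x.max(axis=1, keepdims=True)
    x_exps = np.exp(x - x_row_max)        # corresponds to EmissionExps
    w_exps = np.exp(w)                    # corresponds to TransitionExps

    alpha = np.zeros((seq_len, tag_num))
    alpha[0] = w_exps[0] * x_exps[0]      # start transition times first emission
    log_z = x_row_max[0, 0] + np.log(alpha[0].sum())
    alpha[0] /= alpha[0].sum()
    for k in range(1, seq_len):
        alpha[k] = x_exps[k] * (alpha[k - 1] @ w_exps[2:])
        log_z += x_row_max[k, 0] + np.log(alpha[k].sum())
        alpha[k] /= alpha[k].sum()
    log_z += np.log((alpha[-1] * w_exps[1]).sum())

    # Unnormalized score of the gold label path.
    score = w[0, labels[0]] + x[0, labels[0]] + w[1, labels[-1]]
    for k in range(1, seq_len):
        score += x[k, labels[k]] + w[labels[k - 1] + 2, labels[k]]
    return alpha, score - log_z           # log-likelihood of the label path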
class LinearChainCrfGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("EmissionExps"),
                   "Input(EmissionExps) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("TransitionExps"),
                   "Input(TransitionExps) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")),
                   "Input(LogLikelihood@GRAD) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Emission")),
                   "Output(Emission@GRAD) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Transition")),
                   "Output(Transition@GRAD) should be not null.");

    auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
    auto transition_exps_dims =
        ctx->GetInputDim(framework::GradVarName("TransitionExps"));
    auto label_dims = ctx->GetInputDim("Label");
    PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2UL,
                      "The Input(EmissionExps) should be a 2-D tensor.");
    PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2UL,
                      "The Input(TransitionExps) should be a 2-D tensor.");
    PADDLE_ENFORCE_EQ(
        transition_exps_dims[0] - 2, transition_exps_dims[1],
        "An invalid dimension for the Input(TransitionExps), which should "
        "be a 2-D tensor with shape [(D + 2) x D].");
    PADDLE_ENFORCE_EQ(
        emission_exps_dims[1], transition_exps_dims[1],
        "The 2nd dimension of the Input(EmissionExps) and the "
        "Input(TransitionExps) should be equal to the tag number.");
    PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
                   "The Input(Label) should be a 2-D tensor with the 2nd "
                   "dimensions fixed to 1.");
    PADDLE_ENFORCE_EQ(
        emission_exps_dims[0], label_dims[0],
        "The height of Input(EmissionExps) and the height of Input(Label) "
        "should be the same.");

    ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
    ctx->SetOutputDim(framework::GradVarName("Transition"),
                      transition_exps_dims);
  }
};
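To make the shape checks above concrete, a hypothetical configuration with D = 5 tags and a mini-batch of two sequences totalling 12 time steps would look like this (sizes are illustrative, not from the source):

# Emission / EmissionExps / Emission@GRAD : [12, 5]        (total time steps x D)
# Transition / TransitionExps / Transition@GRAD : [7, 5]   ((D + 2) x D; rows 0 and 1 are the start/end rows)
# Label : [12, 1]                                          (one integer tag per time step)
# LogLikelihood / LogLikelihood@GRAD : [2, 1]              (one value per sequence)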
template <typename T>
@@ -334,6 +390,134 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                   "This kernel only runs on CPU.");
    auto* ll_grad =
        ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"));
    auto* label = ctx.Input<LoDTensor>("Label");
    auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
    auto* transition_exps = ctx.Input<Tensor>("TransitionExps");
    auto* alpha = ctx.Input<Tensor>("Alpha");

    auto* emission_grad =
        ctx.Output<Tensor>(framework::GradVarName("Emission"));
    emission_grad->mutable_data<T>(platform::CPUPlace());
    auto* trans_grad = ctx.Output<Tensor>(framework::GradVarName("Transition"));
    if (trans_grad) trans_grad->mutable_data<T>(platform::CPUPlace());

    auto emission_dims = emission_exps->dims();
    // Beta is the memo table used in dynamic programming to calculate the
    // backward vectors. For a backward vector i (the i-th row of beta), it
    // captures the unnormalized probabilities of partial sequences starting at
    // position i.
    Tensor beta;
    beta.mutable_data<T>(emission_dims, platform::CPUPlace());

    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
    auto x_grad = EigenMatrix<T>::From(*emission_grad);
    auto out_grad = EigenMatrix<T>::From(*ll_grad);
    x_grad.device(place) =
        x_grad * out_grad.broadcast(Eigen::DSizes<int, 2>(1, emission_dims[1]));

    const size_t level = 0;  // currently, only support sequence.
    auto lod = emission_exps->lod();
    for (size_t i = 0; i < lod[level].size() - 1; ++i) {
      int start_pos = static_cast<int>(lod[level][i]);
      int end_pos = static_cast<int>(lod[level][i + 1]);

      const Tensor one_seq_emission_exps =
          emission_exps->Slice<T>(start_pos, end_pos);
      const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
      const Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
      Tensor one_seq_beta = beta.Slice<T>(start_pos, end_pos);
      Tensor one_seq_emission_grad =
          emission_grad->Slice<T>(start_pos, end_pos);

      BackwardOneSequence(ctx.device_context(), &one_seq_emission_exps,
                          transition_exps, &one_seq_alpha, &one_seq_label,
                          &one_seq_beta, trans_grad, &one_seq_emission_grad);
    }
  }
 protected:
  void BackwardOneSequence(const platform::DeviceContext& ctx,
                           const Tensor* emission_exps,
                           const Tensor* transition_exps, const Tensor* alpha,
                           const Tensor* label, Tensor* beta,
                           Tensor* transition_grad,
                           Tensor* emission_grad) const {
    const T* w_exps = transition_exps->data<T>();
    const T* x_exps = emission_exps->data<T>();
    const int* label_value = label->data<int>();
    T* beta_value = beta->data<T>();

    auto x_dims = emission_exps->dims();
    const size_t seq_length = x_dims[0];
    const size_t tag_num = x_dims[1];
    const size_t state_trans_base_idx = 2;

    // Calculate the backward vectors beta.
    for (int i = 0; i < tag_num; ++i)
      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
    NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);

    for (int k = seq_length - 2; k >= 0; --k) {
      for (int i = 0; i < tag_num; ++i) {
        T sum = 0.;
        for (int j = 0; j < tag_num; ++j) {
          sum += x_exps[(i + state_trans_base_idx) * tag_num + j] *
                 beta_value[(k + 1) * tag_num + j] *
                 x_exps[(k + 1) * tag_num + j];
        }
        beta_value[k * tag_num + i] = sum;
      }
      NormalizeL1<T>(beta_value + k * tag_num, tag_num);
    }

    auto alpha_mat = EigenMatrix<T>::From(*alpha);
    auto beta_mat = EigenMatrix<T>::From(*beta);
    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
    auto* place = ctx.GetEigenDevice<platform::CPUPlace>();
    x_grad_mat.device(*place) = alpha_mat * beta_mat;
    x_grad_mat /= x_grad_mat.sum(Eigen::DSizes<int, 1>(1))
                      .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                      .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
    for (int k = 0; k < seq_length; ++k)
      x_grad_mat(k, label_value[k]) -= static_cast<T>(1);

    if (transition_grad) {
      T* trans_grad = transition_grad->data<T>();
      for (size_t k = 0; k < tag_num; ++k) {
        trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
        trans_grad[tag_num + k] +=
            x_grad_mat(/*to end state*/ seq_length - 1, k);
      }

      auto x_exps_mat = EigenMatrix<T>::From(*emission_exps);
      beta_mat = beta_mat * x_exps_mat;
      beta_mat /= beta_mat.sum(Eigen::DSizes<int, 1>(1))
                      .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                      .broadcast(Eigen::DSizes<int, 2>(1, tag_num));

      for (int k = 1; k < seq_length; ++k) {
        T sum = 0.;
        for (int i = 0; i < tag_num; ++i) {
          for (int j = 0; j < tag_num; ++j)
            sum += x_exps_mat(i, j) * alpha_mat(k - 1, i) * beta_mat(k, j);
        }
        sum = static_cast<T>(1) / sum;
        for (int i = 0; i < tag_num; ++i) {
          for (int j = 0; j < tag_num; ++j) {
            trans_grad[(i + 2) * tag_num + j] +=
                sum * x_exps_mat(i, j) * alpha_mat(k - 1, i) * beta_mat(k, j);
          }
        }
        trans_grad[label_value[k - 1] * tag_num + label_value[k]] -=
            static_cast<T>(1);
      }
    }
  }
};
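The backward kernel above is the second half of the classic forward-backward computation. The standard relations it is built around can be written as follows (our notation; the per-row L1 rescaling in the code only changes constants that cancel in the ratios):

    \beta_{T-1}(i) \propto b^{exp}_i
    \beta_k(i) \propto \sum_j w^{exp}_{ij} \, x^{exp}_{k+1,j} \, \beta_{k+1}(j)
    \frac{\partial\,(-\log P(y \mid x))}{\partial x_{k,i}} = \frac{\alpha_k(i)\,\beta_k(i)}{\sum_j \alpha_k(j)\,\beta_k(j)} - \mathbf{1}[y_k = i]

The transition gradient is the analogous pairwise marginal over adjacent positions, minus one for each observed transition y_{k-1} -> y_k, which is what the nested loops over (i, j) accumulate.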
...
@@ -30,20 +30,24 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> {
  void Compute(const framework::ExecutionContext& ctx) const override;

 protected:
  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
                       const Tensor* emission_exps, const Tensor* trans_weights,
                       const Tensor* trans_weight_exps, const Tensor* label,
                       Tensor* alpha) const;
};

template <typename Place, typename T>
class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override;

 protected:
  void BackwardOneSequence(const platform::DeviceContext& ctx,
                           const Tensor* emission_exps,
                           const Tensor* transition_exps, const Tensor* alpha,
                           const Tensor* label, Tensor* beta,
                           Tensor* transition_grad,
                           Tensor* emission_grad) const;
};
}  // namespace operators

...
@@ -4,10 +4,12 @@ import numpy as np
from op_test import OpTest

import pdb


class LinearChainCrfForward(object):
    def __init__(self, seq_start_positions, emission_weights, emission_row_max,
                 emission_exps, transition_weights, transition_exps, labels):
        self.tag_num = emission_weights.shape[1]
        self.seq_num = len(seq_start_positions) - 1
@@ -15,25 +17,25 @@ class LinearChainCrfForward(object):
        self.labels = labels
        self.x = emission_weights

        self.x_row_max = emission_row_max
        self.x_exps = emission_exps

        # unnormalized logits of the transition weights for the start mark.
        self.a = transition_weights[0, :]
        self.a_exps = transition_exps[0, :]
        # unnormalized logits of the transition weights for the end mark.
        self.b = transition_weights[1, :]
        self.b_exps = transition_exps[1, :]
        # unnormalized logits of the transition weights for all the other tags.
        self.w = transition_weights[2:, :]
        self.w_exps = transition_exps[2:, :]

        # The output of the linear chain crf operator.
        # alpha is a memo table in dynamic programming to calculate the
        # normalization factor.
        self.alpha = np.zeros(
            (seq_start_positions[-1], self.tag_num), dtype="float32")
        self.log_likelihood = np.zeros((self.seq_num, 1))

    def _l1_norm(self, x):
        s = np.sum(x)
@@ -91,11 +93,15 @@ class TestLinearChainCrfOp(OpTest):
        lod = [[0]]
        for i in range(SEQ_NUM):
            lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN))
        emission = np.random.uniform(-1, 1,
                                     [lod[-1][-1], TAG_NUM]).astype("float32")
        emission_row_max = np.amax(emission, axis=1, keepdims=True)
        emission_exps = np.exp(emission - emission_row_max)
        transition = np.random.uniform(-0.5, 0.5,
                                       [TAG_NUM + 2, TAG_NUM]).astype("float32")
        transition_exps = np.exp(transition)
        labels = np.random.randint(
            low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")
@@ -105,10 +111,17 @@ class TestLinearChainCrfOp(OpTest):
            "Label": (labels, lod)
        }
        crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
                                    emission_exps, transition, transition_exps,
                                    labels)
        alpha, log_likelihood = crf.crf_forward_compute()

        self.outputs = {
            "Alpha": alpha,
            "EmissionExps": emission_exps,
            "TransitionExps": transition_exps,
            "LogLikelihood": log_likelihood
        }

    def setUp(self):
        self.op_type = "linear_chain_crf"

@@ -117,6 +130,13 @@ class TestLinearChainCrfOp(OpTest):
    def test_check_output(self):
        self.check_output()
    def test_check_grad(self):
        self.check_grad(["Emission", "Transition"], "LogLikelihood")

    def test_check_grad_ignore_transition(self):
        self.check_grad(
            ["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
if __name__ == "__main__":
    unittest.main()