Commit bce4f7d6 authored by caoying03

follow comments.

Parent 4c630869
......@@ -228,8 +228,9 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
PADDLE_ENFORCE_GE(begin_idx, 0,
"The start row index must be greater than 0.");
PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound.");
PADDLE_ENFORCE_LT(begin_idx, end_idx,
"The start row index must be less than the end row index.");
PADDLE_ENFORCE_LT(
begin_idx, end_idx,
"The start row index must be smaller than the end row index.");
if (dims_[0] == 1) {
return *this;
......
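For reference, the checks above pin down the row-slicing contract: 0 <= begin_idx < end_idx <= dims_[0]. Below is a hypothetical standalone restatement of that precondition; it is not part of the Paddle API, and the names and the plain assert are illustrative only.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative restatement of the Slice() precondition enforced above:
// rows [begin_idx, end_idx) must form a non-empty range within the tensor's rows.
inline void CheckSliceRange(int64_t begin_idx, int64_t end_idx,
                            int64_t num_rows) {
  assert(begin_idx >= 0);       // start index must be non-negative
  assert(end_idx <= num_rows);  // end index must not run past the last row
  assert(begin_idx < end_idx);  // the range must contain at least one row
}
```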
......@@ -26,9 +26,10 @@ T NormalizeL1(T* x, size_t len) {
// Right now, we just bet that sum won't be zero. If this really happens, we
// will figure out what should be done then.
PADDLE_ENFORCE(sum,
"The unnormalized probabilites of all possible unfinished "
"The unnormalized probabilities of all possible unfinished "
"sequences must be greater than 0.");
for (size_t i = 0; i < len; ++i) x[i] /= sum;
T s = 1. / sum;
for (size_t i = 0; i < len; ++i) x[i] *= s;
return sum;
}
} // namespace
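Since this hunk changes NormalizeL1 to multiply by the reciprocal of the sum instead of dividing inside the loop, here is a small self-contained version of the same helper with a usage example. It mirrors the new code above but swaps PADDLE_ENFORCE for a plain assert, so the error handling is illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Standalone sketch of the NormalizeL1 helper after this change: compute the
// reciprocal of the sum once and multiply, instead of dividing in the loop.
template <typename T>
T NormalizeL1(T* x, size_t len) {
  T sum = 0.;
  for (size_t i = 0; i < len; ++i) sum += x[i];
  // The real code uses PADDLE_ENFORCE here; the bet is that sum is nonzero.
  assert(sum > 0 &&
         "The unnormalized probabilities of all possible unfinished "
         "sequences must be greater than 0.");
  T s = 1. / sum;
  for (size_t i = 0; i < len; ++i) x[i] *= s;
  return sum;  // callers fold log(sum) back into the log-likelihood
}

int main() {
  std::vector<double> p = {1., 2., 5.};
  double sum = NormalizeL1(p.data(), p.size());
  std::cout << sum << " -> " << p[0] << " " << p[1] << " " << p[2] << "\n";
  // prints: 8 -> 0.125 0.25 0.625
  return 0;
}
```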
......@@ -36,9 +37,9 @@ T NormalizeL1(T* x, size_t len) {
using framework::LoDTensor;
using framework::LoD;
class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LinearChainCrfOpMaker(framework::OpProto* proto,
LinearChainCRFOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
......@@ -51,11 +52,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput(
"Transition",
"(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
"The learnable parameter for linear_chain_crf operator. "
"The learnable parameter for the linear_chain_crf operator. "
"See more details in the operator's comments.");
AddInput(
"Label",
"(LoDTensor, default: LoDTensor<int>). The ground truth which is a 2-D "
"(LoDTensor, default: LoDTensor<int>). The groundtruth which is a 2-D "
"LoDTensor with shape [N x 1], where N is the total element number in "
"a mini-batch.");
AddOutput(
......@@ -82,14 +83,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
.AsIntermediate();
AddOutput(
"LogLikelihood",
"(Tensor, default: Tensor<float>). The logarithm of the "
"conditional "
"(Tensor, default: Tensor<float>). The logarithm of the conditional "
"likelihood of each training sample in a mini-batch. This is a 2-D "
"tensor with shape [S x 1], where S is the sequence number in a "
"mini-batch. "
"Note: S is equal to the sequence number in a mini-batch. The "
"output "
"is no longer a LoDTensor.");
"mini-batch. Note: S is equal to the sequence number in a mini-batch. "
"The output is no longer a LoDTensor.");
AddComment(R"DOC(
Conditional Random Field defines an undirected probabilistic graph with nodes
denoting random variables and edges denoting dependencies between these
......@@ -100,11 +98,11 @@ variables. CRF learns the conditional probability \f$P(Y|X)\f$, where
Linear chain CRF is a special case of CRF that is useful for sequence labeling
task. Sequence labeling tasks do not assume a lot of conditional
independences among inputs. They only concern about the input and the output
being linear sequences. Thus, the graph model of CRF is a simple chain or
a line, which results in a linear chain CRF.
being linear sequences. Thus, the graph model of such a CRF is a simple chain
or a line, which results in the linear chain CRF.
This operator implements the Forward-Backward algorithm for linear chain CRF.
Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
This operator implements the Forward-Backward algorithm for the linear chain
CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
Equation:
......@@ -144,7 +142,7 @@ nonlinear activation.
}
};
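To make the DOC string above concrete, the standard linear chain CRF factorization it refers to (see the Collins notes linked there) can be written as follows. The symbols a, b, w, and x stand for the start transition, end transition, pairwise transition, and emission scores carried by the (D + 2) x D "Transition" input and the "Emission" input; this notation is my reading of the comments, not text from the commit itself.

```latex
% Linear chain CRF over a length-L sequence with D tags (assumed notation):
% a = start transitions, b = end transitions, w = D x D pairwise transitions,
% x_{l,k} = emission score of tag k at position l.
P(y \mid x) = \frac{1}{Z(x)}
  \exp\!\Big( a_{y_1} + b_{y_L}
            + \sum_{l=1}^{L} x_{l, y_l}
            + \sum_{l=2}^{L} w_{y_{l-1}, y_l} \Big),
\qquad
Z(x) = \sum_{y'} \exp\!\Big( a_{y'_1} + b_{y'_L}
            + \sum_{l=1}^{L} x_{l, y'_l}
            + \sum_{l=2}^{L} w_{y'_{l-1}, y'_l} \Big).
```

The forward kernel below accumulates -log Z(x) per sequence (see the new "Now ll is equal to -log(Z)" comment) and combines it with the score of the labeled path to produce the LogLikelihood output.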
class LinearChainCrfOp : public framework::OperatorWithKernel {
class LinearChainCRFOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -211,7 +209,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
};
template <typename T>
class LinearChainCrfOpKernel<platform::CPUPlace, T>
class LinearChainCRFOpKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -262,11 +260,11 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
w_exps.device(place) = w.exp();
auto* alpha = ctx.Output<LoDTensor>("Alpha");
alpha->mutable_data<T>(ctx.GetPlace());
alpha->mutable_data<T>(platform::CPUPlace());
auto* ll = ctx.Output<LoDTensor>("LogLikelihood");
// resize the output tensor to the correct dimension.
ll->Resize({static_cast<int>(seq_num), 1});
T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
T* log_likelihood = ll->mutable_data<T>(platform::CPUPlace());
for (size_t i = 0; i < seq_num; ++i) {
int start_pos = static_cast<int>(in_lod[level][i]);
int end_pos = static_cast<int>(in_lod[level][i + 1]);
......@@ -322,6 +320,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
}
alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
}
// NormalizeL1 is to avoid underflow or overflow at (*).
ll -= x_row_max[k] +
std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
}
......@@ -330,6 +329,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
}
ll -= std::log(sum);
// Now ll is equal to -log(Z).
const int* lbl = label->data<int>();
PADDLE_ENFORCE_LT(
......@@ -347,7 +347,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
}
};
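As a companion to the alpha recursion and the two new NormalizeL1 comments in the kernel above, here is a minimal self-contained sketch of a scaled forward pass: alpha is L1-normalized at every step so the products of exponentials at (*) cannot underflow, and the log of each normalizer is folded back into log(Z). It deliberately omits the start/stop transition rows and the x_row_max shift used by the real kernel; w_exps here is a plain tag_num x tag_num table and every name is illustrative.

```cpp
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Scaled forward recursion for a simplified linear chain CRF (no start/stop
// transitions). Returns log(Z); the per-step normalizers keep alpha in range.
double ForwardLogZ(const std::vector<std::vector<double>>& x_exps,  // [len][tag]
                   const std::vector<std::vector<double>>& w_exps,  // [tag][tag]
                   std::size_t tag_num) {
  auto normalize = [](std::vector<double>& v) {
    double sum = 0.;
    for (double e : v) sum += e;  // assumed positive, as in NormalizeL1
    for (double& e : v) e /= sum;
    return sum;
  };

  std::vector<double> alpha(x_exps[0]);  // alpha_0 = emissions at t = 0
  double log_z = std::log(normalize(alpha));
  for (std::size_t k = 1; k < x_exps.size(); ++k) {
    std::vector<double> next(tag_num, 0.);
    for (std::size_t i = 0; i < tag_num; ++i) {
      double sum = 0.;
      for (std::size_t j = 0; j < tag_num; ++j) sum += alpha[j] * w_exps[j][i];
      next[i] = x_exps[k][i] * sum;  // the (*) product
    }
    log_z += std::log(normalize(next));  // scale factor re-enters log(Z)
    alpha.swap(next);
  }
  return log_z;  // the kernel then combines -log(Z) with the gold-path score
}

int main() {
  // Two steps, two tags, all scores exp(0) = 1: Z = 4 paths, log(Z) = log 4.
  std::vector<std::vector<double>> x = {{1., 1.}, {1., 1.}};
  std::vector<std::vector<double>> w = {{1., 1.}, {1., 1.}};
  std::cout << ForwardLogZ(x, w, 2) << "\n";  // ~1.38629
  return 0;
}
```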
class LinearChainCrfGradOp : public framework::OperatorWithKernel {
class LinearChainCRFGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -407,11 +407,11 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
};
template <typename T>
class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
class LinearChainCRFGradOpKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
PADDLE_ENFORCE(platform::is_cpu_place(platform::CPUPlace()),
"This kernel only runs on CPU.");
auto* label = ctx.Input<LoDTensor>("Label");
auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
......@@ -493,6 +493,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
}
beta_value[k * tag_num + i] = sum;
}
// NormalizeL1 is to avoid underflow or overflow at (**).
NormalizeL1<T>(beta_value + k * tag_num, tag_num);
}
......@@ -534,7 +535,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
T sum = 0.;
for (size_t i = 0; i < tag_num; ++i) {
for (size_t j = 0; j < tag_num; ++j) {
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**)
alpha_mat(k - 1, i) * tmp_mat(k, j);
}
}
......@@ -557,11 +558,11 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
} // namespace paddle
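The term flagged with (**) in the gradient kernel is, in the usual forward-backward derivation, the unnormalized joint marginal of taking the transition i -> j at step k; once normalized (by Z(x), or implicitly through the per-step NormalizeL1 scaling) it is the quantity subtracted from the observed transition count to form the transition-weight gradient. A hedged statement of that identity, in the notation of the equation block earlier:

```latex
% Joint marginal of the transition i -> j at step k (standard identity,
% not copied from the commit):
P(y_{k-1} = i,\; y_k = j \mid x)
  = \frac{\alpha_{k-1}(i)\,\exp(w_{i,j})\,\exp(x_{k,j})\,\beta_k(j)}{Z(x)},
% and for the log-likelihood L = log P(y | x) the transition gradient is
% dL/dw_{i,j} = \sum_k \big[ 1\{y_{k-1}=i, y_k=j\} - P(y_{k-1}=i, y_k=j \mid x) \big].
```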
namespace ops = paddle::operators;
REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
linear_chain_crf_grad, ops::LinearChainCrfGradOp);
REGISTER_OP(linear_chain_crf, ops::LinearChainCRFOp, ops::LinearChainCRFOpMaker,
linear_chain_crf_grad, ops::LinearChainCRFGradOp);
REGISTER_OP_CPU_KERNEL(
linear_chain_crf,
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>);
ops::LinearChainCRFOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
linear_chain_crf_grad,
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>);
ops::LinearChainCRFGradOpKernel<paddle::platform::CPUPlace, float>);
......@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class LinearChainCrfOpKernel : public framework::OpKernel<T> {
class LinearChainCRFOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
......@@ -37,7 +37,7 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> {
};
template <typename Place, typename T>
class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
......