提交 cc220eec 编写于 作者: C caoying03

add forward computation of crf operator.

上级 cbcf11d9
......@@ -114,16 +114,19 @@ class Tensor {
const platform::DeviceContext& ctx);
/**
* @brief Return the slice of the tensor.
* @brief Return a sub-tensor of the given tensor.
*
* @param[in] begin_idx The begin index of the slice.
* @param[in] end_idx The end index of the slice.
* @param[in] begin_idx The index of the start row(inclusive) to slice.
* The index number begins from 0.
* @param[in] end_idx The index of the end row(exclusive) to slice.
* The index number begins from 0.
*/
template <typename T>
inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
platform::Place place() const {
PADDLE_ENFORCE_NOT_NULL(holder_, "Tensor get place() must contains holder");
PADDLE_ENFORCE_NOT_NULL(
holder_, "A holder must exist when calling the method place().");
return holder_->place();
}
......
......@@ -168,10 +168,11 @@ inline void Tensor::CopyFromVector(const std::vector<T>& src,
template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
check_memory_size<T>();
PADDLE_ENFORCE_GE(begin_idx, 0, "Slice begin index is less than zero.");
PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE_GE(begin_idx, 0,
"The start row index must be greater than 0.");
PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound.");
PADDLE_ENFORCE_LT(begin_idx, end_idx,
"Begin index must be less than end index.");
"The start row index must be less than the end row index.");
if (dims_[0] == 1) {
return *this;
......
......@@ -49,7 +49,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
ctx->ShareLoD("X", /*->*/ "Y");
}
// Explicitly set data type of output of the cross_entropy operator
// Explicitly set that data type of the output of the cross_entropy operator
// is determined by its input "X".
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
......
......@@ -17,6 +17,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::LoD;
class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LinearChainCrfOpMaker(framework::OpProto* proto,
......@@ -77,14 +80,14 @@ Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
Equation:
- Denote the first input of this operator (Emission) as \f$x\f$ here.
- The first D values of the second input (Transition) of this operator are for
starting weights, denoted as \f$a\f$ here.
- The next D values of the second input (Transition) of this operator are for
ending weights, denoted as \f$b\f$ here.
- The remaning values of the second input (Transition) are for transition
weights, denoted as \f$w\f$ here.
- Denote the third input of this operator (Label) as \f$s\f$ here.
- Denote Input(Emission) to this operator as \f$x\f$ here.
- The first D values of Input(Transition) to this operator are for starting
weights, denoted as \f$a\f$ here.
- The next D values of Input(Transition) of this operator are for ending
weights, denoted as \f$b\f$ here.
- The remaning values of Input(Transition) are for transition weights,
denoted as \f$w\f$ here.
- Denote Input(Label) as \f$s\f$ here.
The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as:
\f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L}
......@@ -107,8 +110,7 @@ sequences internally, it expects UNSCALED emission feature weights.
Please do not call this op with the emission feature being output of any
nonlinear activation.
3. The 2nd dimension of the first input of this operator (Emission) MUST be
equal to the tag number.
3. The 2nd dimension of Input(Emission) MUST be equal to the tag number.
)DOC");
}
......@@ -136,33 +138,188 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(emission_dims.size(), 2UL,
"The input Emission should be a 2-D tensor.");
"The Input(Emission) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL,
"The input Transition should be a 2-D tensor.");
"The Input(Transition) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
transition_dims[0] + 2, transition_dims[1],
"An invalid dimension for the input Transition, which should "
transition_dims[0] - 2, transition_dims[1],
"An invalid dimension for the Input(Transition), which should "
"be a 2-D tensor with shape [D + 2 x D].");
PADDLE_ENFORCE_EQ(
emission_dims[1], transition_dims[1],
"The 2nd dimension of the input Emission and the input Transition "
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
"should be equal to the tag number.");
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The input Label should be a 2-D tensor "
"with the 2nd dimensions fixed to 1.");
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
PADDLE_ENFORCE_EQ(
emission_dims[0], label_dims[0],
"The height of Input(Emission) and the height of Input(Label) "
"should be the same.");
ctx->SetOutputDim("Alpha", emission_dims);
// (TODO caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
// is the sequence number in a mini-batch. The dimension set here should be
// resized to its correct size in the function Compute.
ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1});
}
// Explicitly set data type of output of the linear_chain_crf operator
// is determined by its input "Emission".
// Explicitly set that the data type of output of the linear_chain_crf
// operator is determined by its input "Emission".
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Emission")->type());
}
};
template <typename T>
class LinearChainCrfOpKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
auto* emission_weights = ctx.Input<LoDTensor>("Emission");
auto* transition_weights = ctx.Input<Tensor>("Transition");
auto* label = ctx.Input<LoDTensor>("Label");
auto in_lod = emission_weights->lod();
// TODO(caoying) The checks related to LoD information should be
// moved into InferShape once after the InferShape is refactored.
PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL,
"The Input(Emission) should be a sequence.");
PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
"The Input(Label) should be a sequence.");
const size_t level = 0;
auto emission_dims = emission_weights->dims();
const size_t seq_num = in_lod[level].size() - 1;
// TODO(caoying) These local variables seems to be created and destroied
// every time this function is called. Will this bring additional overhead?
Tensor emission_exps;
Tensor emission_row_max;
Tensor transition_exps;
emission_exps.mutable_data<T>(emission_dims, platform::CPUPlace());
emission_row_max.mutable_data<T>(
framework::make_ddim({emission_dims[0], 1}), platform::CPUPlace());
transition_exps.mutable_data<T>(transition_weights->dims(),
platform::CPUPlace());
auto* alpha = ctx.Output<Tensor>("Alpha");
alpha->mutable_data<T>(ctx.GetPlace());
auto* ll = ctx.Output<Tensor>("LogLikelihood");
// resize the output tensor to the correct dimension.
ll->Resize({static_cast<int>(seq_num), 1});
T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < seq_num; ++i) {
int start_pos = static_cast<int>(in_lod[level][i]);
int end_pos = static_cast<int>(in_lod[level][i + 1]);
const Tensor one_seq = emission_weights->Slice<T>(start_pos, end_pos);
Tensor one_seq_row_max = emission_row_max.Slice<T>(start_pos, end_pos);
Tensor one_seq_exps = emission_exps.Slice<T>(start_pos, end_pos);
const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
log_likelihood[i] = ForwardOneSequence(
ctx.device_context(), one_seq, one_seq_row_max, one_seq_exps,
(*transition_weights), transition_exps, one_seq_label, one_seq_alpha);
}
}
protected:
T ForwardOneSequence(const platform::DeviceContext& ctx,
const Tensor& emission, Tensor& emission_row_max,
Tensor& emission_exps, const Tensor& trans_weights,
Tensor& trans_weight_exps, const Tensor& label,
Tensor& alpha) const {
// (TODO caoying) Evaluate and optimize this.
// The Eigen compution kernel will be invoked for multiple times.
// Some computations regardless of sequence inforamtion could be performed
// only one time for the entire batch. This potentially could be optimized.
auto x_dims = emission.dims();
const size_t seq_length = x_dims[0];
const size_t tag_num = x_dims[1];
T* alpha_value = alpha.data<T>();
auto x = EigenMatrix<T>::From(emission);
auto x_row_max = EigenMatrix<T>::From(emission_row_max);
const int class_dim = 1;
x_row_max.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
x.maximum(Eigen::DSizes<int, 1>(class_dim))
.reshape(Eigen::DSizes<int, 2>(int(seq_length), 1));
auto x_exps = EigenMatrix<T>::From(emission_exps);
x_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
(x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
auto w = EigenMatrix<T>::From(trans_weights);
auto w_exps = EigenMatrix<T>::From(trans_weight_exps);
w_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) = w.exp();
// The 1st row of w are transition weights for start mask.
const size_t start_ridx = 0;
// The 2nd row of w are transition weights for end mask.
const size_t end_ridx = 1;
// Transition weights among other tags begins from the 3rd row of w.
const size_t state_base_ridx = 2;
for (size_t i = 0; i < tag_num; ++i) {
alpha_value[i] = w_exps(start_ridx, i) * x_exps(0, i);
}
T ll = -x_row_max(0, 1) - std::log(NormalizeL1(alpha_value, tag_num));
for (size_t k = 1; k < seq_length; ++k) {
for (size_t i = 0; i < tag_num; ++i) {
T sum = 0.;
for (size_t j = 0; j < tag_num; ++j) {
sum += alpha_value[(k - 1) * tag_num + j] *
w_exps(j + state_base_ridx, i);
}
alpha_value[k * tag_num + i] = x_exps(k, i) * sum;
}
ll -= x_row_max(k, 1) +
std::log(NormalizeL1(alpha_value + k * tag_num, tag_num));
}
T sum = 0.;
for (size_t i = 0; i < tag_num; ++i) {
sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps(end_ridx, i);
}
ll -= std::log(sum);
const int* lbl = label.data<int>();
PADDLE_ENFORCE_LT(
*std::max_element(lbl, lbl + seq_length), tag_num,
"An invalid tag label that execesses the largest tag number.");
// Calculate the nominator part, which depends on the label sequence.
ll += w(start_ridx, lbl[0]) + x(start_ridx, lbl[0]) +
w(end_ridx, lbl[seq_length - 1]);
for (size_t k = 1; k < seq_length; ++k)
ll += x(k, lbl[k]) + w(lbl[k - 1], lbl[k]);
return -ll;
}
private:
T NormalizeL1(T* x, size_t len) const {
T sum = 0.;
for (size_t i = 0; i < len; ++i) sum += x[i];
// (This comment is from the old LinearChainCRFLayer.)
// Right now, we just bet that sum won't be zero. If this really happens, we
// will figure out what should be done then.
PADDLE_ENFORCE(sum,
"The unnormalized probabilites of all possible unfinished "
"sequences must be greater than 0.");
for (size_t i = 0; i < len; ++i) x[i] /= sum;
return sum;
}
};
class LinearChainCrfGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -171,12 +328,25 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {}
};
template <typename T>
class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
linear_chain_crf_grad, ops::LinearChainCrfGradOp);
REGISTER_OP_CPU_KERNEL(linear_chain_crf, ops::LinearChainCrfOpKernel<float>);
REGISTER_OP_CPU_KERNEL(linear_chain_crf_grad,
ops::LinearChainCrfGradOpKernel<float>);
REGISTER_OP_CPU_KERNEL(
linear_chain_crf,
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
linear_chain_crf_grad,
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>);
......@@ -19,27 +19,31 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
template <typename Place, typename T>
class LinearChainCrfOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
}
void Compute(const framework::ExecutionContext& ctx) const override;
protected:
T ForwardOneSequence(const platform::DeviceContext& ctx,
const Tensor& emission, Tensor& emission_row_max,
Tensor& emission_exps, const Tensor& trans_weights,
Tensor& trans_weight_exps, const Tensor& label,
Tensor& a) const;
private:
T NormalizeL1(T* x, size_t len) const;
};
template <typename T>
template <typename Place, typename T>
class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
}
void Compute(const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
......
......@@ -60,19 +60,23 @@ Because this operators performs a softmax on logits internally, it expects
unscaled logits. Please do not call this op with the output of softmax operator,
which will produce incorrect results.
This operators expects mutually exclusive hard labels, each sample in a batch
is in exactly one class with probabilities 1. Each sample in the batch with one
and only one label.
When the attribute softLabel is set false, this operators expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with
probabilities 1. Each sample in the batch with one and only one label.
Equation:
1) hard label (one-hot label)
Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., K
Loss_j = \f$ -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1, ..., K $\f
2) soft label (a distribution over all classes)
Loss_j = -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i-\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), j = 1,...,K
Loss_j = \f$ -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K $\f
)DOC");
}
......
......@@ -61,13 +61,13 @@ class LinearChainCrfForward(object):
s += alpha[-1, i] * self.b_exps[i]
log_likelihood -= np.log(s)
# calculate the noninator part.
# calculate the nominator part.
log_likelihood += (
self.a[label[0]] + self.x[0, label[0]] + self.b[label[-1]])
for k in range(1, seq_len):
log_likelihood += (
self.x[k, label[k]] + self.w[label[k - 1], label[k]])
return log_likelihood
return -log_likelihood
def crf_forward_compute(self):
for i in range(self.seq_num):
......@@ -102,7 +102,7 @@ class TestLinearChainCrfOp(OpTest):
self.inputs = {
"Emission": (emission, lod),
"Transition": transition,
"label": (labels, lod)
"Label": (labels, lod)
}
crf = LinearChainCrfForward(lod[0], emission, transition, labels)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册