提交 12f2b8eb 编写于 作者: L Liu Yiqun

Correct the forward of sequence_softmax_op.

上级 4d929394
......@@ -42,8 +42,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
auto *in = ctx.Input<framework::Tensor>("X");
int64_t in_size = framework::product(in->dims());
PADDLE_ENFORCE_EQ(capacity, in_size,
PADDLE_ENFORCE_EQ(capacity, in->numel(),
"The size of Input(X) mismatches with Attr(shape).");
// resize output
std::vector<int64_t> shape_int64(shape.size(), 0);
......
......@@ -30,18 +30,20 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
"Output(Out) of SequenceSoftmaxOp should not be null.");
auto *x = ctx.Input<framework::LoDTensor>("X");
auto dims = x->dims();
auto lod = x->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
auto dims = x->dims();
PADDLE_ENFORCE_GE(
dims[0],
/* batch_size */ static_cast<int64_t>(lod[0].size() - 1),
"The first dimension of Input(X) should be larger than batch size.");
PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[0].size() - 1),
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[level].back()),
"The width of each timestep in Input(X) of "
"SequenceSoftmaxOp should be 1.");
dims[0] = lod[0].size() - 1;
std::cout << DebugString() << std::endl;
ctx.Output<framework::LoDTensor>("Out")->Resize({dims});
}
};
......
......@@ -38,7 +38,7 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = x->lod();
const size_t level = lod.size();
const size_t level = lod.size() - 1;
out->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
......@@ -47,6 +47,10 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
Tensor x_i = x->Slice<T>(start_pos, end_pos);
Tensor out_i = out->Slice<T>(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims = framework::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims);
out_i.Resize(dims);
math::SoftmaxFunctor<Place, T>()(&x_i, &out_i, ctx);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册