提交 50931dee 编写于 作者: T tensor-tang

refine seq enum op

test=develop
上级 8e4ad008
...@@ -30,13 +30,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { ...@@ -30,13 +30,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
"Output(X) of SequenceEnumerate operator should not be null."); "Output(X) of SequenceEnumerate operator should not be null.");
const auto x_dims = ctx->GetInputDim("X"); const auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
x_dims.size(), 2,
"Input(X) of SequenceEnumerate operator's rank should be 2.");
PADDLE_ENFORCE_EQ(x_dims[1], 1,
"Input(X) of SequenceEnumerate operator's 2nd "
"dimension should be 1.");
const auto win_size = ctx->Attrs().Get<int>("win_size"); const auto win_size = ctx->Attrs().Get<int>("win_size");
ctx->SetOutputDim("Out", {x_dims[0], win_size}); ctx->SetOutputDim("Out", {x_dims[0], win_size});
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
......
...@@ -27,30 +27,47 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> { ...@@ -27,30 +27,47 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out"); auto* out = context.Output<LoDTensor>("Out");
int win_size = context.Attr<int>("win_size"); int win_size = context.Attr<int>("win_size");
int pad_value = context.Attr<int>("pad_value"); auto pad_value = static_cast<T>(context.Attr<int>("pad_value"));
auto in_dims = in->dims(); auto in_dims = in->dims();
auto in_lod = in->lod(); auto lod0 = in->lod()[0];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(), static_cast<uint64_t>(in_dims[0]), lod0.back(),
"The actual input data's size mismatched with LoD information."); "The actual input data's size mismatched with LoD information.");
PADDLE_ENFORCE_EQ(
in_dims.size(), 2UL,
"Input(X) of SequenceEnumerate operator's rank should be 2.");
PADDLE_ENFORCE_EQ(in_dims[1], 1,
"Input(X) of SequenceEnumerate operator's 2nd "
"dimension should be 1.");
// Generate enumerate sequence set // Generate enumerate sequence set
auto lod0 = in_lod[0];
auto in_data = in->data<T>(); auto in_data = in->data<T>();
out->Resize({in_dims[0], win_size}); out->Resize({in_dims[0], win_size});
out->set_lod(in->lod());
auto out_data = out->mutable_data<T>(context.GetPlace()); auto out_data = out->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < lod0.size() - 1; ++i) { for (size_t i = 0; i < lod0.size() - 1; ++i) {
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) { int start = lod0[i];
for (int word_idx = 0; word_idx < win_size; ++word_idx) { int end = lod0[i + 1];
size_t word_pos = idx + word_idx; int copy_size = win_size < end - start + 1 ? win_size : end - start + 1;
out_data[win_size * idx + word_idx] = int mid = end + 1 - copy_size;
word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value; int pad_num = win_size - copy_size;
copy_size *= sizeof(T);
for (int idx = start; idx < mid; ++idx) {
std::memcpy(out_data, in_data + idx, copy_size);
out_data += win_size;
} }
for (int idx = mid; idx < end; ++idx) {
copy_size -= sizeof(T);
pad_num++;
std::memcpy(out_data, in_data + idx, copy_size);
T* pdata = out_data + copy_size / sizeof(T);
for (int i = 0; i < pad_num; ++i) {
pdata[i] = pad_value;
}
out_data += win_size;
} }
} }
out->set_lod(in->lod());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册