diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc index f5d6060bc365ad5d252127d1e6806ae7f06f93f6..cc4eedbf4de2272caac75eb1e5a1d51feaf8cb38 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc @@ -30,13 +30,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { "Output(X) of SequenceEnumerate operator should not be null."); const auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ( - x_dims.size(), 2, - "Input(X) of SequenceEnumerate operator's rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[1], 1, - "Input(X) of SequenceEnumerate operator's 2nd " - "dimension should be 1."); - const auto win_size = ctx->Attrs().Get("win_size"); ctx->SetOutputDim("Out", {x_dims[0], win_size}); ctx->ShareLoD("X", "Out"); diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h index 18da69993b2ad5879dd4678ec0d4b06d7e30cb0a..6a1eb6e625b6990506ba554de4e2398daeb64451 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h @@ -27,30 +27,47 @@ class SequenceEnumerateKernel : public framework::OpKernel { auto* in = context.Input("X"); auto* out = context.Output("Out"); int win_size = context.Attr("win_size"); - int pad_value = context.Attr("pad_value"); + auto pad_value = static_cast(context.Attr("pad_value")); auto in_dims = in->dims(); - auto in_lod = in->lod(); - + auto lod0 = in->lod()[0]; PADDLE_ENFORCE_EQ( - static_cast(in_dims[0]), in_lod[0].back(), + static_cast(in_dims[0]), lod0.back(), "The actual input data's size mismatched with LoD information."); + PADDLE_ENFORCE_EQ( + in_dims.size(), 2UL, + "Input(X) of SequenceEnumerate operator's rank should be 2."); + PADDLE_ENFORCE_EQ(in_dims[1], 1, + "Input(X) of SequenceEnumerate operator's 2nd " + "dimension should be 1."); // Generate enumerate sequence set - auto lod0 = in_lod[0]; auto in_data = in->data(); out->Resize({in_dims[0], win_size}); + out->set_lod(in->lod()); auto out_data = out->mutable_data(context.GetPlace()); for (size_t i = 0; i < lod0.size() - 1; ++i) { - for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) { - for (int word_idx = 0; word_idx < win_size; ++word_idx) { - size_t word_pos = idx + word_idx; - out_data[win_size * idx + word_idx] = - word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value; + int start = lod0[i]; + int end = lod0[i + 1]; + int copy_size = win_size < end - start + 1 ? win_size : end - start + 1; + int mid = end + 1 - copy_size; + int pad_num = win_size - copy_size; + copy_size *= sizeof(T); + for (int idx = start; idx < mid; ++idx) { + std::memcpy(out_data, in_data + idx, copy_size); + out_data += win_size; + } + for (int idx = mid; idx < end; ++idx) { + copy_size -= sizeof(T); + pad_num++; + std::memcpy(out_data, in_data + idx, copy_size); + T* pdata = out_data + copy_size / sizeof(T); + for (int i = 0; i < pad_num; ++i) { + pdata[i] = pad_value; } + out_data += win_size; } } - out->set_lod(in->lod()); } };