提交 acd1aaea 编写于 作者: W wanghaoshuang

fix issues

上级 901b0411
...@@ -28,7 +28,7 @@ class SeqExpandOp : public framework::OperatorWithKernel { ...@@ -28,7 +28,7 @@ class SeqExpandOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SeqExpandOp should not be null."); "Input(X) of SeqExpandOp should not be null.");
int repeat = ctx->Attrs().Get<int>("repeat"); int repeat = ctx->Attrs().Get<int>("repeat");
DDim out_dim; framework::DDim out_dim;
if (repeat == 0) { if (repeat == 0) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->HasInput("Y"), ctx->HasInput("Y"),
...@@ -38,7 +38,6 @@ class SeqExpandOp : public framework::OperatorWithKernel { ...@@ -38,7 +38,6 @@ class SeqExpandOp : public framework::OperatorWithKernel {
} else { } else {
out_dim = ctx->GetInputDim("X"); out_dim = ctx->GetInputDim("X");
out_dim[0] = out_dim[0] * repeat; out_dim[0] = out_dim[0] * repeat;
ctx->SetOutputDim("Out", y_dim);
} }
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PadOp should not be null."); "Output(Out) of PadOp should not be null.");
......
...@@ -21,7 +21,6 @@ namespace paddle { ...@@ -21,7 +21,6 @@ namespace paddle {
namespace operators { namespace operators {
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using LoD = paddle::framework::LoD;
template <typename Place, typename T> template <typename Place, typename T>
class SeqExpandKernel : public framework::OpKernel<T> { class SeqExpandKernel : public framework::OpKernel<T> {
...@@ -35,11 +34,11 @@ class SeqExpandKernel : public framework::OpKernel<T> { ...@@ -35,11 +34,11 @@ class SeqExpandKernel : public framework::OpKernel<T> {
if (repeat != 0) { if (repeat != 0) {
if (x->lod().size() == 0) { if (x->lod().size() == 0) {
std::vector<size_t> level0(x->dims()[0]); std::vector<size_t> level0;
for (size_t i = 0; i <= x->dims()[0]; i++) { for (size_t i = 0; i <= x->dims()[0]; i++) {
level0.push_back(i * repeat); level0.push_back(i * repeat);
} }
const LoD out_lod; framework::LoD out_lod;
out_lod.push_back(level0); out_lod.push_back(level0);
out->set_lod(out_lod); out->set_lod(out_lod);
} }
...@@ -55,14 +54,15 @@ class SeqExpandKernel : public framework::OpKernel<T> { ...@@ -55,14 +54,15 @@ class SeqExpandKernel : public framework::OpKernel<T> {
} }
} }
} }
if (paddle::platform::CPUPlace() == Place) { if (platform::is_cpu_place(context.GetPlace())) {
for (int i = 0; i < out_dim[0]; ++i) { for (int i = 0; i < out_dim[0]; ++i) {
memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i], memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i],
sizeof(T) * element_len); sizeof(T) * element_len);
} }
} else { } else {
for (int i = 0; i < out_dim[0]; ++i) { for (int i = 0; i < out_dim[0]; ++i) {
hl_memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i], hl_memcpy(out_data + element_len * i,
const_cast<T*>(x_data) + element_len * cpy_map[i],
sizeof(T) * element_len); sizeof(T) * element_len);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册