提交 acd1aaea 编写于 作者: W wanghaoshuang

fix issues

上级 901b0411
......@@ -28,7 +28,7 @@ class SeqExpandOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SeqExpandOp should not be null.");
int repeat = ctx->Attrs().Get<int>("repeat");
DDim out_dim;
framework::DDim out_dim;
if (repeat == 0) {
PADDLE_ENFORCE(
ctx->HasInput("Y"),
......@@ -38,7 +38,6 @@ class SeqExpandOp : public framework::OperatorWithKernel {
} else {
out_dim = ctx->GetInputDim("X");
out_dim[0] = out_dim[0] * repeat;
ctx->SetOutputDim("Out", y_dim);
}
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PadOp should not be null.");
......
......@@ -21,7 +21,6 @@ namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using LoD = paddle::framework::LoD;
template <typename Place, typename T>
class SeqExpandKernel : public framework::OpKernel<T> {
......@@ -35,11 +34,11 @@ class SeqExpandKernel : public framework::OpKernel<T> {
if (repeat != 0) {
if (x->lod().size() == 0) {
std::vector<size_t> level0(x->dims()[0]);
std::vector<size_t> level0;
for (size_t i = 0; i <= x->dims()[0]; i++) {
level0.push_back(i * repeat);
}
const LoD out_lod;
framework::LoD out_lod;
out_lod.push_back(level0);
out->set_lod(out_lod);
}
......@@ -55,14 +54,15 @@ class SeqExpandKernel : public framework::OpKernel<T> {
}
}
}
if (paddle::platform::CPUPlace() == Place) {
if (platform::is_cpu_place(context.GetPlace())) {
for (int i = 0; i < out_dim[0]; ++i) {
memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i],
sizeof(T) * element_len);
}
} else {
for (int i = 0; i < out_dim[0]; ++i) {
hl_memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i],
hl_memcpy(out_data + element_len * i,
const_cast<T*>(x_data) + element_len * cpy_map[i],
sizeof(T) * element_len);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册