未验证 提交 8b5307bf 编写于 作者: S ShenLiang 提交者: GitHub

fix flatten api grad (#30426) (#30441)

上级 35c8eaf5
...@@ -68,7 +68,9 @@ class FlattenGradKernel : public framework::OpKernel<T> { ...@@ -68,7 +68,9 @@ class FlattenGradKernel : public framework::OpKernel<T> {
auto in_dims = ctx.Input<framework::LoDTensor>("X")->dims(); auto in_dims = ctx.Input<framework::LoDTensor>("X")->dims();
d_x->mutable_data(ctx.GetPlace(), d_out->type()); d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); framework::TensorCopy(
*d_out, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), d_x);
d_x->Resize(in_dims); d_x->Resize(in_dims);
} }
}; };
...@@ -107,7 +109,9 @@ class Flatten2GradKernel : public framework::OpKernel<T> { ...@@ -107,7 +109,9 @@ class Flatten2GradKernel : public framework::OpKernel<T> {
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
d_x->mutable_data(ctx.GetPlace(), d_out->type()); d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); framework::TensorCopy(
*d_out, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), d_x);
d_x->Resize(x_dims); d_x->Resize(x_dims);
} }
}; };
...@@ -175,7 +179,9 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> { ...@@ -175,7 +179,9 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
d_x->mutable_data(ctx.GetPlace(), d_out->type()); d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); framework::TensorCopy(
*d_out, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), d_x);
d_x->Resize(x_dims); d_x->Resize(x_dims);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册