未验证 提交 e85be1b1 编写于 作者: S ShenLiang 提交者: GitHub

fix flatten api grad (#30426)

上级 c94a4b94
......@@ -68,7 +68,9 @@ class FlattenGradKernel : public framework::OpKernel<T> {
auto in_dims = ctx.Input<framework::LoDTensor>("X")->dims();
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
framework::TensorCopy(
*d_out, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), d_x);
d_x->Resize(in_dims);
}
};
......@@ -107,7 +109,9 @@ class Flatten2GradKernel : public framework::OpKernel<T> {
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
framework::TensorCopy(
*d_out, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), d_x);
d_x->Resize(x_dims);
}
};
......@@ -175,7 +179,9 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
framework::TensorCopy(
*d_out, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), d_x);
d_x->Resize(x_dims);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册