未验证 提交 5af81f83 编写于 作者: W wangchaochaohu 提交者: GitHub

fix gpu kernel for numel Op (#27085)

上级 39d5bb6d
......@@ -53,7 +53,7 @@ REGISTER_OPERATOR(
size, ops::SizeOp, ops::SizeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int32_t>,
REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int64_t>,
ops::SizeKernel<paddle::platform::float16>,
ops::SizeKernel<float>, ops::SizeKernel<double>,
ops::SizeKernel<bool>);
......@@ -16,7 +16,7 @@ limitations under the License. */
REGISTER_OP_CUDA_KERNEL(
size, paddle::operators::SizeKernel<int>,
paddle::operators::SizeKernel<int32_t>,
paddle::operators::SizeKernel<int64_t>,
paddle::operators::SizeKernel<paddle::platform::float16>,
paddle::operators::SizeKernel<float>, paddle::operators::SizeKernel<bool>,
paddle::operators::SizeKernel<double>);
......@@ -26,8 +26,18 @@ class SizeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("Input");
auto* out_t = ctx.Output<Tensor>("Out");
auto out_data = out_t->mutable_data<int64_t>(platform::CPUPlace());
out_data[0] = in_t->numel();
auto place = ctx.GetPlace();
auto out_data = out_t->mutable_data<int64_t>(place);
auto cpu_place = platform::CPUPlace();
if (place == cpu_place) {
out_data[0] = in_t->numel();
} else {
Tensor cpu_tensor;
auto cpu_data =
cpu_tensor.mutable_data<int64_t>(out_t->dims(), cpu_place);
cpu_data[0] = in_t->numel();
TensorCopy(cpu_tensor, place, out_t);
}
}
};
} // namespace operators
......
......@@ -1001,7 +1001,7 @@ def chunk(x, chunks, axis=0, name=None):
x_np = np.random.random([3, 9, 5]).astype("int32")
x = paddle.to_tensor(x_np)
out0, out1, out22 = paddle.chunk(x, chunks=3, axis=1)
out0, out1, out2 = paddle.chunk(x, chunks=3, axis=1)
# out0.shape [3, 3, 5]
# out1.shape [3, 3, 5]
# out2.shape [3, 3, 5]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册