未验证 提交 5af81f83 编写于 作者: W wangchaochaohu 提交者: GitHub

fix gpu kernel for numel Op (#27085)

上级 39d5bb6d
...@@ -53,7 +53,7 @@ REGISTER_OPERATOR( ...@@ -53,7 +53,7 @@ REGISTER_OPERATOR(
size, ops::SizeOp, ops::SizeOpMaker, size, ops::SizeOp, ops::SizeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int32_t>, REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int64_t>,
ops::SizeKernel<paddle::platform::float16>, ops::SizeKernel<paddle::platform::float16>,
ops::SizeKernel<float>, ops::SizeKernel<double>, ops::SizeKernel<float>, ops::SizeKernel<double>,
ops::SizeKernel<bool>); ops::SizeKernel<bool>);
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
size, paddle::operators::SizeKernel<int>, size, paddle::operators::SizeKernel<int>,
paddle::operators::SizeKernel<int32_t>, paddle::operators::SizeKernel<int64_t>,
paddle::operators::SizeKernel<paddle::platform::float16>, paddle::operators::SizeKernel<paddle::platform::float16>,
paddle::operators::SizeKernel<float>, paddle::operators::SizeKernel<bool>, paddle::operators::SizeKernel<float>, paddle::operators::SizeKernel<bool>,
paddle::operators::SizeKernel<double>); paddle::operators::SizeKernel<double>);
...@@ -26,8 +26,18 @@ class SizeKernel : public framework::OpKernel<T> { ...@@ -26,8 +26,18 @@ class SizeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("Input"); auto* in_t = ctx.Input<Tensor>("Input");
auto* out_t = ctx.Output<Tensor>("Out"); auto* out_t = ctx.Output<Tensor>("Out");
auto out_data = out_t->mutable_data<int64_t>(platform::CPUPlace()); auto place = ctx.GetPlace();
out_data[0] = in_t->numel(); auto out_data = out_t->mutable_data<int64_t>(place);
auto cpu_place = platform::CPUPlace();
if (place == cpu_place) {
out_data[0] = in_t->numel();
} else {
Tensor cpu_tensor;
auto cpu_data =
cpu_tensor.mutable_data<int64_t>(out_t->dims(), cpu_place);
cpu_data[0] = in_t->numel();
TensorCopy(cpu_tensor, place, out_t);
}
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -1001,7 +1001,7 @@ def chunk(x, chunks, axis=0, name=None): ...@@ -1001,7 +1001,7 @@ def chunk(x, chunks, axis=0, name=None):
x_np = np.random.random([3, 9, 5]).astype("int32") x_np = np.random.random([3, 9, 5]).astype("int32")
x = paddle.to_tensor(x_np) x = paddle.to_tensor(x_np)
out0, out1, out22 = paddle.chunk(x, chunks=3, axis=1) out0, out1, out2 = paddle.chunk(x, chunks=3, axis=1)
# out0.shape [3, 3, 5] # out0.shape [3, 3, 5]
# out1.shape [3, 3, 5] # out1.shape [3, 3, 5]
# out2.shape [3, 3, 5] # out2.shape [3, 3, 5]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册