diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 785b16ae283b9c5472ff6797a9faa6b3e287c6f5..44fe4f5193420670c21b62f71d82e7e8f5b868a6 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/mean_op.h" #include #include #include +#include "paddle/fluid/framework/op_registry.h" + namespace paddle { namespace operators { @@ -94,21 +95,3 @@ REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType, ops::MeanGradMaker); REGISTER_OPERATOR(mean_grad, ops::MeanGradOp, ops::MeanGradNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL( - mean, ops::MeanKernel, - ops::MeanKernel, - ops::MeanKernel, - ops::MeanKernel>, - ops::MeanKernel>); -REGISTER_OP_CPU_KERNEL( - mean_grad, ops::MeanGradKernel, - ops::MeanGradKernel, - ops::MeanGradKernel, - ops::MeanGradKernel>, - ops::MeanGradKernel>); diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu deleted file mode 100644 index 813dce6080130c0e4894f085c8c199e147e275bb..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/mean_op.cu +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#ifdef __NVCC__ -#include "cub/cub.cuh" -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif -#include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" -#include "paddle/fluid/operators/mean_op.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace operators { - -template -__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) { - using MT = typename details::MPTypeTrait::Type; - int idx = blockDim.x * blockIdx.x + threadIdx.x; - auto data = static_cast(in_data[0]); - for (; idx < N; idx += blockDim.x * gridDim.x) { - out_data[idx] = static_cast(data / (static_cast(N))); - } -} - -template -class MeanCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - - const T* in_data = input->data(); - T* out_data = output->mutable_data(context.GetPlace()); - auto numel = input->numel(); - auto rank = input->dims().size(); - auto place = context.GetPlace(); - auto stream = context.cuda_device_context().stream(); - - if (rank == 0) { // scalar - auto gpu_place = place; - memory::Copy(gpu_place, out_data, gpu_place, in_data, numel * sizeof(T), - stream); - return; - } - - using Div = kernel_primitives::DivideFunctor; - std::vector reduce_dims; - reduce_dims.reserve(rank); - for (decltype(rank) i = 0; i < rank; ++i) { - reduce_dims.push_back(i); - } - TensorReduceImpl>( - context.cuda_device_context(), *input, output, - kps::IdentityFunctor(), reduce_dims, stream, true); - } -}; - -template -class MeanCUDAGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto OG = context.Input(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ(OG->numel(), 1, - platform::errors::InvalidArgument( - "Mean Gradient Input Tensor len should be 1. But " - "received Out@Grad's elements num is %d.", - OG->numel())); - auto IG = context.Output(framework::GradVarName("X")); - IG->mutable_data(context.GetPlace()); - - auto in_data = OG->data(); - auto size_prob = IG->numel(); - auto out_data = IG->data(); - int threads = 512; - int grid = (size_prob + threads - 1) / threads; - auto stream = context.cuda_device_context().stream(); - MeanRunKernel<<>>(in_data, out_data, - size_prob); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - mean, ops::MeanCUDAKernel, - ops::MeanCUDAKernel, - ops::MeanCUDAKernel, - ops::MeanCUDAKernel>, - ops::MeanCUDAKernel>); -REGISTER_OP_CUDA_KERNEL( - mean_grad, - ops::MeanCUDAGradKernel, - ops::MeanCUDAGradKernel, - ops::MeanCUDAGradKernel, - ops::MeanCUDAGradKernel>, - ops::MeanCUDAGradKernel>); diff --git a/paddle/fluid/operators/mean_op.h b/paddle/fluid/operators/mean_op.h deleted file mode 100644 index 4780150751bf66c3d53e5eebb1ad1080a48a7420..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/mean_op.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -template -using EigenScalar = framework::EigenScalar; -template -using EigenVector = framework::EigenVector; - -template -class MeanKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - - output->mutable_data(context.GetPlace()); - - auto X = EigenVector::Flatten(*input); - auto y = EigenScalar::From(*output); - auto& place = - *context.template device_context().eigen_device(); - - y.device(place) = X.mean(); - } -}; - -template -class MeanGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto OG = context.Input(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ(OG->numel(), 1UL, - platform::errors::InvalidArgument( - "Mean Gradient should be scalar. But received " - "Out@Grad's elements num is %d.", - OG->numel())); - auto IG = context.Output(framework::GradVarName("X")); - IG->mutable_data(context.GetPlace()); - - T ig_size = static_cast(IG->numel()); - Eigen::DSizes bcast(static_cast(ig_size)); - EigenVector::Flatten(*IG).device( - *context.template device_context().eigen_device()) = - (EigenVector::From(*OG) / ig_size).broadcast(bcast); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/mean_op_mlu.cc b/paddle/fluid/operators/mean_op_mlu.cc index f8246165c550111c3dbc43c0eef16d1fd0299eb2..1fed01194c1a6c4f5743d98e09db1993c8c8e998 100644 --- a/paddle/fluid/operators/mean_op_mlu.cc +++ b/paddle/fluid/operators/mean_op_mlu.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/mean_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/platform/device/mlu/device_context.h" #include "paddle/fluid/platform/float16.h" @@ -20,6 +20,8 @@ namespace paddle { namespace operators { +using Tensor = framework::Tensor; + template class MeanMLUKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/mean_op_npu.cc b/paddle/fluid/operators/mean_op_npu.cc index d81594658044a3492c1a453102138d4f0ce16486..7e15a793fd1b2e1f536e8c912fe4b596489f0ff5 100644 --- a/paddle/fluid/operators/mean_op_npu.cc +++ b/paddle/fluid/operators/mean_op_npu.cc @@ -9,13 +9,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/mean_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; + template class MeanNPUKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/mean_op_xpu.cc b/paddle/fluid/operators/mean_op_xpu.cc index 53bc658af61b263942da63f061c4d7545adde2b8..ef96fe2f03ba41b7599cb9324eef16ee5b37e944 100644 --- a/paddle/fluid/operators/mean_op_xpu.cc +++ b/paddle/fluid/operators/mean_op_xpu.cc @@ -12,15 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/mean_op.h" #ifdef PADDLE_WITH_XPU #include #include #include +#include "paddle/fluid/framework/op_registry.h" + namespace paddle { namespace operators { +using Tensor = framework::Tensor; + template class MeanXPUKernel : public framework::OpKernel { using XPUType = typename XPUTypeTrait::Type; diff --git a/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc index ab92a165c76d124e27c2635863846e52815c3d61..503d3ec33762fb45f4132d133c3c5ad5540eca1a 100644 --- a/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/operators/fill_constant_op.h" -#include "paddle/fluid/operators/mean_op.h" namespace paddle { namespace operators { diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index d9cff03e89ca212a4bdbde84dbc031ca68f8be6f..b4616c8c1b721bdbaf6b77ce0583b158b7a7bb44 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -50,8 +50,6 @@ const std::unordered_set deprecated_op_names({"diag", "matmul", "matmul_grad", "matmul_grad_grad", - "mean", - "mean_grad", "max", "max_grad", "min", diff --git a/paddle/phi/kernels/cpu/mean_all_grad_kernel.cc b/paddle/phi/kernels/cpu/mean_all_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..2472c4412f6eacb7d887a7a39055507d2d5b91c5 --- /dev/null +++ b/paddle/phi/kernels/cpu/mean_all_grad_kernel.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/mean_all_grad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void MeanAllGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + PADDLE_ENFORCE_EQ(out_grad.numel(), + 1UL, + phi::errors::InvalidArgument( + "Mean Gradient should be scalar. But received " + "Out@Grad's elements num is %d.", + out_grad.numel())); + dev_ctx.template Alloc(x_grad); + + T ig_size = static_cast(x_grad->numel()); + Eigen::DSizes bcast(static_cast(ig_size)); + EigenVector::Flatten(*x_grad).device(*dev_ctx.eigen_device()) = + (EigenVector::From(out_grad) / ig_size).broadcast(bcast); +} + +} // namespace phi + +PD_REGISTER_KERNEL(mean_all_grad, + CPU, + ALL_LAYOUT, + phi::MeanAllGradKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/mean_all_kernel.cc b/paddle/phi/kernels/cpu/mean_all_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..8321194984756ead20ad4b38afaf365dd89b2386 --- /dev/null +++ b/paddle/phi/kernels/cpu/mean_all_kernel.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/mean_all_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void MeanAllKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + dev_ctx.template Alloc(out); + + auto X = EigenVector::Flatten(x); + auto y = EigenScalar::From(*out); + auto& place = *dev_ctx.eigen_device(); + + y.device(place) = X.mean(); +} + +} // namespace phi + +PD_REGISTER_KERNEL(mean_all, + CPU, + ALL_LAYOUT, + phi::MeanAllKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/mean_all_grad_kernel.cu b/paddle/phi/kernels/gpu/mean_all_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..83d4e3a57735f08c18f5320c9f65623f31c657c4 --- /dev/null +++ b/paddle/phi/kernels/gpu/mean_all_grad_kernel.cu @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/mean_all_kernel.h" + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) { + using MT = typename dtype::MPTypeTrait::Type; + int idx = blockDim.x * blockIdx.x + threadIdx.x; + auto data = static_cast(in_data[0]); + for (; idx < N; idx += blockDim.x * gridDim.x) { + out_data[idx] = static_cast(data / (static_cast(N))); + } +} + +template +void MeanAllGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + PADDLE_ENFORCE_EQ(out_grad.numel(), + 1, + phi::errors::InvalidArgument( + "Mean Gradient Input Tensor len should be 1. But " + "received Out@Grad's elements num is %d.", + out_grad.numel())); + dev_ctx.template Alloc(x_grad); + + auto in_data = out_grad.data(); + auto size_prob = x_grad->numel(); + auto out_data = x_grad->data(); + int threads = 512; + int grid = (size_prob + threads - 1) / threads; + auto stream = dev_ctx.stream(); + MeanRunKernel<<>>(in_data, out_data, size_prob); +} + +} // namespace phi + +PD_REGISTER_KERNEL(mean_all_grad, + GPU, + ALL_LAYOUT, + phi::MeanAllGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/mean_all_kernel.cu b/paddle/phi/kernels/gpu/mean_all_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..799865be26e24412a62fc5136ef9a4a00c2fd2ee --- /dev/null +++ b/paddle/phi/kernels/gpu/mean_all_kernel.cu @@ -0,0 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/mean_all_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + +#include "paddle/fluid/memory/memcpy.h" + +namespace phi { + +template +void MeanAllKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + const T* in_data = x.data(); + T* out_data = dev_ctx.template Alloc(out); + auto numel = x.numel(); + auto rank = x.dims().size(); + auto place = dev_ctx.GetPlace(); + auto stream = dev_ctx.stream(); + + if (rank == 0) { // scalar + paddle::memory::Copy( + place, out_data, place, in_data, numel * sizeof(T), stream); + return; + } + + std::vector reduce_dims; + reduce_dims.reserve(rank); + for (decltype(rank) i = 0; i < rank; ++i) { + reduce_dims.push_back(i); + } + funcs::ReduceKernel>( + dev_ctx, + x, + out, + kps::IdentityFunctor(), + reduce_dims, + /*is_mean=*/true); +} + +} // namespace phi + +PD_REGISTER_KERNEL(mean_all, + GPU, + ALL_LAYOUT, + phi::MeanAllKernel, + float, + double, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/mean_all_grad_kernel.h b/paddle/phi/kernels/mean_all_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..bede8ac9e049a4f9ca28b33426b90abb612fd47f --- /dev/null +++ b/paddle/phi/kernels/mean_all_grad_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MeanAllGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/mean_all_kernel.h b/paddle/phi/kernels/mean_all_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..a332526643710a49104296e6be339d9548498b14 --- /dev/null +++ b/paddle/phi/kernels/mean_all_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +// In order to be compatible with `mean` op in fluid, +// it is no longer used in 2.x API. It can not implement by call +// ReduceMeanKernel because ReduceMeanKernel doesn't support bfloat16 now, +// maybe we can unify this kernel to ReduceMeanKernel series in the future +template +void MeanAllKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/mean_sig.cc b/paddle/phi/ops/compat/mean_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..6decd0da0b08698f942ccef1b25f070098f7a501 --- /dev/null +++ b/paddle/phi/ops/compat/mean_sig.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature MeanOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("mean_all", {"X"}, {}, {"Out"}); +} + +KernelSignature MeanGradOpGradArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "mean_all_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(mean, mean_all); +PD_REGISTER_BASE_KERNEL_NAME(mean_grad, mean_all_grad); + +PD_REGISTER_ARG_MAPPING_FN(mean, phi::MeanOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(mean_grad, phi::MeanGradOpGradArgumentMapping);