diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
index 28e1145db42123b9dacfa9e359e08476d16ab4c0..7fe1852f7396cb8cebe4b83f4cc80a8023421351 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc
+++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
@@ -38,7 +38,7 @@ USE_OP(softmax_with_cross_entropy);
 USE_OP_ITSELF(reduce_mean);
 USE_OP_ITSELF(reduce_sum);
 USE_OP_ITSELF(reduce_sum_grad);
-USE_OP(reduce_mean_grad);
+USE_OP_ITSELF(reduce_mean_grad);
 USE_OP_ITSELF(reshape2_grad);
 USE_OP(softmax_with_cross_entropy_grad);
 USE_OP_ITSELF(elementwise_add_grad);
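A note on the one-line test change above: `USE_OP(op)` both references the operator definition and pulls in the op's fluid kernel registration symbols, while `USE_OP_ITSELF(op)` references only the operator itself. Since this patch deletes the fluid `reduce_mean_grad` kernels and re-registers them in phi, the kernel half of `USE_OP` would have nothing left to link against. A rough sketch of the relationship, assuming the shape of the macros in paddle/fluid/framework/op_registry.h (simplified from memory, not a verbatim quote):

// Hypothetical simplification of the two macros' intent:
//   USE_OP_ITSELF(op)  -> "link in the operator definition of `op`"
//   USE_OP(op)         -> USE_OP_ITSELF(op) plus "link in op's fluid kernels"
#define USE_OP_ITSELF(op_type) extern int TouchOpRegistrar_##op_type()
#define USE_OP(op_type) \
  USE_OP_ITSELF(op_type); \
  extern int TouchOpKernelRegistrar_##op_type()

Once the fluid kernel registration for reduce_mean_grad is removed, only the first form still resolves, hence the switch below for every op whose grad kernel moves to phi.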
- -#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" - -REGISTER_OP_CUDA_KERNEL( - reduce_max_grad, ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc index 4a18330913803f822436118a35fb957b7e31b391..dc41979defb9314f2efb942f0f530c3b5da3bb8b 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc @@ -107,12 +107,3 @@ REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp, ops::ReduceMeanDoubleGradDescMaker, ops::ReduceMeanDoubleGradOpBaseMaker, ops::ReduceMeanGradNoNeedBufferVarInferer); - -template -using CPUReduceMeanGradKernel = - ops::ReduceGradKernel; - -REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel, - CPUReduceMeanGradKernel, - CPUReduceMeanGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu deleted file mode 100644 index a578c9f7d81083c533028b9c8912a24006ed0292..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// .part used to speed up nvcc compile -#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h" - -template -using CUDAReduceMeanGradKernel = - ops::ReduceCudaGradKernel; - -REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel, - CUDAReduceMeanGradKernel, - CUDAReduceMeanGradKernel, - CUDAReduceMeanGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_op.cc b/paddle/fluid/operators/reduce_ops/reduce_min_op.cc index b9915f2b484f140bfd776b64459a19c6788a55c9..5e5b04d57b002d8e8ecab9ddaf8186118f4bf187 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_min_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_min_op.cc @@ -35,13 +35,3 @@ REGISTER_OPERATOR( paddle::framework::DefaultGradOpMaker, ReduceMinInferShapeFunctor); REGISTER_OPERATOR(reduce_min_grad, ops::ReduceGradOp) - -REGISTER_OP_CPU_KERNEL( - reduce_min_grad, ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_min_op.part.cu deleted file mode 100644 index bf886063786a8c36884ed20fef41c99468156c01..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reduce_ops/reduce_min_op.part.cu +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc
index eb745ab9c56c5b3cfa62eb36713ebc2485282d6d..b1abdf9e8a758008dff49176c2d6b6682de5b622 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc
@@ -14,6 +14,10 @@
 
 #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
+
 namespace paddle {
 namespace framework {
 class OpDesc;
@@ -26,14 +30,20 @@ class CPUDeviceContext;
 }  // namespace platform
 }  // namespace paddle
 
-REGISTER_REDUCE_OP(reduce_prod);
+namespace ops = paddle::operators;
+
+class ReduceProdOpMaker : public ops::ReduceOpMaker {
+ protected:
+  virtual std::string GetName() const { return "reduce_prod"; }
+  virtual std::string GetOpType() const { return "Reduce reduce_prod"; }
+};
+
+DECLARE_INFER_SHAPE_FUNCTOR(reduce_prod, ReduceProdInferShapeFunctor,
+                            PD_INFER_META(phi::ReduceInferMetaBase));
 
-REGISTER_OP_CPU_KERNEL(reduce_prod_grad,
-                       ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
-                                             float, ops::ProdGradFunctor>,
-                       ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
-                                             double, ops::ProdGradFunctor>,
-                       ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
-                                             int, ops::ProdGradFunctor>,
-                       ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
-                                             int64_t, ops::ProdGradFunctor>);
+REGISTER_OPERATOR(
+    reduce_prod, ops::ReduceOp, ReduceProdOpMaker,
+    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
+    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
+    ReduceProdInferShapeFunctor);
+REGISTER_OPERATOR(reduce_prod_grad, ops::ReduceGradOp);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_prod_op.part.cu
deleted file mode 100644
index 0610cdd94f89c0371988fac7955d07fc5498a69f..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.part.cu
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
-
-REGISTER_OP_CUDA_KERNEL(
-    reduce_prod_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
-                                            float, ops::ProdGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
-                          ops::ProdGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
-                          ops::ProdGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
-                          ops::ProdGradFunctor>);
- -#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h" - -REGISTER_OP_CUDA_KERNEL( - reduce_prod_grad, ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops::ReduceGradKernel); diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index b1da573c49f2f20c6b25beae189fe5952efd3cef..946230cb169d20db56a46399552b629348c4783f 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -47,8 +47,13 @@ const std::unordered_set deprecated_op_names({"diag", "matmul_grad", "matmul_grad_grad", "mean", + "mean_grad", "max", + "max_grad", "min", + "min_grad", + "prod", + "prod_grad", "any", "all", "reshape", diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 02b5b2d74ad2914f60a1df08e500b06733b95aaa..aa76561c5ce6f272bd5f8096f66e9b3382987de4 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -31,10 +31,11 @@ set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_k matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel - triangular_solve_grad_kernel determinant_grad_kernel) + triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel) kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function) kernel_library(gumbel_softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(gumbel_softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) +kernel_library(reduce_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel) kernel_library(matrix_power_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse) kernel_library(matrix_power_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse) kernel_library(maxout_kernel DEPS ${COMMON_KERNEL_DEPS} maxouting) diff --git a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_grad_kernel.cc similarity index 53% rename from paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc rename to paddle/phi/kernels/cpu/reduce_grad_kernel.cc index efea054555e86be79b5cdb09fe8c4784a1ad0c3b..78a7ae8d415b5d4b18fdf8e469576db50f739e38 100644 --- a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_grad_kernel.cc @@ -12,33 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/reduce_sum_grad_kernel.h" +#include "paddle/phi/kernels/reduce_grad_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cast_kernel.h" -#include "paddle/phi/kernels/cpu/reduce_grad.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" +#include "paddle/phi/kernels/impl/reduce_grad.h" +#include "paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h" +#include "paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h" +#include "paddle/phi/kernels/impl/reduce_prod_grad_kernel_impl.h" namespace phi { -struct SumGradFunctor { - template - void operator()(const DeviceContext& place, - X* x, - Y* y, - DX* dx, - DY* dy, - const Dim& dim, - int size) { - dx->device(place) = dy->broadcast(dim); - } -}; - template void ComputeFromInput(const Context& dev_ctx, const DenseTensor& x, @@ -111,16 +97,38 @@ void ReduceSumGradKernel(const Context& dev_ctx, } } - ReduceGradKernel(dev_ctx, - x, - out_grad, - paddle::none, - dims, - keep_dim, - reduce_all, - in_dtype, - out_dtype, - x_grad); + ReduceGradKernel(dev_ctx, + x, + out_grad, + paddle::none, + dims, + keep_dim, + reduce_all, + in_dtype, + out_dtype, + x_grad); +} + +template +void ReduceMeanGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DataType in_dtype, + DataType out_dtype, + DenseTensor* x_grad) { + ReduceGradKernel(dev_ctx, + x, + out_grad, + paddle::none, + dims, + keep_dim, + reduce_all, + in_dtype, + out_dtype, + x_grad); } } // namespace phi @@ -137,3 +145,38 @@ PD_REGISTER_KERNEL(sum_grad, int64_t, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(mean_grad, + CPU, + ALL_LAYOUT, + phi::ReduceMeanGradKernel, + bool, + float, + double) {} + +PD_REGISTER_KERNEL(prod_grad, + CPU, + ALL_LAYOUT, + phi::ReduceProdGradKernel, + float, + double, + int, + int64_t) {} + +PD_REGISTER_KERNEL(max_grad, + CPU, + ALL_LAYOUT, + phi::ReduceMaxGradKernel, + float, + double, + int, + int64_t) {} + +PD_REGISTER_KERNEL(min_grad, + CPU, + ALL_LAYOUT, + phi::ReduceMinGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/funcs/reduce_functor.h b/paddle/phi/kernels/funcs/reduce_functor.h index c74880e04322474e28385997b5022ebf52643bf4..b793afb63b1dca9bbd8ad09b83461567de6371ad 100644 --- a/paddle/phi/kernels/funcs/reduce_functor.h +++ b/paddle/phi/kernels/funcs/reduce_functor.h @@ -73,5 +73,82 @@ struct AnyFunctor { } }; +struct MeanGradFunctor { + template + void operator()(const DeviceContext& place, + X* x, + Y* y, + DX* dx, + DY* dy, + const Dim& dim, + int size) { + dx->device(place) = dy->broadcast(dim) / dx->constant(size); + } +}; + +struct SumGradFunctor { + template + void operator()(const DeviceContext& place, + X* x, + Y* y, + DX* dx, + DY* dy, + const Dim& dim, + int size) { + dx->device(place) = dy->broadcast(dim); + } +}; + +struct ProdGradFunctor { + template + void operator()(const DeviceContext& place, + X* x, + Y* y, + DX* dx, + DY* dy, + const Dim& dim, + int size) { + dx->device(place) = dy->broadcast(dim) * y->broadcast(dim) * x->inverse(); + } +}; + +struct MaxOrMinGradFunctor { + template + void operator()(const DeviceContext& place, + X* x, + Y* y, + DX* dx, + DY* dy, + const Dim& dim, + int size) { + auto equals = (*x) == y->broadcast(dim); + auto ones = dx->constant(1); + auto zeros = dx->constant(0); + // If there are multiple minimum 
diff --git a/paddle/phi/kernels/funcs/reduce_grad_functions.h b/paddle/phi/kernels/funcs/reduce_grad_functions.h
index 3488b6f2f86b20e0b758f3aa75a6739c40cd81db..11197a52261d7d0fd7618d2c7c0de09b57abe0d8 100644
--- a/paddle/phi/kernels/funcs/reduce_grad_functions.h
+++ b/paddle/phi/kernels/funcs/reduce_grad_functions.h
@@ -41,14 +41,14 @@ void ReduceGradFunctor(const Context& dev_ctx,
   Eigen::array<int, D> broadcast_dim;
   for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
 
-  int broad_cats_times = 1;
+  int broad_cast_times = 1;
   for (size_t i = 0; i < dims_ref.size(); ++i) {
     if (dims_ref[i] < 0) {
       dims_ref[i] = x_rank + dims_ref[i];
     }
     reduced_dims_v[dims_ref[i]] = 1;
     broadcast_dim[dims_ref[i]] = x_dims[dims_ref[i]];
-    broad_cats_times *= x_dims[dims_ref[i]];
+    broad_cast_times *= x_dims[dims_ref[i]];
   }
   auto reduced_dims = phi::make_ddim(reduced_dims_v);
   auto x_reduce = EigenTensor<T, D>::From(input1, reduced_dims);
@@ -62,7 +62,7 @@ void ReduceGradFunctor(const Context& dev_ctx,
                    &x_grad,
                    &x_reduce_grad,
                    broadcast_dim,
-                   broad_cats_times);
+                   broad_cast_times);
 }
 
 inline void GetOriginDimFromShuffled(const DDim& src_dim,
diff --git a/paddle/phi/kernels/gpu/reduce_grad.h b/paddle/phi/kernels/gpu/reduce_grad.h
index d21c8a3fa46f81c046c722db50ac62fb57cf64f4..e32101b73728f637da0626d691018842aedd62e7 100644
--- a/paddle/phi/kernels/gpu/reduce_grad.h
+++ b/paddle/phi/kernels/gpu/reduce_grad.h
@@ -43,5 +43,59 @@ void ReduceGrad(const GPUContext& dev_ctx,
       }));
 }
 
+template <typename T,
+          typename Context,
+          template <typename, typename> class TransformOp>
+void ReduceGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& out_grad,
+                      const std::vector<int64_t>& dims,
+                      bool keep_dim,
+                      bool reduce_all,
+                      DataType in_dtype,
+                      DataType out_dtype,
+                      DenseTensor* x_grad) {
+  auto* in_x = &x;
+  auto* d_out = &out_grad;
+  auto* d_x = x_grad;
+
+  auto pt_out_dtype = in_dtype;
+
+  // get reduce_dim and reduce_num for reduce_mean_grad
+  int dim_size = in_x->dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+
+  auto update_dims = vectorize(d_x->dims());
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (in_x->dims())[i];
+    update_dims[i] = 1;
+  }
+  // make new tensor
+  DenseTensor new_d_out(d_out->dtype());
+  new_d_out.ShareDataWith(*d_out);
+  new_d_out.Resize(phi::make_ddim(update_dims));
+  if (in_dtype != DataType::UNDEFINED) {
+    dev_ctx.Alloc(d_x, in_dtype);
+  } else {
+    dev_ctx.Alloc(d_x, d_out->dtype());
+  }
+
+  auto pt_d_out = new_d_out;
+  auto pt_d_x = *d_x;
+  if (in_dtype == DataType::UNDEFINED) {
+    pt_out_dtype = d_out->dtype();
+  }
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+
+  phi::ReduceGrad<T, TransformOp<T, MPType>>(
+      dev_ctx,
+      &pt_d_out,
+      &pt_d_x,
+      pt_out_dtype,
+      TransformOp<T, MPType>(reduce_num));
+}
+
 }  // namespace phi
 
 #endif
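The templated GPU path above folds sum- and mean-grad into one routine: the broadcast of d_out runs through TransformOp<T, MPType>(reduce_num), which is kps::IdentityFunctor for sum and kps::DivideFunctor for mean (see the registrations in the new reduce_grad_kernel.cu below). In scalar form the two instantiations differ only in the per-element transform; a sketch of the assumed contract, with Identity and Divide as hypothetical stand-ins for the kps functors:

// Scalar sketch of the TransformOp contract assumed by ReduceGradKernel.
template <typename T, typename MPType>
struct Identity {
  explicit Identity(int /*reduce_num*/) {}
  T operator()(T v) const { return v; }  // sum_grad: dx = broadcast(dy)
};

template <typename T, typename MPType>
struct Divide {
  explicit Divide(int reduce_num) : n_(static_cast<T>(reduce_num)) {}
  T operator()(T v) const { return v / n_; }  // mean_grad: dx = broadcast(dy) / reduce_num
 private:
  T n_;
};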
diff --git a/paddle/phi/kernels/gpu/reduce_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..5256048267ea19a4cb12387ebbc582a2df1bd1b1
--- /dev/null
+++ b/paddle/phi/kernels/gpu/reduce_grad_kernel.cu
@@ -0,0 +1,119 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reduce_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
+#include "paddle/phi/kernels/gpu/reduce_grad.h"
+#include "paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h"
+#include "paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h"
+#include "paddle/phi/kernels/impl/reduce_prod_grad_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceSumGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const std::vector<int64_t>& dims,
+                         bool keep_dim,
+                         bool reduce_all,
+                         DataType in_dtype,
+                         DataType out_dtype,
+                         DenseTensor* x_grad) {
+  ReduceGradKernel<T, Context, kps::IdentityFunctor>(dev_ctx,
+                                                     x,
+                                                     out_grad,
+                                                     dims,
+                                                     keep_dim,
+                                                     reduce_all,
+                                                     in_dtype,
+                                                     out_dtype,
+                                                     x_grad);
+}
+
+template <typename T, typename Context>
+void ReduceMeanGradKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& out_grad,
+                          const std::vector<int64_t>& dims,
+                          bool keep_dim,
+                          bool reduce_all,
+                          DataType in_dtype,
+                          DataType out_dtype,
+                          DenseTensor* x_grad) {
+  ReduceGradKernel<T, Context, kps::DivideFunctor>(dev_ctx,
+                                                   x,
+                                                   out_grad,
+                                                   dims,
+                                                   keep_dim,
+                                                   reduce_all,
+                                                   in_dtype,
+                                                   out_dtype,
+                                                   x_grad);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(sum_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReduceSumGradKernel,
+                   bool,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(mean_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReduceMeanGradKernel,
+                   bool,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(prod_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReduceProdGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
+
+PD_REGISTER_KERNEL(max_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReduceMaxGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
+
+PD_REGISTER_KERNEL(min_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReduceMinGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
- -#include "paddle/phi/kernels/reduce_sum_grad_kernel.h" - -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/reduce_function.h" -#include "paddle/phi/kernels/gpu/reduce_grad.h" - -namespace phi { - -template -void ReduceSumGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& out_grad, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DataType in_dtype, - DataType out_dtype, - DenseTensor* x_grad) { - auto* in_x = &x; - auto* d_out = &out_grad; - auto* d_x = x_grad; - - auto pt_out_dtype = in_dtype; - - // get reduce_dim and reduce_num for reduce_mean_grad - int dim_size = in_x->dims().size(); - std::vector reduce_dims = - funcs::details::GetReduceDim(dims, dim_size, reduce_all); - - auto update_dims = vectorize(d_x->dims()); - int reduce_num = 1; - for (auto i : reduce_dims) { - reduce_num *= (in_x->dims())[i]; - update_dims[i] = 1; - } - // make new tensor - DenseTensor new_d_out(d_out->dtype()); - new_d_out.ShareDataWith(*d_out); - new_d_out.Resize(phi::make_ddim(update_dims)); - if (in_dtype != DataType::UNDEFINED) { - dev_ctx.Alloc(d_x, in_dtype); - } else { - dev_ctx.Alloc(d_x, d_out->dtype()); - } - - auto pt_d_out = new_d_out; - auto pt_d_x = *d_x; - if (in_dtype == DataType::UNDEFINED) { - pt_out_dtype = d_out->dtype(); - } - using MPType = typename kps::details::MPTypeTrait::Type; - - phi::ReduceGrad>( - dev_ctx, - &pt_d_out, - &pt_d_x, - pt_out_dtype, - kps::IdentityFunctor(reduce_num)); -} - -} // namespace phi - -PD_REGISTER_KERNEL(sum_grad, - GPU, - ALL_LAYOUT, - phi::ReduceSumGradKernel, - bool, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16, - int, - int64_t, - phi::dtype::complex, - phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/reduce_grad.h b/paddle/phi/kernels/impl/reduce_grad.h similarity index 100% rename from paddle/phi/kernels/cpu/reduce_grad.h rename to paddle/phi/kernels/impl/reduce_grad.h diff --git a/paddle/phi/kernels/reduce_sum_grad_kernel.h b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h similarity index 51% rename from paddle/phi/kernels/reduce_sum_grad_kernel.h rename to paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h index ab4d63297efffc70710e496efa08f4b9c7e5f7ce..4a74416e3916492e6d3a40e09ca347db485fff7c 100644 --- a/paddle/phi/kernels/reduce_sum_grad_kernel.h +++ b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h @@ -14,19 +14,34 @@ #pragma once -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/reduce_grad_kernel.h" + +#include "paddle/phi/kernels/funcs/reduce_functor.h" +#include "paddle/phi/kernels/impl/reduce_grad.h" + namespace phi { template -void ReduceSumGradKernel(const Context& dev_ctx, +void ReduceMaxGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, + const DenseTensor& out, const std::vector& dims, bool keep_dim, bool reduce_all, DataType in_dtype, DataType out_dtype, - DenseTensor* x_grad); + DenseTensor* x_grad) { + ReduceGradKernel(dev_ctx, + x, + out_grad, + out, + dims, + keep_dim, + reduce_all, + in_dtype, + out_dtype, + x_grad); +} } // namespace phi diff --git a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..baaa544f137366f1e0343c25bc373cc08350f7fd --- /dev/null +++ b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h @@ -0,0 +1,47 @@ 
diff --git a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..baaa544f137366f1e0343c25bc373cc08350f7fd
--- /dev/null
+++ b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/reduce_grad_kernel.h"
+
+#include "paddle/phi/kernels/funcs/reduce_functor.h"
+#include "paddle/phi/kernels/impl/reduce_grad.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceMinGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const DenseTensor& out,
+                         const std::vector<int64_t>& dims,
+                         bool keep_dim,
+                         bool reduce_all,
+                         DataType in_dtype,
+                         DataType out_dtype,
+                         DenseTensor* x_grad) {
+  ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(dev_ctx,
+                                                           x,
+                                                           out_grad,
+                                                           out,
+                                                           dims,
+                                                           keep_dim,
+                                                           reduce_all,
+                                                           in_dtype,
+                                                           out_dtype,
+                                                           x_grad);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/reduce_prod_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_prod_grad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..6b93e98cec0168ab55e15e3401a72738f79d3a07
--- /dev/null
+++ b/paddle/phi/kernels/impl/reduce_prod_grad_kernel_impl.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/reduce_grad_kernel.h"
+
+#include "paddle/phi/kernels/funcs/reduce_functor.h"
+#include "paddle/phi/kernels/impl/reduce_grad.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceProdGradKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& out_grad,
+                          const DenseTensor& out,
+                          const std::vector<int64_t>& dims,
+                          bool keep_dim,
+                          bool reduce_all,
+                          DataType in_dtype,
+                          DataType out_dtype,
+                          DenseTensor* x_grad) {
+  ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(dev_ctx,
+                                                       x,
+                                                       out_grad,
+                                                       out,
+                                                       dims,
+                                                       keep_dim,
+                                                       reduce_all,
+                                                       in_dtype,
+                                                       out_dtype,
+                                                       x_grad);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/reduce_grad_kernel.h b/paddle/phi/kernels/reduce_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee6f3d19a094d29546e82e7138933eceb96459d0
--- /dev/null
+++ b/paddle/phi/kernels/reduce_grad_kernel.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/dense_tensor.h"
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceSumGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const std::vector<int64_t>& dims,
+                         bool keep_dim,
+                         bool reduce_all,
+                         DataType in_dtype,
+                         DataType out_dtype,
+                         DenseTensor* x_grad);
+
+template <typename T, typename Context>
+void ReduceMeanGradKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& out_grad,
+                          const std::vector<int64_t>& dims,
+                          bool keep_dim,
+                          bool reduce_all,
+                          DataType in_dtype,
+                          DataType out_dtype,
+                          DenseTensor* x_grad);
+
+template <typename T, typename Context>
+void ReduceProdGradKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& out_grad,
+                          const DenseTensor& out,
+                          const std::vector<int64_t>& dims,
+                          bool keep_dim,
+                          bool reduce_all,
+                          DataType in_dtype,
+                          DataType out_dtype,
+                          DenseTensor* x_grad);
+
+template <typename T, typename Context>
+void ReduceMaxGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const DenseTensor& out,
+                         const std::vector<int64_t>& dims,
+                         bool keep_dim,
+                         bool reduce_all,
+                         DataType in_dtype,
+                         DataType out_dtype,
+                         DenseTensor* x_grad);
+
+template <typename T, typename Context>
+void ReduceMinGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const DenseTensor& out,
+                         const std::vector<int64_t>& dims,
+                         bool keep_dim,
+                         bool reduce_all,
+                         DataType in_dtype,
+                         DataType out_dtype,
+                         DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/reduce_kernel.h b/paddle/phi/kernels/reduce_kernel.h
index 75f52c36beb76abcd0cc05a7b46935a56d35da64..69bcb47bc98eadd46eeff5c1f92ccf9cf0c9a9d3 100644
--- a/paddle/phi/kernels/reduce_kernel.h
+++ b/paddle/phi/kernels/reduce_kernel.h
@@ -16,7 +16,6 @@
 
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/infermeta/unary.h"
-#include "paddle/phi/kernels/empty_kernel.h"
 
 namespace phi {
 template <typename T, typename Context>
"prod_grad", + {"X", GradVarName("Out"), "Out"}, + {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"}, + {GradVarName("X")}); +} + } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum); @@ -147,6 +183,10 @@ PD_REGISTER_BASE_KERNEL_NAME(reduce_all, all); PD_REGISTER_BASE_KERNEL_NAME(reduce_any, any); PD_REGISTER_BASE_KERNEL_NAME(reduce_sum_grad, sum_grad); +PD_REGISTER_BASE_KERNEL_NAME(reduce_mean_grad, mean_grad); +PD_REGISTER_BASE_KERNEL_NAME(reduce_prod_grad, prod_grad); +PD_REGISTER_BASE_KERNEL_NAME(reduce_max_grad, max_grad); +PD_REGISTER_BASE_KERNEL_NAME(reduce_min_grad, min_grad); PD_REGISTER_ARG_MAPPING_FN(reduce_sum, phi::ReduceSumOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_mean, phi::ReduceMeanOpArgumentMapping); @@ -158,3 +198,11 @@ PD_REGISTER_ARG_MAPPING_FN(reduce_any, phi::ReduceAnyOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_sum_grad, phi::ReduceSumGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(reduce_mean_grad, + phi::ReduceMeanGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(reduce_prod_grad, + phi::ReduceProdGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(reduce_max_grad, + phi::ReduceMaxGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(reduce_min_grad, + phi::ReduceMinGradOpArgumentMapping);