diff --git a/paddle/fluid/eager/tests/performance_tests/benchmark_eager_cuda.cc b/paddle/fluid/eager/tests/performance_tests/benchmark_eager_cuda.cc
index 14e7ce8cfcfb4dea0907cd128873223c8e5859a2..9f59f4fc03045aa8a122820e2328a76f33ad9877 100644
--- a/paddle/fluid/eager/tests/performance_tests/benchmark_eager_cuda.cc
+++ b/paddle/fluid/eager/tests/performance_tests/benchmark_eager_cuda.cc
@@ -186,7 +186,7 @@ TEST(Benchmark, EagerIntermediateMLPCUDA) {
 USE_OP_ITSELF(scale);
 USE_OP_ITSELF(matmul_v2);
 USE_OP_ITSELF(reduce_sum);
-USE_OP(reduce_sum_grad);
+USE_OP_ITSELF(reduce_sum_grad);
 USE_OP_ITSELF(elementwise_add);
 
 #endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
diff --git a/paddle/fluid/eager/tests/performance_tests/benchmark_fluid_cuda.cc b/paddle/fluid/eager/tests/performance_tests/benchmark_fluid_cuda.cc
index e9b7d10070dbf22f10e617d34f143992d19fb659..df77fc1360b4994310ddce242459b435d771ee20 100644
--- a/paddle/fluid/eager/tests/performance_tests/benchmark_fluid_cuda.cc
+++ b/paddle/fluid/eager/tests/performance_tests/benchmark_fluid_cuda.cc
@@ -248,7 +248,7 @@ TEST(Benchmark, FluidMLPCUDA) {
 USE_OP_ITSELF(scale);
 USE_OP_ITSELF(matmul_v2);
 USE_OP_ITSELF(reduce_sum);
-USE_OP(reduce_sum_grad);
+USE_OP_ITSELF(reduce_sum_grad);
 USE_OP_ITSELF(elementwise_add);
 
 #endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
index a69cc0d6b866d08ae1f5e65e3da41c525e83c47e..219aae71127ed8963b4bfe4e8ee5e7259dbf7d02 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc
+++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
@@ -37,7 +37,7 @@ USE_OP(elementwise_mul);
 USE_OP(softmax_with_cross_entropy);
 USE_OP_ITSELF(reduce_mean);
 USE_OP_ITSELF(reduce_sum);
-USE_OP(reduce_sum_grad);
+USE_OP_ITSELF(reduce_sum_grad);
 USE_OP(reduce_mean_grad);
 USE_OP_ITSELF(reshape2_grad);
 USE_OP(softmax_with_cross_entropy_grad);
diff --git a/paddle/fluid/imperative/tests/test_tracer.cc b/paddle/fluid/imperative/tests/test_tracer.cc
index d05036f7a12ebdc3db5fbfda5eb50c295c0478e4..0696de908a917cfd7b0ebd1334bb2ad4e12559ee 100644
--- a/paddle/fluid/imperative/tests/test_tracer.cc
+++ b/paddle/fluid/imperative/tests/test_tracer.cc
@@ -591,5 +591,5 @@ TEST(test_tracer, eager_tracer) {
 USE_OP(mul);
 USE_OP(mul_grad);
 USE_OP_ITSELF(reduce_sum);
-USE_OP(reduce_sum_grad);
+USE_OP_ITSELF(reduce_sum_grad);
 USE_OP_ITSELF(elementwise_add);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
index 6441d53239e955957c3bda85eebeceb5af695e8c..2a78774f3706e73bd8931e80fe020faac58d7ff5 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -114,16 +114,3 @@ REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
                   ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>,
                   ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>,
                   ops::ReduceSumGradNoNeedBufferVarInferer);
-
-template <typename T>
-using CPUReduceSumGradKernel =
-    ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T,
-                             ops::SumGradFunctor, true>;
-
-REGISTER_OP_CPU_KERNEL(
-    reduce_sum_grad, CPUReduceSumGradKernel<bool>,
-    CPUReduceSumGradKernel<float>, CPUReduceSumGradKernel<double>,
-    CPUReduceSumGradKernel<paddle::platform::float16>,
-    CPUReduceSumGradKernel<int>, CPUReduceSumGradKernel<int64_t>,
-    CPUReduceSumGradKernel<paddle::platform::complex<float>>,
-    CPUReduceSumGradKernel<paddle::platform::complex<double>>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu
deleted file mode 100644
index 2f6bf127518090916c4b947daf1d1f202fdd5960..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
-
-template <typename T>
-using CUDAReduceSumGradKernel =
-    ops::ReduceCudaGradKernel<T, kps::IdentityFunctor<T>>;
-
-REGISTER_OP_CUDA_KERNEL(
-    reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
-    CUDAReduceSumGradKernel<float>, CUDAReduceSumGradKernel<double>,
-    CUDAReduceSumGradKernel<paddle::platform::float16>,
-    CUDAReduceSumGradKernel<paddle::platform::bfloat16>,
-    CUDAReduceSumGradKernel<int>, CUDAReduceSumGradKernel<int64_t>,
-    CUDAReduceSumGradKernel<paddle::platform::complex<float>>,
-    CUDAReduceSumGradKernel<paddle::platform::complex<double>>);
diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h
index f2b7f00cb6b8598fe7736e4cb38f03122f871807..00e9bff9bd5910ceedcca3dfb3a7a64ec88596df 100644
--- a/paddle/phi/core/compat/op_utils.h
+++ b/paddle/phi/core/compat/op_utils.h
@@ -55,6 +55,7 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
                                                            "expand_grad",
                                                            "expand_as_grad",
                                                            "sum",
+                                                           "sum_grad",
                                                            "top_k",
                                                            "top_k_grad"});
 
diff --git a/paddle/phi/kernels/cpu/reduce_grad.h b/paddle/phi/kernels/cpu/reduce_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..f56d3d3ed50f7e72910115f7ec28914a5eade2e8
--- /dev/null
+++ b/paddle/phi/kernels/cpu/reduce_grad.h
@@ -0,0 +1,132 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
+
+namespace phi {
+
+template <typename Context,
+          typename T,
+          typename Functor,
+          bool kNoNeedBufferX = false,
+          bool kNoNeedBufferY = false>
+void ComputeFromInput(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& out_grad,
+                      const paddle::optional<DenseTensor>& out,
+                      const DenseTensor& input2,
+                      const std::vector<int64_t>& dims,
+                      bool keep_dim,
+                      bool reduce_all,
+                      DataType in_dtype,
+                      DataType out_dtype,
+                      DenseTensor* x_grad) {
+  auto* input0 = &x;
+  auto* input1 = out.get_ptr();
+  auto* output = x_grad;
+  dev_ctx.template Alloc<T>(output);
+
+  // The dims has full dim, set the reduce_all is True
+  const auto& input_dim_size = x.dims().size();
+  std::set<int64_t> dims_set(dims.begin(), dims.end());
+  bool full_dim = true;
+  for (auto i = 0; i < input_dim_size; i++) {
+    if (dims_set.find(i) == dims_set.end()) {
+      full_dim = false;
+      break;
+    }
+  }
+  reduce_all = (reduce_all || full_dim);
+  // NOTE: EigenTensor::From() uses tensor->data()
+  // if op has NoNeedBufferVarsInferer, the corresponding kNoNeedBufferX or
+  // kNoNeedBufferY should set true
+  // and use fake var that has same dims.
+  if (kNoNeedBufferX) {
+    input0 = output;
+  }
+  if (kNoNeedBufferY) {
+    input1 = &input2;
+  }
+
+  const std::vector<int> const_dims{dims.begin(), dims.end()};
+
+  // NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
+  // not be set as Input in grad Maker, use Out_grad to replace here
+  if (!input1) input1 = &input2;
+  Functor functor;
+
+  funcs::LaunchReduceGradKernel<Context, T, Functor>(dev_ctx,
+                                                     input0,
+                                                     input1,
+                                                     &input2,
+                                                     output,
+                                                     functor,
+                                                     const_dims,
+                                                     reduce_all);
+}
+
+template <typename Context,
+          typename T,
+          typename Functor,
+          bool kNoNeedBufferX = false,
+          bool kNoNeedBufferY = false>
+void ReduceGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& out_grad,
+                      const paddle::optional<DenseTensor>& out,
+                      const std::vector<int64_t>& dims,
+                      bool keep_dim,
+                      bool reduce_all,
+                      DataType in_dtype,
+                      DataType out_dtype,
+                      DenseTensor* x_grad) {
+  if (in_dtype != DataType::UNDEFINED) {
+    DenseTensorMeta x_grad_meta(out_dtype, x_grad->dims(), x_grad->layout());
+    DenseTensor x_grad_tmp =
+        phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
+    ComputeFromInput<Context, T, Functor, kNoNeedBufferX, kNoNeedBufferY>(
+        dev_ctx,
+        x,
+        out_grad,
+        out,
+        out_grad,
+        dims,
+        keep_dim,
+        reduce_all,
+        in_dtype,
+        out_dtype,
+        &x_grad_tmp);
+
+    phi::CastKernel<T>(dev_ctx, x_grad_tmp, in_dtype, x_grad);
+  } else {
+    ComputeFromInput<Context, T, Functor, kNoNeedBufferX, kNoNeedBufferY>(
+        dev_ctx,
+        x,
+        out_grad,
+        out,
+        out_grad,
+        dims,
+        keep_dim,
+        reduce_all,
+        in_dtype,
+        out_dtype,
+        x_grad);
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..efea054555e86be79b5cdb09fe8c4784a1ad0c3b
--- /dev/null
+++ b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
@@ -0,0 +1,139 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reduce_sum_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/cpu/reduce_grad.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+namespace phi {
+
+struct SumGradFunctor {
+  template <typename DeviceContext,
+            typename X,
+            typename Y,
+            typename DX,
+            typename DY,
+            typename Dim>
+  void operator()(const DeviceContext& place,
+                  X* x,
+                  Y* y,
+                  DX* dx,
+                  DY* dy,
+                  const Dim& dim,
+                  int size) {
+    dx->device(place) = dy->broadcast(dim);
+  }
+};
+
+template <typename T, typename Context>
+void ComputeFromInput(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& input2,
+                      const std::vector<int64_t>& dims,
+                      DenseTensor* x_grad) {
+  auto* input0 = &x;
+  auto* output = x_grad;
+  dev_ctx.template Alloc<T>(output);
+
+  const auto* input2_d = input2.data<T>();
+  auto* output_d = output->data<T>();
+
+  // handle reduce_all
+  if (input2.dims().size() == 1 && input2.dims()[0] == 1) {
+    for (int64_t i = 0; i < phi::product(input0->dims()); ++i) {
+      output_d[i] = input2_d[0];
+    }
+    return;
+  }
+
+  // handle reduce by one dimension
+  int reduce_dim_index = dims[0];
+  if (reduce_dim_index < 0) {
+    reduce_dim_index += input0->dims().size();
+  }
+
+  auto& input_dim = input0->dims();
+  int64_t before_dim = 1;
+  for (int i = 0; i < reduce_dim_index; ++i) {
+    before_dim *= input_dim[i];
+  }
+  int64_t reduce_dim = input_dim[reduce_dim_index];
+  int64_t after_dim = 1;
+  for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) {
+    after_dim *= input_dim[i];
+  }
+  for (int64_t i = 0; i < before_dim; ++i) {
+    for (int64_t j = 0; j < reduce_dim; ++j) {
+      for (int64_t k = 0; k < after_dim; ++k) {
+        output_d[i * reduce_dim * after_dim + j * after_dim + k] =
+            input2_d[i * after_dim + k];
+      }
+    }
+  }
+}
+
+template <typename T, typename Context>
+void ReduceSumGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const std::vector<int64_t>& dims,
+                         bool keep_dim,
+                         bool reduce_all,
+                         DataType in_dtype,
+                         DataType out_dtype,
+                         DenseTensor* x_grad) {
+  if (dims.size() == 1) {
+    if (out_dtype != DataType::UNDEFINED) {
+      DenseTensorMeta x_grad_meta(out_dtype, x_grad->dims(), x_grad->layout());
+      DenseTensor x_grad_tmp =
+          phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
+
+      ComputeFromInput<T, Context>(dev_ctx, x, out_grad, dims, &x_grad_tmp);
+
+      phi::CastKernel<T>(dev_ctx, x_grad_tmp, in_dtype, x_grad);
+
+    } else {
+      ComputeFromInput<T, Context>(dev_ctx, x, out_grad, dims, x_grad);
+    }
+    return;  // the single-dim fast path has fully written x_grad
+  }
+
+  ReduceGradKernel<Context, T, SumGradFunctor, true>(dev_ctx,
+                                                     x,
+                                                     out_grad,
+                                                     paddle::none,
+                                                     dims,
+                                                     keep_dim,
+                                                     reduce_all,
+                                                     in_dtype,
+                                                     out_dtype,
+                                                     x_grad);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(sum_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ReduceSumGradKernel,
+                   bool,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/funcs/reduce_grad_functions.h b/paddle/phi/kernels/funcs/reduce_grad_functions.h
new file mode 100644
index 0000000000000000000000000000000000000000..3488b6f2f86b20e0b758f3aa75a6739c40cd81db
--- /dev/null
+++ b/paddle/phi/kernels/funcs/reduce_grad_functions.h
@@ -0,0 +1,177 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/cpu/reduce.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+namespace phi {
+
+namespace funcs {
+
+// This ReduceGradFunctor is only the CPU implement.
+template <typename Context, typename T, size_t D, typename Functor>
+void ReduceGradFunctor(const Context& dev_ctx,
+                       const DenseTensor& input0,
+                       const DenseTensor& input1,
+                       const DenseTensor& input2,
+                       DenseTensor* output,
+                       Functor functor,
+                       const std::vector<int>& dims) {
+  auto x = phi::EigenTensor<T, D>::From(input0);
+  auto x_grad = phi::EigenTensor<T, D>::From(*output);
+  auto x_rank = static_cast<int>(x.dimensions().size());
+  auto x_dims = input0.dims();
+  auto reduced_dims_v = phi::vectorize(x_dims);
+  std::vector<int> dims_ref = dims;
+  Eigen::array<int, D> broadcast_dim;
+  for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
+
+  int broad_cats_times = 1;
+  for (size_t i = 0; i < dims_ref.size(); ++i) {
+    if (dims_ref[i] < 0) {
+      dims_ref[i] = x_rank + dims_ref[i];
+    }
+    reduced_dims_v[dims_ref[i]] = 1;
+    broadcast_dim[dims_ref[i]] = x_dims[dims_ref[i]];
+    broad_cats_times *= x_dims[dims_ref[i]];
+  }
+  auto reduced_dims = phi::make_ddim(reduced_dims_v);
+  auto x_reduce = EigenTensor<T, D>::From(input1, reduced_dims);
+  auto x_reduce_grad = EigenTensor<T, D>::From(input2, reduced_dims);
+
+  auto& place = *dev_ctx.eigen_device();
+
+  functor(place,
+          &x,
+          &x_reduce,
+          &x_grad,
+          &x_reduce_grad,
+          broadcast_dim,
+          broad_cats_times);
+}
+
+inline void GetOriginDimFromShuffled(const DDim& src_dim,
+                                     const std::vector<int>& dims,
+                                     std::vector<int>* origin_dim) {
+  DDim shuffled_dims(src_dim);
+  size_t n = src_dim.size();
+  std::vector<int64_t> perm_axis(n);
+  std::vector<int64_t> dims_64{dims.begin(), dims.end()};
+  GetShuffledDim(src_dim, &shuffled_dims, dims_64, &perm_axis);
+  for (size_t i = 0; i < n; ++i) {
+    (*origin_dim)[perm_axis[i]] = i;
+  }
+}
+
+template <typename Context, typename T, typename Functor>
+void HandleLargeDimGrad(const Context& dev_ctx,
+                        const DenseTensor* x,
+                        const DenseTensor* out,
+                        const DenseTensor* dout,
+                        DenseTensor* dx,
+                        Functor functor,
+                        const std::vector<int>& dims) {
+  const int64_t unreduced = out->numel();
+  const int64_t reduced = x->numel() / unreduced;
+  DDim out_dim(out->dims());
+  DDim x_dim(x->dims());
+  // transpose and reshape X
+  DenseTensor shuffled_x;
+  std::vector<int64_t> dims_64{dims.begin(), dims.end()};
+  GetShuffledInput<Context, T>(dev_ctx, *x, &shuffled_x, dims_64);
+  DDim shuffled_dim = shuffled_x.dims();
+  shuffled_x.Resize({unreduced, reduced});
+  // reshape dX {unreduced, reduced}
+  dx->Resize({unreduced, reduced});
+  ReduceGradFunctor<Context, T, 2, Functor>(
+      dev_ctx, shuffled_x, *out, *dout, dx, functor, {1});
+  // transpose dX
+  std::vector<int> origin_axis(x_dim.size());
+  GetOriginDimFromShuffled(x_dim, dims, &origin_axis);
+  DenseTensor dx_tmp;
+  paddle::framework::TensorCopy(*dx, dev_ctx.GetPlace(), &dx_tmp);
+  dx_tmp.Resize(shuffled_dim);
+  dx->Resize(x_dim);
+  phi::funcs::TransposeNormal<Context, T> trans;
+  trans(dev_ctx, dx_tmp, dx, origin_axis);
+}
+
+// Only for CPU
+template <typename Context, typename T, typename Functor>
+void LaunchReduceGradKernel(const Context& dev_ctx,
+                            const DenseTensor* input0,
+                            const DenseTensor* input1,
+                            const DenseTensor* input2,
+                            DenseTensor* output,
+                            Functor functor,
+                            const std::vector<int>& dims,
+                            bool reduce_all = false) {
+  if (reduce_all) {
+    auto x = phi::EigenVector<T>::Flatten(*input0);
+    auto x_reduce = phi::EigenVector<T>::Flatten(*input1);
+    auto x_reduce_grad = phi::EigenVector<T>::Flatten(*input2);
+    auto x_grad = phi::EigenVector<T>::Flatten(*output);
+    auto& place = *dev_ctx.eigen_device();
+    auto broadcast_dim =
+        Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
+    functor(place,
+            &x,
+            &x_reduce,
+            &x_grad,
+            &x_reduce_grad,
+            broadcast_dim,
+            broadcast_dim[0]);
+  } else {
+    int rank = input0->dims().size();
+    switch (rank) {
+      case 1:
+        ReduceGradFunctor<Context, T, 1, Functor>(
+            dev_ctx, *input0, *input1, *input2, output, functor, dims);
+        break;
+      case 2:
+        ReduceGradFunctor<Context, T, 2, Functor>(
+            dev_ctx, *input0, *input1, *input2, output, functor, dims);
+        break;
+      case 3:
+        ReduceGradFunctor<Context, T, 3, Functor>(
+            dev_ctx, *input0, *input1, *input2, output, functor, dims);
+        break;
+      case 4:
+        ReduceGradFunctor<Context, T, 4, Functor>(
+            dev_ctx, *input0, *input1, *input2, output, functor, dims);
+        break;
+      case 5:
+        ReduceGradFunctor<Context, T, 5, Functor>(
+            dev_ctx, *input0, *input1, *input2, output, functor, dims);
+        break;
+      case 6:
+        ReduceGradFunctor<Context, T, 6, Functor>(
+            dev_ctx, *input0, *input1, *input2, output, functor, dims);
+        break;
+      default:
+        HandleLargeDimGrad<Context, T, Functor>(
+            dev_ctx, input0, input1, input2, output, functor, dims);
+        break;
+    }
+  }
+}
+
+}  // namespace funcs
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/reduce_grad.h b/paddle/phi/kernels/gpu/reduce_grad.h
index a2b1c8631c7b44fefff5871515b77d9a67d992e2..d21c8a3fa46f81c046c722db50ac62fb57cf64f4 100644
--- a/paddle/phi/kernels/gpu/reduce_grad.h
+++ b/paddle/phi/kernels/gpu/reduce_grad.h
@@ -23,6 +23,7 @@
 #include <set>
 #include <vector>
 
+#include "paddle/phi/api/ext/dispatch.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 
 namespace phi {
diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9f4ddc3cf37a744355f6f79b7cd18b3d06b80062
--- /dev/null
+++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
@@ -0,0 +1,90 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reduce_sum_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
+#include "paddle/phi/kernels/gpu/reduce_grad.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceSumGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const std::vector<int64_t>& dims,
+                         bool keep_dim,
+                         bool reduce_all,
+                         DataType in_dtype,
+                         DataType out_dtype,
+                         DenseTensor* x_grad) {
+  auto* in_x = &x;
+  auto* d_out = &out_grad;
+  auto* d_x = x_grad;
+
+  auto pt_out_dtype = in_dtype;
+
+  // get reduce_dim and reduce_num for reduce_mean_grad
+  int dim_size = in_x->dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+
+  auto update_dims = vectorize(d_x->dims());
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (in_x->dims())[i];
+    update_dims[i] = 1;
+  }
+  // make new tensor
+  DenseTensor new_d_out(d_out->dtype());
+  new_d_out.ShareDataWith(*d_out);
+  new_d_out.Resize(phi::make_ddim(update_dims));
+  if (in_dtype != DataType::UNDEFINED) {
+    dev_ctx.Alloc(d_x, in_dtype);
+  } else {
+    dev_ctx.Alloc(d_x, d_out->dtype());
+  }
+
+  auto pt_d_out = new_d_out;
+  auto pt_d_x = *d_x;
+  if (in_dtype == DataType::UNDEFINED) {
+    pt_out_dtype = d_out->dtype();
+  }
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+
+  phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
+      dev_ctx,
+      &pt_d_out,
+      &pt_d_x,
+      pt_out_dtype,
+      kps::IdentityFunctor<T, MPType>(reduce_num));
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(sum_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReduceSumGradKernel,
+                   bool,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/reduce_sum_grad_kernel.h b/paddle/phi/kernels/reduce_sum_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..ab4d63297efffc70710e496efa08f4b9c7e5f7ce
--- /dev/null
+++ b/paddle/phi/kernels/reduce_sum_grad_kernel.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/dense_tensor.h"
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceSumGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const std::vector<int64_t>& dims,
+                         bool keep_dim,
+                         bool reduce_all,
+                         DataType in_dtype,
+                         DataType out_dtype,
+                         DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/reduce_sig.cc b/paddle/phi/ops/compat/reduce_sig.cc
index 36798abe4c11b8f57f110bc369f4892c898f8fe9..997f1505bd08d991aa3f13f1ad831c0107664b2f 100644
--- a/paddle/phi/ops/compat/reduce_sig.cc
+++ b/paddle/phi/ops/compat/reduce_sig.cc
@@ -74,13 +74,25 @@ KernelSignature ReduceMaxOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("unregistered", {}, {}, {});
 }
 
+KernelSignature ReduceSumGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "sum_grad",
+      {"X", GradVarName("Out")},
+      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
+      {GradVarName("X")});
+}
+
 }  // namespace phi
 
 PD_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum);
 PD_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean);
 PD_REGISTER_BASE_KERNEL_NAME(reduce_max, max);
+PD_REGISTER_BASE_KERNEL_NAME(reduce_sum_grad, sum_grad);
 
 PD_REGISTER_ARG_MAPPING_FN(reduce_sum, phi::ReduceSumOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(reduce_mean, phi::ReduceMeanOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(reduce_prod, phi::ReduceProdOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(reduce_max, phi::ReduceMaxOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(reduce_sum_grad,
+                           phi::ReduceSumGradOpArgumentMapping);