diff --git a/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc b/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc index 36776cebfcd46dc0998e8e5e75793656b6631482..61f238f19d1378c03c0804387f520bcbcf86c46d 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc @@ -11,20 +11,28 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + +namespace ops = paddle::operators; + +class ReduceAMaxOpMaker : public ops::ReduceOpMaker { + protected: + virtual std::string GetName() const { return "reduce_amax"; } + virtual std::string GetOpType() const { return "Reduce reduce_amax"; } +}; + +DECLARE_INFER_SHAPE_FUNCTOR(reduce_amax, + ReduceAMaxInferShapeFunctor, + PD_INFER_META(phi::ReduceInferMetaBase)); -REGISTER_REDUCE_OP(reduce_amax); -REGISTER_OP_CPU_KERNEL( +REGISTER_OPERATOR( reduce_amax, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel); -REGISTER_OP_CPU_KERNEL( - reduce_amax_grad, - ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops:: - ReduceGradKernel); + ops::ReduceOp, + ReduceAMaxOpMaker, + paddle::framework::DefaultGradOpMaker, + paddle::framework::DefaultGradOpMaker, + ReduceAMaxInferShapeFunctor); +REGISTER_OPERATOR(reduce_amax_grad, ops::ReduceGradOp) diff --git a/paddle/fluid/operators/reduce_ops/reduce_amax_op.kps b/paddle/fluid/operators/reduce_ops/reduce_amax_op.kps deleted file mode 100644 index 09987279184694d234bdaf0bee5e0a3478c2ab1a..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reduce_ops/reduce_amax_op.kps +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef PADDLE_WITH_XPU_KP -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#endif - -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/phi/core/kernel_registry.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -#ifdef PADDLE_WITH_XPU_KP -REGISTER_OP_KERNEL( - reduce_amax, KP, plat::XPUPlace, - ops::ReduceCudaKernel); -#else -REGISTER_OP_CUDA_KERNEL( - reduce_amax, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel); -#endif diff --git a/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu deleted file mode 100644 index d19819f17dc775338da464bc4c36f1ab6baec5cc..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" - -template -using CUDAReduceMaxGradKernel = - ops::ReduceCudaAMaxAMinGradKernel; -REGISTER_OP_CUDA_KERNEL(reduce_amax_grad, - CUDAReduceMaxGradKernel, - CUDAReduceMaxGradKernel, - CUDAReduceMaxGradKernel, - CUDAReduceMaxGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc b/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc index bb99ca9b17e7ea80f8d7b386a14b86b467a78df1..aac8414ac197d16f82b329d811a9ca149ab3cbea 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc @@ -11,20 +11,28 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + +namespace ops = paddle::operators; + +class ReduceAMinOpMaker : public ops::ReduceOpMaker { + protected: + virtual std::string GetName() const { return "reduce_amin"; } + virtual std::string GetOpType() const { return "Reduce reduce_amin"; } +}; + +DECLARE_INFER_SHAPE_FUNCTOR(reduce_amin, + ReduceAMinInferShapeFunctor, + PD_INFER_META(phi::ReduceInferMetaBase)); -REGISTER_REDUCE_OP(reduce_amin); -REGISTER_OP_CPU_KERNEL( +REGISTER_OPERATOR( reduce_amin, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel); -REGISTER_OP_CPU_KERNEL( - reduce_amin_grad, - ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops::ReduceGradKernel, - ops:: - ReduceGradKernel); + ops::ReduceOp, + ReduceAMinOpMaker, + paddle::framework::DefaultGradOpMaker, + paddle::framework::DefaultGradOpMaker, + ReduceAMinInferShapeFunctor); +REGISTER_OPERATOR(reduce_amin_grad, ops::ReduceGradOp) diff --git a/paddle/fluid/operators/reduce_ops/reduce_amin_op.kps b/paddle/fluid/operators/reduce_ops/reduce_amin_op.kps deleted file mode 100644 index 5e1139396d90cb82d8db72d4fda18806daa597e8..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reduce_ops/reduce_amin_op.kps +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef PADDLE_WITH_XPU_KP -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#endif - -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/phi/core/kernel_registry.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -#ifdef PADDLE_WITH_XPU_KP -REGISTER_OP_KERNEL( - reduce_amin, KP, plat::XPUPlace, - ops::ReduceCudaKernel); -#else -REGISTER_OP_CUDA_KERNEL( - reduce_amin, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel); -#endif diff --git a/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu deleted file mode 100644 index f5580d784b5896c128439b5ac7e63a7163d6c6bd..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" - -template -using CUDAReduceMinGradKernel = - ops::ReduceCudaAMaxAMinGradKernel; -REGISTER_OP_CUDA_KERNEL(reduce_amin_grad, - CUDAReduceMinGradKernel, - CUDAReduceMinGradKernel, - CUDAReduceMinGradKernel, - CUDAReduceMinGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_max_op.h b/paddle/fluid/operators/reduce_ops/reduce_min_max_op.h index 3d0f7bd08f9301bbe31441ed3e79747690bf59f8..a458dd09f4aaa4761cb8dac31764b7ea7f7b8c97 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_min_max_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_min_max_op.h @@ -55,120 +55,5 @@ struct MaxOrMinGradFunctor { } }; -#define HANDLE_AXIS_DIM(BROADCAST_DIM, AXIS_DIM) \ - if (broadcast_dim_size == BROADCAST_DIM && rank == AXIS_DIM) { \ - AMaxOrAMinAxisIsListGradFunctor( \ - place, x, y, dx, dy, dim, axis_dim); \ - } - -template -void AMaxOrAMinAxisIsListGradFunctor(const DeviceContext& place, - X* x, - Y* y, - DX* dx, - DY* dy, - const Dim& dim, - const std::vector& axis_dim) { - // R is x->dimensions().size(); - // D is axis_dim->dimensions().size(); - auto axis = Eigen::array(); - auto reshape_x = Eigen::array(); - auto reshape_y = Eigen::array(); - - for (int i = 0; i < D; i++) axis[i] = axis_dim[i]; - for (int i = 0; i < R; i++) { - reshape_x[i] = x->dimensions()[i]; - reshape_y[i] = y->dimensions()[i]; - } - - auto equals = (*x) == y->broadcast(dim); - auto ones = dx->constant(1); - auto zeros = dx->constant(0); - auto mask = equals.select(ones, zeros); - dx->device(place) = - dy->broadcast(dim) * mask / - mask.reshape(reshape_x).sum(axis).reshape(reshape_y).broadcast(dim); -} - -struct AMaxOrAMinGradFunctor { - template - void operator()(const DeviceContext& place, - X* x, - Y* y, - DX* dx, - DY* dy, - const Dim& dim, - int size) { - auto equals = (*x) == y->broadcast(dim); - auto ones = dx->constant(1); - auto zeros = dx->constant(0); - auto mask = equals.select(ones, zeros); - - // If there are multiple minimum or maximum elements, - // we evenly distribute gradient between these equal values - size_t x_numel = 1; - for (size_t i = 0; i < x->dimensions().size(); i++) - x_numel *= x->dimensions()[i]; - // reduce_all - if (size == static_cast(x_numel)) { - auto equal_number = mask.sum() - .reshape(Eigen::array({1})) - .broadcast(Eigen::array({size})); - dx->device(place) = dy->broadcast(dim) * mask / equal_number; - return; - } - - // compute forward reduce axis_dim by dim (which is broadcast_dim) - std::vector axis_dim; - int broadcast_dim_size = static_cast(dim.size()); - for (int i = 0; i < broadcast_dim_size; i++) { - if (dim[i] > 1) { - axis_dim.push_back(i); - } - } - - int rank = static_cast(axis_dim.size()); - // axis is a int element - if (rank == 1) { - auto axis = Eigen::array({axis_dim[0]}); - dx->device(place) = - dy->broadcast(dim) * mask / - mask.sum(axis).reshape(dy->dimensions()).broadcast(dim); - return; - } - // axis is list, HANDLE_AXIS_DIM(broadcast_dim_size, rank) - HANDLE_AXIS_DIM(3, 2); - HANDLE_AXIS_DIM(4, 2); - HANDLE_AXIS_DIM(4, 3); - // comments for accelerating compiling temporarily. - // HANDLE_AXIS_DIM(5, 2); - // HANDLE_AXIS_DIM(5, 3); - // HANDLE_AXIS_DIM(5, 4); - // HANDLE_AXIS_DIM(6, 2); - // HANDLE_AXIS_DIM(6, 3); - // HANDLE_AXIS_DIM(6, 4); - // HANDLE_AXIS_DIM(6, 5); - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index e9bc3905a22ee30cfb5dae3efa5a5ee53e463fbb..9e53a6b56de5cab3e62a315523079a407b23d189 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -838,87 +838,6 @@ struct DivideFunctor { inline HOSTDEVICE T operator()(const T a, const T b) const { return a / b; } }; - -template class TransformOp> -class ReduceCudaAMaxAMinGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - bool reduce_all = context.Attr("reduce_all"); - std::vector dims = context.Attr>("dim"); - auto* in_x = context.Input("X"); - auto* out_y = context.Input("Out"); - auto* d_out = - context.Input(framework::GradVarName("Out")); - auto* d_x = context.Output(framework::GradVarName("X")); - auto out_dtype = context.Attr("in_dtype"); - auto pt_out_dtype = framework::TransToPhiDataType( - static_cast(out_dtype)); - // get reduce_dim and reduce_num for reduce_mean_grad - int dim_size = in_x->dims().size(); - std::vector reduce_dims = GetReduceDim(dims, dim_size, reduce_all); - auto update_dims = vectorize(d_x->dims()); - int reduce_num = 1; - for (auto i : reduce_dims) { - reduce_num *= (in_x->dims())[i]; - update_dims[i] = 1; - } - auto& dev_ctx = context.cuda_device_context(); - - // make new tensor reduce_out - phi::DenseTensor new_y(out_y->type()); - new_y.ShareDataWith(*out_y); - new_y.Resize(phi::make_ddim(update_dims)); - - // make new tensor d_out - phi::DenseTensor new_dout(d_out->type()); - new_dout.ShareDataWith(*d_out); - new_dout.Resize(phi::make_ddim(update_dims)); - d_x->mutable_data(dev_ctx.GetPlace(), d_out->dtype()); - - auto new_in = paddle::experimental::MakePhiDenseTensor(*in_x); - auto new_in_tensor = new_in.get(); - - auto new_dx = paddle::experimental::MakePhiDenseTensor(*d_x); - auto new_dx_tensor = new_dx.get(); - - // make equal_out - phi::DenseTensor* equal_out = new phi::DenseTensor(); - equal_out->Resize(in_x->dims()); - dev_ctx.template Alloc(equal_out); - auto equal_out_tensor = *equal_out; - - // make new tensor equal_count - phi::DenseTensor* equal_count = new phi::DenseTensor(); - equal_count->Resize(phi::make_ddim(update_dims)); - dev_ctx.template Alloc(equal_count); - - // compute - // 1. equal_out = Equal(x, y) - std::vector equal_inputs = {&new_y, new_in_tensor}; - std::vector equal_outputs = {&equal_out_tensor}; - phi::funcs::BroadcastKernel( - dev_ctx, equal_inputs, &equal_outputs, 0, EqualFunctor()); - // 2. equal_count = reduceSum(equal_out) - using MPType = typename kps::details::MPTypeTrait::Type; - phi::funcs:: - ReduceKernel>( - dev_ctx, - equal_out_tensor, - equal_count, - kps::IdentityFunctor(), - reduce_dims, - false); - - // 3. dx = Div(dout, equal_out) - std::vector grad_inputs = {&equal_out_tensor, - equal_count}; - std::vector grad_outputs = {new_dx_tensor}; - phi::funcs::BroadcastKernel( - dev_ctx, grad_inputs, &grad_outputs, 0, DivideFunctor()); - delete equal_out; - delete equal_count; - } -}; #endif #endif diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index d123e00f975f6ffaafe345c9ea9a027de4b463ba..25cdd37ddea9d90d3b5757eee2330c34cccd5490 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -129,6 +129,24 @@ kernel : func : allclose +- api : amax + args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + output : Tensor(out) + infer_meta : + func : ReduceInferMeta + kernel : + func : amax + backward : amax_grad + +- api : amin + args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + output : Tensor(out) + infer_meta : + func : ReduceInferMeta + kernel : + func : amin + backward : amin_grad + - api : angle args : (Tensor x) output : Tensor diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index f1304e8f8c70cad82292f515c011c8481493f034..c00d9fd9a627b1a4a650447c32f6c9b1885c1306 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -92,6 +92,26 @@ kernel : func : addmm_grad +- backward_api : amax_grad + forward: amax (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : amax_grad + +- backward_api : amin_grad + forward: amin (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : amin_grad + - backward_api : angle_grad forward : angle (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) diff --git a/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffe9133d6d94c9cc284910038666f2bb1d37fb6c --- /dev/null +++ b/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/reduce_amax_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" +#include "paddle/phi/kernels/impl/reduce_grad.h" + +namespace phi { + +template +void ReduceAMaxGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* x_grad) { + ReduceGradKernel( + dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(amax_grad, + CPU, + ALL_LAYOUT, + phi::ReduceAMaxGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/reduce_amax_kernel.cc b/paddle/phi/kernels/cpu/reduce_amax_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac3b5ce762e293998788610df6df7ee658d4b4a7 --- /dev/null +++ b/paddle/phi/kernels/cpu/reduce_amax_kernel.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/reduce_amax_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/reduce.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" + +namespace phi { + +template +void AMaxRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { + auto out_dtype = x.dtype(); + phi::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(amax_raw, + CPU, + ALL_LAYOUT, + phi::AMaxRawKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bb0e5061cc20a6a30622184436099730c2fb34a --- /dev/null +++ b/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/reduce_amin_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" +#include "paddle/phi/kernels/impl/reduce_grad.h" + +namespace phi { + +template +void ReduceAMinGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* x_grad) { + ReduceGradKernel( + dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(amin_grad, + CPU, + ALL_LAYOUT, + phi::ReduceAMinGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/reduce_amin_kernel.cc b/paddle/phi/kernels/cpu/reduce_amin_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8f090f93ffd3a8363e493f76b514107c6504a13 --- /dev/null +++ b/paddle/phi/kernels/cpu/reduce_amin_kernel.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/reduce_amin_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/reduce.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" + +namespace phi { + +template +void AMinRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { + auto out_dtype = x.dtype(); + phi::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(amin_raw, + CPU, + ALL_LAYOUT, + phi::AMinRawKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/funcs/reduce_functor.h b/paddle/phi/kernels/funcs/reduce_functor.h index 9bf1bfecabbf22f5fdc87d9c7426ec7525ac1046..34032e153c0496ca64cfc6ab86cfe5fe64bc37e4 100644 --- a/paddle/phi/kernels/funcs/reduce_functor.h +++ b/paddle/phi/kernels/funcs/reduce_functor.h @@ -14,6 +14,9 @@ #pragma once +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + namespace phi { namespace funcs { @@ -178,5 +181,120 @@ struct MaxOrMinGradFunctor { } }; +#define HANDLE_AXIS_DIM(BROADCAST_DIM, AXIS_DIM) \ + if (broadcast_dim_size == BROADCAST_DIM && rank == AXIS_DIM) { \ + AMaxOrAMinAxisIsListGradFunctor( \ + place, x, y, dx, dy, dim, axis_dim); \ + } + +template +void AMaxOrAMinAxisIsListGradFunctor(const DeviceContext& place, + X* x, + Y* y, + DX* dx, + DY* dy, + const Dim& dim, + const std::vector& axis_dim) { + // R is x->dimensions().size(); + // D is axis_dim->dimensions().size(); + auto axis = Eigen::array(); + auto reshape_x = Eigen::array(); + auto reshape_y = Eigen::array(); + + for (int i = 0; i < D; i++) axis[i] = axis_dim[i]; + for (int i = 0; i < R; i++) { + reshape_x[i] = x->dimensions()[i]; + reshape_y[i] = y->dimensions()[i]; + } + + auto equals = (*x) == y->broadcast(dim); + auto ones = dx->constant(1); + auto zeros = dx->constant(0); + auto mask = equals.select(ones, zeros); + dx->device(place) = + dy->broadcast(dim) * mask / + mask.reshape(reshape_x).sum(axis).reshape(reshape_y).broadcast(dim); +} + +struct AMaxOrAMinGradFunctor { + template + void operator()(const DeviceContext& place, + X* x, + Y* y, + DX* dx, + DY* dy, + const Dim& dim, + int size) { + auto equals = (*x) == y->broadcast(dim); + auto ones = dx->constant(1); + auto zeros = dx->constant(0); + auto mask = equals.select(ones, zeros); + + // If there are multiple minimum or maximum elements, + // we evenly distribute gradient between these equal values + size_t x_numel = 1; + for (size_t i = 0; i < x->dimensions().size(); i++) + x_numel *= x->dimensions()[i]; + // reduce_all + if (size == static_cast(x_numel)) { + auto equal_number = mask.sum() + .reshape(Eigen::array({1})) + .broadcast(Eigen::array({size})); + dx->device(place) = dy->broadcast(dim) * mask / equal_number; + return; + } + + // compute forward reduce axis_dim by dim (which is broadcast_dim) + std::vector axis_dim; + int broadcast_dim_size = static_cast(dim.size()); + for (int i = 0; i < broadcast_dim_size; i++) { + if (dim[i] > 1) { + axis_dim.push_back(i); + } + } + + int rank = static_cast(axis_dim.size()); + // axis is a int element + if (rank == 1) { + auto axis = Eigen::array({axis_dim[0]}); + dx->device(place) = + dy->broadcast(dim) * mask / + mask.sum(axis).reshape(dy->dimensions()).broadcast(dim); + return; + } + // axis is list, HANDLE_AXIS_DIM(broadcast_dim_size, rank) + HANDLE_AXIS_DIM(3, 2); + HANDLE_AXIS_DIM(4, 2); + HANDLE_AXIS_DIM(4, 3); + // comments for accelerating compiling temporarily. + // HANDLE_AXIS_DIM(5, 2); + // HANDLE_AXIS_DIM(5, 3); + // HANDLE_AXIS_DIM(5, 4); + // HANDLE_AXIS_DIM(6, 2); + // HANDLE_AXIS_DIM(6, 3); + // HANDLE_AXIS_DIM(6, 4); + // HANDLE_AXIS_DIM(6, 5); + } +}; + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..a75ef42889da2ea90994fe5f41781c9207e26e6a --- /dev/null +++ b/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/reduce_amin_amax_common.h" +#include "paddle/phi/kernels/reduce_max_grad_kernel.h" + +namespace phi { + +template +void ReduceAMaxGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* x_grad) { + ReduceCudaAMaxAMinGrad( + dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); +} +} // namespace phi + +PD_REGISTER_KERNEL(amax_grad, + GPU, + ALL_LAYOUT, + phi::ReduceAMaxGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h new file mode 100644 index 0000000000000000000000000000000000000000..fe3cd89d5bc97490646a14fe29a0a08d01d108a8 --- /dev/null +++ b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h @@ -0,0 +1,103 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/api/lib/utils/tensor_utils.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/compare_functors.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" + +namespace phi { + +template +void ReduceCudaAMaxAMinGrad(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* x_grad) { + auto* in_x = &x; + auto* out_y = &out; + auto* d_out = &out_grad; + auto* d_x = x_grad; + // get reduce_dim and reduce_num for reduce_mean_grad + int dim_size = in_x->dims().size(); + auto reduce_dims = funcs::details::GetReduceDim(dims, dim_size, reduce_all); + auto update_dims = vectorize(d_x->dims()); + int reduce_num = 1; + for (auto i : reduce_dims) { + reduce_num *= (in_x->dims())[i]; + update_dims[i] = 1; + } + + // make new tensor reduce_out + phi::DenseTensor new_y(out_y->type()); + new_y.ShareDataWith(*out_y); + new_y.Resize(phi::make_ddim(update_dims)); + + // make new tensor d_out + phi::DenseTensor new_dout(d_out->type()); + new_dout.ShareDataWith(*d_out); + new_dout.Resize(phi::make_ddim(update_dims)); + dev_ctx.Alloc(d_x, d_out->dtype()); + + auto new_in = paddle::experimental::MakePhiDenseTensor(*in_x); + auto new_in_tensor = new_in.get(); + + auto new_dx = paddle::experimental::MakePhiDenseTensor(*d_x); + auto new_dx_tensor = new_dx.get(); + + // make equal_out + phi::DenseTensor* equal_out = new phi::DenseTensor(); + equal_out->Resize(in_x->dims()); + dev_ctx.template Alloc(equal_out); + auto equal_out_tensor = *equal_out; + + // make new tensor equal_count + phi::DenseTensor* equal_count = new phi::DenseTensor(); + equal_count->Resize(phi::make_ddim(update_dims)); + dev_ctx.template Alloc(equal_count); + + // compute + // 1. equal_out = Equal(x, y) + std::vector equal_inputs = {&new_y, new_in_tensor}; + std::vector equal_outputs = {&equal_out_tensor}; + funcs::BroadcastKernel( + dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor()); + // 2. equal_count = reduceSum(equal_out) + using MPType = typename kps::details::MPTypeTrait::Type; + phi::funcs:: + ReduceKernel>( + dev_ctx, + equal_out_tensor, + equal_count, + kps::IdentityFunctor(), + reduce_dims, + false); + + // 3. dx = Div(dout, equal_out) + std::vector grad_inputs = {&equal_out_tensor, + equal_count}; + std::vector grad_outputs = {new_dx_tensor}; + funcs::BroadcastKernel( + dev_ctx, grad_inputs, &grad_outputs, 0, funcs::DivideFunctor()); + delete equal_out; + delete equal_count; +} +} // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..152ef494b4c13090f09dafc42d1da3f16229e541 --- /dev/null +++ b/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/reduce_amin_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/reduce_amin_amax_common.h" + +namespace phi { + +template +void ReduceAMinGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* x_grad) { + ReduceCudaAMaxAMinGrad( + dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); +} +} // namespace phi + +PD_REGISTER_KERNEL(amin_grad, + GPU, + ALL_LAYOUT, + phi::ReduceAMinGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/kps/reduce_amax_kernel.cu b/paddle/phi/kernels/kps/reduce_amax_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..57197fd9d5b8a24f87d8a41e7a18c4f8f3637656 --- /dev/null +++ b/paddle/phi/kernels/kps/reduce_amax_kernel.cu @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/reduce.h" +#include "paddle/phi/kernels/reduce_amin_kernel.h" + +namespace phi { + +template +void AMaxRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { + auto out_dtype = x.dtype(); + phi::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + +} // namespace phi + +#ifdef PADDLE_WITH_XPU_KP +PD_REGISTER_KERNEL(amax_raw, KPS, ALL_LAYOUT, phi::AMaxRawKernel, float) {} +#else +PD_REGISTER_KERNEL(amax_raw, + KPS, + ALL_LAYOUT, + phi::AMaxRawKernel, + float, + double, + int, + int64_t) {} +#endif diff --git a/paddle/phi/kernels/kps/reduce_amin_kernel.cu b/paddle/phi/kernels/kps/reduce_amin_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..230adcc829441824b5b9341da83504d884120b34 --- /dev/null +++ b/paddle/phi/kernels/kps/reduce_amin_kernel.cu @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/reduce_amin_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/reduce.h" + +namespace phi { + +template +void AMinRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { + auto out_dtype = x.dtype(); + phi::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + +} // namespace phi + +#ifdef PADDLE_WITH_XPU_KP +PD_REGISTER_KERNEL(amin_raw, KPS, ALL_LAYOUT, phi::AMinRawKernel, float) {} +#else +PD_REGISTER_KERNEL(amin_raw, + KPS, + ALL_LAYOUT, + phi::AMinRawKernel, + float, + double, + int, + int64_t) {} +#endif diff --git a/paddle/phi/kernels/reduce_amax_grad_kernel.h b/paddle/phi/kernels/reduce_amax_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..82518c11675c39bafef409ab92c6cd605b2ebc3f --- /dev/null +++ b/paddle/phi/kernels/reduce_amax_grad_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ReduceAMaxGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/reduce_amax_kernel.cc b/paddle/phi/kernels/reduce_amax_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..acec25d83db6a28f195b1835bf2c14f7cbc9629d --- /dev/null +++ b/paddle/phi/kernels/reduce_amax_kernel.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/reduce_amax_kernel.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void AMaxKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DenseTensor* out) { + bool reduce_all = false; + AMaxRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + amax, CPU, ALL_LAYOUT, phi::AMaxKernel, float, double, int, int64_t) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL( + amax, GPU, ALL_LAYOUT, phi::AMaxKernel, float, double, int, int64_t) {} +#endif + +#if defined(PADDLE_WITH_XPU_KP) +PD_REGISTER_KERNEL(amax, KPS, ALL_LAYOUT, phi::AMaxKernel, float) {} +#endif diff --git a/paddle/phi/kernels/reduce_amax_kernel.h b/paddle/phi/kernels/reduce_amax_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..79a287b4871364acb1cf9a77006b6967044eb58e --- /dev/null +++ b/paddle/phi/kernels/reduce_amax_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void AMaxRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out); + +template +void AMaxKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/reduce_amin_grad_kernel.h b/paddle/phi/kernels/reduce_amin_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..96f157e2038628773875ed816ae01bef670c183f --- /dev/null +++ b/paddle/phi/kernels/reduce_amin_grad_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ReduceAMinGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/reduce_amin_kernel.cc b/paddle/phi/kernels/reduce_amin_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..28e6e587f40201f6d63f95c9197214d31238a441 --- /dev/null +++ b/paddle/phi/kernels/reduce_amin_kernel.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/reduce_amin_kernel.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void AMinKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DenseTensor* out) { + bool reduce_all = false; + AMinRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + amin, CPU, ALL_LAYOUT, phi::AMinKernel, float, double, int, int64_t) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL( + amin, GPU, ALL_LAYOUT, phi::AMinKernel, float, double, int, int64_t) {} +#endif + +#if defined(PADDLE_WITH_XPU_KP) +PD_REGISTER_KERNEL(amin, KPS, ALL_LAYOUT, phi::AMinKernel, float) {} +#endif diff --git a/paddle/phi/kernels/reduce_amin_kernel.h b/paddle/phi/kernels/reduce_amin_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b36351dd5258f8a50ca18525037e3d13810200e8 --- /dev/null +++ b/paddle/phi/kernels/reduce_amin_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void AMinRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out); + +template +void AMinKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/reduce_sig.cc b/paddle/phi/ops/compat/reduce_sig.cc index a0ba07f5e8e2cf4c9b21819681092a70c1357a7c..e796307c0c9b3aa95c18f038dc23c1051e89bcaf 100644 --- a/paddle/phi/ops/compat/reduce_sig.cc +++ b/paddle/phi/ops/compat/reduce_sig.cc @@ -83,6 +83,22 @@ KernelSignature ReduceMaxOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("unregistered", {}, {}, {}); } +KernelSignature ReduceAMaxOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("X")) { + bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); + // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in + // InferShape, so we must return the "max_raw" KernelSignature. + // And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with + // the "max_raw" KernelSignature + if (ctx.IsForInferShape() || reduce_all) { + return KernelSignature( + "amax_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); + } + return KernelSignature("amax", {"X"}, {"dim", "keep_dim"}, {"Out"}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + KernelSignature ReduceMinOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.IsDenseTensorInput("X")) { bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); @@ -99,6 +115,22 @@ KernelSignature ReduceMinOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("unregistered", {}, {}, {}); } +KernelSignature ReduceAMinOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("X")) { + bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); + // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in + // InferShape, so we must return the "min_raw" KernelSignature. + // And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with + // the "min_raw" KernelSignature + if (ctx.IsForInferShape() || reduce_all) { + return KernelSignature( + "amin_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); + } + return KernelSignature("amin", {"X"}, {"dim", "keep_dim"}, {"Out"}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + KernelSignature ReduceAnyOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.IsDenseTensorInput("X")) { bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); @@ -151,6 +183,14 @@ KernelSignature ReduceMaxGradOpArgumentMapping( {"X@GRAD"}); } +KernelSignature ReduceAMaxGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("amax_grad", + {"X", "Out", "Out@GRAD"}, + {"dim", "keep_dim", "reduce_all"}, + {"X@GRAD"}); +} + KernelSignature ReduceMinGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("min_grad", @@ -159,6 +199,14 @@ KernelSignature ReduceMinGradOpArgumentMapping( {"X@GRAD"}); } +KernelSignature ReduceAMinGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("amin_grad", + {"X", "Out", "Out@GRAD"}, + {"dim", "keep_dim", "reduce_all"}, + {"X@GRAD"}); +} + KernelSignature ReduceProdGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("prod_grad", @@ -173,6 +221,8 @@ PD_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum); PD_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean); PD_REGISTER_BASE_KERNEL_NAME(reduce_max, max); PD_REGISTER_BASE_KERNEL_NAME(reduce_min, min); +PD_REGISTER_BASE_KERNEL_NAME(reduce_amax, amax); +PD_REGISTER_BASE_KERNEL_NAME(reduce_amin, amin); PD_REGISTER_BASE_KERNEL_NAME(reduce_prod, prod); PD_REGISTER_BASE_KERNEL_NAME(reduce_all, all); PD_REGISTER_BASE_KERNEL_NAME(reduce_any, any); @@ -182,12 +232,16 @@ PD_REGISTER_BASE_KERNEL_NAME(reduce_mean_grad, mean_grad); PD_REGISTER_BASE_KERNEL_NAME(reduce_prod_grad, prod_grad); PD_REGISTER_BASE_KERNEL_NAME(reduce_max_grad, max_grad); PD_REGISTER_BASE_KERNEL_NAME(reduce_min_grad, min_grad); +PD_REGISTER_BASE_KERNEL_NAME(reduce_amax_grad, amax_grad); +PD_REGISTER_BASE_KERNEL_NAME(reduce_amin_grad, amin_grad); PD_REGISTER_ARG_MAPPING_FN(reduce_sum, phi::ReduceSumOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_mean, phi::ReduceMeanOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_prod, phi::ReduceProdOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_max, phi::ReduceMaxOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(reduce_amax, phi::ReduceAMaxOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_min, phi::ReduceMinOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(reduce_amin, phi::ReduceAMinOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_all, phi::ReduceAllOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_any, phi::ReduceAnyOpArgumentMapping); @@ -199,5 +253,9 @@ PD_REGISTER_ARG_MAPPING_FN(reduce_prod_grad, phi::ReduceProdGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_max_grad, phi::ReduceMaxGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(reduce_amax_grad, + phi::ReduceAMaxGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_min_grad, phi::ReduceMinGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(reduce_amin_grad, + phi::ReduceAMinGradOpArgumentMapping); diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index c85d9226e67bd871eb73474b49681384bbdd5f3e..f70c9a0c41011699537f080cc7b0ae7d412784ae 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -29,7 +29,7 @@ from .layer_function_generator import _generate_doc_string_, generate_activation import paddle from ..static import Variable -from ..framework import core, in_dygraph_mode, _non_static_mode, LayerHelper +from ..framework import core, in_dygraph_mode, _non_static_mode, LayerHelper, _in_legacy_dygraph from ..fluid.framework import _in_legacy_dygraph from ..framework import _varbase_creator, convert_np_dtype_to_dtype_ from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype @@ -2334,7 +2334,11 @@ def amax(x, axis=None, keepdim=False, name=None): """ reduce_all, axis = _get_reduce_all_value(axis) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if reduce_all: + axis = range(len(x.shape)) + return _C_ops.final_state_amax(x, axis, keepdim) + if _in_legacy_dygraph(): return _C_ops.reduce_amax(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) helper = LayerHelper('amax', **locals()) @@ -2446,9 +2450,12 @@ def amin(x, axis=None, keepdim=False, name=None): """ reduce_all, axis = _get_reduce_all_value(axis) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if reduce_all: + axis = range(len(x.shape)) + return _C_ops.final_state_amin(x, axis, keepdim) + elif _in_legacy_dygraph(): return _C_ops.reduce_amin(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) - helper = LayerHelper('amin', **locals()) check_variable_and_dtype( x, 'x', ['float32', 'float64', 'int32', 'int64'], 'amin')