diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
index 2c3359ffa8e46f0d30a01d73fccb95d88771480a..62d87b6917e4059f08dcdf5ccb4eed6434211e43 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc
+++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
@@ -54,7 +54,7 @@ USE_OP(slice_grad);
 USE_OP(lookup_table_grad);
 USE_OP(sqrt);
 USE_OP(elementwise_max);
-USE_OP(elementwise_div);
+USE_OP_ITSELF(elementwise_div);
 USE_OP(sgd);
 USE_OP(squared_l2_norm);
 USE_OP(memcpy_h2d);
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc
index 38cd232e4d1d2237cb5da014d11ba69a91cbe917..13fd9b81a8765aea140ad6ca2fc0383151a51dc7 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc
@@ -102,42 +102,6 @@ REGISTER_OPERATOR(
 REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad,
                   ops::ElementwiseDoubleGradOpInplaceInferer);

-REGISTER_OP_CPU_KERNEL(
-    elementwise_div,
-    ops::ElementwiseDivKernel,
-    ops::ElementwiseDivKernel,
-    ops::ElementwiseDivKernel,
-    ops::ElementwiseDivKernel,
-    ops::ElementwiseDivKernel>,
-    ops::ElementwiseDivKernel>);
-REGISTER_OP_CPU_KERNEL(
-    elementwise_div_grad,
-    ops::ElementwiseDivGradKernel,
-    ops::ElementwiseDivGradKernel,
-    ops::ElementwiseDivGradKernel,
-    ops::ElementwiseDivGradKernel,
-    ops::ElementwiseDivGradKernel>,
-    ops::ElementwiseDivGradKernel>);
-
-REGISTER_OP_CPU_KERNEL(
-    elementwise_div_grad_grad,
-    ops::ElementwiseDivDoubleGradKernel,
-    ops::ElementwiseDivDoubleGradKernel,
-    ops::ElementwiseDivDoubleGradKernel,
-    ops::ElementwiseDivDoubleGradKernel,
-    ops::ElementwiseDivDoubleGradKernel>,
-    ops::ElementwiseDivDoubleGradKernel>);
-
 REGISTER_OP_VERSION(elementwise_div)
     .AddCheckpoint(
         R"ROC(Register elementwise_div for adding the attribute of Scale_y)ROC",
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
deleted file mode 100644
index 9eb4b0352e5337e3fdd758d2e95cfa61d1d62724..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu
+++ /dev/null
@@ -1,96 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
*/ - -#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -namespace paddle { -namespace operators { - -template -typename std::enable_if< - std::is_same::value>::type -ElementwiseDivGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { - int axis = ctx.Attr("axis"); - const auto& dev_ctx = ctx.template device_context(); - const auto place = ctx.GetPlace(); - if (dx != nullptr && dy != nullptr) { - std::vector ins = {dout, out, y}; - GetGradXAndYOut( - dev_ctx, place, axis, ins, dout, dx, dy, DivGradXYFunctor()); - } else if (dx != nullptr && dy == nullptr) { - std::vector ins = {dout, y}; - GetGradXOrYOut(dev_ctx, place, axis, ins, dout, - dx, DivGradXFunctor()); - } else if (dy != nullptr && dx == nullptr) { - std::vector ins = {dout, out, y}; - GetGradXOrYOut( - dev_ctx, place, axis, ins, dout, dy, DivGradYFunctor()); - } -} - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL( - elementwise_div, - ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel>, - ops::ElementwiseDivKernel>); -REGISTER_OP_CUDA_KERNEL( - elementwise_div_grad, - ops::ElementwiseDivGradKernel, - ops::ElementwiseDivGradKernel, - ops::ElementwiseDivGradKernel, - ops::ElementwiseDivGradKernel, - ops::ElementwiseDivGradKernel, - ops::ElementwiseDivGradKernel, - ops::ElementwiseDivGradKernel>, - ops::ElementwiseDivGradKernel>); -REGISTER_OP_CUDA_KERNEL( - elementwise_div_grad_grad, - ops::ElementwiseDivDoubleGradKernel, - ops::ElementwiseDivDoubleGradKernel, - ops::ElementwiseDivDoubleGradKernel, - ops::ElementwiseDivDoubleGradKernel, - ops::ElementwiseDivDoubleGradKernel, - ops::ElementwiseDivDoubleGradKernel, - ops::ElementwiseDivDoubleGradKernel>, - ops::ElementwiseDivDoubleGradKernel>); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index c58a7f36548a57a1c8e7770fa282470fba4cc140..e9adb9abdb528c187817be641b81ffb6f64833b0 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -20,142 +20,6 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -template -void default_elementwise_sub(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, framework::Tensor* z) { - int axis = ctx.Attr("axis"); - auto x_dims = x->dims(); - auto y_dims = y->dims(); - if (x_dims.size() >= y_dims.size()) { - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - SubFunctor(), z); - } else { - ElementwiseComputeEx, DeviceContext, T>( - ctx, x, y, axis, InverseSubFunctor(), z); - } -} - -template -void default_elementwise_div(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, framework::Tensor* z) { - int axis = ctx.Attr("axis"); - auto x_dims = x->dims(); - auto y_dims = y->dims(); - if (x_dims.size() >= y_dims.size()) { - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - DivFunctor(), z); - } else { - ElementwiseComputeEx, DeviceContext, T>( - ctx, x, y, axis, InverseDivFunctor(), z); - } -} - -template -class ElementwiseDivKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - - auto& dev_ctx = ctx.device_context(); - int axis = ctx.Attr("axis"); - auto pt_x = paddle::experimental::MakePhiDenseTensor(*x); - auto pt_y = paddle::experimental::MakePhiDenseTensor(*y); - auto pt_z = paddle::experimental::MakePhiDenseTensor(*z); - phi::DivideRawKernel( - static_cast::TYPE&>(dev_ctx), - *pt_x.get(), *pt_y.get(), axis, pt_z.get()); - } -}; - -template -struct DivGradDX { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; } -}; - -template -struct DivGradDX> { - HOSTDEVICE paddle::platform::complex operator()( - paddle::platform::complex x, paddle::platform::complex y, - paddle::platform::complex out, - paddle::platform::complex dout) const { - paddle::platform::complex y_conj(y.real, -y.imag); - return dout / y_conj; - } -}; - -template -struct DivGradDY { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return -dout * out / y; - } -}; - -template -struct DivGradDY> { - HOSTDEVICE paddle::platform::complex operator()( - paddle::platform::complex x, paddle::platform::complex y, - paddle::platform::complex out, - paddle::platform::complex dout) const { - paddle::platform::complex out_div_y_conj((out / y).real, - -(out / y).imag); - return -dout * out_div_y_conj; - } -}; - -template -struct DivDoubleDY { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return y * out * dout - x * dout; - } -}; - -template -typename std::enable_if< - std::is_same::value>::type -ElementwiseDivGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { - int axis = ctx.Attr("axis"); - - ElemwiseGradCompute, DivGradDY>( - ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX(), DivGradDY()); -} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -typename std::enable_if< - std::is_same::value>::type -ElementwiseDivGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy); -#endif - -template -class ElementwiseDivGradKernel : public ElemwiseGradKernel { - public: - void 
Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Input("Out"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - ElementwiseDivGrad(ctx, x, y, out, dout, dx, dy); - } -}; - class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -206,80 +70,5 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { } }; -template -class ElementwiseDivDoubleGradKernel : public framework::OpKernel { - using Tensor = framework::Tensor; - - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* Y = ctx.Input("Y"); - auto* Out = ctx.Input("Out"); - auto* ddX = ctx.Input("DDX"); - auto* ddY = ctx.Input("DDY"); - auto* dX = ctx.Input("DX"); - - auto* dY = ctx.Output(framework::GradVarName("Y")); - auto* dOut = ctx.Output("DOut"); - auto* ddOut = ctx.Output("DDOut"); - - int axis = ctx.Attr("axis"); - - if (dY) dY->mutable_data(Y->dims(), ctx.GetPlace()); - if (dOut) dOut->mutable_data(Out->dims(), ctx.GetPlace()); - if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); - - // ddX_safe == null ? 0 : ddX - // ddY_safe == null ? 0 : ddY - Tensor ddX_safe, ddY_safe; - GetDoubleGradSafeTensor(ctx, dX, ddX, &ddX_safe); - GetDoubleGradSafeTensor(ctx, Y, ddY, &ddY_safe); - - // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y - // dY = Out * dX * ddY / Y - dX * ddX / Y - // dOut = - dX * ddY - // To save memory, (1) dout can be used as 'tmp' tensor, (2) ddout can - // inplace ddx - Tensor tmp; - if (dOut) { - tmp = *dOut; - } else { - auto& dev_ctx = ctx.template device_context(); - tmp = ctx.AllocateTmpTensor(Out->dims(), dev_ctx); - } - if (dY) { - // dX_div_Y = dX / Y; - Tensor dX_div_Y = tmp; - default_elementwise_div(ctx, dX, Y, &dX_div_Y); - - // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the - // first output tensor is nullptr, the branch to calculate first - // output tensor will not be activated, DivGradDx function will not - // be called and can be ignored, the first branch has little effect - // on running speed. 
- - // dY = Out * dX * ddY / Y - dX * ddX / Y - ElemwiseGradCompute, DivDoubleDY>( - ctx, ddX_safe, ddY_safe, *Out, dX_div_Y, axis, nullptr, dY, - DivGradDX(), DivDoubleDY()); - } - - if (ddOut) { - // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y - default_elementwise_mul(ctx, Out, &ddY_safe, &tmp); - default_elementwise_sub(ctx, &ddX_safe, &tmp, &tmp); - default_elementwise_div(ctx, &tmp, Y, ddOut); - } - - if (dOut) { - // dOut = - dX * ddY - default_elementwise_mul(ctx, dX, &ddY_safe, dOut); - auto& place = - *ctx.template device_context().eigen_device(); - auto dout = framework::EigenVector::Flatten(*dOut); - dout.device(place) = static_cast(-1) * dout; - } - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index 86f5be3071c2d1a84f13da1cef74787003e633bb..8e0bf78e9b7f9c08052c1463acf20c4493ffc9e1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -90,67 +90,6 @@ struct MinFunctor { template using Complex = paddle::platform::complex; -template -struct DivGradXYFunctor { - inline HOSTDEVICE phi::Array operator()(const InT a, const InT b, - const InT c) { - // dx = dout / y - // dy = - dout * out / y - phi::Array outs; - outs[0] = a / c; - outs[1] = -a * b / c; - return outs; - } -}; - -template -struct DivGradXYFunctor, Complex> { - inline HOSTDEVICE phi::Array, 2> operator()( - const Complex a, const Complex b, const Complex c) { - phi::Array, 2> outs; - Complex c_conj(c.real, -c.imag); - Complex out_div_c_conj((b / c).real, -(b / c).imag); - outs[0] = a / c_conj; - outs[1] = -a * out_div_c_conj; - return outs; - } -}; - -// Float div grad -template -struct DivGradXFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { return a / b; } -}; - -// Complex div grad -template -struct DivGradXFunctor> { - inline HOSTDEVICE Complex operator()(const Complex a, - const Complex b) const { - Complex b_conj(b.real, -b.imag); - return a / b_conj; - } -}; - -// Float mul and div -template -struct DivGradYFunctor { - inline HOSTDEVICE T operator()(const T a, const T b, const T c) const { - return -a * b / c; - } -}; - -// Complex mul and div -template -struct DivGradYFunctor> { - inline HOSTDEVICE Complex operator()(const Complex a, - const Complex b, - const Complex c) const { - Complex out_div_c_conj((b / c).real, -(b / c).imag); - return -a * out_div_c_conj; - } -}; - // Fmax template struct FMaxFunctor { diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 61862aa9f87408048c5d31a13c0be8a013046902..80b07721f0b4d1feb669bfce91127b0887d79391 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -45,6 +45,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/kernels/gpu/elementwise_grad.h" #endif @@ -145,17 +146,9 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx, const framework::Tensor &dout, int axis, framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { - const framework::DDim &x_dim = x.dims(); - const framework::DDim &y_dim = y.dims(); const auto &dev_ctx = ctx.template device_context(); - if (x.dims() == y.dims()) { - phi::funcs::ElemwiseGradComputeNoBroadcast( - dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); - } else { - phi::funcs::ElemwiseGradComputeWithBroadcast( - dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); - } + phi::funcs::ElemwiseGradCompute( + dev_ctx, x, y, out, dout, axis, dx, dy, dx_op, dy_op); } // It is a common implementation to compute binary calculation with the support @@ -1174,14 +1167,6 @@ static inline std::vector GetReduceDim(const framework::DDim &in, } #if defined(__NVCC__) || defined(__HIPCC__) -template -void ReduceWrapper(const platform::CUDADeviceContext &dev_ctx, int axis, - framework::Tensor *src, framework::Tensor *dst) { - std::vector reduce_dims = GetReduceDim(dst->dims(), src->dims(), axis); - TensorReduceImpl>( - dev_ctx, *src, dst, kps::IdentityFunctor(), reduce_dims, - dev_ctx.stream()); -} template void GetGradXAndYOut(const platform::CUDADeviceContext &dev_ctx, @@ -1189,36 +1174,8 @@ void GetGradXAndYOut(const platform::CUDADeviceContext &dev_ctx, std::vector ins, const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dy, Functor func) { - framework::Tensor tmp_dx; - framework::Tensor tmp_dy; - dx->mutable_data(place); - dy->mutable_data(place); - std::vector outs; - if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) { - outs = {dx, dy}; - } else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) { - tmp_dx.mutable_data(dout->dims(), place); - outs = {&tmp_dx, dy}; - } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) { - tmp_dy.mutable_data(dout->dims(), place); - outs = {dx, &tmp_dy}; - } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) { - tmp_dy.mutable_data(dout->dims(), place); - tmp_dx.mutable_data(dout->dims(), place); - outs = {&tmp_dx, &tmp_dy}; - } - - paddle::operators::LaunchElementwiseCudaKernel( - dev_ctx, ins, &outs, axis, func); - - if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) { - ReduceWrapper(dev_ctx, axis, &tmp_dx, dx); - } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) { - ReduceWrapper(dev_ctx, axis, &tmp_dy, dy); - } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) { - ReduceWrapper(dev_ctx, axis, &tmp_dx, dx); - ReduceWrapper(dev_ctx, axis, &tmp_dy, dy); - } + phi::GetGradXAndYOut(dev_ctx, place, axis, ins, *dout, dx, dy, + func); } template @@ -1227,22 +1184,8 @@ void GetGradXOrYOut(const platform::CUDADeviceContext &dev_ctx, std::vector ins, const framework::Tensor *dout, framework::Tensor *dxy, Functor func) { - framework::Tensor tmp_dxy; - dxy->mutable_data(place); - - std::vector outs; - if (dxy->dims() != dout->dims()) { - tmp_dxy.mutable_data(dout->dims(), place); - outs = {&tmp_dxy}; - } else { - outs = {dxy}; - } - - paddle::operators::LaunchElementwiseCudaKernel(dev_ctx, ins, &outs, - axis, func); - if (dxy->dims() != dout->dims()) { - 
ReduceWrapper(dev_ctx, axis, &tmp_dxy, dxy);
-  }
+  phi::GetGradXOrYOut(dev_ctx, place, axis, ins, *dout, dxy,
+                      func);
 }
 #endif
diff --git a/paddle/fluid/operators/elementwise/test_elementwise_div_grad_grad.cc b/paddle/fluid/operators/elementwise/test_elementwise_div_grad_grad.cc
index 9aa206efed8c0111f56b6651e0228acc316b1bfe..7890d634e9941718f3420bd50a7ded453379fc69 100644
--- a/paddle/fluid/operators/elementwise/test_elementwise_div_grad_grad.cc
+++ b/paddle/fluid/operators/elementwise/test_elementwise_div_grad_grad.cc
@@ -28,7 +28,7 @@
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"

-USE_OP(elementwise_div);
+USE_OP_ITSELF(elementwise_div);

 namespace paddle {
 namespace operators {
diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
index e48ee805959088aa8a6a3da7fcb8fa02b642c1c8..c9177f1c46eac29ed8c411f60cd7b820bdced8f2 100644
--- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
@@ -18,7 +18,6 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/copy_kernel.h"
 #include "paddle/phi/kernels/cpu/elementwise_grad.h"
-#include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h"

@@ -108,6 +107,20 @@ void SubtractDoubleGradKernel(const Context& dev_ctx,
   phi::SubtractDoubleGradImpl(dev_ctx, y, ddx, ddy, dout, axis, ddout);
 }

+template
+void DivideGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& y,
+                      const DenseTensor& out,
+                      const DenseTensor& dout,
+                      int axis,
+                      DenseTensor* dx,
+                      DenseTensor* dy) {
+  funcs::ElementwiseGradPreProcess(dout, dx);
+  phi::funcs::ElemwiseGradCompute, DivGradDY>(
+      dev_ctx, x, y, out, dout, axis, dx, dy, DivGradDX(), DivGradDY());
+}
+
 }  // namespace phi

 PD_REGISTER_KERNEL(add_grad,
@@ -171,3 +184,25 @@ PD_REGISTER_KERNEL(subtract_double_grad,
                    phi::dtype::bfloat16,
                    phi::dtype::complex,
                    phi::dtype::complex) {}
+
+PD_REGISTER_KERNEL(divide_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::DivideGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   paddle::platform::complex,
+                   paddle::platform::complex) {}
+
+PD_REGISTER_KERNEL(divide_double_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::DivideDoubleGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   paddle::platform::complex,
+                   paddle::platform::complex) {}
diff --git a/paddle/phi/kernels/elementwise_grad_kernel.h b/paddle/phi/kernels/elementwise_grad_kernel.h
index a1b296e326f2198a352b7864bcad54d822492d4e..bcd5a98f07ee9f33ccc7198904c5bed88f4985fe 100644
--- a/paddle/phi/kernels/elementwise_grad_kernel.h
+++ b/paddle/phi/kernels/elementwise_grad_kernel.h
@@ -64,4 +64,25 @@ void SubtractDoubleGradKernel(const Context& dev_ctx,
                               int axis,
                               DenseTensor* ddout);

+template
+void DivideGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& y,
+                      const DenseTensor& out,
+                      const DenseTensor& dout,
+                      int axis,
+                      DenseTensor* dx,
+                      DenseTensor* dy);
+
+template
+void DivideDoubleGradKernel(const Context& dev_ctx,
+                            const DenseTensor& y,
+                            const DenseTensor& out,
+                            const DenseTensor& dx,
+                            paddle::optional ddx,
+                            paddle::optional ddy,
+                            int axis,
+                            DenseTensor* dy,
+                            DenseTensor* dout,
+                            DenseTensor* ddout);
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h
index
aab31cfbd55b64a957ca75840cc6c0bb41e3f8c0..7634c2462738b2b3bdb622e851723aef23045dfd 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -592,5 +592,25 @@ void ElementwiseCompute(const GPUContext &dev_ctx, #endif +template +void DefaultElementwiseOperator(const DeviceContext &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + DenseTensor *z, + int axis = -1) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + dev_ctx.template Alloc(z); + if (x_dims.size() >= y_dims.size()) { + funcs::ElementwiseCompute(dev_ctx, x, y, axis, Functor(), z); + } else { + funcs::ElementwiseCompute( + dev_ctx, x, y, axis, InverseFunctor(), z); + } +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h index c0a3985cd1713b3bf66b278de0fd17cb408e0f63..5615a450b5c5425e02879f36d3e12608d475d4c1 100644 --- a/paddle/phi/kernels/funcs/elementwise_functor.h +++ b/paddle/phi/kernels/funcs/elementwise_functor.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/hostdevice.h" @@ -92,5 +93,72 @@ struct InverseDivideFunctor { inline HOSTDEVICE T operator()(const T a, const T b) const { return b / a; } }; +template +using ComplexType = phi::dtype::complex; + +template +struct DivGradXYFunctor { + inline HOSTDEVICE phi::Array operator()(const InT a, + const InT b, + const InT c) { + // dx = dout / y + // dy = - dout * out / y + phi::Array outs; + outs[0] = a / c; + outs[1] = -a * b / c; + return outs; + } +}; + +template +struct DivGradXYFunctor, ComplexType> { + inline HOSTDEVICE phi::Array, 2> operator()( + const ComplexType a, + const ComplexType b, + const ComplexType c) { + phi::Array, 2> outs; + ComplexType c_conj(c.real, -c.imag); + ComplexType out_div_c_conj((b / c).real, -(b / c).imag); + outs[0] = a / c_conj; + outs[1] = -a * out_div_c_conj; + return outs; + } +}; + +// Float div grad +template +struct DivGradXFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { return a / b; } +}; + +// ComplexType div grad +template +struct DivGradXFunctor> { + inline HOSTDEVICE ComplexType operator()(const ComplexType a, + const ComplexType b) const { + ComplexType b_conj(b.real, -b.imag); + return a / b_conj; + } +}; + +// Float mul and div +template +struct DivGradYFunctor { + inline HOSTDEVICE T operator()(const T a, const T b, const T c) const { + return -a * b / c; + } +}; + +// ComplexType mul and div +template +struct DivGradYFunctor> { + inline HOSTDEVICE ComplexType operator()(const ComplexType a, + const ComplexType b, + const ComplexType c) const { + ComplexType out_div_c_conj((b / c).real, -(b / c).imag); + return -a * out_div_c_conj; + } +}; + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h index dff0cfe5b8b90155f62aa6e465c1fd6450d64807..17bf873587381c54efa23883f10e830c296365ec 100644 --- a/paddle/phi/kernels/funcs/elementwise_grad_base.h +++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h @@ -24,6 +24,7 @@ limitations under the License. */ // See Note [ Why still include the fluid headers? 
] #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/phi/kernels/primitive/kernel_primitives.h" #endif @@ -1758,5 +1759,31 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, #endif +template +void ElemwiseGradCompute(const DeviceContext &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + int axis, + DenseTensor *dx, + DenseTensor *dy, + DX_OP dx_op, + DY_OP dy_op) { + const DDim &x_dim = x.dims(); + const DDim &y_dim = y.dims(); + if (x.dims() == y.dims()) { + ElemwiseGradComputeNoBroadcast( + dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); + } else { + ElemwiseGradComputeWithBroadcast( + dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); + } +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_grad.h b/paddle/phi/kernels/gpu/elementwise_grad.h index 20799f4e37b3bda787f54ed3687120696e0d4537..b356f19555fc426dd5ce184ebe0f5ff39213aa8e 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad.h +++ b/paddle/phi/kernels/gpu/elementwise_grad.h @@ -14,12 +14,101 @@ limitations under the License. */ #pragma once +#include "paddle/phi/common/place.h" #include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_grad_base.h" #include "paddle/phi/kernels/funcs/reduce_function.h" namespace phi { +template +void ReduceWrapper(const GPUContext &dev_ctx, + int axis, + DenseTensor *src, + DenseTensor *dst) { + std::vector reduce_dims = + funcs::GetReduceDim(dst->dims(), src->dims(), axis); + funcs::TensorReduceImpl>( + dev_ctx, + *src, + dst, + kps::IdentityFunctor(), + reduce_dims, + dev_ctx.stream()); +} + +template +void GetGradXAndYOut(const GPUContext &dev_ctx, + const Place &place, + int axis, + std::vector ins, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy, + Functor func) { + DenseTensor tmp_dx; + DenseTensor tmp_dy; + dev_ctx.Alloc(dx); + dev_ctx.Alloc(dy); + std::vector outs; + if (dx->dims() == dout.dims() && dy->dims() == dout.dims()) { + outs = {dx, dy}; + } else if (dx->dims() != dout.dims() && dy->dims() == dout.dims()) { + tmp_dx.Resize(dout.dims()); + dev_ctx.Alloc(&tmp_dx); + outs = {&tmp_dx, dy}; + } else if (dx->dims() == dout.dims() && dy->dims() != dout.dims()) { + tmp_dy.Resize(dout.dims()); + dev_ctx.Alloc(&tmp_dy); + outs = {dx, &tmp_dy}; + } else if (dx->dims() != dout.dims() && dy->dims() != dout.dims()) { + tmp_dy.Resize(dout.dims()); + dev_ctx.Alloc(&tmp_dy); + tmp_dx.Resize(dout.dims()); + dev_ctx.Alloc(&tmp_dx); + outs = {&tmp_dx, &tmp_dy}; + } + + funcs::BroadcastKernel( + dev_ctx, ins, &outs, axis, func); + + if (dx->dims() != dout.dims() && dy->dims() == dout.dims()) { + ReduceWrapper(dev_ctx, axis, &tmp_dx, dx); + } else if (dx->dims() == dout.dims() && dy->dims() != dout.dims()) { + ReduceWrapper(dev_ctx, axis, &tmp_dy, dy); + } else if (dx->dims() != dout.dims() && dy->dims() != dout.dims()) { + ReduceWrapper(dev_ctx, axis, &tmp_dx, dx); + ReduceWrapper(dev_ctx, axis, &tmp_dy, dy); + } +} + +template +void GetGradXOrYOut(const GPUContext &dev_ctx, + const Place &place, + int axis, + std::vector ins, + const DenseTensor &dout, + DenseTensor *dxy, + Functor func) { + DenseTensor tmp_dxy; + dev_ctx.Alloc(dxy); + + std::vector outs; + if (dxy->dims() != dout.dims()) { + tmp_dxy.Resize(dout.dims()); + dev_ctx.Alloc(&tmp_dxy); + outs = {&tmp_dxy}; + } 
else { + outs = {dxy}; + } + + funcs::BroadcastKernel(dev_ctx, ins, &outs, axis, func); + if (dxy->dims() != dout.dims()) { + ReduceWrapper(dev_ctx, axis, &tmp_dxy, dxy); + } +} + /* ****************************** Add Grad @@ -243,4 +332,41 @@ void elementwise_sub_grad(const GPUContext &ctx, dx->mutable_data(ctx.GetPlace()), dy->mutable_data(ctx.GetPlace())); } +/* +****************************** + Div Grad +****************************** +*/ +template +void ElementwiseDivGrad(const GPUContext &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy, + int axis = -1) { + const auto place = dev_ctx.GetPlace(); + if (dx != nullptr && dy != nullptr) { + std::vector ins = {&dout, &out, &y}; + GetGradXAndYOut( + dev_ctx, + place, + axis, + ins, + dout, + dx, + dy, + funcs::DivGradXYFunctor()); + } else if (dx != nullptr && dy == nullptr) { + std::vector ins = {&dout, &y}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor()); + } else if (dy != nullptr && dx == nullptr) { + std::vector ins = {&dout, &out, &y}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor()); + } +} + } // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index d00888aee67019befa99d49cd093a4e171f941c2..45c8b9a21639ffdba1e5ab98415017ba8a91efca 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -15,9 +15,11 @@ #include "paddle/phi/kernels/elementwise_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/copy_kernel.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/gpu/elementwise_grad.h" #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h" @@ -102,6 +104,38 @@ void SubtractDoubleGradKernel(const Context& dev_ctx, phi::SubtractDoubleGradImpl(dev_ctx, y, ddx, ddy, dout, axis, ddout); } +template +void DivideGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy) { + const auto place = dev_ctx.GetPlace(); + if (dx != nullptr && dy != nullptr) { + std::vector ins = {&dout, &out, &y}; + GetGradXAndYOut( + dev_ctx, + place, + axis, + ins, + dout, + dx, + dy, + funcs::DivGradXYFunctor()); + } else if (dx != nullptr && dy == nullptr) { + std::vector ins = {&dout, &y}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor()); + } else if (dy != nullptr && dx == nullptr) { + std::vector ins = {&dout, &out, &y}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor()); + } +} + } // namespace phi PD_REGISTER_KERNEL(add_grad, @@ -168,3 +202,29 @@ PD_REGISTER_KERNEL(subtract_double_grad, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(divide_grad, + GPU, + ALL_LAYOUT, + phi::DivideGradKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(divide_double_grad, + GPU, + ALL_LAYOUT, + phi::DivideDoubleGradKernel, + 
float, + phi::dtype::float16, + phi::dtype::bfloat16, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index ac7d6fd1a0e9ca18ef528399581a30bfe0be5597..e8831f90213b625365c03823aa59db92be818594 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -14,8 +14,11 @@ limitations under the License. */ #pragma once +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" namespace phi { @@ -103,4 +106,157 @@ void SubtractDoubleGradImpl(const Context& dev_ctx, } } +/* +****************************** + Divide Grad +****************************** +*/ + +template +struct DivGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; } +}; + +template +struct DivGradDX> { + HOSTDEVICE phi::dtype::complex operator()( + phi::dtype::complex x, + phi::dtype::complex y, + phi::dtype::complex out, + phi::dtype::complex dout) const { + phi::dtype::complex y_conj(y.real, -y.imag); + return dout / y_conj; + } +}; + +template +struct DivGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return -dout * out / y; + } +}; + +template +struct DivGradDY> { + HOSTDEVICE phi::dtype::complex operator()( + phi::dtype::complex x, + phi::dtype::complex y, + phi::dtype::complex out, + phi::dtype::complex dout) const { + phi::dtype::complex out_div_y_conj((out / y).real, -(out / y).imag); + return -dout * out_div_y_conj; + } +}; + +template +struct DivDoubleDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return y * out * dout - x * dout; + } +}; + +template +void DivideDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dx, + paddle::optional ddx, + paddle::optional ddy, + int axis, + DenseTensor* dy, + DenseTensor* dout, + DenseTensor* ddout) { + if (dy) { + dy->Resize(y.dims()); + dev_ctx.template Alloc(dy); + } + if (dout) { + dout->Resize(out.dims()); + dev_ctx.template Alloc(dout); + } + if (ddout) { + ddout->Resize(out.dims()); + dev_ctx.template Alloc(ddout); + } + // ddX_safe == null ? 0 : ddX + // ddY_safe == null ? 0 : ddY + DenseTensor ddX_safe, ddY_safe; + phi::funcs::GetDoubleGradSafeTensor( + dev_ctx, dx, ddx.get_ptr(), &ddX_safe); + phi::funcs::GetDoubleGradSafeTensor( + dev_ctx, y, ddy.get_ptr(), &ddY_safe); + + // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y + // dY = Out * dX * ddY / Y - dX * ddX / Y + // dOut = - dX * ddY + // To save memory, (1) dout can be used as 'tmp' tensor, (2) ddout can + // inplace ddx + DenseTensor tmp; + if (dout) { + tmp = *dout; + } else { + tmp.Resize(out.dims()); + dev_ctx.template Alloc(&tmp); + } + if (dy) { + // dX_div_Y = dX / Y; + DenseTensor dX_div_Y = tmp; + funcs::DefaultElementwiseOperator, + funcs::InverseDivideFunctor>( + dev_ctx, dx, y, &dX_div_Y, axis); + + // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the + // first output tensor is nullptr, the branch to calculate first + // output tensor will not be activated, DivGradDx function will not + // be called and can be ignored, the first branch has little effect + // on running speed. 
+
+    // dY = Out * dX * ddY / Y - dX * ddX / Y
+    phi::funcs::ElemwiseGradCompute, DivDoubleDY>(
+        dev_ctx,
+        ddX_safe,
+        ddY_safe,
+        out,
+        dX_div_Y,
+        axis,
+        nullptr,
+        dy,
+        DivGradDX(),
+        DivDoubleDY());
+  }
+
+  if (ddout) {
+    // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
+    funcs::DefaultElementwiseOperator,
+                                     funcs::InverseMultiplyFunctor>(
+        dev_ctx, out, ddY_safe, &tmp, axis);
+    funcs::DefaultElementwiseOperator,
+                                     funcs::InverseSubtractFunctor>(
+        dev_ctx, ddX_safe, tmp, &tmp, axis);
+    funcs::DefaultElementwiseOperator,
+                                     funcs::InverseDivideFunctor>(
+        dev_ctx, tmp, y, ddout, axis);
+  }
+
+  if (dout) {
+    // dOut = - dX * ddY
+    funcs::DefaultElementwiseOperator,
+                                     funcs::InverseMultiplyFunctor>(
+        dev_ctx, dx, ddY_safe, dout, axis);
+    auto& place = *dev_ctx.eigen_device();
+    auto dout_result = phi::EigenVector::Flatten(*dout);
+    dout_result.device(place) = static_cast(-1) * dout_result;
+  }
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/math_kernel.cc b/paddle/phi/kernels/math_kernel.cc
index 8b17d8bd2506c0fadbdb769c98e72925e31078c7..a5d3f51e5447fa41447c4b59c3beb8c917f8a0e5 100644
--- a/paddle/phi/kernels/math_kernel.cc
+++ b/paddle/phi/kernels/math_kernel.cc
@@ -208,6 +208,7 @@ PD_REGISTER_KERNEL(divide,
                    int,
                    int64_t,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    complex64,
                    complex128) {}
 PD_REGISTER_KERNEL(multiply,
diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc
index 89846ea0563bb7ab0bde5a14844776c61ee96784..d4a25866907a0a13050b2a8559d59a1810516bf9 100644
--- a/paddle/phi/ops/compat/elementwise_sig.cc
+++ b/paddle/phi/ops/compat/elementwise_sig.cc
@@ -106,6 +106,22 @@ KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
       "subtract_double_grad", {"Y", "DDX", "DDY", "DOut"}, {"axis"}, {"DDOut"});
 }

+KernelSignature ElementwiseDivGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("divide_grad",
+                         {"X", "Y", "Out", GradVarName("Out")},
+                         {"axis"},
+                         {GradVarName("X"), GradVarName("Y")});
+}
+
+KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("divide_double_grad",
+                         {"Y", "Out", "DX", "DDX", "DDY"},
+                         {"axis"},
+                         {GradVarName("Y"), "DOut", "DDOut"});
+}
+
 }  // namespace phi

 PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
@@ -117,6 +133,8 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);

 PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
                            phi::ElementwiseAddOpArgumentMapping);
@@ -136,3 +154,7 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad,
                            phi::ElementwiseSubGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad_grad,
                            phi::ElementwiseSubDoubleGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
+                           phi::ElementwiseDivGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
+                           phi::ElementwiseDivDoubleGradOpArgumentMapping);
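Note (reviewer aid, not part of the patch): every divide-grad functor this change moves into phi implements the same closed-form gradients of out = x / y, namely dx = dout / y and dy = -dout * out / y, with division by conj(y) in the complex specializations (see DivGradDX, DivGradDY, and DivGradXYFunctor above). The short standalone C++ sketch below restates those formulas outside the Paddle tree; DivGradReference, DivGradReferenceComplex, and the sample values in main() are illustrative assumptions, not code introduced by this patch.

#include <complex>
#include <cstdio>

// Closed-form gradients of out = x / y used by the functors in this patch:
//   dx = dout / y
//   dy = -dout * out / y
template <typename T>
void DivGradReference(T x, T y, T dout, T* dx, T* dy) {
  T out = x / y;
  *dx = dout / y;
  *dy = -dout * out / y;
}

// Complex case: the patch divides by conj(y) (and uses conj(out / y)), which
// is what the complex specializations of DivGradDX, DivGradDY, and
// DivGradXYFunctor spell out with explicit real/imag constructors.
template <typename T>
void DivGradReferenceComplex(const std::complex<T>& x,
                             const std::complex<T>& y,
                             const std::complex<T>& dout,
                             std::complex<T>* dx,
                             std::complex<T>* dy) {
  std::complex<T> out = x / y;
  *dx = dout / std::conj(y);
  *dy = -dout * std::conj(out / y);
}

int main() {
  // Real example: x = 6, y = 2, upstream gradient dout = 1.
  double dx = 0.0;
  double dy = 0.0;
  DivGradReference(6.0, 2.0, 1.0, &dx, &dy);
  // Expect dx = 1 / 2 = 0.5 and dy = -(6 / 2) / 2 = -1.5.
  std::printf("dx = %g, dy = %g\n", dx, dy);
  return 0;
}

The same algebra is behind the double-grad comments kept in the patch: ddOut = (ddX - Out * ddY) / Y, dY = Out * dX * ddY / Y - dX * ddX / Y, and dOut = -dX * ddY.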