diff --git a/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu
index 18c846bc2b4699ab0fd7b91fe43d1c9f3fcd1c14..ed6df1e558bed673e495a4fd455049dad08fc5ee 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu
@@ -12,15 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 
-REGISTER_OP_CUDA_KERNEL(
-    reduce_amax_grad,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, float,
-                          ops::MaxOrMinGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
-                          ops::MaxOrMinGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
-                          ops::MaxOrMinGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
-                          ops::MaxOrMinGradFunctor>);
+template <typename T>
+using CUDAReduceMaxGradKernel =
+    ops::ReduceCudaAMaxAMinGradKernel<T, kps::IdentityFunctor>;
+REGISTER_OP_CUDA_KERNEL(reduce_amax_grad, CUDAReduceMaxGradKernel<int>,
+                        CUDAReduceMaxGradKernel<int64_t>,
+                        CUDAReduceMaxGradKernel<float>,
+                        CUDAReduceMaxGradKernel<double>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu
index c7a26049634ce685cde29fbb7d3c77d72b4ecc22..69854da3c4f2590eb3f148f3674daf06e900d1f8 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu
@@ -12,15 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 
-REGISTER_OP_CUDA_KERNEL(
-    reduce_amin_grad,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, float,
-                          ops::MaxOrMinGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
-                          ops::MaxOrMinGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
-                          ops::MaxOrMinGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
-                          ops::MaxOrMinGradFunctor>);
+template <typename T>
+using CUDAReduceMinGradKernel =
+    ops::ReduceCudaAMaxAMinGradKernel<T, kps::IdentityFunctor>;
+REGISTER_OP_CUDA_KERNEL(reduce_amin_grad, CUDAReduceMinGradKernel<int>,
+                        CUDAReduceMinGradKernel<int64_t>,
+                        CUDAReduceMinGradKernel<float>,
+                        CUDAReduceMinGradKernel<double>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index 322ef1fdff67abd861c6603c3e7c4fc6b5d19f39..ff7429f75ebe3a02e3f75083a1c70240a0de837a 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -24,7 +24,6 @@ limitations under the License.
*/ #include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" #include "paddle/phi/kernels/funcs/math_function.h" - // only can include the headers in paddle/phi/api dirs #include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/api/lib/utils/tensor_utils.h" @@ -655,6 +654,7 @@ class ReduceCudaGradKernel : public framework::OpKernel { bool reduce_all = context.Attr("reduce_all"); std::vector dims = context.Attr>("dim"); auto* in_x = context.Input("X"); + auto* d_out = context.Input(framework::GradVarName("Out")); auto* d_x = context.Output(framework::GradVarName("X")); @@ -685,12 +685,106 @@ class ReduceCudaGradKernel : public framework::OpKernel { if (out_dtype <= 0) { pt_out_dtype = d_out->dtype(); } + using MPType = typename kps::details::MPTypeTrait::Type; phi::ReduceGrad>( dev_ctx, pt_d_out.get(), pt_d_x.get(), pt_out_dtype, TransformOp(reduce_num)); } }; + +template +struct EqualFunctor { + inline T initial() { return static_cast(0.0f); } + + inline HOSTDEVICE T operator()(const T a, const T b) const { + return static_cast(a == b); + } +}; + +template +struct DivideFunctor { + inline T initial() { return static_cast(1.0f); } + + inline HOSTDEVICE T operator()(const T a, const T b) const { return a / b; } +}; + +template class TransformOp> +class ReduceCudaAMaxAMinGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + bool reduce_all = context.Attr("reduce_all"); + std::vector dims = context.Attr>("dim"); + auto* in_x = context.Input("X"); + auto* out_y = context.Input("Out"); + auto* d_out = + context.Input(framework::GradVarName("Out")); + auto* d_x = context.Output(framework::GradVarName("X")); + auto out_dtype = context.Attr("in_dtype"); + auto pt_out_dtype = framework::TransToPhiDataType( + static_cast(out_dtype)); + // get reduce_dim and reduce_num for reduce_mean_grad + int dim_size = in_x->dims().size(); + std::vector reduce_dims = GetReduceDim(dims, dim_size, reduce_all); + auto update_dims = vectorize(d_x->dims()); + int reduce_num = 1; + for (auto i : reduce_dims) { + reduce_num *= (in_x->dims())[i]; + update_dims[i] = 1; + } + auto& dev_ctx = context.cuda_device_context(); + + // make new tensor reduce_out + phi::DenseTensor new_y(out_y->type()); + new_y.ShareDataWith(*out_y); + new_y.Resize(phi::make_ddim(update_dims)); + + // make new tensor d_out + phi::DenseTensor new_dout(d_out->type()); + new_dout.ShareDataWith(*d_out); + new_dout.Resize(phi::make_ddim(update_dims)); + d_x->mutable_data(dev_ctx.GetPlace(), d_out->dtype()); + + auto new_in = paddle::experimental::MakePhiDenseTensor(*in_x); + auto new_in_tensor = new_in.get(); + + auto new_dx = paddle::experimental::MakePhiDenseTensor(*d_x); + auto new_dx_tensor = new_dx.get(); + + // make equal_out + phi::DenseTensor* equal_out = new phi::DenseTensor(); + equal_out->Resize(in_x->dims()); + dev_ctx.template Alloc(equal_out); + auto equal_out_tensor = *equal_out; + + // make new tensor equal_count + phi::DenseTensor* equal_count = new phi::DenseTensor(); + equal_count->Resize(phi::make_ddim(update_dims)); + dev_ctx.template Alloc(equal_count); + + // compute + // 1. equal_out = Equal(x, y) + std::vector equal_inputs = {&new_y, new_in_tensor}; + std::vector equal_outputs = {&equal_out_tensor}; + phi::funcs::BroadcastKernel( + dev_ctx, equal_inputs, &equal_outputs, 0, EqualFunctor()); + // 2. 
+    using MPType = typename kps::details::MPTypeTrait<T>::Type;
+    phi::funcs::ReduceKernel<T, T, kps::AddFunctor,
+                             kps::IdentityFunctor<T, MPType>>(
+        dev_ctx, equal_out_tensor, equal_count,
+        kps::IdentityFunctor<T, MPType>(), reduce_dims, false);
+
+    // 3. dx = Div(equal_out, equal_count)
+    std::vector<const phi::DenseTensor*> grad_inputs = {&equal_out_tensor,
+                                                        equal_count};
+    std::vector<phi::DenseTensor*> grad_outputs = {new_dx_tensor};
+    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
+        dev_ctx, grad_inputs, &grad_outputs, 0, DivideFunctor<T>());
+    delete equal_out;
+    delete equal_count;
+  }
+};
 
 #endif
 #endif
diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h
index 88b87c07c7615ccef3a20e3441854bbb6b940394..74e48f39185485fac9d55e778645686955b6d606 100644
--- a/paddle/phi/kernels/funcs/broadcast_function.h
+++ b/paddle/phi/kernels/funcs/broadcast_function.h
@@ -605,7 +605,22 @@ void ElementwiseCompute(const GPUContext &dev_ctx,
       dev_ctx, ins, &outs, axis, func);
 }
 
-#endif
+template <typename DeviceContext,
+          typename Functor,
+          typename InverseFunctor,
+          typename T,
+          typename OutType = T>
+void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
+                                const DenseTensor &x,
+                                const DenseTensor &y,
+                                DenseTensor *z,
+                                int axis = -1) {
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  dev_ctx.template Alloc<OutType>(z);
+  funcs::ElementwiseCompute<Functor, T, OutType>(dev_ctx, x, y, axis, Functor(), z);
+}
+
+#else
 
 template <typename DeviceContext,
           typename Functor,
diff --git a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
--- a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
+template <typename T, typename Context>
+void FrobeniusNormKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const std::vector<int64_t>& dims,
+                         bool keep_dim,
+                         bool reduce_all,
+                         DenseTensor* out) {
+  auto out_dtype = x.dtype();
+  phi::Reduce<T, kps::AddFunctor, kps::SquareFunctor>(
+      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
+  std::vector<const DenseTensor*> ins = {out};
+  std::vector<DenseTensor*> outs = {out};
+  auto functor = funcs::CudaSqrtFunctor<T>();
+  funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+}
+
+}  // namespace phi
 
 PD_REGISTER_KERNEL(
     frobenius_norm, GPU, ALL_LAYOUT, phi::FrobeniusNormKernel, float, double) {}
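
A quick standalone check of the gradient rule that ReduceCudaAMaxAMinGradKernel composes out of BroadcastKernel/ReduceKernel calls above: every element tied for the extreme value receives an equal share of the upstream gradient. This is a minimal host-side sketch with made-up data, not Paddle code; it also folds in the multiplication by dout that the chain rule requires.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // x has a tied maximum (two 3.0 entries); reduce over the whole array.
  std::vector<float> x = {1.0f, 3.0f, 2.0f, 3.0f};
  float y = 3.0f;     // forward output: amax(x)
  float dout = 1.0f;  // upstream gradient d(loss)/d(y)

  // 1. equal_out = Equal(x, y): mask the elements equal to the extreme.
  std::vector<float> equal_out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    equal_out[i] = (x[i] == y) ? 1.0f : 0.0f;

  // 2. equal_count = reduceSum(equal_out): how many elements are tied (2).
  float equal_count = 0.0f;
  for (float e : equal_out) equal_count += e;

  // 3. dx = dout * equal_out / equal_count: the shares sum to dout.
  for (std::size_t i = 0; i < x.size(); ++i)
    std::printf("dx[%zu] = %g\n", i, dout * equal_out[i] / equal_count);
  // Prints 0, 0.5, 0, 0.5.
  return 0;
}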
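
The #if/#else split around the new DefaultElementwiseOperator suggests the GPU broadcast path accepts its operands in either order, while the pre-existing CPU overload (whose template header survives as trailing context above) must put the higher-rank operand first and therefore falls back to an InverseFunctor with the argument order reversed. A standalone sketch of that inverse-functor trick; the Divide/InverseDivide functors and the Compute helper here are ours for illustration, not Paddle's:

#include <cstdio>

struct Divide {
  float operator()(float a, float b) const { return a / b; }
};
struct InverseDivide {  // the same operation with its arguments swapped
  float operator()(float a, float b) const { return b / a; }
};

// Stand-in for ElementwiseCompute: the first operand plays the role of the
// higher-rank tensor that the real CPU API insists on.
template <typename F>
float Compute(float hi, float lo, F f) {
  return f(hi, lo);
}

int main() {
  float x = 1.0f, y = 8.0f;
  // Computing x / y when x may arrive first (GPU path) or second (CPU path
  // after the rank-based swap): both orders agree.
  std::printf("%g\n", Compute(x, y, Divide()));         // 0.125
  std::printf("%g\n", Compute(y, x, InverseDivide()));  // 0.125
  return 0;
}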
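
The new FrobeniusNormKernel expresses ||x||_F = sqrt(sum_i x_i^2) as a squared-sum reduction (phi::Reduce with kps::SquareFunctor) followed by an in-place elementwise sqrt on the reduced tensor. A standalone arithmetic check, plain C++ rather than Paddle code:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {3.0f, 4.0f, 12.0f};
  float out = 0.0f;
  for (float v : x) out += v * v;  // reduce-sum of squares: 9 + 16 + 144
  out = std::sqrt(out);            // elementwise sqrt applied to "out"
  std::printf("frobenius_norm = %g\n", out);  // sqrt(169) = 13
  return 0;
}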