/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/framework/pten_utils.h" // only can include the headers in paddle/pten/include dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/include/core.h" #include "paddle/pten/include/math.h" namespace paddle { namespace operators { template void LaunchBroadcastElementwiseCpuKernel(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, framework::Tensor *z) { int axis = ctx.Attr("axis"); auto x_dims = x->dims(); auto y_dims = y->dims(); if (x_dims.size() >= y_dims.size()) { ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, AddFunctor(), z); } else { ElementwiseComputeEx, DeviceContext, T>( ctx, x, y, axis, InverseAddFunctor(), z); } } template struct SameDimsElemwiseAdd { void operator()(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, framework::Tensor *z); }; template class ElementwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto *x = ctx.Input("X"); auto *y = ctx.Input("Y"); auto *z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); auto &dev_ctx = ctx.device_context(); int axis = ctx.Attr("axis"); auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); pten::Add(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get()); } }; template struct IdentityGrad { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } }; template typename std::enable_if< std::is_same::value>::type default_elementwise_add_grad(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *out, const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dy) { int axis = ctx.Attr("axis"); ElemwiseExplicitGradCompute, IdentityGrad>(ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad(), IdentityGrad()); } template typename std::enable_if< std::is_floating_point::value && std::is_same::value>::type elementwise_add_grad(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *out, const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dy) { auto blas = math::GetBlas(ctx); if (dx) { blas.VCOPY(dout->numel(), dout->data(), dx->mutable_data(ctx.GetPlace())); } if (dy) { blas.VCOPY(dout->numel(), dout->data(), dy->mutable_data(ctx.GetPlace())); } } template typename std::enable_if< !std::is_floating_point::value && std::is_same::value>::type elementwise_add_grad(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *out, const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dy) { default_elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // cuda definition template typename std::enable_if< std::is_same::value>::type elementwise_add_grad(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *out, const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dy); template typename std::enable_if< std::is_same::value>::type default_elementwise_add_grad(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *out, const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dy); #endif template class ElementwiseAddGradKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { ElemwiseGradKernel::Compute(ctx); using Tensor = framework::Tensor; auto *x = ctx.Input("X"); auto *y = ctx.Input("Y"); auto *dout = ctx.Input(framework::GradVarName("Out")); auto *dx = ctx.Output(framework::GradVarName("X")); auto *dy = ctx.Output(framework::GradVarName("Y")); // skip out auto *out = dout; // Special case when dy is not needed and dx doesn't reduce if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) { VLOG(4) << "Special case when dy is not needed and dx doesn't " "reduce"; framework::TensorCopy( *dout, ctx.GetPlace(), ctx.template device_context(), dx); } else if (dx == nullptr && dy != nullptr && dy->dims() == dout->dims()) { VLOG(4) << "Special case when dx is not needed and dy doesn't " "reduce"; framework::TensorCopy( *dout, ctx.GetPlace(), ctx.template device_context(), dy); } else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } else { default_elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } } }; template class ElementwiseAddDoubleGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { using Tensor = framework::Tensor; auto *y = ctx.Input("Y"); auto *dout = ctx.Input("DOut"); auto *ddx = ctx.Input("DDX"); auto *ddy = ctx.Input("DDY"); auto *ddout = ctx.Output("DDOut"); // ddOut = ddx + ddy if (ddout) { Tensor ddx_safe, ddy_safe; GetDoubleGradSafeTensor(ctx, dout, ddx, &ddx_safe); GetDoubleGradSafeTensor(ctx, y, ddy, &ddy_safe); ddout->mutable_data(ctx.GetPlace()); LaunchBroadcastElementwiseCpuKernel(ctx, &ddx_safe, &ddy_safe, ddout); } } }; template class ElementwiseAddTripleGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { using Tensor = framework::Tensor; auto *ddx = ctx.Input("DDX"); auto *ddy = ctx.Input("DDY"); auto *d_ddout = ctx.Input("D_DDOut"); auto *d_ddx = ctx.Output("D_DDX"); auto *d_ddy = ctx.Output("D_DDY"); // skip out auto *out = d_ddout; // Special case when d_ddy is not needed and d_ddx doesn't reduce if (d_ddx != nullptr && d_ddy == nullptr && d_ddx->dims() == d_ddout->dims()) { VLOG(4) << "Special case when d_ddy is not needed and d_ddx doesn't " "reduce"; framework::TensorCopy( *d_ddout, ctx.GetPlace(), ctx.template device_context(), d_ddx); } else if (d_ddx == nullptr && d_ddy != nullptr && d_ddy->dims() == d_ddout->dims()) { VLOG(4) << "Special case when d_ddx is not needed and d_ddy doesn't " "reduce"; framework::TensorCopy( *d_ddout, ctx.GetPlace(), ctx.template device_context(), d_ddy); } else if (d_ddx != nullptr && d_ddy != nullptr && (d_ddx->dims() == d_ddy->dims())) { elementwise_add_grad(ctx, ddx, ddy, out, d_ddout, d_ddx, d_ddy); } else { default_elementwise_add_grad(ctx, ddx, ddy, out, d_ddout, d_ddx, d_ddy); } } }; } // namespace operators } // namespace paddle