未验证 提交 3bf3a6ee 编写于 作者: Y YuanRisheng 提交者: GitHub

[Pten]Refactor elementwise_add grad / double grad / triple grad Kernel and...

[Pten]Refactor elementwise_add grad / double grad / triple grad Kernel and move them to pten (#39048)

* refactor elementwise add grad

* fix compile bugs

* fix unit test bugs

* fix file conflicts

* fix bugs when buildPtenContext
上级 43919d0a
......@@ -369,6 +369,10 @@ static void BuildDygraphPtenKernelContext(
size_t end_idx = start_idx + outs_vector.size();
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (outs_vector[offset] == nullptr) {
kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
continue;
}
auto* var = outs_vector[offset]->MutableVar();
framework::Tensor* tensor_out = nullptr;
if (var->template IsType<framework::LoDTensor>()) {
......
......@@ -33,34 +33,6 @@ class CPUDeviceContext;
namespace paddle {
namespace operators {
template <typename T>
struct SameDimsElemwiseAdd<
platform::CPUDeviceContext, T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
void operator()(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
blas.VADD(x->numel(), x->data<T>(), y->data<T>(), z->data<T>());
}
};
template <typename T>
struct SameDimsElemwiseAdd<
platform::CPUDeviceContext, T,
typename std::enable_if<!std::is_floating_point<T>::value>::type> {
void operator()(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_y = framework::EigenVector<T>::Flatten(*y);
auto eigen_z = framework::EigenVector<T>::Flatten(*z);
auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
eigen_z.device(place) = eigen_x + eigen_y;
}
};
class ElementwiseAddOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Add"; }
......
......@@ -13,139 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/pten/kernels/gpu/elementwise.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
const T* __restrict__ dout, int size, int vec_size, T* dx, T* dy) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
int loop = size / vec_size;
int remainder = size % vec_size;
const float4* dout_vec = reinterpret_cast<const float4*>(dout);
float4* dx_vec = reinterpret_cast<float4*>(dx);
float4* dy_vec = reinterpret_cast<float4*>(dy);
float4 tmp_loop;
for (int i = tid; i < loop; i += stride) {
tmp_loop = dout_vec[i];
dx_vec[i] = tmp_loop;
dy_vec[i] = tmp_loop;
}
if (tid == loop && remainder != 0) {
T tmp_rem;
while (remainder) {
int idx = size - remainder;
remainder--;
tmp_rem = dout[idx];
dx[idx] = tmp_rem;
dy[idx] = tmp_rem;
}
}
}
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
auto* dout_data = dout->data<T>();
// dx
if (dx != nullptr) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims() == dout->dims()) {
if (dx_data != dout_data) {
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
}
} else {
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout->dims()) {
if (dy_data != dout_data) {
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
}
} else {
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
}
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
auto* dout_data = dout->data<T>();
if (dx_data == dout_data && dy_data != dout_data) {
VLOG(4) << "Special case when dx_data is the same as dout_data, "
"only need copy dout to dy";
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
} else if (dx_data != dout_data && dy_data == dout_data) {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"only need copy dout to dx";
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
} else if (dx_data != dout_data && dy_data != dout_data) {
auto size = x->numel();
int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
dim3 grid_size =
dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) /
PREDEFINED_BLOCK_SIZE,
1);
SimpleElemwiseAddGradCUDAKernel<
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, vec_size, dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
} else {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"and dx_data is the same as dout_data, do not need "
"any operator";
}
}
} // namespace operators
namespace operators {} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
elementwise_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
......
......@@ -18,35 +18,13 @@ limitations under the License. */
#include <utility>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
// only can include the headers in paddle/pten/include dirs
#include "paddle/pten/kernels/elementwise_grad_kernel.h"
#include "paddle/pten/kernels/math_kernel.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
void LaunchBroadcastElementwiseCpuKernel(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseAddFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseAddFunctor<T>(), z);
}
}
template <typename DeviceContext, typename T, class Enable = void>
struct SameDimsElemwiseAdd {
void operator()(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z);
};
template <typename DeviceContext, typename T>
class ElementwiseAddKernel : public framework::OpKernel<T> {
public:
......@@ -58,128 +36,29 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::AddRawKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
*x, *y, axis, z);
}
};
template <typename T>
struct IdentityGrad {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
IdentityGrad<T>>(ctx, *x, *y, *out, *dout, axis,
dx, dy, IdentityGrad<T>(),
IdentityGrad<T>());
}
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy) {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dx->mutable_data<T>(ctx.GetPlace()));
}
if (dy) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dy->mutable_data<T>(ctx.GetPlace()));
}
}
template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy) {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy);
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dy);
#endif
template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<Tensor>("Y");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
// skip out
auto *out = dout;
// Special case when dy is not needed and dx doesn't reduce
if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) {
VLOG(4) << "Special case when dy is not needed and dx doesn't "
"reduce";
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
} else if (dx == nullptr && dy != nullptr && dy->dims() == dout->dims()) {
VLOG(4) << "Special case when dx is not needed and dy doesn't "
"reduce";
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
} else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
dy);
}
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
pten::AddGradKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*x, *y, *dout, axis, dx, dy);
}
};
......@@ -195,17 +74,20 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
auto *ddy = ctx.Input<Tensor>("DDY");
auto *ddout = ctx.Output<Tensor>("DDOut");
// ddOut = ddx + ddy
if (ddout) {
Tensor ddx_safe, ddy_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout, ddx, &ddx_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
ddout->mutable_data<T>(ctx.GetPlace());
LaunchBroadcastElementwiseCpuKernel<DeviceContext, T>(ctx, &ddx_safe,
&ddy_safe, ddout);
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
paddle::optional<const pten::DenseTensor &> ddx_optional = paddle::none;
paddle::optional<const pten::DenseTensor &> ddy_optional = paddle::none;
if (ddx != nullptr) {
ddx_optional = *ddx;
}
if (ddy != nullptr) {
ddy_optional = *ddy;
}
pten::AddDoubleGradKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*y, ddx_optional, ddy_optional, *dout, axis, ddout);
}
};
......@@ -219,32 +101,13 @@ class ElementwiseAddTripleGradKernel : public framework::OpKernel<T> {
auto *d_ddout = ctx.Input<Tensor>("D_DDOut");
auto *d_ddx = ctx.Output<Tensor>("D_DDX");
auto *d_ddy = ctx.Output<Tensor>("D_DDY");
// skip out
auto *out = d_ddout;
// Special case when d_ddy is not needed and d_ddx doesn't reduce
if (d_ddx != nullptr && d_ddy == nullptr &&
d_ddx->dims() == d_ddout->dims()) {
VLOG(4) << "Special case when d_ddy is not needed and d_ddx doesn't "
"reduce";
framework::TensorCopy(
*d_ddout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), d_ddx);
} else if (d_ddx == nullptr && d_ddy != nullptr &&
d_ddy->dims() == d_ddout->dims()) {
VLOG(4) << "Special case when d_ddx is not needed and d_ddy doesn't "
"reduce";
framework::TensorCopy(
*d_ddout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), d_ddy);
} else if (d_ddx != nullptr && d_ddy != nullptr &&
(d_ddx->dims() == d_ddy->dims())) {
elementwise_add_grad<DeviceContext, T>(ctx, ddx, ddy, out, d_ddout, d_ddx,
d_ddy);
} else {
default_elementwise_add_grad<DeviceContext, T>(ctx, ddx, ddy, out,
d_ddout, d_ddx, d_ddy);
}
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
pten::AddTripleGradKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*ddx, *ddy, *d_ddout, axis, d_ddx, d_ddy);
}
};
......
......@@ -354,6 +354,18 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
tensor.place(), tensor.layout());
}
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (Type() == "elementwise_add_grad") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
return framework::KernelSignature(
"add_grad", {"X", "Y", framework::GradVarName("Out")}, {"axis"},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
}
return framework::KernelSignature("None", {"X"}, {}, {"Out"});
}
};
class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
......@@ -522,11 +534,9 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &context) const override {
auto *dx =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (dx != nullptr) {
auto &dout =
*context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
dx->set_lod(dout.lod());
}
pten::funcs::ElementwiseGradPreProcess(dout, dx);
}
};
......
......@@ -158,32 +158,6 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
}
}
// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
// explicit gradient can cut off X, Y, Out from gradient op
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor &x,
const framework::Tensor &y,
const framework::Tensor &out,
const framework::Tensor &dout, int axis,
framework::Tensor *dx, framework::Tensor *dy,
DX_OP dx_op, DY_OP dy_op) {
const framework::DDim &x_dim = x.dims();
const framework::DDim &y_dim = y.dims();
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
if (x.dims() == y.dims()) {
pten::funcs::ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
dev_ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op,
dy_op);
} else {
pten::ElemwiseGradComputeWithBroadcast<T, DX_OP, DY_OP>(
dev_ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op,
dy_op);
}
}
// It is a common implementation to compute binary calculation with the support
// of broadcast, supporting both CPU and GPU.
// - CPU implementation cannot support the case when x needs broadcast, thus
......@@ -199,30 +173,20 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, int axis, Functor func,
framework::Tensor *z) {
z->mutable_data<OutType>(ctx.GetPlace());
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
std::vector<const framework::Tensor *> ins = {x, y};
std::vector<framework::Tensor *> outs = {z};
z->mutable_data<OutType>(ctx.GetPlace());
const auto &dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
OutType>(dev_ctx, ins, &outs,
axis, func);
pten::ElementwiseCompute<Functor, T, OutType>(dev_ctx, *x, *y, axis, func,
z);
#endif
return;
}
z->mutable_data<OutType>(ctx.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
const auto &dev_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
pten::ElementwiseCompute<Functor, T, OutType>(
dev_ctx, *pt_x.get(), *pt_y.get(), axis, func, pt_z.get());
pten::ElementwiseCompute<Functor, T, OutType>(dev_ctx, *x, *y, axis, func, z);
}
// FusedElemwiseAndAct
......@@ -1207,36 +1171,16 @@ template <typename DeviceContext, typename T>
static inline void GetDoubleGradSafeTensor(
const framework::ExecutionContext &ctx, const framework::Tensor *x,
const framework::Tensor *ddx, framework::Tensor *ddx_safe) {
if (ddx) {
*ddx_safe = *ddx;
} else {
auto &dev_ctx = ctx.template device_context<DeviceContext>();
*ddx_safe = ctx.AllocateTmpTensor<T, DeviceContext>(x->dims(), dev_ctx);
math::SetConstant<DeviceContext, T> set_zero;
set_zero(ctx.template device_context<DeviceContext>(), ddx_safe,
static_cast<T>(0));
}
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
pten::funcs::GetDoubleGradSafeTensor<DeviceContext, T>(dev_ctx, *x, ddx,
ddx_safe);
}
// for broadcast backwards
static inline std::vector<int> GetReduceDim(const framework::DDim &in,
const framework::DDim &out,
int axis) {
axis =
(axis == -1 ? std::abs(static_cast<int>(out.size() - in.size())) : axis);
std::vector<int> dims;
for (int i = 0; i < axis; ++i) {
dims.push_back(i);
}
for (int i = 0; i < in.size(); ++i) {
if (out[i + axis] != in[i]) {
dims.push_back(i + axis);
}
}
for (int i = axis + in.size(); i < out.size(); ++i) {
dims.push_back(i);
}
return dims;
return pten::funcs::GetReduceDim(in, out, axis);
}
#if defined(__NVCC__) || defined(__HIPCC__)
......
......@@ -78,9 +78,11 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
const auto& dev_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
pten::ElemwiseExplicitGradCompute<T, SubGradDX<T>, SubGradDY<T>>(
dev_ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(),
SubGradDY<T>());
}
template <typename DeviceContext, typename T>
......
......@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "unsupported/Eigen/CXX11/Tensor"
......@@ -52,6 +53,18 @@ template struct SetConstant<platform::CPUDeviceContext,
template struct SetConstant<platform::CPUDeviceContext,
platform::complex<double>>;
template struct SetConstant<pten::CPUContext, platform::float16>;
template struct SetConstant<pten::CPUContext, platform::bfloat16>;
template struct SetConstant<pten::CPUContext, float>;
template struct SetConstant<pten::CPUContext, double>;
template struct SetConstant<pten::CPUContext, int16_t>;
template struct SetConstant<pten::CPUContext, int>;
template struct SetConstant<pten::CPUContext, int64_t>;
template struct SetConstant<pten::CPUContext, bool>;
template struct SetConstant<pten::CPUContext, uint8_t>;
template struct SetConstant<pten::CPUContext, platform::complex<float>>;
template struct SetConstant<pten::CPUContext, platform::complex<double>>;
#ifdef PADDLE_WITH_XPU
template struct SetConstant<platform::XPUDeviceContext, platform::float16>;
template struct SetConstant<platform::XPUDeviceContext, platform::bfloat16>;
......
......@@ -21,6 +21,7 @@ namespace pten {
// the key is sorted by key's alphabet
const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"elementwise_add", "add_raw"},
{"elementwise_add_grad", "add_grad"},
{"elementwise_div", "divide_raw"},
{"elementwise_mul", "muliply_raw"},
{"elementwise_sub", "subtract_raw"},
......
......@@ -9,7 +9,7 @@ add_subdirectory(funcs)
set_property(GLOBAL PROPERTY PTEN_KERNELS "")
set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function)
# remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
......
......@@ -706,4 +706,94 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext& ctx,
}
}
// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
// explicit gradient can cut off X, Y, Out from gradient op
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename T, typename DX_OP, typename DY_OP>
void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy,
DX_OP dx_op,
DY_OP dy_op) {
const DDim& x_dim = x.dims();
const DDim& y_dim = y.dims();
if (x.dims() == y.dims()) {
pten::funcs::ElemwiseGradComputeNoBroadcast<CPUContext, T, DX_OP, DY_OP>(
dev_ctx,
x_dim,
y_dim,
dout,
dout,
out,
dout,
axis,
dx,
dy,
dx_op,
dy_op);
} else {
ElemwiseGradComputeWithBroadcast<T, DX_OP, DY_OP>(dev_ctx,
x_dim,
y_dim,
dout,
dout,
out,
dout,
axis,
dx,
dy,
dx_op,
dy_op);
}
}
// Add Grad
template <typename T>
struct IdentityGrad {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type
elementwise_add_grad(const CPUContext& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int axis = -1) {
auto blas = paddle::operators::math::GetBlas<CPUContext, T>(ctx);
if (dx) {
blas.VCOPY(
dout.numel(), dout.data<T>(), dx->mutable_data<T>(ctx.GetPlace()));
}
if (dy) {
blas.VCOPY(
dout.numel(), dout.data<T>(), dy->mutable_data<T>(ctx.GetPlace()));
}
}
template <typename T>
typename std::enable_if<!std::is_floating_point<T>::value>::type
elementwise_add_grad(const CPUContext& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int axis = -1) {
ElemwiseExplicitGradCompute<T, IdentityGrad<T>, IdentityGrad<T>>(
ctx, x, y, out, dout, axis, dx, dy, IdentityGrad<T>(), IdentityGrad<T>());
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/elementwise_grad_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/pten/kernels/cpu/elementwise.h"
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
#include "paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h"
namespace pten {
template <typename T>
void AddGradFunc(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int axis = -1) {
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_add_grad<T>(dev_ctx, x, y, out, dout, dx, dy);
} else {
ElemwiseExplicitGradCompute<T, IdentityGrad<T>, IdentityGrad<T>>(
dev_ctx,
x,
y,
out,
dout,
axis,
dx,
dy,
IdentityGrad<T>(),
IdentityGrad<T>());
}
}
template <typename T, typename Context>
void AddGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
pten::AddGradImpl<T>(dev_ctx, x, y, dout, axis, dx, dy, AddGradFunc<T>);
}
template <typename T, typename Context>
void AddDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& y,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
const DenseTensor& dout,
int axis,
DenseTensor* ddout) {
pten::AddDoubleGradImpl<T>(
dev_ctx,
y,
ddx,
ddy,
dout,
axis,
ddout,
ElementwiseCompute<funcs::AddFunctor<T>, T>,
ElementwiseCompute<funcs::InverseAddFunctor<T>, T>);
}
template <typename T, typename Context>
void AddTripleGradKernel(const Context& dev_ctx,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& d_ddout,
int axis,
DenseTensor* d_ddx,
DenseTensor* d_ddy) {
pten::AddGradImpl<T>(
dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc<T>);
}
} // namespace pten
PT_REGISTER_KERNEL(add_grad,
CPU,
ALL_LAYOUT,
pten::AddGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL(add_double_grad,
CPU,
ALL_LAYOUT,
pten::AddDoubleGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL(add_triple_grad,
CPU,
ALL_LAYOUT,
pten::AddTripleGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace pten {
template <typename T, typename Context>
void AddGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void AddDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& y,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
const DenseTensor& dout,
int axis,
DenseTensor* ddout);
template <typename T, typename Context>
void AddTripleGradKernel(const Context& dev_ctx,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& d_ddout,
int axis,
DenseTensor* d_ddx,
DenseTensor* d_ddy);
} // namespace pten
......@@ -14,10 +14,12 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/empty_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
......@@ -360,6 +362,43 @@ inline void get_mid_dims(const DDim &x_dims,
}
}
// for broadcast backwards
static inline std::vector<int> GetReduceDim(const paddle::framework::DDim &in,
const paddle::framework::DDim &out,
int axis) {
axis =
(axis == -1 ? std::abs(static_cast<int>(out.size() - in.size())) : axis);
std::vector<int> dims;
for (int i = 0; i < axis; ++i) {
dims.push_back(i);
}
for (int i = 0; i < in.size(); ++i) {
if (out[i + axis] != in[i]) {
dims.push_back(i + axis);
}
}
for (int i = axis + in.size(); i < out.size(); ++i) {
dims.push_back(i);
}
return dims;
}
template <typename DeviceContext, typename T>
static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx,
const DenseTensor &x,
const DenseTensor *ddx,
DenseTensor *ddx_safe) {
if (ddx) {
*ddx_safe = *ddx;
} else {
auto meta = pten::DenseTensorMeta(x.dtype(), x.dims(), x.layout());
*ddx_safe = pten::Empty<T, DeviceContext>(dev_ctx, std::move(meta));
ddx_safe->mutable_data(dev_ctx.GetPlace());
paddle::operators::math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, ddx_safe, static_cast<T>(0));
}
}
template <typename DeviceContext,
typename T,
typename DX_OP,
......@@ -390,6 +429,13 @@ void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx,
dy == nullptr ? nullptr : dy->mutable_data<T>(dev_ctx.GetPlace())});
}
inline void ElementwiseGradPreProcess(const DenseTensor &dout,
DenseTensor *dx) {
if (dx != nullptr) {
dx->set_lod(dout.lod());
}
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename InT, typename OutT>
......
......@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once
#include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/pten/kernels/funcs/common_shape.h"
#include "paddle/pten/kernels/funcs/cuda_kernel_config.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"
#include "paddle/pten/kernels/gpu/reduce.h"
#ifdef __HIPCC__
constexpr int ELEMWISE_MAX_BLOCK_DIM = 256;
......@@ -578,6 +580,20 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx,
}
}
template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const GPUContext &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
int axis,
Functor func,
DenseTensor *z) {
std::vector<const DenseTensor *> ins = {&x, &y};
std::vector<DenseTensor *> outs = {z};
z->mutable_data<OutType>(dev_ctx.GetPlace());
pten::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, OutType>(
dev_ctx, ins, &outs, axis, func);
}
// BACKWARD CODE
// Suppose only has contiguous dims
......@@ -1938,4 +1954,130 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
}
}
template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
int loop = size / vec_size;
int remainder = size % vec_size;
const float4 *dout_vec = reinterpret_cast<const float4 *>(dout);
float4 *dx_vec = reinterpret_cast<float4 *>(dx);
float4 *dy_vec = reinterpret_cast<float4 *>(dy);
float4 tmp_loop;
for (int i = tid; i < loop; i += stride) {
tmp_loop = dout_vec[i];
dx_vec[i] = tmp_loop;
dy_vec[i] = tmp_loop;
}
if (tid == loop && remainder != 0) {
T tmp_rem;
while (remainder) {
int idx = size - remainder;
remainder--;
tmp_rem = dout[idx];
dx[idx] = tmp_rem;
dy[idx] = tmp_rem;
}
}
}
template <typename T>
void default_elementwise_add_grad(const GPUContext &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy,
int axis = -1) {
auto *dout_data = dout.data<T>();
// dx
if (dx != nullptr) {
auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims() == dout.dims()) {
if (dx_data != dout_data) {
pten::Copy(ctx, dout, false, dx);
}
} else {
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(dout)) {
dx->clear();
dx->mutable_data<T>(x.dims(), ctx.GetPlace());
}
std::vector<int> reduce_dims =
funcs::GetReduceDim(x.dims(), out.dims(), axis);
gpuStream_t stream = ctx.stream();
kernels::TensorReduceFunctorImpl<T,
T,
kps::AddFunctor,
kps::IdentityFunctor<T>>(
dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
auto *dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout.dims()) {
if (dy_data != dout_data) {
pten::Copy(ctx, dout, false, dy);
}
} else {
std::vector<int> reduce_dims =
funcs::GetReduceDim(y.dims(), out.dims(), axis);
gpuStream_t stream = ctx.stream();
kernels::TensorReduceFunctorImpl<T,
T,
kps::AddFunctor,
kps::IdentityFunctor<T>>(
dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
}
template <typename T>
void elementwise_add_grad(const GPUContext &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy) {
auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto *dy_data = dy->mutable_data<T>(ctx.GetPlace());
auto *dout_data = dout.data<T>();
if (dx_data == dout_data && dy_data != dout_data) {
VLOG(4) << "Special case when dx_data is the same as dout_data, "
"only need copy dout to dy";
pten::Copy(ctx, dout, false, dy);
} else if (dx_data != dout_data && dy_data == dout_data) {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"only need copy dout to dx";
pten::Copy(ctx, dout, false, dx);
} else if (dx_data != dout_data && dy_data != dout_data) {
auto size = x.numel();
int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
dim3 grid_size =
dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) /
PREDEFINED_BLOCK_SIZE,
1);
SimpleElemwiseAddGradCUDAKernel<
T><<<grid_size, block_size, 0, ctx.stream()>>>(
dout.data<T>(),
size,
vec_size,
dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
} else {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"and dx_data is the same as dout_data, do not need "
"any operator";
}
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/elementwise_grad_kernel.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
#include "paddle/pten/kernels/gpu/elementwise.h"
#include "paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h"
namespace pten {
template <typename T>
void AddGradFunc(const GPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int axis = -1) {
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_add_grad<T>(dev_ctx, x, y, out, dout, dx, dy);
} else {
default_elementwise_add_grad<T>(dev_ctx, x, y, out, dout, dx, dy, axis);
}
}
template <typename T, typename Context>
void AddGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
pten::AddGradImpl<T>(dev_ctx, x, y, dout, axis, dx, dy, AddGradFunc<T>);
}
template <typename T, typename Context>
void AddDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& y,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
const DenseTensor& dout,
int axis,
DenseTensor* ddout) {
pten::AddDoubleGradImpl<T>(
dev_ctx,
y,
ddx,
ddy,
dout,
axis,
ddout,
ElementwiseCompute<funcs::AddFunctor<T>, T>,
ElementwiseCompute<funcs::InverseAddFunctor<T>, T>);
}
template <typename T, typename Context>
void AddTripleGradKernel(const Context& dev_ctx,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& d_ddout,
int axis,
DenseTensor* d_ddx,
DenseTensor* d_ddy) {
pten::AddGradImpl<T>(
dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc<T>);
}
} // namespace pten
PT_REGISTER_KERNEL(add_grad,
GPU,
ALL_LAYOUT,
pten::AddGradKernel,
float,
double,
int,
int64_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL(add_double_grad,
GPU,
ALL_LAYOUT,
pten::AddDoubleGradKernel,
float,
double,
int,
int64_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL(add_triple_grad,
GPU,
ALL_LAYOUT,
pten::AddTripleGradKernel,
float,
double,
int,
int64_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
namespace pten {
template <typename T, typename Context, typename GradFunc>
void AddGradImpl(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad,
DenseTensor* y_grad,
GradFunc grad_func) {
pten::funcs::ElementwiseGradPreProcess(out_grad, x_grad);
auto* out = &out_grad;
// Special case when y_grad is not needed and x_grad doesn't reduce
if (x_grad != nullptr && y_grad == nullptr &&
x_grad->dims() == out_grad.dims()) {
VLOG(4) << "Special case when y_grad is not needed and x_grad doesn't "
"reduce";
pten::Copy(dev_ctx, out_grad, false, x_grad);
} else if (x_grad == nullptr && y_grad != nullptr &&
y_grad->dims() == out_grad.dims()) {
VLOG(4) << "Special case when x_grad is not needed and y_grad doesn't "
"reduce";
pten::Copy(dev_ctx, out_grad, false, y_grad);
} else {
grad_func(dev_ctx, x, y, *out, out_grad, x_grad, y_grad, axis);
}
}
template <typename T,
typename Context,
typename GradFunc,
typename GradInverseFunc>
void AddDoubleGradImpl(const Context& dev_ctx,
const DenseTensor& y,
const paddle::optional<const DenseTensor&>& ddx,
const paddle::optional<const DenseTensor&>& ddy,
const DenseTensor& dout,
int axis,
DenseTensor* ddout,
GradFunc grad_func,
GradInverseFunc grad_inverse_func) {
// ddOut = ddx + ddy
if (ddout) {
DenseTensor ddx_safe, ddy_safe;
funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, dout, ddx.get_ptr(), &ddx_safe);
funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, y, ddy.get_ptr(), &ddy_safe);
ddout->mutable_data<T>(dev_ctx.GetPlace());
auto ddx_dims = ddx_safe.dims();
auto ddy_dims = ddy_safe.dims();
if (ddx_dims.size() >= ddy_dims.size()) {
grad_func(
dev_ctx, ddx_safe, ddy_safe, axis, funcs::AddFunctor<T>(), ddout);
} else {
grad_inverse_func(dev_ctx,
ddx_safe,
ddy_safe,
axis,
funcs::InverseAddFunctor<T>(),
ddout);
}
}
}
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册