Unverified commit 7a1e1193, authored by YuanRisheng, committed by GitHub

refactor elementwise sub grad (#39225)

Parent 5631da9c
...@@ -16,11 +16,26 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
void default_elementwise_sub(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseSubFunctor<T>(), z);
}
}
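A note on why the functor is swapped when y has the higher rank: ElementwiseComputeEx walks the higher-rank operand, so the operands reach the functor in (larger, smaller) order, and InverseSubFunctor restores the x - y semantics. The snippet below is a standalone sketch of that idea; the `*Sketch` names are illustrative and not Paddle's real functor definitions.

```cpp
// Standalone sketch (not Paddle code): the broadcast loop iterates over the
// higher-rank operand, so when y has more dimensions the arguments arrive as
// (y_elem, x_elem). The "inverse" functor swaps them back so the result is
// still x - y.
#include <cassert>

template <typename T>
struct SubFunctorSketch {         // a - b
  T operator()(T a, T b) const { return a - b; }
};

template <typename T>
struct InverseSubFunctorSketch {  // b - a
  T operator()(T a, T b) const { return b - a; }
};

int main() {
  float x = 3.0f, y = 5.0f;
  // x drives the loop: plain subtraction.
  assert(SubFunctorSketch<float>()(x, y) == x - y);
  // y drives the loop: operands are swapped, the inverse functor compensates.
  assert(InverseSubFunctorSketch<float>()(y, x) == x - y);
  return 0;
}
```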
template <typename DeviceContext, typename T>
void default_elementwise_div(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
...
...@@ -17,103 +17,6 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
int64_t size, T* dx,
T* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
if (dx != nullptr) {
dx[col] = dout[col];
}
dy[col] = -dout[col];
col += blockDim.x * gridDim.x;
}
}
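The removed kernel above is a standard grid-stride loop: each thread starts at its global index and strides by blockDim.x * gridDim.x, so any grid size covers any problem size, and the `dx != nullptr` check lets the same kernel serve the dy-only launch further down. A minimal, self-contained CUDA sketch of the same pattern follows; the kernel name, block size, and launch setup are illustrative, not Paddle's launch helpers.

```cuda
// Grid-stride sketch of the sub-grad kernel: dx = dout, dy = -dout.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void SubGradKernelSketch(const float* dout, int64_t size,
                                    float* dx, float* dy) {
  int64_t col = blockIdx.x * blockDim.x + threadIdx.x;
  while (col < size) {
    if (dx != nullptr) dx[col] = dout[col];  // d(x - y)/dx = 1
    dy[col] = -dout[col];                    // d(x - y)/dy = -1
    col += static_cast<int64_t>(blockDim.x) * gridDim.x;  // grid-stride step
  }
}

int main() {
  const int64_t n = 1 << 20;
  std::vector<float> h_dout(n, 2.0f), h_dx(n), h_dy(n);
  float *d_dout, *d_dx, *d_dy;
  cudaMalloc(&d_dout, n * sizeof(float));
  cudaMalloc(&d_dx, n * sizeof(float));
  cudaMalloc(&d_dy, n * sizeof(float));
  cudaMemcpy(d_dout, h_dout.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  const int block = 512;  // stands in for PREDEFINED_BLOCK_SIZE
  const int grid = static_cast<int>((n + block - 1) / block);  // ceil(n/block)
  SubGradKernelSketch<<<grid, block>>>(d_dout, n, d_dx, d_dy);

  cudaMemcpy(h_dx.data(), d_dx, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_dy.data(), d_dy, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("dx[0]=%f dy[0]=%f\n", h_dx[0], h_dy[0]);  // expect 2 and -2

  cudaFree(d_dout); cudaFree(d_dx); cudaFree(d_dy);
  return 0;
}
```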
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
auto* dout_data = dout->data<T>();
// dx
if (dx != nullptr) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims() == dout->dims()) {
if (dx_data != dout_data) {
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
}
} else {
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout->dims()) {
if (dy_data != dout_data) {
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
auto size = dy->numel();
dim3 grid_size =
dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<T><<<
grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, nullptr,
dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
*dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
}
}
}
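When dx (or dy) is smaller than dout because of broadcasting, the gradient has to be summed over the broadcast dimensions before it matches the input shape; GetReduceDim picks out which output dimensions those are and TensorReduceFunctorImpl performs the sum (the dy path runs the same reduction through kps::InverseFunctor, which presumably carries the minus sign so that dy ends up as -sum(dout)). Below is a host-side sketch of that reduce-dimension selection under the usual align-at-axis broadcast convention; the helper name and exact semantics are my assumption, not a copy of Paddle's GetReduceDim.

```cpp
// Sketch: which dims of dout must be summed so the reduced tensor has the
// input's shape. Assumes input dims align to output dims starting at `axis`
// (axis == -1 means right-aligned), the usual elementwise-broadcast rule.
#include <cstdio>
#include <vector>

std::vector<int> ReduceDimsSketch(const std::vector<int>& in_dims,
                                  const std::vector<int>& out_dims,
                                  int axis) {
  if (axis == -1) axis = static_cast<int>(out_dims.size() - in_dims.size());
  std::vector<int> reduce_dims;
  for (int i = 0; i < static_cast<int>(out_dims.size()); ++i) {
    const bool covered =
        i >= axis && i < axis + static_cast<int>(in_dims.size());
    // Dims the input does not cover, or covers with extent 1 while the output
    // extent is larger, were broadcast and must be summed out.
    if (!covered || (in_dims[i - axis] == 1 && out_dims[i] != 1)) {
      reduce_dims.push_back(i);
    }
  }
  return reduce_dims;
}

int main() {
  // x: [3, 4] against dout: [2, 3, 4]            -> reduce dim 0
  for (int d : ReduceDimsSketch({3, 4}, {2, 3, 4}, -1)) printf("%d ", d);
  printf("\n");
  // y: [3, 1] against dout: [2, 3, 4], axis = 1  -> reduce dims 0 and 2
  for (int d : ReduceDimsSketch({3, 1}, {2, 3, 4}, 1)) printf("%d ", d);
  printf("\n");
  return 0;
}
```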
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
auto size = x->numel();
dim3 grid_size =
dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
}
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, float>,
...
...@@ -17,26 +17,11 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/kernels/elementwise_grad_kernel.h"
#include "paddle/pten/kernels/math_kernel.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
void default_elementwise_sub(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseSubFunctor<T>(), z);
}
}
template <typename DeviceContext, typename T>
class ElementwiseSubKernel : public framework::OpKernel<T> {
public:
...@@ -48,76 +33,13 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
- auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
- auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
- auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::SubtractRawKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
- *pt_x.get(), *pt_y.get(), axis, pt_z.get());
+ *x, *y, axis, z);
}
};
template <typename T>
struct SubGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename T>
struct SubGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
};
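These two functors are the entire backward rule for subtraction: for $z = x - y$ with upstream gradient $\mathrm{dout} = \partial L / \partial z$,

$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z}\cdot\frac{\partial z}{\partial x} = \mathrm{dout},
\qquad
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial z}\cdot\frac{\partial z}{\partial y} = -\,\mathrm{dout},
$$

with each gradient additionally summed over any broadcast dimensions when the operand shapes differ; that reduction is handled by ElemwiseExplicitGradCompute (or the GPU reduction path), not by the functors themselves.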
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
const auto& dev_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
pten::ElemwiseExplicitGradCompute<T, SubGradDX<T>, SubGradDY<T>>(
dev_ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(),
SubGradDY<T>());
}
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy);
#endif
template <typename DeviceContext, typename T>
class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
public:
...@@ -130,14 +52,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
- // skip out
- auto* out = dout;
- if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
- elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
- } else {
- default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
- dy);
- }
+ int axis = ctx.Attr<int>("axis");
+ auto& dev_ctx = ctx.device_context<DeviceContext>();
+ pten::SubtractGradKernel<T>(
+ static_cast<const typename framework::ConvertToPtenContext<
+ DeviceContext>::TYPE&>(dev_ctx),
+ *x, *y, *dout, axis, dx, dy);
}
};
...@@ -153,18 +74,21 @@ class ElementwiseSubDoubleGradKernel : public framework::OpKernel<T> {
auto* ddy = ctx.Input<Tensor>("DDY");
auto* ddout = ctx.Output<Tensor>("DDOut");
- // DDOut = ddx - ddy
- if (ddout) {
- Tensor ddx_safe, ddy_safe;
- GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout, ddx, &ddx_safe);
- GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
- ddout->mutable_data<T>(ctx.GetPlace());
- int axis = ctx.Attr<int>("axis");
- ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
- ctx, &ddx_safe, &ddy_safe, axis, SubFunctor<T>(), ddout);
- }
+ int axis = ctx.Attr<int>("axis");
+ auto& dev_ctx = ctx.device_context<DeviceContext>();
+ paddle::optional<const pten::DenseTensor&> ddx_optional = paddle::none;
+ paddle::optional<const pten::DenseTensor&> ddy_optional = paddle::none;
+ if (ddx != nullptr) {
+ ddx_optional = *ddx;
+ }
+ if (ddy != nullptr) {
+ ddy_optional = *ddy;
+ }
+ pten::SubtractDoubleGradKernel<T>(
+ static_cast<const typename framework::ConvertToPtenContext<
+ DeviceContext>::TYPE&>(dev_ctx),
+ *y, ddx_optional, ddy_optional, *dout, axis, ddout);
}
};
...
...@@ -25,6 +25,7 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"elementwise_div", "divide_raw"},
{"elementwise_mul", "muliply_raw"},
{"elementwise_sub", "subtract_raw"},
{"elementwise_sub_grad", "subtract_grad"},
{"fill_any_like", "full_like"},
{"fill_constant", "full"},
{"flatten_contiguous_range", "flatten"},
...
...@@ -743,8 +743,11 @@ void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx,
}
}
- // Add Grad
+ /*
+ ******************************
+ Add Grad
+ ******************************
+ */
template <typename T>
struct IdentityGrad {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
...@@ -786,4 +789,33 @@ elementwise_add_grad(const CPUContext& ctx,
ctx, x, y, out, dout, axis, dx, dy, IdentityGrad<T>(), IdentityGrad<T>());
}
/*
******************************
Sub Grad
******************************
*/
template <typename T>
struct SubGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename T>
struct SubGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
};
template <typename T>
void elementwise_sub_grad(const CPUContext& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int axis = -1) {
ElemwiseExplicitGradCompute<T, SubGradDX<T>, SubGradDY<T>>(
ctx, x, y, out, dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
}
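For the common case where x, y, and dout share one shape, the SubGradDX/SubGradDY pair above reduces to a single pass over the buffers; the broadcasting and axis handling is what ElemwiseExplicitGradCompute adds on top. A plain-C++ sketch of that same-shape case (illustrative names, no Paddle types):

```cpp
// Same-shape subtract gradient: dx[i] = dout[i], dy[i] = -dout[i].
// Standalone illustration of what SubGradDX/SubGradDY compute; the real
// kernel also handles broadcasting and null dx/dy outputs.
#include <cstdio>
#include <vector>

void SubGradSameShapeSketch(const std::vector<float>& dout,
                            std::vector<float>* dx,
                            std::vector<float>* dy) {
  if (dx) dx->resize(dout.size());
  if (dy) dy->resize(dout.size());
  for (size_t i = 0; i < dout.size(); ++i) {
    if (dx) (*dx)[i] = dout[i];   // SubGradDX: pass dout through
    if (dy) (*dy)[i] = -dout[i];  // SubGradDY: negate dout
  }
}

int main() {
  std::vector<float> dout = {1.f, -2.f, 0.5f}, dx, dy;
  SubGradSameShapeSketch(dout, &dx, &dy);
  printf("dx[1]=%f dy[1]=%f\n", dx[1], dy[1]);  // expect -2 and 2
  return 0;
}
```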
} // namespace pten
...@@ -92,6 +92,38 @@ void AddTripleGradKernel(const Context& dev_ctx,
dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc<T>);
}
template <typename T, typename Context>
void SubtractGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
// skip out
auto* out = &dout;
elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy, axis);
}
template <typename T, typename Context>
void SubtractDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& y,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
const DenseTensor& dout,
int axis,
DenseTensor* ddout) {
pten::SubtractDoubleGradImpl<T>(
dev_ctx,
y,
ddx,
ddy,
dout,
axis,
ddout,
ElementwiseCompute<funcs::SubtractFunctor<T>, T>);
}
} // namespace pten
PT_REGISTER_KERNEL(add_grad,
...@@ -126,3 +158,25 @@ PT_REGISTER_KERNEL(add_triple_grad,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL(subtract_grad,
CPU,
ALL_LAYOUT,
pten::SubtractGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL(subtract_double_grad,
CPU,
ALL_LAYOUT,
pten::SubtractDoubleGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
...@@ -46,4 +46,22 @@ void AddTripleGradKernel(const Context& dev_ctx,
DenseTensor* d_ddx,
DenseTensor* d_ddy);
template <typename T, typename Context>
void SubtractGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void SubtractDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& y,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
const DenseTensor& dout,
int axis,
DenseTensor* ddout);
} // namespace pten
...@@ -1952,6 +1952,12 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
}
}
/*
******************************
Add Grad
******************************
*/
template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) {
...@@ -2078,4 +2084,106 @@ void elementwise_add_grad(const GPUContext &ctx,
}
}
/*
******************************
Sub Grad
******************************
*/
template <typename T>
static __global__ void SimpleElemwiseSubGradCUDAKernel(const T *dout,
int64_t size,
T *dx,
T *dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
if (dx != nullptr) {
dx[col] = dout[col];
}
dy[col] = -dout[col];
col += blockDim.x * gridDim.x;
}
}
template <typename T>
void default_elementwise_sub_grad(const GPUContext &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy,
int axis = -1) {
auto *dout_data = dout.data<T>();
// dx
if (dx != nullptr) {
auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims() == dout.dims()) {
if (dx_data != dout_data) {
pten::Copy(ctx, dout, false, dx);
}
} else {
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(dout)) {
dx->clear();
dx->mutable_data<T>(x.dims(), ctx.GetPlace());
}
std::vector<int> reduce_dims =
funcs::GetReduceDim(x.dims(), out.dims(), axis);
gpuStream_t stream = ctx.stream();
kernels::TensorReduceFunctorImpl<T,
T,
kps::AddFunctor,
kps::IdentityFunctor<T>>(
dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
auto *dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout.dims()) {
if (dy_data != dout_data) {
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
auto size = dy->numel();
dim3 grid_size =
dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<
T><<<grid_size, block_size, 0, ctx.stream()>>>(
dout.data<T>(), size, nullptr, dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
std::vector<int> reduce_dims =
funcs::GetReduceDim(y.dims(), out.dims(), axis);
gpuStream_t stream = ctx.stream();
kernels::TensorReduceFunctorImpl<T,
T,
kps::AddFunctor,
kps::InverseFunctor<T>>(
dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
}
}
}
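The IsSharedBufferWith check above guards the inplace case: when dx reuses dout's allocation, reducing into dx would overwrite dout before dy is computed from it, so the code detaches dx onto fresh memory first. A small host-side sketch of the hazard it avoids (plain C++, not the actual tensor machinery):

```cpp
// Why dx must not alias dout on the broadcast path: the reduction writes into
// dx before dy is computed, and dy still needs the original dout values.
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> dout = {1.f, 2.f, 3.f, 4.f};

  // Inplace-style aliasing: "dx" shares dout's buffer.
  std::vector<float>& dx_aliased = dout;
  // Pretend dx has been reduced/overwritten (e.g. summed over a broadcast
  // dim) before dy is computed.
  dx_aliased[0] = dx_aliased[0] + dx_aliased[1];
  dx_aliased[1] = 0.f;

  // dy = -dout now reads the clobbered values: wrong result.
  printf("dy[0] (aliased)  = %f\n", -dout[0]);    // -3, should be -1

  // Detached copy, as the IsSharedBufferWith branch effectively arranges.
  std::vector<float> dout2 = {1.f, 2.f, 3.f, 4.f};
  std::vector<float> dx_copy = dout2;             // fresh storage for dx
  dx_copy[0] += dx_copy[1];                       // reduce into the copy
  printf("dy[0] (detached) = %f\n", -dout2[0]);   // -1, as expected
  return 0;
}
```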
template <typename T>
void elementwise_sub_grad(const GPUContext &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy) {
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
auto size = x.numel();
dim3 grid_size =
dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<
T><<<grid_size, block_size, 0, ctx.stream()>>>(
dout.data<T>(),
size,
dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
}
} // namespace pten
...@@ -82,6 +82,42 @@ void AddTripleGradKernel(const Context& dev_ctx,
dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc<T>);
}
template <typename T, typename Context>
void SubtractGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
// skip out
auto* out = &dout;
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy);
} else {
default_elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy, axis);
}
}
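The dispatch above takes the elementwise fast path only when both gradients are requested and dx and dy have the same dims (no broadcasting occurred, so they also match dout); any other combination falls back to the copy/reduce path. A compact sketch of that predicate, with plain dim vectors standing in for DenseTensor shapes:

```cpp
// Sketch of the fast-path condition used by the GPU SubtractGradKernel:
// both gradients requested and equal shapes. Illustrative only.
#include <cstdio>
#include <vector>

bool UseSameShapeFastPath(const std::vector<int>* dx_dims,
                          const std::vector<int>* dy_dims) {
  return dx_dims != nullptr && dy_dims != nullptr && *dx_dims == *dy_dims;
}

int main() {
  std::vector<int> a = {2, 3, 4}, b = {2, 3, 4}, c = {3, 4};
  printf("%d\n", UseSameShapeFastPath(&a, &b));       // 1: simple kernel
  printf("%d\n", UseSameShapeFastPath(&a, &c));       // 0: reduce path
  printf("%d\n", UseSameShapeFastPath(&a, nullptr));  // 0: only dx wanted
  return 0;
}
```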
template <typename T, typename Context>
void SubtractDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& y,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
const DenseTensor& dout,
int axis,
DenseTensor* ddout) {
pten::SubtractDoubleGradImpl<T>(
dev_ctx,
y,
ddx,
ddy,
dout,
axis,
ddout,
ElementwiseCompute<funcs::SubtractFunctor<T>, T>);
}
} // namespace pten
PT_REGISTER_KERNEL(add_grad,
...@@ -119,3 +155,27 @@ PT_REGISTER_KERNEL(add_triple_grad,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL(subtract_grad,
GPU,
ALL_LAYOUT,
pten::SubtractGradKernel,
float,
double,
int,
int64_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL(subtract_double_grad,
GPU,
ALL_LAYOUT,
pten::SubtractDoubleGradKernel,
float,
double,
int,
int64_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
...@@ -85,4 +85,27 @@ void AddDoubleGradImpl(const Context& dev_ctx,
}
}
template <typename T, typename Context, typename GradFunc>
void SubtractDoubleGradImpl(const Context& dev_ctx,
const DenseTensor& y,
const paddle::optional<const DenseTensor&>& ddx,
const paddle::optional<const DenseTensor&>& ddy,
const DenseTensor& dout,
int axis,
DenseTensor* ddout,
GradFunc grad_func) {
// DDOut = ddx - ddy
if (ddout) {
DenseTensor ddx_safe, ddy_safe;
funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, dout, ddx.get_ptr(), &ddx_safe);
funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, y, ddy.get_ptr(), &ddy_safe);
ddout->mutable_data<T>(dev_ctx.GetPlace());
grad_func(
dev_ctx, ddx_safe, ddy_safe, axis, funcs::SubtractFunctor<T>(), ddout);
}
}
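A short note on the math SubtractDoubleGradImpl implements: the first-order map is linear, $\mathrm{d}x = \mathrm{dout}$ and $\mathrm{d}y = -\mathrm{dout}$, so the incoming second-order gradients $\mathrm{dd}x$ and $\mathrm{dd}y$ propagate to the gradient of $\mathrm{dout}$ as

$$
\mathrm{DDOut}
= \frac{\partial\,\mathrm{d}x}{\partial\,\mathrm{dout}}\,\mathrm{dd}x
+ \frac{\partial\,\mathrm{d}y}{\partial\,\mathrm{dout}}\,\mathrm{dd}y
= \mathrm{dd}x - \mathrm{dd}y,
$$

which is exactly the `// DDOut = ddx - ddy` computation above. When either ddx or ddy is absent, GetDoubleGradSafeTensor appears (from this code path) to substitute a zero-filled tensor shaped like its reference argument, so the missing term simply drops out of the subtraction.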
} // namespace pten