Unverified Commit 452c75b8, authored by YuanRisheng, committed by GitHub

move elementwise mul grad (#40252)

Parent 0604df9e
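For orientation: this commit deletes the fluid-side elementwise_mul gradient kernels, functors, and registrations, and re-registers the same computation in the phi library as multiply_grad, multiply_double_grad, and multiply_triple_grad, wiring the old op names to the new kernels through the argument-mapping functions at the end of the diff. The gradient math itself is unchanged. A minimal standalone sketch of the first-order rule being moved (plain C++ with a hypothetical helper name, independent of the Paddle/phi API):

#include <cstddef>
#include <vector>

// Sketch only: per-element multiply gradient for out = x * y,
// i.e. dx = dout * y and dy = dout * x (no broadcasting).
void MultiplyGradReference(const std::vector<double>& x,
                           const std::vector<double>& y,
                           const std::vector<double>& dout,
                           std::vector<double>* dx,
                           std::vector<double>* dy) {
  dx->resize(dout.size());
  dy->resize(dout.size());
  for (std::size_t i = 0; i < dout.size(); ++i) {
    (*dx)[i] = dout[i] * y[i];  // dL/dx = dout * y
    (*dy)[i] = dout[i] * x[i];  // dL/dy = dout * x
  }
}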
......@@ -46,7 +46,7 @@ USE_OP(matmul_grad);
USE_OP(square);
USE_OP(transpose2_grad);
USE_OP(concat_grad);
USE_OP(elementwise_mul_grad);
USE_OP_ITSELF(elementwise_mul_grad);
USE_OP(sigmoid_grad);
USE_OP(tanh_grad);
USE_OP(sum);
......
......@@ -196,47 +196,6 @@ struct MinGradXYFunctor {
}
};
template <typename T>
struct MulGradFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; }
};
template <typename T>
struct MulGradFunctor<Complex<T>> {
inline HOSTDEVICE Complex<T> operator()(const Complex<T> a,
const Complex<T> b) const {
Complex<T> b_conj(b.real, -b.imag);
return a * b_conj;
}
};
template <typename InT, typename OutT>
struct MulGradXYFunctor {
inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT a, const InT b,
const InT c) {
phi::Array<OutT, 2> outs;
// dx = dout * y
outs[0] = a * b;
// dy = dout * x
outs[1] = a * c;
return outs;
}
};
template <typename InT, typename OutT>
struct MulGradXYFunctor<Complex<InT>, Complex<OutT>> {
inline HOSTDEVICE phi::Array<Complex<OutT>, 2> operator()(
const Complex<InT> a, const Complex<InT> b, const Complex<InT> c) {
phi::Array<Complex<OutT>, 2> outs;
// dx = dout * y
Complex<InT> b_conj(b.real, -b.imag);
outs[0] = a * b_conj;
// dy = dout * x
Complex<InT> c_conj(c.real, -c.imag);
outs[1] = a * c_conj;
return outs;
}
};
// Ternary compare
template <typename T>
struct MaxGradXFunctor {
......
......@@ -173,55 +173,6 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<float>>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
bool>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_triple_grad,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
bool>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_mul)
.AddCheckpoint(
......
......@@ -63,33 +63,6 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
}
};
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMulGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
const auto place = ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const framework::Tensor*> ins = {dout, y, x};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const framework::Tensor*> ins = {dout, y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
dx, MulGradFunctor<T>());
} else if (dx == nullptr && dy != nullptr) {
std::vector<const framework::Tensor*> ins = {dout, x};
GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
dy, MulGradFunctor<T>());
}
}
} // namespace operators
} // namespace paddle
......@@ -103,44 +76,3 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::bfloat16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::bfloat16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::bfloat16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_triple_grad,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
plat::bfloat16>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
......@@ -137,244 +137,6 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
}
}
};
template <typename T>
struct MulGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
};
template <typename T>
struct MulGradDX<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> y_conj(y.real, -y.imag);
return dout * y_conj;
}
};
template <typename T>
struct MulGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
};
template <typename T>
struct MulGradDY<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> x_conj(x.real, -x.imag);
return dout * x_conj;
}
};
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
ElementwiseMulGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMulGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);
#endif
template <typename DeviceContext, typename T>
class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* out = dout; // out is not necessary
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
ElementwiseMulGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
};
template <typename DeviceContext, typename T>
class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>("DOut");
auto* ddx = ctx.Input<Tensor>("DDX");
auto* ddy = ctx.Input<Tensor>("DDY");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* ddout = ctx.Output<Tensor>("DDOut");
if (ddout) ddout->mutable_data<T>(ctx.GetPlace());
Tensor ddx_safe, ddy_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, x, ddx, &ddx_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
// dx = dout * ddy
// dy = dout * ddx
// ddout = ddx * y + x * ddy
// reorder the computation to save memory, so that ddout can reuse ddx's
// storage in place and dx can serve as a temporary tensor
// (1) dx = x * ddy
// (2) dy = dout * ddx
// (3) ddout = ddx * y
// (4) ddout = ddout + dx
// (5) dx = dout * ddy
if (ddout) {
int axis = ctx.Attr<int>("axis");
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
// when numel(ddout) > numel(ddx), ddout cannot reuse ddx's memory in place
if (ddout->numel() > ddx->numel()) {
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX<T>(),
MulGradDY<T>());
Tensor ddout_tmp;
ddout_tmp.mutable_data<T>(ddout->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, y, &ddx_safe, ddout);
default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, x,
&ddout_tmp);
auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = framework::EigenVector<T>::Flatten(ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
} else {
// reuse dx to save memory instead of allocating a temporary tensor
Tensor* ddout_tmp = dx;
default_elementwise_mul<DeviceContext, T>(ctx, x, &ddy_safe, ddout_tmp);
// NOTE: in the following ElemwiseGradCompute, the first output tensor is
// nullptr, so the branch that computes the first output is never taken and
// MulGradDX is never invoked; this skipped branch has little effect on
// running speed.
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy,
MulGradDX<T>(), MulGradDY<T>());
default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, y, ddout);
auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = framework::EigenVector<T>::Flatten(*ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
default_elementwise_mul<DeviceContext, T>(ctx, dout, &ddy_safe, dx);
}
}
}
};
template <typename DeviceContext, typename T>
class ElementwiseMulTripleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
// get input
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>("DOut");
auto* ddx = ctx.Input<framework::Tensor>("DDX");
auto* ddy = ctx.Input<framework::Tensor>("DDY");
auto* d_dx = ctx.Input<framework::Tensor>("D_DX");
auto* d_dy = ctx.Input<framework::Tensor>("D_DY");
auto* d_ddout = ctx.Input<framework::Tensor>("D_DDOut");
// get output
auto* out_d_x = ctx.Output<framework::Tensor>("D_X");
auto* out_d_y = ctx.Output<framework::Tensor>("D_Y");
auto* out_d_dout = ctx.Output<framework::Tensor>("D_DOut");
auto* out_d_ddx = ctx.Output<framework::Tensor>("D_DDX");
auto* out_d_ddy = ctx.Output<framework::Tensor>("D_DDY");
if (out_d_x) out_d_x->mutable_data<T>(x->dims(), ctx.GetPlace());
if (out_d_y) out_d_y->mutable_data<T>(y->dims(), ctx.GetPlace());
if (out_d_dout) out_d_dout->mutable_data<T>(dout->dims(), ctx.GetPlace());
if (out_d_ddx) out_d_ddx->mutable_data<T>(x->dims(), ctx.GetPlace());
if (out_d_ddy) out_d_ddy->mutable_data<T>(y->dims(), ctx.GetPlace());
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
Tensor ddx_safe, ddy_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, x, ddx, &ddx_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
if (d_ddout) {
if (out_d_x) {
// out_d_x = ddy * d_ddout
default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, d_ddout,
out_d_x);
}
if (out_d_y) {
// out_d_y = ddx * d_ddout
default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, d_ddout,
out_d_y);
}
}
if (out_d_dout) {
// get out_d_dout
// out_d_dout = ddy * d_dx + d_dy * ddx
Tensor out_d_dout_tmp;
out_d_dout_tmp.mutable_data<T>(dout->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, d_dy, &ddx_safe,
out_d_dout);
default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, d_dx,
&out_d_dout_tmp);
auto out_d_dout_t = framework::EigenVector<T>::Flatten(*out_d_dout);
auto out_d_dout_tmp_t =
framework::EigenVector<T>::Flatten(out_d_dout_tmp);
out_d_dout_t.device(place) = out_d_dout_t + out_d_dout_tmp_t;
}
if (out_d_ddx) {
// get out_d_ddx
// out_d_ddx = dout * d_dy + y * d_ddout
Tensor out_d_ddx_tmp;
out_d_ddx_tmp.mutable_data<T>(ddx->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, dout, d_dy, out_d_ddx);
default_elementwise_mul<DeviceContext, T>(ctx, y, d_ddout,
&out_d_ddx_tmp);
auto out_d_ddx_t = framework::EigenVector<T>::Flatten(*out_d_ddx);
auto out_d_ddx_tmp_t = framework::EigenVector<T>::Flatten(out_d_ddx_tmp);
out_d_ddx_t.device(place) = out_d_ddx_t + out_d_ddx_tmp_t;
}
if (out_d_ddy) {
// get out_d_ddy
// out_d_ddy = dout * d_dx + x * d_ddout
Tensor out_d_ddy_tmp;
out_d_ddy_tmp.mutable_data<T>(ddy->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, dout, d_dx, out_d_ddy);
default_elementwise_mul<DeviceContext, T>(ctx, x, d_ddout,
&out_d_ddy_tmp);
auto out_d_ddy_t = framework::EigenVector<T>::Flatten(*out_d_ddy);
auto out_d_ddy_tmp_t = framework::EigenVector<T>::Flatten(out_d_ddy_tmp);
out_d_ddy_t.device(place) = out_d_ddy_t + out_d_ddy_tmp_t;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -121,6 +121,20 @@ void DivideGradKernel(const Context& dev_ctx,
dev_ctx, x, y, out, dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
}
template <typename T, typename Context>
void MultiplyGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
funcs::ElementwiseGradPreProcess(dout, dx);
auto* out = &dout; // out is not necessary
phi::funcs::ElemwiseGradCompute<Context, T, MulGradDX<T>, MulGradDY<T>>(
dev_ctx, x, y, *out, dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
}
} // namespace phi
PD_REGISTER_KERNEL(add_grad,
......@@ -193,8 +207,8 @@ PD_REGISTER_KERNEL(divide_grad,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(divide_double_grad,
CPU,
......@@ -204,5 +218,44 @@ PD_REGISTER_KERNEL(divide_double_grad,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(multiply_grad,
CPU,
ALL_LAYOUT,
phi::MultiplyGradKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(multiply_double_grad,
CPU,
ALL_LAYOUT,
phi::MultiplyDoubleGradKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(multiply_triple_grad,
CPU,
ALL_LAYOUT,
phi::MultiplyTripleGradKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -85,4 +85,43 @@ void DivideDoubleGradKernel(const Context& dev_ctx,
DenseTensor* dy,
DenseTensor* dout,
DenseTensor* ddout);
template <typename T, typename Context>
void MultiplyGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void MultiplyDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
int axis,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout);
template <typename T, typename Context>
void MultiplyTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
const DenseTensor& d_dx,
const DenseTensor& d_dy,
paddle::optional<const DenseTensor&> d_ddout,
int axis,
DenseTensor* d_x,
DenseTensor* d_y,
DenseTensor* d_dout,
DenseTensor* d_ddx,
DenseTensor* d_ddy);
} // namespace phi
......@@ -160,5 +160,49 @@ struct DivGradYFunctor<ComplexType<T>> {
}
};
template <typename T>
struct MultiplyGradFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; }
};
template <typename T>
struct MultiplyGradFunctor<ComplexType<T>> {
inline HOSTDEVICE ComplexType<T> operator()(const ComplexType<T> a,
const ComplexType<T> b) const {
ComplexType<T> b_conj(b.real, -b.imag);
return a * b_conj;
}
};
template <typename InT, typename OutT>
struct MultiplyGradXYFunctor {
inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT a,
const InT b,
const InT c) {
phi::Array<OutT, 2> outs;
// dx = dout * y
outs[0] = a * b;
// dy = dout * x
outs[1] = a * c;
return outs;
}
};
template <typename InT, typename OutT>
struct MultiplyGradXYFunctor<ComplexType<InT>, ComplexType<OutT>> {
inline HOSTDEVICE phi::Array<ComplexType<OutT>, 2> operator()(
const ComplexType<InT> a,
const ComplexType<InT> b,
const ComplexType<InT> c) {
phi::Array<ComplexType<OutT>, 2> outs;
// dx = dout * y
ComplexType<InT> b_conj(b.real, -b.imag);
outs[0] = a * b_conj;
// dy = dout * x
ComplexType<InT> c_conj(c.real, -c.imag);
outs[1] = a * c_conj;
return outs;
}
};
} // namespace funcs
} // namespace phi
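A note on the conjugates in the complex specializations above: for complex inputs the kernels return the incoming gradient times the conjugate of the other operand, which is the convention Paddle uses for gradients of real-valued losses over complex tensors. As a short math sketch (my reading, not stated in the source), for out = x ⊙ y:

\[
\frac{\partial L}{\partial x} = d_{out} \odot \bar{y}, \qquad
\frac{\partial L}{\partial y} = d_{out} \odot \bar{x}.
\]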
......@@ -360,4 +360,41 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx,
}
}
/*
******************************
Mul Grad
******************************
*/
template <typename T>
void ElementwiseMulGrad(const GPUContext &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy,
int axis) {
const auto place = dev_ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor *> ins = {&dout, &y, &x};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx,
place,
axis,
ins,
dout,
dx,
dy,
funcs::MultiplyGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const DenseTensor *> ins = {&dout, &y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
dev_ctx, place, axis, ins, dout, dx, funcs::MultiplyGradFunctor<T>());
} else if (dx == nullptr && dy != nullptr) {
std::vector<const DenseTensor *> ins = {&dout, &x};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
dev_ctx, place, axis, ins, dout, dy, funcs::MultiplyGradFunctor<T>());
}
}
} // namespace phi
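To make the ins ordering above easier to follow: when both dx and dy are requested, the inputs are packed as {dout, y, x}, so the functor receives (a, b, c) = (dout, y, x) per element and returns {a * b, a * c} = {dout * y, dout * x}, which are written to dx and dy respectively. A tiny sketch of that per-element contract, using std::array in place of phi::Array (hypothetical name, not the Paddle API):

#include <array>

// Mirrors MultiplyGradXYFunctor's per-element contract:
// inputs (dout, y, x) -> outputs {dx, dy} = {dout * y, dout * x}.
template <typename T>
std::array<T, 2> MultiplyGradXYSketch(T dout, T y, T x) {
  return {dout * y, dout * x};
}

// Example: dout = 2, y = 3, x = 5 gives dx = 6 and dy = 10.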
......@@ -136,6 +136,18 @@ void DivideGradKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void MultiplyGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
funcs::ElementwiseGradPreProcess(dout, dx);
ElementwiseMulGrad<T>(dev_ctx, x, y, dout, dx, dy, axis);
}
} // namespace phi
PD_REGISTER_KERNEL(add_grad,
......@@ -228,3 +240,45 @@ PD_REGISTER_KERNEL(divide_double_grad,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(multiply_grad,
GPU,
ALL_LAYOUT,
phi::MultiplyGradKernel,
float,
phi::dtype::float16,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(multiply_double_grad,
GPU,
ALL_LAYOUT,
phi::MultiplyDoubleGradKernel,
float,
phi::dtype::float16,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(multiply_triple_grad,
GPU,
ALL_LAYOUT,
phi::MultiplyTripleGradKernel,
float,
phi::dtype::float16,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -259,4 +259,277 @@ void DivideDoubleGradKernel(const Context& dev_ctx,
}
}
template <typename T>
struct MulGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
};
template <typename T>
struct MulGradDX<phi::dtype::complex<T>> {
HOSTDEVICE phi::dtype::complex<T> operator()(
phi::dtype::complex<T> x,
phi::dtype::complex<T> y,
phi::dtype::complex<T> out,
phi::dtype::complex<T> dout) const {
phi::dtype::complex<T> y_conj(y.real, -y.imag);
return dout * y_conj;
}
};
/*
******************************
Multiply Grad
******************************
*/
template <typename T>
struct MulGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
};
template <typename T>
struct MulGradDY<phi::dtype::complex<T>> {
HOSTDEVICE phi::dtype::complex<T> operator()(
phi::dtype::complex<T> x,
phi::dtype::complex<T> y,
phi::dtype::complex<T> out,
phi::dtype::complex<T> dout) const {
phi::dtype::complex<T> x_conj(x.real, -x.imag);
return dout * x_conj;
}
};
template <typename T, typename Context>
void MultiplyDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
int axis,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout) {
if (ddout) dev_ctx.template Alloc<T>(ddout);
DenseTensor ddx_safe, ddy_safe;
funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, x, ddx.get_ptr(), &ddx_safe);
funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, y, ddy.get_ptr(), &ddy_safe);
// dx = dout * ddy
// dy = dout * ddx
// ddout = ddx * y + x * ddy
// reorder the computation to save memory, so that ddout can reuse ddx's
// storage in place and dx can serve as a temporary tensor
// (1) dx = x * ddy
// (2) dy = dout * ddx
// (3) ddout = ddx * y
// (4) ddout = ddout + dx
// (5) dx = dout * ddy
if (ddout) {
auto& place = *dev_ctx.eigen_device();
// when numel(ddout) > numel(ddx), ddout cannot reuse ddx's memory in place
if (ddout->numel() > ddx.get_ptr()->numel()) {
phi::funcs::ElemwiseGradCompute<Context, T, MulGradDX<T>, MulGradDY<T>>(
dev_ctx,
ddx_safe,
ddy_safe,
dout,
dout,
axis,
dx,
dy,
MulGradDX<T>(),
MulGradDY<T>());
DenseTensor ddout_tmp;
ddout_tmp.Resize(ddout->dims());
dev_ctx.template Alloc<T>(&ddout_tmp);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, y, ddx_safe, ddout, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddy_safe, x, &ddout_tmp, axis);
auto ddout_t = phi::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = phi::EigenVector<T>::Flatten(ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
} else {
// reuse dx to save memory instead of allocating a temporary tensor
DenseTensor* ddout_tmp = dx;
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, x, ddy_safe, ddout_tmp, axis);
// NOTE: in the following ElemwiseGradCompute, the first output tensor is
// nullptr, so the branch that computes the first output is never taken and
// MulGradDX is never invoked; this skipped branch has little effect on
// running speed.
phi::funcs::ElemwiseGradCompute<Context, T, MulGradDX<T>, MulGradDY<T>>(
dev_ctx,
ddx_safe,
ddy_safe,
dout,
dout,
axis,
nullptr,
dy,
MulGradDX<T>(),
MulGradDY<T>());
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddx_safe, y, ddout, axis);
auto ddout_t = phi::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = phi::EigenVector<T>::Flatten(*ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dout, ddy_safe, dx, axis);
}
}
}
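The reordered steps above produce the same values as the naive formulas in the leading comment. A quick scalar sanity check of those double-grad identities (a standalone sketch, not exercising the phi kernels):

#include <cassert>

// Scalar check of the double-grad identities for out = x * y:
//   ddout = ddx * y + x * ddy,  dx = dout * ddy,  dy = dout * ddx.
int main() {
  const double x = 2.0, y = 3.0, dout = 4.0, ddx = 5.0, ddy = 6.0;
  const double ddout = ddx * y + x * ddy;  // forward-mode gradient of out
  const double dx = dout * ddy;            // gradient w.r.t. x
  const double dy = dout * ddx;            // gradient w.r.t. y
  assert(ddout == 27.0);  // 5*3 + 2*6
  assert(dx == 24.0);     // 4*6
  assert(dy == 20.0);     // 4*5
  return 0;
}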
template <typename T, typename Context>
void MultiplyTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
const DenseTensor& d_dx,
const DenseTensor& d_dy,
paddle::optional<const DenseTensor&> d_ddout,
int axis,
DenseTensor* d_x,
DenseTensor* d_y,
DenseTensor* d_dout,
DenseTensor* d_ddx,
DenseTensor* d_ddy) {
if (d_x) {
d_x->Resize(x.dims());
dev_ctx.template Alloc<T>(d_x);
}
if (d_y) {
d_y->Resize(y.dims());
dev_ctx.template Alloc<T>(d_y);
}
if (d_dout) {
d_dout->Resize(dout.dims());
dev_ctx.template Alloc<T>(d_dout);
}
if (d_ddx) {
d_ddx->Resize(x.dims());
dev_ctx.template Alloc<T>(d_ddx);
}
if (d_ddy) {
d_ddy->Resize(y.dims());
dev_ctx.template Alloc<T>(d_ddy);
}
auto& place = *dev_ctx.eigen_device();
DenseTensor ddx_safe, ddy_safe;
funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, x, ddx.get_ptr(), &ddx_safe);
funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, y, ddy.get_ptr(), &ddy_safe);
if (d_ddout.get_ptr()) {
if (d_x) {
// d_x = ddy * d_ddout
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddy_safe, *(d_ddout.get_ptr()), d_x, axis);
}
if (d_y) {
// d_y = ddx * d_ddout
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddx_safe, *(d_ddout.get_ptr()), d_y, axis);
}
}
if (d_dout) {
// get d_dout
// d_dout = ddy * d_dx + d_dy * ddx
DenseTensor d_dout_tmp;
d_dout_tmp.Resize(dout.dims());
dev_ctx.template Alloc<T>(&d_dout_tmp);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, d_dy, ddx_safe, d_dout, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddy_safe, d_dx, &d_dout_tmp, axis);
auto d_dout_t = phi::EigenVector<T>::Flatten(*d_dout);
auto d_dout_tmp_t = phi::EigenVector<T>::Flatten(d_dout_tmp);
d_dout_t.device(place) = d_dout_t + d_dout_tmp_t;
}
if (d_ddx) {
// get d_ddx
// d_ddx = dout * d_dy + y * d_ddout
DenseTensor d_ddx_tmp;
d_ddx_tmp.Resize(ddx->dims());
dev_ctx.template Alloc<T>(&d_ddx_tmp);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dout, d_dy, d_ddx, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, y, *(d_ddout.get_ptr()), &d_ddx_tmp, axis);
auto d_ddx_t = phi::EigenVector<T>::Flatten(*d_ddx);
auto d_ddx_tmp_t = phi::EigenVector<T>::Flatten(d_ddx_tmp);
d_ddx_t.device(place) = d_ddx_t + d_ddx_tmp_t;
}
if (d_ddy) {
// get d_ddy
// d_ddy = dout * d_dx + x * d_ddout
DenseTensor d_ddy_tmp;
d_ddy_tmp.Resize(ddy->dims());
dev_ctx.template Alloc<T>(&d_ddy_tmp);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dout, d_dx, d_ddy, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, x, *(d_ddout.get_ptr()), &d_ddy_tmp, axis);
auto d_ddy_t = phi::EigenVector<T>::Flatten(*d_ddy);
auto d_ddy_tmp_t = phi::EigenVector<T>::Flatten(d_ddy_tmp);
d_ddy_t.device(place) = d_ddy_t + d_ddy_tmp_t;
}
}
} // namespace phi
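The triple-grad outputs above can be read off by differentiating the double-grad outputs dx = dout ⊙ ddy, dy = dout ⊙ ddx, and ddout = ddx ⊙ y + x ⊙ ddy with respect to their inputs, with d_dx, d_dy, and d_ddout as the incoming gradients. A sketch of the resulting identities (my derivation; it matches the inline comments above):

\[
\begin{aligned}
d_x &= ddy \odot d_{ddout}, \qquad
d_y = ddx \odot d_{ddout},\\
d_{dout} &= ddy \odot d_{dx} + ddx \odot d_{dy},\\
d_{ddx} &= dout \odot d_{dy} + y \odot d_{ddout}, \qquad
d_{ddy} = dout \odot d_{dx} + x \odot d_{ddout}.
\end{aligned}
\]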
......@@ -122,6 +122,31 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
{GradVarName("Y"), "DOut", "DDOut"});
}
KernelSignature ElementwiseMulGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("multiply_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
}
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("multiply_double_grad",
{"X", "Y", "DOut", "DDX", "DDY"},
{"axis"},
{GradVarName("X"), GradVarName("Y"), "DDOut"});
}
KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"multiply_triple_grad",
{"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
{"axis"},
{"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
......@@ -135,6 +160,9 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
phi::ElementwiseAddOpArgumentMapping);
......@@ -158,3 +186,9 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
phi::ElementwiseDivGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
phi::ElementwiseDivDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad,
phi::ElementwiseMulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
phi::ElementwiseMulTripleGradOpArgumentMapping);