Unverified commit 36a102f8, authored by Lijunhui, committed by GitHub

optimize elementwise_mul_grad using new interfaces (#37728)

* init commit: new elem_mul_grad

* add template specialization for complex in multiply

* address review comments

* correct dx and dy computation when T is complex

* address review comments

* update to new ReduceFunctor

* multi-output broadcast

* call functions

* call functions with comments

* remove comments
Parent 905c8022
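Background for the complex-number commits above: for out = x * y with complex operands and a real-valued loss, the consistent gradients are dx = dout * conj(y) and dy = dout * conj(x) (the Wirtinger-calculus convention), which is why the functors in this diff build b_conj and c_conj by negating the imaginary part. A minimal standalone sketch of the rule, using std::complex in place of paddle::platform::complex:

#include <complex>
#include <iostream>

int main() {
  // out = x * y; dout is the upstream gradient dL/dout.
  std::complex<float> x(1.0f, 2.0f), y(3.0f, -1.0f), dout(0.5f, 0.25f);

  // Gradients of a real loss w.r.t. complex inputs multiply dout by the
  // conjugate of the other operand, matching MulGradFunctor<Complex<T>>.
  std::complex<float> dx = dout * std::conj(y);  // dL/dx
  std::complex<float> dy = dout * std::conj(x);  // dL/dy

  std::cout << "dx = " << dx << ", dy = " << dy << "\n";
  return 0;
}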
@@ -194,5 +194,47 @@ struct FMinFunctor<paddle::platform::float16> {
  }
};
template <typename T>
struct MulGradFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};
template <typename T>
struct MulGradFunctor<Complex<T>> {
inline HOSTDEVICE Complex<T> operator()(const Complex<T>& a,
const Complex<T>& b) const {
Complex<T> b_conj(b.real, -b.imag);
return a * b_conj;
}
};
template <typename InT, typename OutT>
struct MulGradXYFunctor {
  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(const InT& a,
                                                                 const InT& b,
                                                                 const InT& c) {
paddle::framework::Array<OutT, 2> outs;
// dx = dout * y
outs[0] = a * b;
// dy = dout * x
outs[1] = a * c;
return outs;
}
};
template <typename InT, typename OutT>
struct MulGradXYFunctor<Complex<InT>, Complex<OutT>> {
inline HOSTDEVICE paddle::framework::Array<Complex<OutT>, 2> operator()(
const Complex<InT>& a, const Complex<InT>& b, const Complex<InT>& c) {
paddle::framework::Array<Complex<OutT>, 2> outs;
// dx = dout * y
Complex<InT> b_conj(b.real, -b.imag);
outs[0] = a * b_conj;
// dy = dout * x
Complex<InT> c_conj(c.real, -c.imag);
outs[1] = a * c_conj;
return outs;
}
};
}  // namespace operators
}  // namespace paddle
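The MulGradXYFunctor added above exists so that a single element-wise pass over (dout, y, x) can emit dx and dy together instead of launching one kernel per gradient. A hedged host-side sketch of how such a two-output functor is consumed, with a plain loop standing in for the GetGradXAndYOut broadcast kernel and std::array standing in for paddle::framework::Array:

#include <array>
#include <cstddef>
#include <vector>

// Simplified stand-in for MulGradXYFunctor<T, T>: one call yields both
// dx = dout * y and dy = dout * x for a single element.
template <typename T>
struct MulGradXY {
  std::array<T, 2> operator()(const T& dout, const T& y, const T& x) const {
    return {dout * y, dout * x};
  }
};

// One fused pass that fills dx and dy together (same-shape case only;
// the real GPU path also handles broadcasting and axis reduction).
template <typename T>
void MulGradSameDims(const std::vector<T>& dout, const std::vector<T>& x,
                     const std::vector<T>& y, std::vector<T>* dx,
                     std::vector<T>* dy) {
  MulGradXY<T> f;
  for (std::size_t i = 0; i < dout.size(); ++i) {
    std::array<T, 2> outs = f(dout[i], y[i], x[i]);
    (*dx)[i] = outs[0];
    (*dy)[i] = outs[1];
  }
}

The fused form reads dout once and launches one kernel rather than two, which is the main saving over applying the single-output MulGradFunctor twice.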
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
@@ -68,71 +69,43 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
  }
};
-template <typename T>
-static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
-                                                       const T* out,
-                                                       const T* dout,
-                                                       int64_t size, T* dx,
-                                                       T* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-
-  while (col < size) {
-    T o = dout[col];
-    dx[col] = y[col] * o;
-    dy[col] = x[col] * o;
-    col += blockDim.x * gridDim.x;
-  }
-}
-
-template <>
-__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<float>>(
-    const plat::complex<float>* x, const plat::complex<float>* y,
-    const plat::complex<float>* out, const plat::complex<float>* dout,
-    int64_t size, plat::complex<float>* dx, plat::complex<float>* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-  while (col < size) {
-    plat::complex<float> o = dout[col];
-    dx[col] = plat::complex<float>(y[col].real, -y[col].imag) * o;
-    dy[col] = plat::complex<float>(x[col].real, -x[col].imag) * o;
-    col += blockDim.x * gridDim.x;
-  }
-}
-
-template <>
-__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<double>>(
-    const plat::complex<double>* x, const plat::complex<double>* y,
-    const plat::complex<double>* out, const plat::complex<double>* dout,
-    int64_t size, plat::complex<double>* dx, plat::complex<double>* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-  while (col < size) {
-    plat::complex<double> o = dout[col];
-    dx[col] = plat::complex<double>(y[col].real, -y[col].imag) * o;
-    dy[col] = plat::complex<double>(x[col].real, -x[col].imag) * o;
-    col += blockDim.x * gridDim.x;
-  }
-}
-
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
-elementwise_mul_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy) {
-  dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
-  auto size = x->numel();
-  dim3 grid_size =
-      dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
-  SimpleElemwiseMulGradCUDAKernel<
-      T><<<grid_size, block_size, 0,
-           ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
-      x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
-      dx->mutable_data<T>(ctx.GetPlace()), dy->mutable_data<T>(ctx.GetPlace()));
-}
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+ElementwiseMulGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  const auto& dev_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
+  const auto place = ctx.GetPlace();
+
+  if (dx != nullptr && dy != nullptr) {
+    dx->mutable_data<T>(place);
+    if (dx->IsSharedBufferWith(*dout)) {
+      dx->clear();
+      dx->mutable_data<T>(x->dims(), place);
+    }
+    std::vector<const framework::Tensor*> ins = {dout, y, x};
+    GetGradXAndYOut<ElementwiseType::kBinary, T>(
+        dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor<T, T>());
+  } else if (dx != nullptr && dy == nullptr) {
+    dx->mutable_data<T>(place);
+    if (dx->IsSharedBufferWith(*dout)) {
+      dx->clear();
+      dx->mutable_data<T>(x->dims(), place);
+    }
+    std::vector<const framework::Tensor*> ins = {dout, y};
+    GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
+                                                dx, MulGradFunctor<T>());
+  } else if (dx == nullptr && dy != nullptr) {
+    std::vector<const framework::Tensor*> ins = {dout, x};
+    GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
+                                                dy, MulGradFunctor<T>());
+  }
+}
}  // namespace operators
}  // namespace paddle
...
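The ElementwiseMulGrad definition above replaces the hand-rolled SimpleElemwiseMulGradCUDAKernel variants with three branches: both gradients requested (fused GetGradXAndYOut), only dx, or only dy (GetGradXOrYOut each). A self-contained CPU sketch of that control flow, with plain loops standing in for the Paddle kernels and the broadcast/axis-reduction machinery omitted:

#include <cstddef>
#include <vector>

// Either output pointer may be null when the caller does not need that
// gradient; the fused two-output path runs only when both are wanted.
template <typename T>
void ElementwiseMulGradSketch(const std::vector<T>& x,
                              const std::vector<T>& y,
                              const std::vector<T>& dout,
                              std::vector<T>* dx, std::vector<T>* dy) {
  if (dx != nullptr && dy != nullptr) {
    // Fused path: one pass emits both outputs (cf. MulGradXYFunctor).
    for (std::size_t i = 0; i < dout.size(); ++i) {
      (*dx)[i] = dout[i] * y[i];
      (*dy)[i] = dout[i] * x[i];
    }
  } else if (dx != nullptr) {
    // dx only: ins = {dout, y}, cf. MulGradFunctor.
    for (std::size_t i = 0; i < dout.size(); ++i) (*dx)[i] = dout[i] * y[i];
  } else if (dy != nullptr) {
    // dy only: ins = {dout, x}, cf. MulGradFunctor.
    for (std::size_t i = 0; i < dout.size(); ++i) (*dy)[i] = dout[i] * x[i];
  }
}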
@@ -174,26 +174,23 @@ struct MulGradDY<paddle::platform::complex<T>> {
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_mul_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy) {
+ElementwiseMulGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
      ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
-elementwise_mul_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy);
+ElementwiseMulGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy);
#endif
template <typename DeviceContext, typename T>
@@ -209,14 +206,8 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
    auto* out = dout;  // out is not necessary
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");
-    if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
-      elementwise_mul_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
-    } else {
-      ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
-          ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(),
-          MulGradDY<T>());
-    }
+
+    ElementwiseMulGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
  }
};
...
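A note on the last hunk: the old Compute only took the fast path when dx->dims() == dy->dims() and otherwise fell back to ElemwiseGradCompute, whereas the new ElementwiseMulGrad handles broadcast inputs uniformly by reducing each gradient over its broadcast axes. An illustrative example (shapes assumed for illustration, not taken from the diff): with x of shape [rows, cols] multiplied by y of shape [cols], dy must be summed over the broadcast row axis:

#include <cstddef>
#include <vector>

// dy[j] = sum_i dout[i][j] * x[i][j]: the reduction that the
// reduce_op.cu.h path performs after the broadcast multiply.
std::vector<float> BroadcastMulGradY(const std::vector<float>& x,
                                     const std::vector<float>& dout,
                                     std::size_t rows, std::size_t cols) {
  std::vector<float> dy(cols, 0.0f);
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      dy[j] += dout[i * cols + j] * x[i * cols + j];
    }
  }
  return dy;
}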