未验证 提交 55cd9cb8 编写于 作者: C crystal 提交者: GitHub

implementation of broadcast div backward by reduce (#38044)

* add elementwise div

* move mul and div grad functor

* Combine multiple CUDA kernels

* Update the reduce interface call

* add multi-output

* add multi-output div

* add branch judge

* Package branch

* Combine the x and y functions into one
上级 d1dc677a
......@@ -13,9 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
......@@ -23,83 +20,39 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
const T* out,
const T* dout,
int64_t size, T* dx,
T* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
T o = dout[col];
dx[col] = o / y[col];
dy[col] = -o * out[col] / y[col];
col += blockDim.x * gridDim.x;
}
}
template <>
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
const paddle::platform::complex<float>* x,
const paddle::platform::complex<float>* y,
const paddle::platform::complex<float>* out,
const paddle::platform::complex<float>* dout, int64_t size,
paddle::platform::complex<float>* dx,
paddle::platform::complex<float>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
paddle::platform::complex<float> o = dout[col];
paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
}
}
template <>
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
const paddle::platform::complex<double>* x,
const paddle::platform::complex<double>* y,
const paddle::platform::complex<double>* out,
const paddle::platform::complex<double>* dout, int64_t size,
paddle::platform::complex<double>* dx,
paddle::platform::complex<double>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
paddle::platform::complex<double> o = dout[col];
paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
}
}
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_div_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
auto size = x->numel();
dim3 grid_size =
dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
SimpleElemwiseDivGradCUDAKernel<
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
dx->mutable_data<T>(ctx.GetPlace()), dy->mutable_data<T>(ctx.GetPlace()));
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseDivGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
const auto place = ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
std::vector<const framework::Tensor*> ins = {dout, out, y};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, dy, DivGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
std::vector<const framework::Tensor*> ins = {dout, y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
dx, DivGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) {
std::vector<const framework::Tensor*> ins = {dout, out, y};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dy, DivGradYFunctor<T>());
}
}
} // namespace operators
......
......@@ -111,26 +111,24 @@ struct DivDoubleDY {
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_div_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
ElementwiseDivGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_div_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy);
ElementwiseDivGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);
#endif
template <typename DeviceContext, typename T>
......@@ -146,15 +144,8 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(),
DivGradDY<T>());
}
ElementwiseDivGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
};
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
......@@ -87,6 +89,71 @@ struct MinFunctor {
}
};
template <typename T>
using Complex = paddle::platform::complex<T>;
template <typename InT, typename OutT>
struct DivGradXYFunctor {
inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(const InT a,
const InT b,
const InT c) {
// dx = dout / y
// dy = - dout * out / y
paddle::framework::Array<OutT, 2> outs;
outs[0] = a / c;
outs[1] = -a * b / c;
return outs;
}
};
template <typename InT, typename OutT>
struct DivGradXYFunctor<Complex<InT>, Complex<OutT>> {
inline HOSTDEVICE paddle::framework::Array<Complex<OutT>, 2> operator()(
const Complex<InT> a, const Complex<InT> b, const Complex<InT> c) {
paddle::framework::Array<Complex<OutT>, 2> outs;
Complex<InT> c_conj(c.real, -c.imag);
Complex<InT> out_div_c_conj((b / c).real, -(b / c).imag);
outs[0] = a / c_conj;
outs[1] = -a * out_div_c_conj;
return outs;
}
};
// Float div grad
template <typename T>
struct DivGradXFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; }
};
// Complex div grad
template <typename T>
struct DivGradXFunctor<Complex<T>> {
inline HOSTDEVICE Complex<T> operator()(const Complex<T>& a,
const Complex<T>& b) const {
Complex<T> b_conj(b.real, -b.imag);
return a / b_conj;
}
};
// Float mul and div
template <typename T>
struct DivGradYFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const {
return -a * b / c;
}
};
// Complex mul and div
template <typename T>
struct DivGradYFunctor<Complex<T>> {
inline HOSTDEVICE Complex<T> operator()(const Complex<T>& a,
const Complex<T>& b,
const Complex<T>& c) const {
Complex<T> out_div_c_conj((b / c).real, -(b / c).imag);
return -a * out_div_c_conj;
}
};
// Fmax
template <typename T>
struct FMaxFunctor {
......
......@@ -42,6 +42,7 @@ limitations under the License. */
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
......@@ -2556,5 +2557,77 @@ static inline std::vector<int> GetReduceDim(const framework::DDim &in,
}
return dims;
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
void ReduceWrapper(const platform::CUDADeviceContext &dev_ctx, int axis,
framework::Tensor *src, framework::Tensor *dst) {
std::vector<int> reduce_dims = GetReduceDim(dst->dims(), src->dims(), axis);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*src, dst, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream());
}
template <ElementwiseType ET, typename T, typename Functor>
void GetGradXAndYOut(const platform::CUDADeviceContext &dev_ctx,
const platform::Place &place, int axis,
std::vector<const framework::Tensor *> ins,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy, Functor func) {
framework::Tensor tmp_dx;
framework::Tensor tmp_dy;
dy->mutable_data<T>(place);
std::vector<framework::Tensor *> outs;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
outs = {dx, dy};
} else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
tmp_dx.mutable_data<T>(dout->dims(), place);
outs = {&tmp_dx, dy};
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
tmp_dy.mutable_data<T>(dout->dims(), place);
outs = {dx, &tmp_dy};
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
tmp_dy.mutable_data<T>(dout->dims(), place);
tmp_dx.mutable_data<T>(dout->dims(), place);
outs = {&tmp_dx, &tmp_dy};
}
LaunchElementwiseCudaKernel<ET, T, T, decltype(func), 2>(dev_ctx, ins, &outs,
axis, func);
if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
}
}
template <ElementwiseType ET, typename T, typename Functor>
void GetGradXOrYOut(const platform::CUDADeviceContext &dev_ctx,
const platform::Place &place, int axis,
std::vector<const framework::Tensor *> ins,
const framework::Tensor *dout, framework::Tensor *dxy,
Functor func) {
framework::Tensor tmp_dxy;
dxy->mutable_data<T>(place);
std::vector<framework::Tensor *> outs;
if (dxy->dims() != dout->dims()) {
tmp_dxy.mutable_data<T>(dout->dims(), place);
outs = {&tmp_dxy};
} else {
outs = {dxy};
}
LaunchElementwiseCudaKernel<ET, T, T>(dev_ctx, ins, &outs, axis, func);
if (dxy->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dxy, dxy);
}
}
#endif
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册