Unverified commit 567e6bbc, authored by crystal, committed by GitHub

implementation of broadcast sub backward by reduce (#37754)

* add broadcast_sub

* add broadcast_sub
Parent b4a67491
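For context on what "backward by reduce" means here: for a broadcasted subtraction out = x - y, the gradient with respect to a broadcast input is the upstream gradient summed over the broadcast axes, and for y it is additionally negated (dx = sum-reduce of dout to x's shape, dy = -(sum-reduce of dout to y's shape)). A minimal CPU-only sketch of that rule, with shapes and names chosen purely for illustration (not Paddle code):

```cpp
#include <cstdio>
#include <vector>

// Sketch: out = x - y, where x has shape [N, C] and y (shape [C]) is
// broadcast over the first axis. Then dx = dout (same shape as x) and
// dy[c] = -sum_n dout[n][c], i.e. a negated reduction over the broadcast axis.
int main() {
  const int N = 2, C = 3;
  std::vector<float> dout(N * C, 1.0f);  // upstream gradient, all ones
  std::vector<float> dx(N * C, 0.0f);    // gradient for x: elementwise copy
  std::vector<float> dy(C, 0.0f);        // gradient for y: reduced over N

  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      dx[n * C + c] = dout[n * C + c];   // d(x - y)/dx = 1
      dy[c] -= dout[n * C + c];          // d(x - y)/dy = -1, summed over n
    }
  }

  for (int c = 0; c < C; ++c) printf("dy[%d] = %g\n", c, dy[c]);  // -2 each
  return 0;
}
```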
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
......@@ -30,12 +32,69 @@ static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  while (col < size) {
    // dx may be a null pointer when only dy is requested (the dy-only launch
    // below passes nullptr for dx), so guard the write.
    if (dx != nullptr) {
      dx[col] = dout[col];
    }
    dy[col] = -dout[col];
    col += blockDim.x * gridDim.x;
  }
}

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  auto* dout_data = dout->data<T>();
  // dx: same shape as dout -> plain copy; otherwise sum-reduce dout over the
  // broadcast dimensions of x.
  if (dx != nullptr) {
    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    if (dx->dims() == dout->dims()) {
      if (dx_data != dout_data) {
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), dx);
      }
    } else {
      // For inplace strategy, dx will be stored in addr of dout, which makes
      // the result of dy wrong.
      if (dx->IsSharedBufferWith(*dout)) {
        dx->clear();
        dx->mutable_data<T>(x->dims(), ctx.GetPlace());
      }
      std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);
    }
  }
  // dy: same shape as dout -> elementwise negation via the kernel; otherwise
  // reduce the negated dout over the broadcast dimensions of y (CustomSub).
  if (dy != nullptr) {
    auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
    if (dy->dims() == dout->dims()) {
      if (dy_data != dout_data) {
        dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
        auto size = dy->numel();
        dim3 grid_size = dim3(
            (size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
        SimpleElemwiseSubGradCUDAKernel<T><<<
            grid_size, block_size, 0,
            ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
            dout->data<T>(), size, nullptr,
            dy->mutable_data<T>(ctx.GetPlace()));
      }
    } else {
      std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl<T, T, CustomSub>(*dout, dy, reduce_dims, stream);
    }
  }
}

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
......
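The CUDA default_elementwise_sub_grad above takes the elementwise path when a gradient has the same shape as dout (TensorCopy for dx, the negating kernel for dy), and otherwise reduces dout over the broadcast dimensions, using CustomSum for dx and CustomSub for dy. The sketch below only illustrates the role GetReduceDim plays in picking those dimensions, assuming the usual right-aligned broadcasting controlled by the axis attribute; it is not Paddle's implementation.

```cpp
#include <cstdio>
#include <vector>

// Illustrative only: choose which axes of `out` must be summed so the result
// has the (smaller) shape of `in`, given the elementwise `axis` attribute.
// `axis == -1` means `in` is aligned to the trailing dimensions of `out`.
std::vector<int> ReduceDimsSketch(const std::vector<int>& in_dims,
                                  const std::vector<int>& out_dims, int axis) {
  if (axis == -1) axis = static_cast<int>(out_dims.size() - in_dims.size());
  std::vector<int> reduce_dims;
  for (int i = 0; i < static_cast<int>(out_dims.size()); ++i) {
    bool outside = i < axis || i >= axis + static_cast<int>(in_dims.size());
    // Axes not covered by `in`, or covered with size 1, were broadcast and
    // therefore have to be summed in the backward pass.
    if (outside || in_dims[i - axis] == 1) reduce_dims.push_back(i);
  }
  return reduce_dims;
}

int main() {
  // y of shape [3] subtracted from x of shape [2, 3]: dy reduces axis 0.
  for (int a : ReduceDimsSketch({3}, {2, 3}, -1)) printf("reduce axis %d\n", a);
  return 0;
}
```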
......@@ -71,6 +71,21 @@ struct SubGradDY {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
};

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
      ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
}

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
......@@ -79,13 +94,21 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy);

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
......@@ -108,15 +131,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    // skip out
    auto* out = dout;
    if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
      elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
    } else {
      default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
                                                     dy);
    }
  }
};
......
......@@ -86,6 +86,20 @@ struct DivideFunctor {
  Tx n_inv;
};

/**
 * @brief Default inverse functor
 */
template <typename Tx, typename Ty = Tx>
struct InverseFunctor {
  HOSTDEVICE inline InverseFunctor() {}
  HOSTDEVICE explicit inline InverseFunctor(int n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(-x);
  }
};

/**
 * @brief Default unary square functor
 */
......
......@@ -64,6 +64,17 @@ struct CustomSum {
  }
};

template <typename Tx, typename Ty = Tx>
struct CustomSub {
  using Transformer = kps::InverseFunctor<Tx>;

  inline Ty initial() { return static_cast<Ty>(0.0f); }

  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
    return b + a;
  }
};

template <typename Tx, typename Ty = Tx>
struct CustomMean {
  using Transformer = kps::DivideFunctor<Tx>;
......
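Tying the two new pieces together: the reduction applies the functor's Transformer to every element and then combines the transformed values with operator(). For CustomSub the transformer is InverseFunctor (x -> -x) and the combiner is addition, so reducing dout with CustomSub yields -sum(dout) over the chosen axes, which is exactly dy for a broadcasted subtraction. A plain C++ model of that transform-then-reduce behaviour (not the kps/GPU implementation):

```cpp
#include <cstdio>

// Conceptual model: negate each element (InverseFunctor) and add the results
// up (CustomSub's operator()), starting from CustomSub::initial() == 0.
// The reduction therefore returns -sum(x).
float NegatedSum(const float* x, int n) {
  float acc = 0.0f;      // CustomSub::initial()
  for (int i = 0; i < n; ++i) {
    float t = -x[i];     // InverseFunctor: x -> -x
    acc = acc + t;       // CustomSub: a + b
  }
  return acc;
}

int main() {
  const float v[4] = {1.f, 2.f, 3.f, 4.f};
  printf("%g\n", NegatedSum(v, 4));  // prints -10
  return 0;
}
```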