From 560b4323495a0de9fc36a77a6c2f99d20a21d68f Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Tue, 8 Dec 2020 17:23:45 +0800 Subject: [PATCH] Revert "improve elementwise_add_grad perf (#29277)" (#29464) This reverts commit befd6d53383b160cac492a92f9358fd59f0861c7. --- .../elementwise/elementwise_add_op.cu | 309 +----------------- .../elementwise/elementwise_add_op.h | 60 ++-- 2 files changed, 37 insertions(+), 332 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index e460a96cbf..8de6416065 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -11,16 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/platform/complex128.h" #include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/float16.h" -#define WARPSIZE 32 - namespace ops = paddle::operators; namespace plat = paddle::platform; @@ -78,10 +74,11 @@ static __global__ void SimpleElemwiseAddGradCUDAKernel(const T* dout, template typename std::enable_if< std::is_same::value>::type -ElementwiseAddGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { +elementwise_add_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); auto size = x->numel(); dim3 grid_size = @@ -93,302 +90,6 @@ ElementwiseAddGrad(const framework::ExecutionContext& ctx, dy->mutable_data(ctx.GetPlace())); } -inline static bool UseReduceFirstAxisRank1(const framework::DDim& dout_dims, - const framework::DDim& x_dims, - const framework::DDim& y_dims, - const int axis) { - int start_axis = - (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); - - if (y_dims[y_dims.size() - 1] == 1) { - return false; - } - - if (y_dims.size() > 1) { - for (int i = 0; i < y_dims.size() - 1; ++i) { - if (y_dims[i] != 1) { - return false; - } - } - return true; - } else if (start_axis == x_dims.size() - 1) { - return true; - } - return false; -} - -inline static bool UseReduceFirstAxisRank2(const framework::DDim& dout_dims, - const framework::DDim& x_dims, - const framework::DDim& y_dims, - const int axis) { - int start_axis = - (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); - - if (y_dims.size() < 2 || - x_dims[x_dims.size() - 2] != y_dims[y_dims.size() - 2] || - x_dims[x_dims.size() - 1] != y_dims[y_dims.size() - 1]) { - return false; - } - - if (start_axis == x_dims.size() - 2) { - return true; - } else if (start_axis == 0) { - for (int i = 0; i < y_dims.size() - 2; ++i) { - if (y_dims[i] != 1) { - return false; - } - } - return true; - } - return false; -} - -inline static bool UseReduceSecondAxisRank2(const framework::DDim& dout_dims, - const framework::DDim& x_dims, - const framework::DDim& y_dims, - const int axis, int* start, - int* end) { - if (x_dims.size() != y_dims.size() || y_dims.size() < 3) { - return false; - } - - auto y_dims_vec = framework::vectorize(y_dims); - auto start_iter = std::find(y_dims_vec.begin(), y_dims_vec.end(), 1); - auto end_iter = std::find(y_dims_vec.rbegin(), y_dims_vec.rend(), 1); - if (start_iter == y_dims_vec.end() || start_iter == y_dims_vec.end() - 1) { - return false; - } else { - *start = std::distance(y_dims_vec.begin(), start_iter); - *end = y_dims_vec.size() - 1 - std::distance(y_dims_vec.rbegin(), end_iter); - for (int i = *start; i <= *end; ++i) { - if (y_dims[i] != 1) { - return false; - } - } - return true; - } -} - -template -__global__ __launch_bounds__(1024) void ReduceFirstAixsKernel( - const T* in, T* out, const int64_t num_rows, const int64_t num_cols, OP op, - T init) { - int row = blockIdx.y * blockDim.y + threadIdx.y; - int col = blockIdx.x * blockDim.x + threadIdx.x; - - T sum = init; - if (row < num_rows && col < num_cols) sum = in[row * num_cols + col]; - - __shared__ __align__( - alignof(T)) char partial_sums_raw[WARPSIZE * (WARPSIZE + 1) * sizeof(T)]; - T* partial_sums = reinterpret_cast(partial_sums_raw); - - row += gridDim.y * blockDim.y; - - if (col < num_cols) { - for (; row < num_rows; row += gridDim.y * blockDim.y) { - sum = op(sum, in[row * num_cols + col]); - } - } - - partial_sums[threadIdx.x * (WARPSIZE + 1) + threadIdx.y] = sum; - - __syncthreads(); - - if (threadIdx.y == 0 && col < num_cols) { - T s = partial_sums[threadIdx.x * (WARPSIZE + 1)]; - - const int numRowsThisBlock = min(static_cast(blockDim.y), - num_rows - blockIdx.y * blockDim.y); - - for (int row = 1; row < numRowsThisBlock; ++row) { - T t = partial_sums[threadIdx.x * (WARPSIZE + 1) + row]; - s = op(s, t); - } - - out[col * gridDim.y + blockIdx.y] = s; - } -} - -template -static void ElemwiseYGradRank1CUDA(const framework::ExecutionContext& ctx, - const framework::Tensor& dout, - const int rows, const int cols, - framework::Tensor* dx, - framework::Tensor* dy) { - dim3 block_dim(WARPSIZE, std::min(rows, 1024 / WARPSIZE)); - dim3 grid_dim((cols + (WARPSIZE - 1)) / WARPSIZE, 1, 1); - - if (dx) { - dx->mutable_data(ctx.GetPlace()); - framework::TensorCopy( - dout, ctx.GetPlace(), - ctx.template device_context(), dx); - } - if (dy) { - dy->mutable_data(ctx.GetPlace()); - const T* dout_data = dout.data(); - T* dy_data = dy->data(); - auto stream = ctx.template device_context().stream(); - ReduceFirstAixsKernel<<>>( - dout_data, dy_data, rows, cols, AddFunctor(), static_cast(0)); - } -} - -template -__global__ __launch_bounds__(1024) void ReduceFirstOrSecondAxisKernel( - const T* in, T* out, const int num_planes, const int num_rows, - const int num_cols, OP op, T init) { - const int gid = threadIdx.x + blockIdx.x * blockDim.x; - const int elems_per_plane = num_rows * num_cols; - - const int plane = gid / num_cols; - const int col = gid % num_cols; - - if (plane >= num_planes) return; - - if (num_rows == 1) { - out[plane * elems_per_plane + col] = in[plane * elems_per_plane + col]; - return; - } - - T sum = op(in[plane * elems_per_plane + col], - in[plane * elems_per_plane + num_cols + col]); - for (int row = 2; row < num_rows; ++row) { - sum = op(sum, in[plane * elems_per_plane + row * num_cols + col]); - } - - out[plane * num_cols + col] = sum; -} - -template -static void ElemwiseYGradRank2CUDA(const framework::ExecutionContext& ctx, - const framework::Tensor& dout, - const int planes, const int rows, - const int cols, framework::Tensor* dx, - framework::Tensor* dy) { - int num_threads = 128; - int num_blocks = (rows + num_threads - 1) / num_threads; - - if (planes != 1) { - num_blocks = (planes * cols + num_threads - 1) / num_threads; - } - - if (dx) { - dx->mutable_data(ctx.GetPlace()); - framework::TensorCopy( - dout, ctx.GetPlace(), - ctx.template device_context(), dx); - } - if (dy) { - dy->mutable_data(ctx.GetPlace()); - const T* dout_data = dout.data(); - T* dy_data = dy->data(); - auto stream = ctx.template device_context().stream(); - ReduceFirstOrSecondAxisKernel<<>>( - dout_data, dy_data, planes, rows, cols, AddFunctor(), - static_cast(0)); - } -} - -template -static bool ElemwiseGradUseReduce(const framework::ExecutionContext& ctx, - const int axis, const framework::DDim x_dims, - const framework::DDim y_dims, - const framework::Tensor& dout, - framework::Tensor* dx, - framework::Tensor* dy) { - int start = 0; - int end = 0; - auto x_dims_vec = framework::vectorize(x_dims); - if (UseReduceFirstAxisRank1(dout.dims(), x_dims, y_dims, axis)) { - int rows = std::accumulate(x_dims_vec.begin(), x_dims_vec.end() - 1, 1, - std::multiplies()); - int cols = dx->dims()[dx->dims().size() - 1]; - if (cols > 512 && cols < 4096) { - ElemwiseYGradRank1CUDA(ctx, dout, rows, cols, dx, dy); - return true; - } - } - - if (UseReduceFirstAxisRank2(dout.dims(), x_dims, y_dims, axis)) { - int rows = std::accumulate(x_dims_vec.begin(), x_dims_vec.end() - 2, 1, - std::multiplies()); - int cols = - dx->dims()[dx->dims().size() - 1] * dx->dims()[dx->dims().size() - 2]; - if (cols > 4096) { - ElemwiseYGradRank2CUDA(ctx, dout, 1, rows, cols, dx, - dy); - return true; - } - } - - if (UseReduceSecondAxisRank2(dout.dims(), x_dims, y_dims, axis, &start, - &end)) { - int planes = std::accumulate(x_dims_vec.begin(), x_dims_vec.begin() + start, - 1, std::multiplies()); - int rows = std::accumulate(x_dims_vec.begin() + start, - x_dims_vec.begin() + end + 1, 1, - std::multiplies()); - int cols = std::accumulate(x_dims_vec.begin() + end + 1, x_dims_vec.end(), - 1, std::multiplies()); - if (rows / (planes * cols) < 16) { - ElemwiseYGradRank2CUDA(ctx, dout, planes, rows, cols, - dx, dy); - return true; - } - } - - return false; -} - -template -class ElementwiseAddGradKernel - : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - // skip out - auto* out = dout; - int axis = ctx.Attr("axis"); - - // Special case when dy is not needed and dx doesn't reduce - if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) { - VLOG(4) << "Special case when dy is not needed and dx doesn't " - "reduce"; - framework::TensorCopy( - *dout, ctx.GetPlace(), - ctx.template device_context(), dx); - } else if (dx == nullptr && dy != nullptr && dy->dims() == dout->dims()) { - VLOG(4) << "Special case when dx is not needed and dy doesn't " - "reduce"; - framework::TensorCopy( - *dout, ctx.GetPlace(), - ctx.template device_context(), dy); - } else if (dx && dy && (dx->dims() == dy->dims())) { - ElementwiseAddGrad(ctx, x, y, out, dout, - dx, dy); - } else if (dx && dx->dims() == dout->dims() && - ElemwiseGradUseReduce( - ctx, axis, x->dims(), y->dims(), *dout, dx, dy)) { - } else if (dy && dy->dims() == dout->dims() && - ElemwiseGradUseReduce( - ctx, axis, x->dims(), y->dims(), *dout, dy, dx)) { - } else { - DefaultElementwiseAddGrad(ctx, x, y, out, - dout, dx, dy); - } - } -}; - } // namespace operators } // namespace paddle REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index 23223fc06d..acda31e0f2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -22,10 +22,9 @@ namespace paddle { namespace operators { template -void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, - const framework::Tensor *y, - framework::Tensor *z) { +void default_elementwise_add(const framework::ExecutionContext &ctx, + const framework::Tensor *x, + const framework::Tensor *y, framework::Tensor *z) { int axis = ctx.Attr("axis"); auto x_dims = x->dims(); auto y_dims = y->dims(); @@ -58,7 +57,7 @@ class ElementwiseAddKernel : public framework::OpKernel { SameDimsElemwiseAdd same_dims_add; same_dims_add(ctx, x, y, z); } else { - DefaultElementwiseAddGrad(ctx, x, y, z); + default_elementwise_add(ctx, x, y, z); } } }; @@ -69,12 +68,13 @@ struct IdentityGrad { }; template -void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, - const framework::Tensor *y, - const framework::Tensor *out, - const framework::Tensor *dout, - framework::Tensor *dx, framework::Tensor *dy) { +void default_elementwise_add_grad(const framework::ExecutionContext &ctx, + const framework::Tensor *x, + const framework::Tensor *y, + const framework::Tensor *out, + const framework::Tensor *dout, + framework::Tensor *dx, + framework::Tensor *dy) { int axis = ctx.Attr("axis"); ElemwiseExplicitGradCompute, @@ -87,10 +87,11 @@ template typename std::enable_if< std::is_floating_point::value && std::is_same::value>::type -ElementwiseAddGrad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - const framework::Tensor *out, const framework::Tensor *dout, - framework::Tensor *dx, framework::Tensor *dy) { +elementwise_add_grad(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + const framework::Tensor *out, + const framework::Tensor *dout, framework::Tensor *dx, + framework::Tensor *dy) { auto blas = math::GetBlas(ctx); if (dx) { blas.VCOPY(dout->numel(), dout->data(), @@ -107,11 +108,12 @@ template typename std::enable_if< !std::is_floating_point::value && std::is_same::value>::type -ElementwiseAddGrad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - const framework::Tensor *out, const framework::Tensor *dout, - framework::Tensor *dx, framework::Tensor *dy) { - DefaultElementwiseAddGrad(ctx, x, y, out, dout, dx, dy); +elementwise_add_grad(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + const framework::Tensor *out, + const framework::Tensor *dout, framework::Tensor *dx, + framework::Tensor *dy) { + default_elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } #ifdef PADDLE_WITH_CUDA @@ -119,10 +121,11 @@ ElementwiseAddGrad(const framework::ExecutionContext &ctx, template typename std::enable_if< std::is_same::value>::type -ElementwiseAddGrad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - const framework::Tensor *out, const framework::Tensor *dout, - framework::Tensor *dx, framework::Tensor *dy); +elementwise_add_grad(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + const framework::Tensor *out, + const framework::Tensor *dout, framework::Tensor *dx, + framework::Tensor *dy); #endif template @@ -155,9 +158,10 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel { *dout, ctx.GetPlace(), ctx.template device_context(), dy); } else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { - ElementwiseAddGrad(ctx, x, y, out, dout, dx, dy); + elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } else { - DefaultElementwiseAddGrad(ctx, x, y, out, dout, dx, dy); + default_elementwise_add_grad(ctx, x, y, out, dout, dx, + dy); } } }; @@ -182,8 +186,8 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel { GetDoubleGradSafeTensor(ctx, y, ddy, &ddy_safe); ddout->mutable_data(ctx.GetPlace()); - DefaultElementwiseAddGrad(ctx, &ddx_safe, &ddy_safe, - ddout); + default_elementwise_add(ctx, &ddx_safe, &ddy_safe, + ddout); } } }; -- GitLab