From ec101919307695deefad6d99dbda0610be856b0a Mon Sep 17 00:00:00 2001
From: ForFishes <2282912238@qq.com>
Date: Sun, 27 Sep 2020 16:00:40 +0800
Subject: [PATCH] fix the memory copy

---
 paddle/fluid/operators/matmul_v2_op.h | 52 +++++++++++----------------
 1 file changed, 21 insertions(+), 31 deletions(-)

diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h
index 70a4cdccf3c..c5ecec60a96 100644
--- a/paddle/fluid/operators/matmul_v2_op.h
+++ b/paddle/fluid/operators/matmul_v2_op.h
@@ -44,11 +44,7 @@ template <typename DeviceContext, typename T>
 void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const std::vector<int>& reduce_dims,
                             const paddle::framework::ExecutionContext& ctx) {
-  if (reduce_dims.empty()) {
-    // FIXME maybe reduce this copy operation
-    framework::TensorCopySync(*input, ctx.GetPlace(), output);
-    return;
-  }
+  if (reduce_dims.empty()) return;
 #ifdef __NVCC__
   auto stream = ctx.cuda_device_context().stream();
   TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
@@ -577,48 +573,47 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
       // So we should avoid the case in reality.
       VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
                  "wastes the memory. So we should avoid the case in reality";
-      Tensor dx_help, dy_help;
       if (transpose_x) {
         if (transpose_y) {
           // X'Y': dA = Y'G', dB = G'X'
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
-                                             &dx_help, true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
+                                             true, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
-                                             &dy_help, true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
+                                             true, true, ctx);
         } else {
           // X'Y: dX = YG', dY = XG
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
-                                             &dx_help, false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
+                                             false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
-                                             &dy_help, false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
+                                             false, false, ctx);
         }
       } else {
         if (transpose_y) {
           // XY': dX = GY, dY = G'X
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
-                                             &dx_help, false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
+                                             false, false, ctx);
          if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
-                                             &dy_help, true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
+                                             true, false, ctx);
        } else {
           // XY: dX = GY', dY = X'G
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
-                                             &dx_help, false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
+                                             false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
-                                             &dy_help, true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
+                                             true, false, ctx);
         }
       }

       // get help dims
-      const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
-      const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
+      const std::vector<std::int64_t> dx_help_dims = vectorize(dx->dims());
+      const std::vector<std::int64_t> dy_help_dims = vectorize(dy->dims());

       std::vector<std::int64_t> dx_broadcast_dims(ndim);
       std::vector<std::int64_t> dy_broadcast_dims(ndim);
@@ -642,18 +637,13 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
           dy_reduce_dims.push_back(idx);
         }
       }
-
       // reduce sum to get grad by ReduceSum
       if (dx) {
-        dx->Resize(dx_help.dims());
-        ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
-                                                 ctx);
+        ReduceSumForMatmulGrad<DeviceContext, T>(dx, dx, dx_reduce_dims, ctx);
         dx->Resize(x.dims());
       }
       if (dy) {
-        dy->Resize(dy_help.dims());
-        ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
-                                                 ctx);
+        ReduceSumForMatmulGrad<DeviceContext, T>(dy, dy, dy_reduce_dims, ctx);
         dy->Resize(y.dims());
       }
     }
-- 
GitLab
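
Note appended for readers (not part of the applied patch): before this change, MatMulFunction wrote each gradient into a scratch tensor (dx_help/dy_help), and ReduceSumForMatmulGrad either reduced it into dx/dy or, when reduce_dims was empty, fell back to a full TensorCopySync. The patch makes MatMulFunction write into dx/dy directly, so the empty-reduce_dims case becomes a no-op and the scratch allocation plus copy disappear. Below is a minimal standalone C++ sketch of the before/after flow; Tensor, MatMul, GradOld, and GradNew are simplified stand-ins invented for illustration, not the real Paddle APIs.

// Hypothetical stand-ins: Tensor is a flat float buffer and MatMul is
// element-wise, just to model "compute a result the size of dx".
#include <cassert>
#include <cstddef>
#include <vector>

using Tensor = std::vector<float>;

void MatMul(const Tensor& a, const Tensor& b, Tensor* out) {
  out->assign(a.size(), 0.0f);
  for (std::size_t i = 0; i < a.size(); ++i) (*out)[i] = a[i] * b[i];
}

// Old flow: the result lands in a temporary, and when there is nothing
// to reduce it is still copied wholesale into dx (TensorCopySync in Paddle).
void GradOld(const Tensor& y, const Tensor& dout, Tensor* dx,
             const std::vector<int>& reduce_dims) {
  Tensor dx_help;    // extra gradient-sized allocation
  MatMul(y, dout, &dx_help);
  if (reduce_dims.empty()) {
    *dx = dx_help;   // extra full-tensor copy
    return;
  }
  // ... otherwise reduce dx_help into dx along reduce_dims ...
}

// New flow: compute straight into dx; an empty reduce_dims is a no-op.
void GradNew(const Tensor& y, const Tensor& dout, Tensor* dx,
             const std::vector<int>& reduce_dims) {
  MatMul(y, dout, dx);              // no dx_help temporary
  if (reduce_dims.empty()) return;  // nothing to reduce, nothing to copy
  // ... otherwise reduce dx in place along reduce_dims ...
}

int main() {
  Tensor y{1, 2, 3}, dout{4, 5, 6}, dx_old, dx_new;
  GradOld(y, dout, &dx_old, {});
  GradNew(y, dout, &dx_new, {});
  assert(dx_old == dx_new);  // same gradient, one fewer allocation and copy
  return 0;
}

In the common non-broadcast case (reduce_dims empty), this saves one gradient-sized allocation and one device copy per input, which is exactly the cost of the removed TensorCopySync fast path.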