Commit ec101919 authored by F ForFishes

fix the unnecessary memory copy in the matmul_v2 grad kernel

Parent e239302e
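This change removes the intermediate dx_help / dy_help tensors from MatMulV2GradKernel. MatMulFunction now writes each gradient directly into dx / dy, ReduceSumForMatmulGrad then reduces in place (input == output), and the reduce_dims.empty() fast path becomes a plain return instead of a framework::TensorCopySync, saving one temporary allocation and one device copy per gradient.

Below is a minimal, self-contained sketch of the resulting pattern. It is plain C++, not Paddle code: BatchedMatMul and ReduceBatchInPlace are illustrative stand-ins for MatMulFunction and ReduceSumForMatmulGrad.

// Sketch: compute the batched gradient directly into its output buffer and
// fold the broadcast (batch) dimension in place, so no temporary tensor and
// no extra device copy are needed.
#include <cstddef>
#include <iostream>
#include <vector>

// C[b] = A[b] * B[b] for every batch b; A is m x k, B is k x n, C is m x n.
void BatchedMatMul(const std::vector<float>& a, const std::vector<float>& b,
                   std::vector<float>* c, std::size_t batch, std::size_t m,
                   std::size_t k, std::size_t n) {
  c->assign(batch * m * n, 0.0f);  // write straight into the caller's buffer
  for (std::size_t bi = 0; bi < batch; ++bi)
    for (std::size_t i = 0; i < m; ++i)
      for (std::size_t j = 0; j < n; ++j)
        for (std::size_t p = 0; p < k; ++p)
          (*c)[bi * m * n + i * n + j] +=
              a[bi * m * k + i * k + p] * b[bi * k * n + p * n + j];
}

// Sum the batch dimension in place: afterwards the first mn entries of *t
// hold sum over b of t[b], mirroring an in-place ReduceSum over broadcast dims.
void ReduceBatchInPlace(std::vector<float>* t, std::size_t batch,
                        std::size_t mn) {
  for (std::size_t bi = 1; bi < batch; ++bi)
    for (std::size_t i = 0; i < mn; ++i) (*t)[i] += (*t)[bi * mn + i];
  t->resize(mn);  // shrink back to the unbroadcast gradient shape
}

int main() {
  const std::size_t batch = 2, m = 2, k = 2, n = 2;
  std::vector<float> dout(batch * m * k, 1.0f);  // stand-in for G
  std::vector<float> y(batch * k * n, 2.0f);     // stand-in for Y
  std::vector<float> dx;                         // gradient, no dx_help temp
  BatchedMatMul(dout, y, &dx, batch, m, k, n);   // dX = G * Y (sketch)
  ReduceBatchInPlace(&dx, batch, m * n);         // replaces the extra copy
  for (float v : dx) std::cout << v << ' ';      // prints: 8 8 8 8
  std::cout << '\n';
  return 0;
}

With the temporary gone, nothing needs to be copied when no broadcast axis has to be reduced: the caller has already written the final gradient into the output tensor, which is why the empty-reduce_dims branch in ReduceSumForMatmulGrad can now simply return.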
...
@@ -44,11 +44,7 @@ template <typename DeviceContext, typename T>
 void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const std::vector<int>& reduce_dims,
                             const paddle::framework::ExecutionContext& ctx) {
-  if (reduce_dims.empty()) {
-    // FIXME maybe reduce this copy operation
-    framework::TensorCopySync(*input, ctx.GetPlace(), output);
-    return;
-  }
+  if (reduce_dims.empty()) return;
 #ifdef __NVCC__
   auto stream = ctx.cuda_device_context().stream();
   TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
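As context for the #ifdef __NVCC__ branch above: TensorReduce is cub-based, and the core cub idiom is the two-phase call shown in this hedged, standalone CUDA sketch. It computes a whole-tensor sum; Paddle's TensorReduce generalizes this to arbitrary reduce_dims, and none of the code below is Paddle's.

// Standalone cub reduction: the first call only computes the scratch size,
// the second performs the reduction. Compile with nvcc.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int n = 8;
  float h_in[n] = {1, 2, 3, 4, 5, 6, 7, 8};
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Phase 1: query how much temporary storage the reduction needs.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);
  cudaMalloc(&d_temp, temp_bytes);
  // Phase 2: run the actual sum reduction.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

  float h_out = 0.0f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);  // prints: sum = 36.000000
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_temp);
  return 0;
}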
...
@@ -577,48 +573,47 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
     // So we should avoid the case in reality.
     VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
                "wastes the memory. So we should avoid the case in reality";
-    Tensor dx_help, dy_help;
     if (transpose_x) {
       if (transpose_y) {
         // X'Y': dA = Y'G', dB = G'X'
         if (dx)
-          MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
-                                           &dx_help, true, true, ctx);
+          MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
+                                           true, true, ctx);
         if (dy)
-          MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
-                                           &dy_help, true, true, ctx);
+          MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
+                                           true, true, ctx);
       } else {
         // X'Y: dX = YG', dY = XG
         if (dx)
-          MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
-                                           &dx_help, false, true, ctx);
+          MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
+                                           false, true, ctx);
        if (dy)
-          MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
-                                           &dy_help, false, false, ctx);
+          MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
+                                           false, false, ctx);
       }
     } else {
       if (transpose_y) {
         // XY': dX = GY, dY = G'X
         if (dx)
-          MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
-                                           &dx_help, false, false, ctx);
+          MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
+                                           false, false, ctx);
         if (dy)
-          MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
-                                           &dy_help, true, false, ctx);
+          MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
+                                           true, false, ctx);
       } else {
         // XY: dX = GY', dY = X'G
         if (dx)
-          MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
-                                           &dx_help, false, true, ctx);
+          MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
+                                           false, true, ctx);
         if (dy)
-          MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
-                                           &dy_help, true, false, ctx);
+          MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
+                                           true, false, ctx);
       }
     }
     // get help dims
-    const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
-    const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
+    const std::vector<std::int64_t> dx_help_dims = vectorize(dx->dims());
+    const std::vector<std::int64_t> dy_help_dims = vectorize(dy->dims());
     std::vector<std::int64_t> dx_broadcast_dims(ndim);
     std::vector<std::int64_t> dy_broadcast_dims(ndim);
...
@@ -642,18 +637,13 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
         dy_reduce_dims.push_back(idx);
       }
     }
     // reduce sum to get grad by ReduceSum
     if (dx) {
-      dx->Resize(dx_help.dims());
-      ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
-                                               ctx);
+      ReduceSumForMatmulGrad<DeviceContext, T>(dx, dx, dx_reduce_dims, ctx);
       dx->Resize(x.dims());
     }
     if (dy) {
-      dy->Resize(dy_help.dims());
-      ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
-                                               ctx);
+      ReduceSumForMatmulGrad<DeviceContext, T>(dy, dy, dy_reduce_dims, ctx);
       dy->Resize(y.dims());
     }
   }
...
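The dx_reduce_dims / dy_reduce_dims bookkeeping kept by the second hunk decides which axes of the directly computed gradient must be summed away: every batch axis where the original operand had extent 1 but the broadcast gradient is wider. A hedged standalone sketch of that derivation follows; BroadcastReduceDims is an illustrative name, not Paddle's API.

#include <cstdint>
#include <iostream>
#include <vector>

// For a gradient computed at the broadcast shape, collect the batch axes
// (everything except the trailing two matrix dims) where the original
// operand had extent 1, i.e. where the gradient must be reduce-summed.
std::vector<int> BroadcastReduceDims(
    const std::vector<std::int64_t>& grad_dims,
    const std::vector<std::int64_t>& x_dims) {
  std::vector<int> reduce_dims;
  // Right-align x_dims against grad_dims (NumPy-style broadcasting).
  const int offset =
      static_cast<int>(grad_dims.size()) - static_cast<int>(x_dims.size());
  for (int idx = 0; idx < static_cast<int>(grad_dims.size()) - 2; ++idx) {
    const std::int64_t x_extent = idx < offset ? 1 : x_dims[idx - offset];
    if (x_extent == 1 && grad_dims[idx] > 1) reduce_dims.push_back(idx);
  }
  return reduce_dims;
}

int main() {
  // x is [3, 4]; multiplied against a [5, 2] batch, its gradient comes back
  // as [5, 2, 3, 4] and must be summed over axes 0 and 1 before the final
  // dx->Resize(x.dims()).
  for (int d : BroadcastReduceDims({5, 2, 3, 4}, {3, 4}))
    std::cout << d << ' ';  // prints: 0 1
  std::cout << '\n';
  return 0;
}

After that reduction, the dx->Resize(x.dims()) / dy->Resize(y.dims()) calls in the diff restore the unbroadcast operand shapes.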