Commit ec101919 authored by ForFishes

fix the memory copy

Write the MatMulFunction results directly into dx/dy and reduce them in place, instead of materializing dx_help/dy_help temporaries and copying them into the outputs when no reduction is needed.

Parent e239302e
@@ -44,11 +44,7 @@ template <typename DeviceContext, typename T>
 void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const std::vector<int>& reduce_dims,
                             const paddle::framework::ExecutionContext& ctx) {
-  if (reduce_dims.empty()) {
-    // FIXME maybe reduce this copy operation
-    framework::TensorCopySync(*input, ctx.GetPlace(), output);
-    return;
-  }
+  if (reduce_dims.empty()) return;
 #ifdef __NVCC__
   auto stream = ctx.cuda_device_context().stream();
   TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
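With the early return in place, ReduceSumForMatmulGrad no longer copies *input into *output when there is nothing to reduce; the caller must now produce the result directly in the output tensor, which the kernel below does by passing dx/dy straight to MatMulFunction. A minimal sketch of the old vs. new contract, using a plain std::vector stand-in rather than framework::Tensor (all names here are hypothetical, not Paddle APIs):

```cpp
// contract_sketch.cc -- illustrative toy only; the real code operates on
// paddle::framework::Tensor and an ExecutionContext, not std::vector.
#include <cassert>
#include <vector>

using Buf = std::vector<double>;

// Old contract: with no dims to reduce, fall back to a full copy.
void ReduceOld(const Buf& in, Buf* out, const std::vector<int>& dims) {
  if (dims.empty()) {
    *out = in;  // extra allocation + element-wise copy: the cost this commit removes
    return;
  }
  // ... real reduction elided ...
}

// New contract: with no dims to reduce, do nothing. The caller has already
// written the result into *out (it passed dx/dy straight to MatMulFunction),
// so input and output may even alias.
void ReduceNew(const Buf* in, Buf* out, const std::vector<int>& dims) {
  (void)in;
  (void)out;
  if (dims.empty()) return;
  // ... real reduction elided ...
}

int main() {
  Buf grad = {1.0, 2.0, 3.0};   // pretend MatMulFunction produced this in dx
  ReduceNew(&grad, &grad, {});  // no-op: the copy is gone
  assert(grad[1] == 2.0);
  return 0;
}
```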
@@ -577,48 +573,47 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
       // So we should avoid the case in reality.
       VLOG(3) << "It costs much time to reduce sum for the broadcast and "
                  "wastes the memory, so we should avoid this case in reality";
-      Tensor dx_help, dy_help;
       if (transpose_x) {
         if (transpose_y) {
           // X'Y': dA = Y'G', dB = G'X'
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
-                                             &dx_help, true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
+                                             true, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
-                                             &dy_help, true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
+                                             true, true, ctx);
         } else {
           // X'Y: dX = YG', dY = XG
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
-                                             &dx_help, false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
+                                             false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
-                                             &dy_help, false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
+                                             false, false, ctx);
         }
       } else {
         if (transpose_y) {
           // XY': dX = GY, dY = G'X
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
-                                             &dx_help, false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
+                                             false, false, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
-                                             &dy_help, true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
+                                             true, false, ctx);
         } else {
           // XY: dX = GY', dY = X'G
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
-                                             &dx_help, false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
+                                             false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
-                                             &dy_help, true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
+                                             true, false, ctx);
         }
       }
       // get help dims
-      const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
-      const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
+      const std::vector<std::int64_t> dx_help_dims = vectorize(dx->dims());
+      const std::vector<std::int64_t> dy_help_dims = vectorize(dy->dims());
       std::vector<std::int64_t> dx_broadcast_dims(ndim);
       std::vector<std::int64_t> dy_broadcast_dims(ndim);
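For reference, the four branches above encode the standard matmul gradient identities. Writing G = dout, for Z = XY we have Z_{ik} = Σ_j X_{ij} Y_{jk}, hence ∂L/∂X_{ij} = Σ_k G_{ik} Y_{jk} = (G Yᵀ)_{ij}; the remaining cases follow from the same element-wise argument:

```latex
% G = dout; each line matches one branch of the kernel above.
\begin{aligned}
Z &= XY:               & dX &= G\,Y^{\top},      & dY &= X^{\top}G,\\
Z &= XY^{\top}:        & dX &= G\,Y,             & dY &= G^{\top}X,\\
Z &= X^{\top}Y:        & dX &= Y\,G^{\top},      & dY &= X\,G,\\
Z &= X^{\top}Y^{\top}: & dX &= Y^{\top}G^{\top}, & dY &= G^{\top}X^{\top}.
\end{aligned}
```

Each right-hand side is itself a (possibly batched) matmul, which is why MatMulFunction can produce both gradients directly; broadcasting in the forward pass is what forces the reduce-sum step handled in the next hunk.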
@@ -642,18 +637,13 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
           dy_reduce_dims.push_back(idx);
         }
       }
       // reduce sum to get grad by ReduceSum
       if (dx) {
-        dx->Resize(dx_help.dims());
-        ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
-                                                 ctx);
+        ReduceSumForMatmulGrad<DeviceContext, T>(dx, dx, dx_reduce_dims, ctx);
         dx->Resize(x.dims());
       }
       if (dy) {
-        dy->Resize(dy_help.dims());
-        ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
-                                                 ctx);
+        ReduceSumForMatmulGrad<DeviceContext, T>(dy, dy, dy_reduce_dims, ctx);
         dy->Resize(y.dims());
       }
     }
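Putting the two hunks together, the new flow is: MatMulFunction writes the full broadcast-shape gradient straight into dx, ReduceSumForMatmulGrad collapses the broadcast dimensions in place, and the final Resize restores x's original shape. A hypothetical standalone CPU illustration of that reduction step on a flat buffer (the real kernel calls TensorReduce under __NVCC__; shapes and names here are invented):

```cpp
// broadcast_reduce_sketch.cc -- toy illustration of reducing over
// dx_reduce_dims; not the actual Paddle implementation.
#include <cstdio>
#include <vector>

int main() {
  // Say x had shape [3] and was broadcast to [2, 3] in the forward pass.
  // MatMulFunction wrote the full-shape gradient into dx:
  std::vector<double> dx_full = {1, 2, 3,
                                 4, 5, 6};  // shape [2, 3]
  // dx_reduce_dims = {0}: sum out the broadcast dimension to recover
  // a gradient with x's original shape [3].
  std::vector<double> dx(3, 0.0);
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j) dx[j] += dx_full[i * 3 + j];
  for (double v : dx) std::printf("%g ", v);  // prints: 5 7 9
  std::printf("\n");
  return 0;
}
```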