Unverified commit 1efef8ba, authored by ShenLiang, committed by GitHub

Fix bug of matmul_v2 for broadcast case (#29599)

* fix bug of matmul_v2 for broadcast
Parent a9082082
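For context, here is a minimal Python sketch (not part of this commit) of the broadcast case that the fix and the new tests target, assuming the paddle 2.x dygraph API (paddle.to_tensor, paddle.matmul, paddle.grad). The shapes mirror the added TestMatMuklOpBroadcast1 case: the batch dims broadcast as (3, 1) x (1, 2) -> (3, 2), so the gradient of x must be reduce-summed back to x's shape.

import numpy as np
import paddle

# Hypothetical repro of the broadcast matmul_v2 gradient case fixed here.
x = paddle.to_tensor(np.random.rand(3, 1, 10, 10), dtype='float64', stop_gradient=False)
y = paddle.to_tensor(np.random.rand(1, 2, 10, 10), dtype='float64', stop_gradient=False)

out = paddle.matmul(x, y, transpose_x=True, transpose_y=True)  # shape (3, 2, 10, 10)
dx, dy = paddle.grad(out.sum(), [x, y])

print(dx.shape)  # [3, 1, 10, 10] -- reduced over the broadcast batch dim
print(dy.shape)  # [1, 2, 10, 10]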
@@ -44,7 +44,6 @@ template <typename DeviceContext, typename T>
 void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const std::vector<int>& reduce_dims,
                             const paddle::framework::ExecutionContext& ctx) {
-  if (reduce_dims.empty()) return;
 #ifdef __NVCC__
   auto stream = ctx.cuda_device_context().stream();
   TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
@@ -602,47 +601,48 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
       // So we should avoid the case in reality.
       VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
                  "wastes the memory. So we should avoid the case in reality";
+      Tensor dx_help, dy_help;
       if (transpose_x) {
         if (transpose_y) {
           // X'Y': dA = Y'G', dB = G'X'
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
-                                             true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
+                                             &dx_help, true, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
-                                             true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
+                                             &dy_help, true, true, ctx);
         } else {
           // X'Y: dX = YG', dY = XG
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
-                                             false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
+                                             &dx_help, false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
-                                             false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
+                                             &dy_help, false, false, ctx);
         }
       } else {
         if (transpose_y) {
           // XY': dX = GY, dY = G'X
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
-                                             false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
+                                             &dx_help, false, false, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
-                                             true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
+                                             &dy_help, true, false, ctx);
         } else {
           // XY: dX = GY', dY = X'G
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
-                                             false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
+                                             &dx_help, false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
-                                             true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
+                                             &dy_help, true, false, ctx);
         }
       }
       // get help dims
-      const std::vector<std::int64_t> dx_help_dims = vectorize(dx->dims());
-      const std::vector<std::int64_t> dy_help_dims = vectorize(dy->dims());
+      const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
+      const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
       std::vector<std::int64_t> dx_broadcast_dims(ndim);
       std::vector<std::int64_t> dy_broadcast_dims(ndim);
@@ -668,11 +668,21 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
       }
       // reduce sum to get grad by ReduceSum
       if (dx) {
-        ReduceSumForMatmulGrad<DeviceContext, T>(dx, dx, dx_reduce_dims, ctx);
+        if (dx_reduce_dims.empty()) {
+          *dx = std::move(dx_help);
+        } else {
+          ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
+                                                   ctx);
+        }
         dx->Resize(x.dims());
       }
       if (dy) {
-        ReduceSumForMatmulGrad<DeviceContext, T>(dy, dy, dy_reduce_dims, ctx);
+        if (dy_reduce_dims.empty()) {
+          *dy = std::move(dy_help);
+        } else {
+          ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
+                                                   ctx);
+        }
         dy->Resize(y.dims());
       }
     }
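For a concrete picture of the reduce step above: with x of shape (3, 1, 10, 10) and y of shape (1, 2, 10, 10), dx_help comes out of MatMulFunction with the broadcast batch shape (3, 2, 10, 10), and the axes where x had size 1 must be summed away before resizing to x's shape. A rough numpy sketch of that idea (variable names are illustrative, not taken from the kernel):

import numpy as np

# dx_help carries the broadcast batch shape produced by the matmul above.
dx_help = np.random.rand(3, 2, 10, 10)
x_shape = (3, 1, 10, 10)

# Axes where x was broadcast (size 1 in x, >1 in dx_help) are summed,
# loosely mirroring dx_reduce_dims in the kernel.
reduce_dims = tuple(i for i, (xs, hs) in enumerate(zip(x_shape, dx_help.shape))
                    if xs == 1 and hs > 1)

dx = dx_help.sum(axis=reduce_dims, keepdims=True) if reduce_dims else dx_help
dx = dx.reshape(x_shape)  # corresponds to dx->Resize(x.dims())
print(dx.shape)  # (3, 1, 10, 10)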
@@ -286,6 +286,30 @@ class TestMatMuklOp17(TestMatMulV2Op):
         self.trans_y = False


+class TestMatMuklOpBroadcast1(TestMatMulV2Op):
+    """
+    case 14_3
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 10, 10)
+        self.y_shape = (1, 2, 10, 10)
+        self.trans_x = True
+        self.trans_y = True
+
+
+class TestMatMuklOpBroadcast2(TestMatMulV2Op):
+    """
+    case 14_4
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 10, 10)
+        self.y_shape = (1, 2, 10, 10)
+        self.trans_x = False
+        self.trans_y = True
+
+
 #--------------------test matmul fp16--------------------