Unverified · Commit 1efef8ba authored by ShenLiang, committed by GitHub

Fix bug of matmul_v2 for broadcast case (#29599)

* fix bug of matmul_v2 for broadcast
Parent a9082082
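The fix targets gradients of `matmul_v2` when the two inputs have batch dimensions that only match through broadcasting. A minimal repro sketch, assuming the Paddle 2.0 dygraph API (`paddle.rand`, `Tensor.backward`) and borrowing the shapes from the new tests at the bottom of this diff:

```python
import paddle

# Batch dims (3, 1) and (1, 2) broadcast to (3, 2) in the forward pass,
# so each input's gradient must be reduce-summed back to its own shape.
x = paddle.rand([3, 1, 10, 10])
y = paddle.rand([1, 2, 10, 10])
x.stop_gradient = False
y.stop_gradient = False

out = paddle.matmul(x, y, transpose_x=True, transpose_y=True)
paddle.sum(out).backward()

print(x.grad.shape)  # expected (3, 1, 10, 10): summed over the broadcast axis
print(y.grad.shape)  # expected (1, 2, 10, 10)
```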
```diff
@@ -44,7 +44,6 @@ template <typename DeviceContext, typename T>
 void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const std::vector<int>& reduce_dims,
                             const paddle::framework::ExecutionContext& ctx) {
-  if (reduce_dims.empty()) return;
 #ifdef __NVCC__
   auto stream = ctx.cuda_device_context().stream();
   TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
```
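The removed guard is not lost: as the later hunk shows, the call sites now skip this helper entirely (moving a temporary instead) when nothing needs reducing, so the helper may assume `reduce_dims` is non-empty. A NumPy analogue of the helper's contract, as a sketch only (the real kernel dispatches `TensorReduce` with `cub::Sum` on CUDA):

```python
import numpy as np

def reduce_sum_for_matmul_grad(src, reduce_dims):
    # Sketch of the contract: sum the broadcast-shaped gradient over the
    # broadcast axes. Callers now guarantee reduce_dims is non-empty,
    # mirroring the removed early return.
    return src.sum(axis=tuple(reduce_dims), keepdims=True)
```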
```diff
@@ -602,47 +601,48 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
       // So we should avoid the case in reality.
       VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
                  "wastes the memory. So we should avoid the case in reality";
+      Tensor dx_help, dy_help;
       if (transpose_x) {
         if (transpose_y) {
           // X'Y': dA = Y'G', dB = G'X'
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
-                                             true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
+                                             &dx_help, true, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
-                                             true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
+                                             &dy_help, true, true, ctx);
         } else {
           // X'Y: dX = YG', dY = XG
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
-                                             false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
+                                             &dx_help, false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
-                                             false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
+                                             &dy_help, false, false, ctx);
         }
       } else {
         if (transpose_y) {
           // XY': dX = GY, dY = G'X
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
-                                             false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
+                                             &dx_help, false, false, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
-                                             true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
+                                             &dy_help, true, false, ctx);
         } else {
           // XY: dX = GY', dY = X'G
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
-                                             false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
+                                             &dx_help, false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
-                                             true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
+                                             &dy_help, true, false, ctx);
         }
       }
       // get help dims
-      const std::vector<std::int64_t> dx_help_dims = vectorize(dx->dims());
-      const std::vector<std::int64_t> dy_help_dims = vectorize(dy->dims());
+      const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
+      const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
       std::vector<std::int64_t> dx_broadcast_dims(ndim);
       std::vector<std::int64_t> dy_broadcast_dims(ndim);
```
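The comments in this hunk encode the standard transpose-case gradient identities; e.g. for `out = X'Y'`, `dX = Y'G'` and `dY = G'X'` where `G` is the upstream gradient. A quick NumPy finite-difference spot check of that case (shapes and tolerance are my own choices):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))
Y = rng.standard_normal((5, 4))
G = rng.standard_normal((3, 5))   # upstream grad of out = X.T @ Y.T

dX = Y.T @ G.T                    # dX = Y'G'
dY = G.T @ X.T                    # dY = G'X'

# Perturb one entry of X and compare against the analytic gradient.
eps = 1e-6
X2 = X.copy()
X2[0, 0] += eps
numeric = ((X2.T @ Y.T - X.T @ Y.T) * G).sum() / eps
print(np.isclose(dX[0, 0], numeric, atol=1e-4))  # expect True
```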
```diff
@@ -668,11 +668,21 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
       }
       // reduce sum to get grad by ReduceSum
       if (dx) {
-        ReduceSumForMatmulGrad<DeviceContext, T>(dx, dx, dx_reduce_dims, ctx);
+        if (dx_reduce_dims.empty()) {
+          *dx = std::move(dx_help);
+        } else {
+          ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
+                                                   ctx);
+        }
         dx->Resize(x.dims());
       }
       if (dy) {
-        ReduceSumForMatmulGrad<DeviceContext, T>(dy, dy, dy_reduce_dims, ctx);
+        if (dy_reduce_dims.empty()) {
+          *dy = std::move(dy_help);
+        } else {
+          ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
+                                                   ctx);
+        }
         dy->Resize(y.dims());
       }
     }
```
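This hunk is the heart of the fix: the matmul product now lands in a temporary (`dx_help`/`dy_help`) instead of being written into `dx`/`dy` at the broadcast shape and then reduced in place with aliased input and output. A NumPy walk-through of the `dx` path for the test shapes (the reduce axis here is my own illustration):

```python
import numpy as np

# x has batch dims (3, 1); the broadcast batch is (3, 2), so dx_help
# carries the broadcast shape and axis 1 must be reduce-summed away.
dx_help = np.ones((3, 2, 10, 10))
dx_reduce_dims = (1,)

if not dx_reduce_dims:
    dx = dx_help                                 # *dx = std::move(dx_help)
else:
    dx = dx_help.sum(axis=dx_reduce_dims, keepdims=True)

dx = dx.reshape((3, 1, 10, 10))                  # dx->Resize(x.dims())
print(dx.shape)                                  # (3, 1, 10, 10)
```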
......
```diff
@@ -286,6 +286,30 @@ class TestMatMuklOp17(TestMatMulV2Op):
         self.trans_y = False
 
 
+class TestMatMuklOpBroadcast1(TestMatMulV2Op):
+    """
+    case 14_3
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 10, 10)
+        self.y_shape = (1, 2, 10, 10)
+        self.trans_x = True
+        self.trans_y = True
+
+
+class TestMatMuklOpBroadcast2(TestMatMulV2Op):
+    """
+    case 14_4
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 10, 10)
+        self.y_shape = (1, 2, 10, 10)
+        self.trans_x = False
+        self.trans_y = True
+
+
 #--------------------test matmul fp16--------------------
```
......
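The new cases pin down the previously broken pattern: batch dims (3, 1) and (1, 2) that only meet through broadcasting, with at least one transposed operand. A reference computation for what such a test checks, sketched with NumPy (the base class's actual comparison machinery is not shown here):

```python
import numpy as np

x = np.random.random((3, 1, 10, 10)).astype("float64")
y = np.random.random((1, 2, 10, 10)).astype("float64")

# trans_x=True, trans_y=True: transpose the last two axes, then let
# np.matmul broadcast the leading batch dims (3, 1) x (1, 2) -> (3, 2).
ref = np.matmul(x.swapaxes(-1, -2), y.swapaxes(-1, -2))
print(ref.shape)  # (3, 2, 10, 10)
```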