未验证 提交 1e8432f2 编写于 作者: J jakpiase 提交者: GitHub

[CHERRY-PICK] Fix for matmul_v2 oneDNN op broadcasting when inputs dims have...

[CHERRY-PICK] Fix for matmul_v2 oneDNN op broadcasting when inputs dims have different lengths (#38733)

* fix for matmul_v2 broadcasting

* resolved conflicts
上级 457fe72c
...@@ -152,7 +152,7 @@ class MatMulV2MKLDNNKernel ...@@ -152,7 +152,7 @@ class MatMulV2MKLDNNKernel
x_bd_dims[x_bd_dims.size() - 2] = x_dims[0]; x_bd_dims[x_bd_dims.size() - 2] = x_dims[0];
} else { } else {
for (size_t i = 0; i < x_dims.size(); ++i) { for (size_t i = 0; i < x_dims.size(); ++i) {
x_bd_dims[i] = x_dims[i]; x_bd_dims[x_bd_dims.size() - x_dims.size() + i] = x_dims[i];
} }
} }
if (y_dims.size() == 1) { if (y_dims.size() == 1) {
...@@ -162,20 +162,21 @@ class MatMulV2MKLDNNKernel ...@@ -162,20 +162,21 @@ class MatMulV2MKLDNNKernel
y_bd_dims[y_bd_dims.size() - 2] = y_dims[0]; y_bd_dims[y_bd_dims.size() - 2] = y_dims[0];
} else { } else {
for (size_t i = 0; i < y_dims.size(); ++i) { for (size_t i = 0; i < y_dims.size(); ++i) {
y_bd_dims[i] = y_dims[i]; y_bd_dims[y_bd_dims.size() - y_dims.size() + i] = y_dims[i];
} }
} }
if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2) { if (x_dims.size() > 2 && y_dims.size() > 2) {
for (size_t i = 0; i < x_dims.size() - 2; ++i) { for (size_t i = 0; i < x_bd_dims.size() - 2; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true, x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] == 1 ||
paddle::platform::errors::InvalidArgument( y_bd_dims[i] == 1,
"Tensor dimensions are incorrect for broadcasting." true, paddle::platform::errors::InvalidArgument(
"Dimensions in X and Y must be same or equal to 1, but " "Tensor dimensions are incorrect for broadcasting."
"received x_dim[%d]=%d and y_dims[%d]= %d", "Dimensions in X and Y must be same or equal to 1, but "
i, x_dims[i], i, y_dims[i])); "received x_dim[%d]=%d and y_dims[%d]= %d",
out_dims[i] = std::max(x_dims[i], y_dims[i]); i, x_bd_dims[i], i, y_bd_dims[i]));
out_dims[i] = std::max(x_bd_dims[i], y_bd_dims[i]);
} }
out->Resize(make_ddim(out_dims)); out->Resize(make_ddim(out_dims));
} }
...@@ -237,11 +238,11 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> { ...@@ -237,11 +238,11 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
dy_tmp->mutable_data<T>(ctx.GetPlace()); dy_tmp->mutable_data<T>(ctx.GetPlace());
} }
void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx, void ReduceSumForMatmulGradOutput(
const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine onednn_engine, const dnnl::engine onednn_engine, const Tensor* dx_tmp, Tensor* dx,
const Tensor* dx_tmp, Tensor* dx, std::vector<int64_t>& dx_dims,
std::vector<int64_t> dx_dims) const { const std::vector<int64_t>& squeezed_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler( paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, dx_dims); ctx.GetPlace(), dx_tmp, dx, dx_dims);
...@@ -257,6 +258,19 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> { ...@@ -257,6 +258,19 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
reduction_p->execute(astream, reduction_args); reduction_p->execute(astream, reduction_args);
astream.wait(); astream.wait();
dx->set_format(paddle::platform::GetMKLDNNFormat(
dst_memory_p->get_desc().reshape(squeezed_dims)));
}
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t>& dims,
int new_size) const {
std::vector<int64_t> new_dims(new_size, 1);
for (size_t i = 0; i < dims.size(); ++i) {
new_dims[new_size - dims.size() + i] = dims[i];
}
return new_dims;
} }
void RunKernel(const ExecutionContext& ctx) const { void RunKernel(const ExecutionContext& ctx) const {
...@@ -295,8 +309,14 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> { ...@@ -295,8 +309,14 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
bool trans_y = ctx.Attr<bool>("trans_y"); bool trans_y = ctx.Attr<bool>("trans_y");
auto dout_dims = vectorize(dout->dims()); auto dout_dims = vectorize(dout->dims());
int ndims = std::max(x->dims().size(), y->dims().size()); size_t ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max(ndims, 3); ndims = std::max<size_t>(ndims, 3);
if (x_dims.size() != ndims) {
x_dims = ExtendDimsWithOnes(x_dims, ndims);
} else if (y_dims.size() != ndims) {
y_dims = ExtendDimsWithOnes(y_dims, ndims);
}
// in broadcasting scenario new memory is required because // in broadcasting scenario new memory is required because
// reduce sum must be calculated upon broadcasted dims // reduce sum must be calculated upon broadcasted dims
...@@ -340,21 +360,21 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> { ...@@ -340,21 +360,21 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
if (x_dims != dx_bd_dims) { if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx, ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx,
x_dims); x_dims,
paddle::framework::vectorize(x->dims()));
} else { } else {
*dx = std::move(dx_tmp); *dx = std::move(dx_tmp);
} }
if (y_dims != dy_bd_dims) { if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy, ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy,
y_dims); y_dims,
paddle::framework::vectorize(y->dims()));
} else { } else {
*dy = std::move(dy_tmp); *dy = std::move(dy_tmp);
} }
dx->set_layout(paddle::framework::DataLayout::kMKLDNN); dx->Resize(x->dims());
dx->set_format(x->format()); dy->Resize(y->dims());
dy->set_layout(paddle::framework::DataLayout::kMKLDNN);
dy->set_format(y->format());
} }
}; };
} // anonymous namespace } // anonymous namespace
......
...@@ -251,6 +251,41 @@ class TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): ...@@ -251,6 +251,41 @@ class TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
self.trans_y = False self.trans_y = False
class TestMatMulV2MatrixXMatrix4Dx3DTransposeXOneDNNOp(
TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (5, 4, 15, 10)
self.y_shape = (1, 15, 20)
self.trans_x = True
self.trans_y = False
class TestMatMulV2MatrixXMatrix3Dx4DTransposeYOneDNNOp(
TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (2, 10, 15)
self.y_shape = (4, 2, 20, 15)
self.trans_x = False
self.trans_y = True
class TestMatMulV2MatrixXMatrix5Dx3DTransposeXTransposeYOneDNNOp(
TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (4, 3, 2, 15, 10)
self.y_shape = (1, 20, 15)
self.trans_x = True
self.trans_y = True
class TestMatMulV2MatrixXMatrix3Dx4DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (1, 1, 32, 16)
self.y_shape = (16, 16, 16)
self.trans_x = False
self.trans_y = False
# BF16 TESTS # BF16 TESTS
def create_bf16_test_class(parent): def create_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16() @OpTestTool.skip_if_not_cpu_bf16()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册