From be22021c0bb0db02f2302d28031425f5187a5bce Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 16 Aug 2023 15:00:18 +0800 Subject: [PATCH] fix bmm op bugs in static mode with dynamic shape (#56135) --- paddle/phi/infermeta/binary.cc | 38 +++++++++++++++++----------------- python/paddle/tensor/linalg.py | 4 ++-- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index fee5882787e..08a1e0e1e45 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -271,27 +271,27 @@ void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { "Input(Y) of BmmOp must be 3-dimensional in BmmOp, " "but received Y's shape: [%s].", y_ndims)); - PADDLE_ENFORCE_EQ( + std::vector dim_out; + auto cal_shape_fn = [](int64_t x, int64_t y, const std::string& error_str) { + if (x == -1) { + return y; + } else if (y == -1) { + return x; + } + PADDLE_ENFORCE_EQ(x, y, phi::errors::InvalidArgument(error_str, x, y)); + return x; + }; + cal_shape_fn(x_dims[2], + y_dims[1], + "Input(X)'s width must be equal with Input(Y)'s height in BmmOp," + "but receive X's width: [%s]," + "Y's height: [%s]."); + dim_out.push_back(cal_shape_fn( x_dims[0], y_dims[0], - phi::errors::InvalidArgument( - "Input(X) and Input(Y) must have the same batch size in BmmOp, " - "but received X's batch size: [%s]," - "Y's batch size [%s]", - x_dims[0], - y_dims[0])); - PADDLE_ENFORCE_EQ( - x_dims[2], - y_dims[1], - phi::errors::InvalidArgument( - "Input(X)'s width must be equal with Input(Y)'s height in BmmOp," - "but receive X's width: [%s]," - "Y's height: [%s].", - x_dims[2], - y_dims[1])); - - std::vector dim_out; - dim_out.push_back(x_dims[0]); + "Input(X) and Input(Y) must have the same batch size in BmmOp, " + "but received X's batch size: [%s]," + "Y's batch size [%s]")); dim_out.push_back(x_dims[1]); dim_out.push_back(y_dims[2]); out->set_dims(phi::make_ddim(dim_out)); diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 8d9db61f929..375d414015d 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1593,13 +1593,13 @@ def bmm(x, y, name=None): x_shape, y_shape ) ) - if x_shape[2] != y_shape[1]: + if x_shape[2] != -1 and y_shape[1] != -1 and x_shape[2] != y_shape[1]: raise ValueError( "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".format( x_shape, y_shape ) ) - if x_shape[0] != y_shape[0]: + if x_shape[0] != -1 and y_shape[0] != -1 and x_shape[0] != y_shape[0]: raise ValueError( "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".format( x_shape, y_shape -- GitLab