未验证 提交 be22021c 编写于 作者: X xiongkun 提交者: GitHub

fix bmm op bugs in static mode with dynamic shape (#56135)

上级 edba06e1
...@@ -271,27 +271,27 @@ void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { ...@@ -271,27 +271,27 @@ void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
"Input(Y) of BmmOp must be 3-dimensional in BmmOp, " "Input(Y) of BmmOp must be 3-dimensional in BmmOp, "
"but received Y's shape: [%s].", "but received Y's shape: [%s].",
y_ndims)); y_ndims));
PADDLE_ENFORCE_EQ( std::vector<int64_t> dim_out;
auto cal_shape_fn = [](int64_t x, int64_t y, const std::string& error_str) {
if (x == -1) {
return y;
} else if (y == -1) {
return x;
}
PADDLE_ENFORCE_EQ(x, y, phi::errors::InvalidArgument(error_str, x, y));
return x;
};
cal_shape_fn(x_dims[2],
y_dims[1],
"Input(X)'s width must be equal with Input(Y)'s height in BmmOp,"
"but receive X's width: [%s],"
"Y's height: [%s].");
dim_out.push_back(cal_shape_fn(
x_dims[0], x_dims[0],
y_dims[0], y_dims[0],
phi::errors::InvalidArgument(
"Input(X) and Input(Y) must have the same batch size in BmmOp, " "Input(X) and Input(Y) must have the same batch size in BmmOp, "
"but received X's batch size: [%s]," "but received X's batch size: [%s],"
"Y's batch size [%s]", "Y's batch size [%s]"));
x_dims[0],
y_dims[0]));
PADDLE_ENFORCE_EQ(
x_dims[2],
y_dims[1],
phi::errors::InvalidArgument(
"Input(X)'s width must be equal with Input(Y)'s height in BmmOp,"
"but receive X's width: [%s],"
"Y's height: [%s].",
x_dims[2],
y_dims[1]));
std::vector<int64_t> dim_out;
dim_out.push_back(x_dims[0]);
dim_out.push_back(x_dims[1]); dim_out.push_back(x_dims[1]);
dim_out.push_back(y_dims[2]); dim_out.push_back(y_dims[2]);
out->set_dims(phi::make_ddim(dim_out)); out->set_dims(phi::make_ddim(dim_out));
......
...@@ -1593,13 +1593,13 @@ def bmm(x, y, name=None): ...@@ -1593,13 +1593,13 @@ def bmm(x, y, name=None):
x_shape, y_shape x_shape, y_shape
) )
) )
if x_shape[2] != y_shape[1]: if x_shape[2] != -1 and y_shape[1] != -1 and x_shape[2] != y_shape[1]:
raise ValueError( raise ValueError(
"x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".format( "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".format(
x_shape, y_shape x_shape, y_shape
) )
) )
if x_shape[0] != y_shape[0]: if x_shape[0] != -1 and y_shape[0] != -1 and x_shape[0] != y_shape[0]:
raise ValueError( raise ValueError(
"x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".format( "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".format(
x_shape, y_shape x_shape, y_shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册