You need to sign in or sign up before continuing.
提交 0e402989 编写于 作者: P phlrain

fix matmul shape check; test=develop

上级 b7baeed7
...@@ -290,8 +290,10 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -290,8 +290,10 @@ class MatMulOp : public framework::OperatorWithKernel {
context->Attrs().Get<bool>("transpose_Y")); context->Attrs().Get<bool>("transpose_Y"));
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_); PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ || if (context->IsRuntime()) {
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0); PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
}
std::vector<int64_t> dim_out; std::vector<int64_t> dim_out;
if (mat_dim_x.batch_size_ != 0) { if (mat_dim_x.batch_size_ != 0) {
dim_out = framework::vectorize(dim_x); dim_out = framework::vectorize(dim_x);
......
...@@ -4901,6 +4901,9 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): ...@@ -4901,6 +4901,9 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
if len(y_shape) > 2 and len(x_shape) > 2: if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]): for i, dim_x in enumerate(x_shape[:-2]):
# don't check neg shape
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]: if dim_x != y_shape[i]:
raise ValueError("Invalid inputs for matmul. x(%s), y(%s)" % raise ValueError("Invalid inputs for matmul. x(%s), y(%s)" %
(x.shape, y.shape)) (x.shape, y.shape))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册