From d37bd5cc6306e77397501ade14918fb56d0035f5 Mon Sep 17 00:00:00 2001 From: Hongyu Liu <43953930+phlrain@users.noreply.github.com> Date: Thu, 21 Mar 2019 20:14:57 +0800 Subject: [PATCH] Merge pull request #16347 from phlrain/fix_matmul_check fix matmul shape check --- paddle/fluid/operators/matmul_op.cc | 6 ++++-- python/paddle/fluid/layers/nn.py | 3 +++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 242a1b9ae9..f182827452 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -290,8 +290,10 @@ class MatMulOp : public framework::OperatorWithKernel { context->Attrs().Get("transpose_Y")); PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_); - PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ || - mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0); + if (context->IsRuntime()) { + PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ || + mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0); + } std::vector dim_out; if (mat_dim_x.batch_size_ != 0) { dim_out = framework::vectorize(dim_x); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index cdd64e89b4..bc108e9e34 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4732,6 +4732,9 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): if len(y_shape) > 2: for i, dim_x in enumerate(x_shape[:-2]): + # don't check neg shape + if dim_x < 0 or y_shape[i] < 0: + continue if dim_x != y_shape[i]: raise ValueError("Invalid inputs for matmul.") -- GitLab