From 0e40298949516f78a92213becedfaaedbec7b3e9 Mon Sep 17 00:00:00 2001 From: phlrain Date: Thu, 21 Mar 2019 07:23:19 +0000 Subject: [PATCH] fix matmul shape check; test=develop --- paddle/fluid/operators/matmul_op.cc | 6 ++++-- python/paddle/fluid/layers/nn.py | 3 +++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 242a1b9ae92..f1828274520 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -290,8 +290,10 @@ class MatMulOp : public framework::OperatorWithKernel { context->Attrs().Get("transpose_Y")); PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_); - PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ || - mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0); + if (context->IsRuntime()) { + PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ || + mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0); + } std::vector dim_out; if (mat_dim_x.batch_size_ != 0) { dim_out = framework::vectorize(dim_x); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 28655314d30..1c0674fb5a9 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4901,6 +4901,9 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): if len(y_shape) > 2 and len(x_shape) > 2: for i, dim_x in enumerate(x_shape[:-2]): + # don't check neg shape + if dim_x < 0 or y_shape[i] < 0: + continue if dim_x != y_shape[i]: raise ValueError("Invalid inputs for matmul. x(%s), y(%s)" % (x.shape, y.shape)) -- GitLab