提交 84e3adbe 编写于 作者: C chengduo 提交者: ceci3

Fix reshape bug (#16069)

* In some case, the input may have one than one negative value.
test=develop

* fix matmul bug
test=develop
上级 ab19d92e
...@@ -56,6 +56,9 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -56,6 +56,9 @@ class ReshapeOp : public framework::OperatorWithKernel {
static framework::DDim ValidateShape(const std::vector<int> shape, static framework::DDim ValidateShape(const std::vector<int> shape,
const framework::DDim &in_dims) { const framework::DDim &in_dims) {
const int64_t in_size = framework::product(in_dims); const int64_t in_size = framework::product(in_dims);
auto in_dims_vec = framework::vectorize(in_dims);
bool all_positive = std::all_of(in_dims_vec.cbegin(), in_dims_vec.cend(),
[](int64_t i) { return i > 0; });
// only one dimension can be set to -1, whose size will be automatically // only one dimension can be set to -1, whose size will be automatically
// infered. // infered.
const int64_t unk_dim_val = -1; const int64_t unk_dim_val = -1;
...@@ -88,7 +91,7 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -88,7 +91,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
} }
if (unk_dim_idx != -1) { if (unk_dim_idx != -1) {
if (in_size > 0) { if (all_positive) {
// in_size < 0 and is un-determinate in compile time, skip the check, // in_size < 0 and is un-determinate in compile time, skip the check,
// for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
// capacity = -24, in_size = -8, output_shape[0] = 0 // capacity = -24, in_size = -8, output_shape[0] = 0
......
...@@ -4834,11 +4834,6 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): ...@@ -4834,11 +4834,6 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
""" """
def __check_input(x, y): def __check_input(x, y):
if len(y.shape) > len(x.shape):
raise ValueError(
"Invalid inputs for matmul. "
"x's rank should be always greater than or equal to y'rank.")
x_shape = list(x.shape) x_shape = list(x.shape)
y_shape = list(y.shape) y_shape = list(y.shape)
if len(x_shape) == 1: if len(x_shape) == 1:
...@@ -4854,10 +4849,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): ...@@ -4854,10 +4849,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
if x_shape[-1] != y_shape[-2]: if x_shape[-1] != y_shape[-2]:
raise ValueError("Invalid inputs for matmul.") raise ValueError("Invalid inputs for matmul.")
if len(y_shape) > 2: if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]): for i, dim_x in enumerate(x_shape[:-2]):
if dim_x != y_shape[i]: if dim_x != y_shape[i]:
raise ValueError("Invalid inputs for matmul.") raise ValueError("Invalid inputs for matmul. x(%s), y(%s)" %
(x.shape, y.shape))
__check_input(x, y) __check_input(x, y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册