Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0fb18bc2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0fb18bc2
编写于
12月 02, 2020
作者:
S
ShenLiang
提交者:
GitHub
12月 02, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enforce the matmul_v2 error message (#29297)
上级
9b59a589
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
51 addition
and
22 deletion
+51
-22
paddle/fluid/operators/matmul_v2_op.h
paddle/fluid/operators/matmul_v2_op.h
+51
-22
未找到文件。
paddle/fluid/operators/matmul_v2_op.h
浏览文件 @
0fb18bc2
...
...
@@ -71,8 +71,14 @@ static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims,
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
x_bd_dims
[
i
]
==
y_bd_dims
[
i
]
||
x_bd_dims
[
i
]
<=
1
||
y_bd_dims
[
i
]
<=
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input(X) and Input(Y) has error dim."
));
true
,
platform
::
errors
::
InvalidArgument
(
"Input(X) and Input(Y) has error dim."
"X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s],"
"or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1,"
"But received X_broadcast's shape[%s] = [%s]"
"received Y_broadcast's shape[%s] = [%s]"
,
i
,
i
,
i
,
i
,
i
,
x_bd_dims
[
i
],
i
,
y_bd_dims
[
i
]));
if
(
x_bd_dims
[
i
]
==
0
||
y_bd_dims
[
i
]
==
0
)
{
out_bd_dims
[
i
]
=
0
;
}
else
{
...
...
@@ -118,10 +124,13 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
const
T
*
y_data
=
Y
->
data
<
T
>
();
if
(
x_ndim
==
1
&&
y_ndim
==
1
)
{
PADDLE_ENFORCE_EQ
(
X
->
numel
(),
Y
->
numel
(),
platform
::
errors
::
InvalidArgument
(
"X's numbers is not equal to Y's numbers,"
"when X/Y's dims =1"
));
PADDLE_ENFORCE_EQ
(
X
->
numel
(),
Y
->
numel
(),
platform
::
errors
::
InvalidArgument
(
"X's numbers must be equal to Y's numbers,"
"when X/Y's dims =1. But received X has [%d] elements,"
"received Y has [%d] elements"
,
X
->
numel
(),
Y
->
numel
()));
VLOG
(
3
)
<<
"MatMul's case 1"
;
Out
->
Resize
({
1
});
Out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
@@ -140,13 +149,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
if
(
x_ndim
==
1
)
{
const
int
N
=
X
->
numel
();
if
(
trans_y
)
{
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
1
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
));
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
1
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d"
,
y_ndim
-
1
,
N
,
y_ndim
-
1
,
y_dims
[
y_ndim
-
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
2
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
));
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
2
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d"
,
y_ndim
-
2
,
N
,
y_ndim
-
2
,
y_dims
[
y_ndim
-
2
]));
}
std
::
vector
<
std
::
int64_t
>
out_dims
(
y_ndim
-
1
);
if
(
trans_y
)
{
...
...
@@ -182,13 +197,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
if
(
y_ndim
==
1
)
{
const
int
N
=
Y
->
numel
();
if
(
trans_x
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
x_ndim
-
2
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(X) has error dim."
));
PADDLE_ENFORCE_EQ
(
x_dims
[
x_ndim
-
2
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(X) has error dim."
"X'dims[%d] must be equal to %d"
"But received X'dims[%d] is %d"
,
x_ndim
-
2
,
N
,
x_ndim
-
2
,
x_dims
[
x_ndim
-
2
]));
}
else
{
PADDLE_ENFORCE_EQ
(
x_dims
[
x_ndim
-
1
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(X) has error dim."
));
PADDLE_ENFORCE_EQ
(
x_dims
[
x_ndim
-
1
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(X) has error dim."
"X'dims[%d] must be equal to %d"
"But received X'dims[%d] is %d"
,
x_ndim
-
1
,
N
,
x_ndim
-
1
,
x_dims
[
x_ndim
-
1
]));
}
std
::
vector
<
std
::
int64_t
>
out_dims
(
x_ndim
-
1
);
if
(
trans_x
)
{
...
...
@@ -225,11 +246,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
const
int
M
=
trans_x
?
x_dims
[
x_ndim
-
1
]
:
x_dims
[
x_ndim
-
2
];
const
int
K
=
trans_x
?
x_dims
[
x_ndim
-
2
]
:
x_dims
[
x_ndim
-
1
];
if
(
trans_y
)
{
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
1
],
K
,
platform
::
errors
::
InvalidArgument
(
"Input(X) has error dim."
));
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
1
],
K
,
platform
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d"
,
y_ndim
-
1
,
K
,
y_ndim
-
1
,
y_dims
[
y_ndim
-
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
2
],
K
,
platform
::
errors
::
InvalidArgument
(
"Input(X) has error dim."
));
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
2
],
K
,
platform
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d"
,
y_ndim
-
2
,
K
,
y_ndim
-
2
,
y_dims
[
y_ndim
-
2
]));
}
const
int
N
=
trans_y
?
y_dims
[
y_ndim
-
2
]
:
y_dims
[
y_ndim
-
1
];
const
int
ndim
=
(
std
::
max
)(
x_ndim
,
y_ndim
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录