Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f5597d9a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
396
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
f5597d9a
编写于
7月 19, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb): make error information of input channel mismatch more readable
GitOrigin-RevId: 6f95260070bf826dc78cd23a3e62548c0d1cb9a8
上级
38bd5999
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
53 additions
and
13 deletions
+53
-13
dnn/src/common/convolution.cpp
dnn/src/common/convolution.cpp
+38
-7
imperative/src/impl/ops/convolution.cpp
imperative/src/impl/ops/convolution.cpp
+15
-6
未找到文件。
dnn/src/common/convolution.cpp
浏览文件 @
f5597d9a
...
...
@@ -573,8 +573,15 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
filter
.
param
<
dtype
::
QuantizedS8
>
().
scale
));
}
else
{
megdnn_throw
(
ssprintf
(
"unsupported input / filter DType: %s x %s"
,
src
.
name
(),
filter
.
name
()));
"runtime does not support input / filter DType: %s x %s"
"now support case list: FLOAT x FLOAT
\n
"
" Int8 x Int8
\n
"
" QuantizedS8 x QuantizedS8
\n
"
" Quantized8Asymm x Quantized8Asymm
\n
"
" QuantizedS4 x QuantizedS4
\n
"
" Quantized4Asymm x Quantized4Asymm
\n
"
" QuantizedS1 x QuantizedS1
\n
"
,
src
.
name
(),
filter
.
name
()));
}
if
(
!
dst
.
valid
())
{
dst
=
supported_dst_dtype
.
at
(
0
);
...
...
@@ -588,8 +595,21 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
}
MEGDNN_MARK_USED_VAR
(
dst_supported
);
megdnn_assert
(
dst_supported
,
"unsupported Conv(%s, %s) -> %s"
,
src
.
name
(),
filter
.
name
(),
dst
.
name
());
dst_supported
,
"runtime does not support Conv(%s, %s) -> %s"
"now support case list: Conv(FLOAT x FLOAT) -> FLOAT
\n
"
" Conv(Int8 x Int8) -> Int32
\n
"
" Conv(QuantizedS8 x QuantizedS8) -> "
"QuantizedS32
\n
"
" Conv(Quantized8Asymm x Quantized8Asymm) -> "
"Quantized32Asymm
\n
"
" Conv(QuantizedS4 x QuantizedS4) -> "
"QuantizedS32
\n
"
" Conv(Quantized4Asymm x Quantized4Asymm) -> "
"Quantized32Asymm
\n
"
" Conv(QuantizedS1 x QuantizedS1) -> "
"QuantizedS32
\n
"
,
src
.
name
(),
filter
.
name
(),
dst
.
name
());
}
megdnn_assert
(
(
param
().
compute_mode
==
Param
::
ComputeMode
::
FLOAT32
||
...
...
@@ -1098,15 +1118,26 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad
}
}
else
{
megdnn_throw
(
ssprintf
(
"unsupported input / diff DType: %s x %s"
,
filter
.
name
(),
diff
.
name
()));
"runtime does not support input / diff DType: %s x %s"
"now support case list: FLOAT x FLOAT
\n
"
" Int8 x Int8
\n
"
" QuantizedS8 x QuantizedS8
\n
"
" Quantized8Asymm x Quantized8Asymm
\n
"
,
filter
.
name
(),
diff
.
name
()));
}
if
(
!
grad
.
valid
())
{
grad
=
supported_dst_dtype
.
at
(
0
);
}
else
{
megdnn_assert
(
vec_contains
(
supported_dst_dtype
,
grad
),
"unsupported ConvBwd(%s, %s) -> %s"
,
filter
.
name
(),
diff
.
name
(),
grad
.
name
());
"runtime does not support ConvBwd(%s, %s) -> %s"
"now support case list: ConvBwd(FLOAT x FLOAT) -> FLOAT
\n
"
" ConvBwd(Int8 x Int8) -> Int32
\n
"
" ConvBwd(QuantizedS8 x QuantizedS8) -> "
"QuantizedS32
\n
"
" ConvBwd(Quantized8Asymm x Quantized8Asymm) -> "
"Quantized32Asymm
\n
"
,
filter
.
name
(),
diff
.
name
(),
grad
.
name
());
}
megdnn_assert
(
param
().
compute_mode
!=
Param
::
ComputeMode
::
FLOAT32
...
...
imperative/src/impl/ops/convolution.cpp
浏览文件 @
f5597d9a
...
...
@@ -95,8 +95,11 @@ TensorLayout do_shape_infer(
dilated_spatial
[
i
]
=
(
filter
[
i
+
flt_start
+
flt_spatial_start
]
-
1
)
*
dilation
[
i
]
+
1
;
}
mgb_assert
(
icpg
*
group
==
src
[
src_or_dst_c_pos
],
"group conv invalid"
);
mgb_assert
(
icpg
*
group
==
src
[
src_or_dst_c_pos
],
"group conv invalid: input channel of Conv expect %zu, but got %zu
\n
"
"hint: weight may be changed by mistake
\n
"
,
icpg
*
group
,
src
[
src_or_dst_c_pos
]);
TensorLayout
dst
{
src
.
dtype
};
dst
.
ndim
=
src_ndim
;
dst
[
0
]
=
src
[
0
];
...
...
@@ -310,8 +313,11 @@ TensorLayout convbwd_do_shape_infer(
dilated_spatial
[
i
]
=
(
filter
[
i
+
flt_start
+
flt_spatial_start
]
-
1
)
*
dilation
[
i
]
+
1
;
}
mgb_assert
(
ocpg
*
group
==
diff
[
src_or_dst_c_pos
],
"group conv invalid"
);
mgb_assert
(
ocpg
*
group
==
diff
[
src_or_dst_c_pos
],
"group conv invalid: input channel of Conv expect %zu, but got %zu
\n
"
"hint: weight may be changed by mistake
\n
"
,
ocpg
*
group
,
diff
[
src_or_dst_c_pos
]);
auto
deduce
=
[](
size_t
out
,
size_t
filter
,
size_t
stride
,
size_t
pad
)
{
auto
i
=
(
out
-
1
)
*
stride
+
filter
;
mgb_assert
(
i
>
pad
*
2
);
...
...
@@ -479,8 +485,11 @@ TensorLayout do_shape_infer(
dilated_spatial
[
i
]
=
(
filter
[
i
+
flt_start
+
flt_spatial_start
]
-
1
)
*
dilation
[
i
]
+
1
;
}
mgb_assert
(
icpg
*
group
==
src
[
src_or_dst_c_pos
],
"group conv invalid"
);
mgb_assert
(
icpg
*
group
==
src
[
src_or_dst_c_pos
],
"group conv invalid: input channel of Conv expect %zu, but got %zu
\n
"
"hint: weight may be changed by mistake
\n
"
,
icpg
*
group
,
src
[
src_or_dst_c_pos
]);
TensorLayout
dst
{
src
.
dtype
};
dst
.
ndim
=
src_ndim
;
dst
[
0
]
=
src
[
0
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录