Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
67cfce9f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
67cfce9f
编写于
5月 31, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative/amp): add is_scalar check in elemwise and concat
GitOrigin-RevId: 61a612e92a716030d5d7ad6f6ee3258f03e35069
上级
d313f926
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
49 addition
and
30 deletion
+49
-30
imperative/python/megengine/amp/convert_format.py
imperative/python/megengine/amp/convert_format.py
+0
-7
imperative/python/megengine/functional/math.py
imperative/python/megengine/functional/math.py
+1
-2
imperative/python/test/unit/amp/test_convert_format.py
imperative/python/test/unit/amp/test_convert_format.py
+0
-7
imperative/src/impl/transformations/format.cpp
imperative/src/impl/transformations/format.cpp
+48
-14
未找到文件。
imperative/python/megengine/amp/convert_format.py
浏览文件 @
67cfce9f
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
copy
import
deepcopy
from
copy
import
deepcopy
from
..
import
functional
as
F
from
..
import
functional
as
F
...
...
imperative/python/megengine/functional/math.py
浏览文件 @
67cfce9f
...
@@ -592,7 +592,6 @@ def matmul(
...
@@ -592,7 +592,6 @@ def matmul(
transpose_a
=
False
,
transpose_a
=
False
,
transpose_b
=
False
,
transpose_b
=
False
,
compute_mode
=
"default"
,
compute_mode
=
"default"
,
format
=
"default"
,
)
->
Tensor
:
)
->
Tensor
:
r
"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
r
"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
...
@@ -625,7 +624,7 @@ def matmul(
...
@@ -625,7 +624,7 @@ def matmul(
array([[10., 13.],
array([[10., 13.],
[28., 40.]], dtype=float32)
[28., 40.]], dtype=float32)
"""
"""
return
_matmul
(
inp1
,
inp2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
)
return
_matmul
(
inp1
,
inp2
,
transpose_a
,
transpose_b
,
compute_mode
)
def
dot
(
inp1
:
Tensor
,
inp2
:
Tensor
)
->
Tensor
:
def
dot
(
inp1
:
Tensor
,
inp2
:
Tensor
)
->
Tensor
:
...
...
imperative/python/test/unit/amp/test_convert_format.py
浏览文件 @
67cfce9f
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
...
imperative/src/impl/transformations/format.cpp
浏览文件 @
67cfce9f
...
@@ -23,24 +23,42 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
...
@@ -23,24 +23,42 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
if
(
format
==
target
)
if
(
format
==
target
)
return
as
(
tensor
,
target
);
return
as
(
tensor
,
target
);
auto
&&
shape
=
tensor
.
value
().
shape
().
cast
<
ShapeValue
>
();
if
(
format
==
FT
::
NHWC
&&
(
target
==
FT
::
NCHW
||
target
==
FT
::
DEFAULT
))
{
if
(
format
==
FT
::
NHWC
&&
(
target
==
FT
::
NCHW
||
target
==
FT
::
DEFAULT
))
{
// FIXME(czh): temporary fast path for group conv 5D weight.
// FIXME(czh): temporary fast path for group conv 5D weight.
if
(
tensor
.
value
().
shape
().
cast
<
ShapeValue
>
()
.
ndim
==
5
)
{
if
(
shape
.
ndim
==
5
)
{
pattern
=
{
0
,
1
,
4
,
2
,
3
};
pattern
=
{
0
,
1
,
4
,
2
,
3
};
}
else
{
}
else
if
(
shape
.
ndim
==
4
)
{
pattern
=
{
0
,
3
,
1
,
2
};
pattern
=
{
0
,
3
,
1
,
2
};
}
else
{
mgb_throw
(
MegBrainError
,
"Unsupport format conversion for tensor %s(shape=%s) from %s to %s"
,
tensor
.
to_string
().
c_str
(),
shape
.
to_string
().
c_str
(),
format
.
to_string
().
c_str
(),
Format
(
target
).
to_string
().
c_str
());
}
}
}
else
if
((
format
==
FT
::
NCHW
||
format
==
FT
::
DEFAULT
)
&&
target
==
FT
::
NHWC
)
{
}
else
if
((
format
==
FT
::
NCHW
||
format
==
FT
::
DEFAULT
)
&&
target
==
FT
::
NHWC
)
{
if
(
tensor
.
value
().
shape
().
cast
<
ShapeValue
>
()
.
ndim
==
5
)
{
if
(
shape
.
ndim
==
5
)
{
pattern
=
{
0
,
1
,
3
,
4
,
2
};
pattern
=
{
0
,
1
,
3
,
4
,
2
};
}
else
{
}
else
if
(
shape
.
ndim
==
4
)
{
pattern
=
{
0
,
2
,
3
,
1
};
pattern
=
{
0
,
2
,
3
,
1
};
}
else
{
mgb_throw
(
MegBrainError
,
"Unsupport format conversion for tensor %s(shape=%s) from %s to %s"
,
tensor
.
to_string
().
c_str
(),
shape
.
to_string
().
c_str
(),
format
.
to_string
().
c_str
(),
Format
(
target
).
to_string
().
c_str
());
}
}
}
else
{
}
else
{
mgb_throw
(
mgb_throw
(
MegBrainError
,
"Unsupport format conversion from %s to %s"
,
MegBrainError
,
"Unsupport format conversion for tensor %s(shape=%s) from %s to %s"
,
tensor
.
to_string
().
c_str
(),
shape
.
to_string
().
c_str
(),
format
.
to_string
().
c_str
(),
Format
(
target
).
to_string
().
c_str
());
format
.
to_string
().
c_str
(),
Format
(
target
).
to_string
().
c_str
());
}
}
mgb_log_debug
(
"Change tensor %s from %s to %s"
,
tensor
.
to_string
().
c_str
(),
format
.
to_string
().
c_str
(),
Format
(
target
).
to_string
().
c_str
());
auto
output
=
auto
output
=
imperative
::
apply
(
*
Dimshuffle
::
make
(
pattern
,
scope
),
{
tensor
.
value
()})[
0
];
imperative
::
apply
(
*
Dimshuffle
::
make
(
pattern
,
scope
),
{
tensor
.
value
()})[
0
];
return
m_value_type
.
make
(
output
,
target
);
return
m_value_type
.
make
(
output
,
target
);
...
@@ -380,9 +398,7 @@ inline ValueRefList unify_inputs_format(
...
@@ -380,9 +398,7 @@ inline ValueRefList unify_inputs_format(
ValueRefList
unified_inputs
(
inputs
.
size
());
ValueRefList
unified_inputs
(
inputs
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
&&
inp
=
inputs
[
i
].
cast
(
t
.
value_type
());
auto
&&
inp
=
inputs
[
i
].
cast
(
t
.
value_type
());
if
(
inp
.
format
()
!=
dst_fmt
&&
if
(
inp
.
format
()
!=
dst_fmt
)
{
(
inp
.
value
().
shape
().
cast
<
ShapeValue
>
().
ndim
==
4
||
inp
.
value
().
shape
().
cast
<
ShapeValue
>
().
ndim
==
5
))
{
unified_inputs
[
i
]
=
t
.
to
(
inp
,
dst_fmt
,
scope
);
unified_inputs
[
i
]
=
t
.
to
(
inp
,
dst_fmt
,
scope
);
}
else
{
}
else
{
unified_inputs
[
i
]
=
inputs
[
i
];
unified_inputs
[
i
]
=
inputs
[
i
];
...
@@ -396,7 +412,16 @@ ValueRefList elemwise_rule(
...
@@ -396,7 +412,16 @@ ValueRefList elemwise_rule(
const
FormatTransformation
&
t
)
{
const
FormatTransformation
&
t
)
{
FT
format
=
get_inputs_format
(
inputs
,
t
);
FT
format
=
get_inputs_format
(
inputs
,
t
);
if
(
format
==
FT
::
NHWC
&&
auto_convert
)
{
if
(
format
==
FT
::
NHWC
&&
auto_convert
)
{
auto
unified_inputs
=
unify_inputs_format
(
inputs
,
FT
::
NHWC
,
op
.
scope
(),
t
);
ValueRefList
unified_inputs
(
inputs
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
&&
inp
=
inputs
[
i
].
cast
(
t
.
value_type
());
if
(
inp
.
format
()
!=
FT
::
NHWC
&&
inp
.
value
().
is_scalar
())
{
unified_inputs
[
i
]
=
t
.
value_type
().
make
(
inp
.
value
(),
FT
::
NHWC
);
}
else
{
unified_inputs
[
i
]
=
inputs
[
i
];
}
}
unified_inputs
=
unify_inputs_format
(
unified_inputs
,
FT
::
NHWC
,
op
.
scope
(),
t
);
return
t
.
wrap_outputs
(
return
t
.
wrap_outputs
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
unified_inputs
)),
format
);
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
unified_inputs
)),
format
);
}
}
...
@@ -410,7 +435,16 @@ ValueRefList concat_rule(
...
@@ -410,7 +435,16 @@ ValueRefList concat_rule(
if
(
!
(
format
==
FT
::
NHWC
&&
auto_convert
))
{
if
(
!
(
format
==
FT
::
NHWC
&&
auto_convert
))
{
return
t
.
wrap_outputs
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)),
format
);
return
t
.
wrap_outputs
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)),
format
);
}
}
auto
unified_inputs
=
unify_inputs_format
(
inputs
,
FT
::
NHWC
,
op
.
scope
(),
t
);
ValueRefList
unified_inputs
(
inputs
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
&&
inp
=
inputs
[
i
].
cast
(
t
.
value_type
());
if
(
inp
.
format
()
!=
FT
::
NHWC
&&
inp
.
value
().
is_scalar
())
{
unified_inputs
[
i
]
=
t
.
value_type
().
make
(
inp
.
value
(),
FT
::
NHWC
);
}
else
{
unified_inputs
[
i
]
=
inputs
[
i
];
}
}
unified_inputs
=
unify_inputs_format
(
unified_inputs
,
FT
::
NHWC
,
op
.
scope
(),
t
);
// TODO: handle 5D NHWC Tensor from group conv
// TODO: handle 5D NHWC Tensor from group conv
auto
axis
=
op
.
axis
;
auto
axis
=
op
.
axis
;
if
(
axis
==
2
||
axis
==
3
)
{
if
(
axis
==
2
||
axis
==
3
)
{
...
@@ -441,7 +475,7 @@ ValueRefList batchnorm_rule(
...
@@ -441,7 +475,7 @@ ValueRefList batchnorm_rule(
const
FormatTransformation
&
t
)
{
const
FormatTransformation
&
t
)
{
auto
&&
inp_format
=
inputs
[
0
].
cast
(
t
.
value_type
()).
format
();
auto
&&
inp_format
=
inputs
[
0
].
cast
(
t
.
value_type
()).
format
();
if
(
inp_format
==
FT
::
NHWC
)
{
if
(
inp_format
==
FT
::
NHWC
)
{
auto
&&
new_param
=
op
.
param
();
auto
new_param
=
op
.
param
();
new_param
.
param_dim
=
BatchNorm
::
ParamDim
::
DIM_111C
;
new_param
.
param_dim
=
BatchNorm
::
ParamDim
::
DIM_111C
;
auto
new_op
=
BatchNorm
::
make
(
new_param
);
auto
new_op
=
BatchNorm
::
make
(
new_param
);
return
identity_rule_helper
(
*
new_op
,
inputs
,
t
);
return
identity_rule_helper
(
*
new_op
,
inputs
,
t
);
...
@@ -454,7 +488,7 @@ ValueRefList adaptive_pooling_rule(
...
@@ -454,7 +488,7 @@ ValueRefList adaptive_pooling_rule(
const
FormatTransformation
&
t
)
{
const
FormatTransformation
&
t
)
{
auto
&&
inp_format
=
inputs
[
0
].
cast
(
t
.
value_type
()).
format
();
auto
&&
inp_format
=
inputs
[
0
].
cast
(
t
.
value_type
()).
format
();
if
(
inp_format
==
FT
::
NHWC
)
{
if
(
inp_format
==
FT
::
NHWC
)
{
auto
&&
new_param
=
op
.
param
();
auto
new_param
=
op
.
param
();
new_param
.
format
=
AdaptivePooling
::
Format
::
NHWC
;
new_param
.
format
=
AdaptivePooling
::
Format
::
NHWC
;
auto
new_op
=
AdaptivePooling
::
make
(
new_param
,
op
.
shape
);
auto
new_op
=
AdaptivePooling
::
make
(
new_param
,
op
.
shape
);
return
identity_rule_helper
(
*
new_op
,
inputs
,
t
);
return
identity_rule_helper
(
*
new_op
,
inputs
,
t
);
...
@@ -518,7 +552,7 @@ FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
...
@@ -518,7 +552,7 @@ FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
const FormatTransformation& t) { \
const FormatTransformation& t) { \
auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
if (inp_format == FT::NHWC) { \
if (inp_format == FT::NHWC) { \
auto
&& new_param = _op.param();
\
auto
new_param = _op.param();
\
new_param.format = Op::Format::NHWC; \
new_param.format = Op::Format::NHWC; \
auto new_op = Op::make(new_param); \
auto new_op = Op::make(new_param); \
return identity_rule_helper(*new_op, inputs, t); \
return identity_rule_helper(*new_op, inputs, t); \
...
@@ -535,7 +569,7 @@ FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE)
...
@@ -535,7 +569,7 @@ FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE)
const FormatTransformation& t) { \
const FormatTransformation& t) { \
auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
if (inp_format == FT::NHWC) { \
if (inp_format == FT::NHWC) { \
auto
&& new_param = _op.param();
\
auto
new_param = _op.param();
\
new_param.format = Op::Format::NHWC; \
new_param.format = Op::Format::NHWC; \
auto new_op = Op::make(new_param, _op.policy()); \
auto new_op = Op::make(new_param, _op.policy()); \
return identity_rule_helper(*new_op, inputs, t); \
return identity_rule_helper(*new_op, inputs, t); \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录