Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d313f926
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d313f926
编写于
5月 28, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative/amp): fix format transformation for symbol trans
GitOrigin-RevId: 96cc237c67e25c8cb1567eb08325db65adc1c57d
上级
261a5bce
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
47 addition
and
24 deletion
+47
-24
imperative/python/megengine/autodiff/grad_manager.py
imperative/python/megengine/autodiff/grad_manager.py
+0
-2
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+3
-1
imperative/python/src/transformation.h
imperative/python/src/transformation.h
+1
-1
imperative/python/test/unit/amp/test_convert_format.py
imperative/python/test/unit/amp/test_convert_format.py
+19
-15
imperative/python/test/unit/core/test_formatted_tensor.py
imperative/python/test/unit/core/test_formatted_tensor.py
+5
-1
imperative/src/impl/transformations/format.cpp
imperative/src/impl/transformations/format.cpp
+18
-4
imperative/src/include/megbrain/imperative/transformations/format.h
.../src/include/megbrain/imperative/transformations/format.h
+1
-0
未找到文件。
imperative/python/megengine/autodiff/grad_manager.py
浏览文件 @
d313f926
...
@@ -260,7 +260,6 @@ class GradManager:
...
@@ -260,7 +260,6 @@ class GradManager:
push_scope
(
"backward"
)
push_scope
(
"backward"
)
set_option
(
"record_computing_path"
,
0
)
set_option
(
"record_computing_path"
,
0
)
_origin_auto_format
=
get_auto_format_convert
()
_origin_auto_format
=
get_auto_format_convert
()
set_auto_format_convert
(
False
)
from
..functional
import
ones_like
from
..functional
import
ones_like
global
backwarding_grad_manager
global
backwarding_grad_manager
...
@@ -304,7 +303,6 @@ class GradManager:
...
@@ -304,7 +303,6 @@ class GradManager:
self
.
release
()
self
.
release
()
backwarding_grad_manager
=
cache
backwarding_grad_manager
=
cache
set_option
(
"record_computing_path"
,
1
)
set_option
(
"record_computing_path"
,
1
)
set_auto_format_convert
(
_origin_auto_format
)
pop_scope
(
"backward"
)
pop_scope
(
"backward"
)
def
record
(
self
):
def
record
(
self
):
...
...
imperative/python/megengine/functional/tensor.py
浏览文件 @
d313f926
...
@@ -274,7 +274,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
...
@@ -274,7 +274,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
return
x
return
x
# set x's format to use FormatTransformation rule for Broadcast.
# set x's format to use FormatTransformation rule for Broadcast.
return
broadcast_to
(
x
,
inp
.
shape
)
rst
=
broadcast_to
(
x
,
inp
.
shape
)
rst
.
format
=
inp
.
format
return
rst
def
broadcast_to
(
inp
:
Tensor
,
shape
:
Union
[
int
,
Iterable
[
int
]])
->
Tensor
:
def
broadcast_to
(
inp
:
Tensor
,
shape
:
Union
[
int
,
Iterable
[
int
]])
->
Tensor
:
...
...
imperative/python/src/transformation.h
浏览文件 @
d313f926
...
@@ -26,7 +26,7 @@ public:
...
@@ -26,7 +26,7 @@ public:
Eval
,
Eval
,
};
};
std
::
array
<
std
::
vector
<
std
::
shared_ptr
<
Transformation
>>
,
8
>
segments
;
std
::
array
<
std
::
vector
<
std
::
shared_ptr
<
Transformation
>>
,
9
>
segments
;
private:
private:
template
<
Segment
segment
>
template
<
Segment
segment
>
...
...
imperative/python/test/unit/amp/test_convert_format.py
浏览文件 @
d313f926
...
@@ -12,6 +12,7 @@ import megengine.functional as F
...
@@ -12,6 +12,7 @@ import megengine.functional as F
import
megengine.module
as
M
import
megengine.module
as
M
from
megengine
import
Parameter
,
Tensor
,
amp
from
megengine
import
Parameter
,
Tensor
,
amp
from
megengine.core._config
import
set_auto_format_convert
from
megengine.core._config
import
set_auto_format_convert
from
megengine.core._trace_option
import
use_symbolic_shape
class
MyModule
(
M
.
Module
):
class
MyModule
(
M
.
Module
):
...
@@ -41,22 +42,25 @@ class MyModule(M.Module):
...
@@ -41,22 +42,25 @@ class MyModule(M.Module):
def
test_convert_module
(
is_inplace
):
def
test_convert_module
(
is_inplace
):
m
=
MyModule
()
m
=
MyModule
()
expected_shape
=
{
expected_shape
=
{
"i.bn.weight"
:
(
1
,
1
,
1
,
4
),
"i.bn.weight"
:
(
1
,
4
,
1
,
1
),
"i.bn.bias"
:
(
1
,
1
,
1
,
4
),
"i.bn.bias"
:
(
1
,
4
,
1
,
1
),
"i.bn.running_mean"
:
(
1
,
1
,
1
,
4
),
"i.bn.running_mean"
:
(
1
,
4
,
1
,
1
),
"i.bn.running_var"
:
(
1
,
1
,
1
,
4
),
"i.bn.running_var"
:
(
1
,
4
,
1
,
1
),
"conv.weight"
:
(
2
,
2
,
4
,
4
,
2
),
"conv.weight"
:
(
2
,
2
,
2
,
4
,
4
),
"conv.bias"
:
(
1
,
1
,
1
,
4
),
"conv.bias"
:
(
1
,
4
,
1
,
1
),
"bn.weight"
:
(
1
,
1
,
1
,
4
),
"bn.weight"
:
(
1
,
4
,
1
,
1
),
"bn.bias"
:
(
1
,
1
,
1
,
4
),
"bn.bias"
:
(
1
,
4
,
1
,
1
),
"bn.running_mean"
:
(
1
,
1
,
1
,
4
),
"bn.running_mean"
:
(
1
,
4
,
1
,
1
),
"bn.running_var"
:
(
1
,
1
,
1
,
4
),
"bn.running_var"
:
(
1
,
4
,
1
,
1
),
"param"
:
(
1
,
1
,
1
,
3
),
"param"
:
(
1
,
3
,
1
,
1
),
"buff"
:
(
1
,
1
,
1
,
3
),
"buff"
:
(
1
,
3
,
1
,
1
),
}
}
m
=
amp
.
convert_module_format
(
m
,
is_inplace
)
m
=
amp
.
convert_module_format
(
m
,
is_inplace
)
for
name
,
param
in
m
.
named_tensors
():
for
name
,
param
in
m
.
named_tensors
():
assert
param
.
format
==
"nhwc"
assert
param
.
format
==
"nhwc"
set_auto_format_convert
(
False
)
if
use_symbolic_shape
():
assert
param
.
shape
==
expected_shape
[
name
],
name
np
.
testing
.
assert_array_equal
(
set_auto_format_convert
(
True
)
param
.
shape
.
numpy
(),
expected_shape
[
name
],
name
)
else
:
assert
param
.
shape
==
expected_shape
[
name
],
name
imperative/python/test/unit/core/test_formatted_tensor.py
浏览文件 @
d313f926
...
@@ -6,6 +6,7 @@ import megengine.functional as F
...
@@ -6,6 +6,7 @@ import megengine.functional as F
import
megengine.module
as
M
import
megengine.module
as
M
from
megengine
import
tensor
from
megengine
import
tensor
from
megengine.autodiff
import
GradManager
from
megengine.autodiff
import
GradManager
from
megengine.core._trace_option
import
use_symbolic_shape
from
megengine.jit
import
trace
from
megengine.jit
import
trace
...
@@ -121,7 +122,10 @@ def test_repeat(is_symbolic):
...
@@ -121,7 +122,10 @@ def test_repeat(is_symbolic):
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
])
def
test_getshape
(
is_symbolic
):
def
test_getshape
(
is_symbolic
):
def
func
(
x
):
def
func
(
x
):
return
x
.
shape
if
use_symbolic_shape
():
return
x
.
shape
.
numpy
()
else
:
return
x
.
shape
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
,
is_symbolic
)
_compare_nchw_nhwc
(
data
,
func
,
is_symbolic
)
...
...
imperative/src/impl/transformations/format.cpp
浏览文件 @
d313f926
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/utility.h"
...
@@ -75,6 +76,17 @@ inline ValueRefList FormatTransformation::wrap_outputs(
...
@@ -75,6 +76,17 @@ inline ValueRefList FormatTransformation::wrap_outputs(
}
}
return
wrapped_outputs
;
return
wrapped_outputs
;
}
}
inline
bool
FormatTransformation
::
check_all_format_value
(
const
Span
<
ValueRef
>&
inputs
)
const
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
!
inputs
[
i
].
as_ref
(
m_value_type
))
{
return
false
;
}
}
return
true
;
}
namespace
{
namespace
{
ValueShape
convert_nhwc2nchw_shape
(
const
ValueShape
&
shape
)
{
ValueShape
convert_nhwc2nchw_shape
(
const
ValueShape
&
shape
)
{
...
@@ -369,7 +381,8 @@ inline ValueRefList unify_inputs_format(
...
@@ -369,7 +381,8 @@ inline ValueRefList unify_inputs_format(
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
&&
inp
=
inputs
[
i
].
cast
(
t
.
value_type
());
auto
&&
inp
=
inputs
[
i
].
cast
(
t
.
value_type
());
if
(
inp
.
format
()
!=
dst_fmt
&&
if
(
inp
.
format
()
!=
dst_fmt
&&
inp
.
value
().
shape
().
cast
<
ShapeValue
>
().
ndim
==
4
)
{
(
inp
.
value
().
shape
().
cast
<
ShapeValue
>
().
ndim
==
4
||
inp
.
value
().
shape
().
cast
<
ShapeValue
>
().
ndim
==
5
))
{
unified_inputs
[
i
]
=
t
.
to
(
inp
,
dst_fmt
,
scope
);
unified_inputs
[
i
]
=
t
.
to
(
inp
,
dst_fmt
,
scope
);
}
else
{
}
else
{
unified_inputs
[
i
]
=
inputs
[
i
];
unified_inputs
[
i
]
=
inputs
[
i
];
...
@@ -568,6 +581,10 @@ struct FormatRuleRegistry {
...
@@ -568,6 +581,10 @@ struct FormatRuleRegistry {
ValueRefList
FormatTransformation
::
apply_transformation
(
ValueRefList
FormatTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
// bypass SymbolValue
if
(
!
check_all_format_value
(
inputs
))
{
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
));
}
// all inputs should be FormattedTensorValue
// all inputs should be FormattedTensorValue
auto
iter
=
format_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
auto
iter
=
format_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
if
(
iter
!=
format_rules
.
end
())
{
if
(
iter
!=
format_rules
.
end
())
{
...
@@ -628,9 +645,6 @@ ValueRefList FormatTransformation::apply_transformation(
...
@@ -628,9 +645,6 @@ ValueRefList FormatTransformation::apply_transformation(
auto
&&
format
=
inp_ref
->
format
();
auto
&&
format
=
inp_ref
->
format
();
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
format
);
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
format
);
}
else
{
}
else
{
mgb_log_warn
(
"Not FormattedTensorValue input for IdentityLike op: %s, %s"
,
op
.
to_string
().
c_str
(),
inputs
[
0
].
to_string
().
c_str
());
return
imperative
::
apply
(
op
,
inputs
);
return
imperative
::
apply
(
op
,
inputs
);
}
}
}
else
if
(
op
.
is
<
AttachGrad
>
())
{
}
else
if
(
op
.
is
<
AttachGrad
>
())
{
...
...
imperative/src/include/megbrain/imperative/transformations/format.h
浏览文件 @
d313f926
...
@@ -70,6 +70,7 @@ public:
...
@@ -70,6 +70,7 @@ public:
const
ValueRef
&
output
,
Format
format
=
Format
::
Type
::
DEFAULT
)
const
;
const
ValueRef
&
output
,
Format
format
=
Format
::
Type
::
DEFAULT
)
const
;
inline
ValueRefList
wrap_outputs
(
inline
ValueRefList
wrap_outputs
(
const
ValueRefList
&
outputs
,
Format
format
=
Format
::
Type
::
DEFAULT
)
const
;
const
ValueRefList
&
outputs
,
Format
format
=
Format
::
Type
::
DEFAULT
)
const
;
inline
bool
check_all_format_value
(
const
Span
<
ValueRef
>&
inputs
)
const
;
TypedValueRef
<
FormattedTensorValue
>
as
(
TypedValueRef
<
FormattedTensorValue
>
as
(
const
FormattedTensorValue
&
,
const
Format
::
Type
&
target
)
const
;
const
FormattedTensorValue
&
,
const
Format
::
Type
&
target
)
const
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录