Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fd41302c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fd41302c
编写于
3月 25, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative/amp): add set_format
GitOrigin-RevId: 91de6f49de7dc334cbe17f2a29e11a8f40ee79d6
上级
fc633ce4
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
78 addition
and
21 deletion
+78
-21
imperative/python/megengine/amp/convert_format.py
imperative/python/megengine/amp/convert_format.py
+4
-1
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+2
-0
imperative/python/megengine/module/normalization.py
imperative/python/megengine/module/normalization.py
+9
-2
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+5
-1
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+8
-1
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+6
-0
imperative/python/test/unit/core/test_formatted_tensor.py
imperative/python/test/unit/core/test_formatted_tensor.py
+3
-0
imperative/src/impl/basic_operators.cpp
imperative/src/impl/basic_operators.cpp
+8
-1
imperative/src/impl/transformations/format.cpp
imperative/src/impl/transformations/format.cpp
+16
-12
imperative/src/include/megbrain/imperative/basic_operators.h
imperative/src/include/megbrain/imperative/basic_operators.h
+13
-1
imperative/src/include/megbrain/imperative/transformations/format.h
.../src/include/megbrain/imperative/transformations/format.h
+4
-2
未找到文件。
imperative/python/megengine/amp/convert_format.py
浏览文件 @
fd41302c
...
...
@@ -29,10 +29,13 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
# TODO: use initialization from tensor after fixing format setting
if
x
.
format
!=
"nhwc"
:
if
inplace
:
# reset will destroy backward grad
data
=
x
.
numpy
().
transpose
(
*
pattern
)
x
[...]
=
Tensor
(
data
,
format
=
"nhwc"
)
else
:
x
=
Tensor
(
x
.
numpy
().
transpose
(
*
pattern
),
format
=
"nhwc"
)
# use mge interface to maintain grad
x
=
F
.
transpose
(
x
,
pattern
)
x
.
format
=
"nhwc"
return
x
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
fd41302c
...
...
@@ -245,6 +245,8 @@ def conv2d(
sparse_type
=
"dense"
if
groups
==
1
else
"group"
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
with
_config
.
_override
(
auto_format_convert
=
False
):
print
(
compute_mode
,
inp
.
shape
,
inp
.
format
,
weight
.
shape
,
weight
.
format
)
op
=
builtin
.
Convolution
(
stride_h
=
stride_h
,
stride_w
=
stride_w
,
...
...
imperative/python/megengine/module/normalization.py
浏览文件 @
fd41302c
import
numpy
as
np
import
megengine
as
mge
import
megengine.functional
as
F
from
megengine
import
Parameter
...
...
@@ -34,6 +35,7 @@ class GroupNorm(Module):
def
forward
(
self
,
x
):
N
,
C
,
H
,
W
=
x
.
shape
format
=
x
.
format
assert
C
==
self
.
num_channels
x
=
x
.
reshape
(
N
,
self
.
num_groups
,
-
1
)
...
...
@@ -44,7 +46,9 @@ class GroupNorm(Module):
x
=
x
.
reshape
(
N
,
C
,
H
,
W
)
if
self
.
affine
:
x
=
self
.
weight
.
reshape
(
1
,
-
1
,
1
,
1
)
*
x
+
self
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
# FIXME(czh): remove this after making it a builtin op.
if
format
==
"nhwc"
:
x
=
mge
.
amp
.
convert_tensor_format
(
x
,
inplace
=
False
)
return
x
def
_module_info_string
(
self
)
->
str
:
...
...
@@ -81,6 +85,7 @@ class InstanceNorm(Module):
def
forward
(
self
,
x
):
N
,
C
,
H
,
W
=
x
.
shape
format
=
x
.
format
assert
C
==
self
.
num_channels
x
=
x
.
reshape
(
N
,
C
,
-
1
)
mean
=
x
.
mean
(
axis
=
2
,
keepdims
=
True
)
...
...
@@ -90,7 +95,9 @@ class InstanceNorm(Module):
x
=
x
.
reshape
(
N
,
C
,
H
,
W
)
if
self
.
affine
:
x
=
self
.
weight
.
reshape
(
1
,
-
1
,
1
,
1
)
*
x
+
self
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
# FIXME(czh): remove this after making it a builtin op.
if
format
==
"nhwc"
:
x
=
mge
.
amp
.
convert_tensor_format
(
x
,
inplace
=
False
)
return
x
def
_module_info_string
(
self
)
->
str
:
...
...
imperative/python/megengine/tensor.py
浏览文件 @
fd41302c
...
...
@@ -122,7 +122,11 @@ class Tensor(_Tensor, ArrayMethodMixin):
@
property
def
format
(
self
)
->
str
:
return
super
().
format
return
super
().
format
()
@
format
.
setter
def
format
(
self
,
format
):
super
().
_set_format
(
format
)
@
property
def
qparams
(
self
):
...
...
imperative/python/src/tensor.cpp
浏览文件 @
fd41302c
...
...
@@ -584,6 +584,12 @@ void TensorWrapper::set_module_trace_info(PyObject* obj) {
module_trace_info_map
[
m_tensor
->
data
()]
=
py
::
reinterpret_borrow
<
py
::
object
>
(
obj
);
}
void
TensorWrapper
::
_set_format
(
PyObject
*
dest
)
{
auto
py_dest
=
py
::
reinterpret_borrow
<
py
::
object
>
(
dest
);
auto
format
=
py_dest
.
cast
<
std
::
string
>
();
m_tensor
->
set_format
(
format
);
}
void
TensorWrapper
::
_set_name
(
PyObject
*
dest
)
{
auto
py_dest
=
py
::
reinterpret_borrow
<
py
::
object
>
(
dest
);
auto
name
=
py_dest
.
cast
<
std
::
string
>
();
...
...
@@ -812,7 +818,7 @@ void init_tensor(py::module m) {
.
def_getset
<&
TensorWrapper
::
shape
>
(
"shape"
)
.
def_getset
<&
TensorWrapper
::
dtype
>
(
"dtype"
)
.
def_getset
<&
TensorWrapper
::
device
>
(
"device"
)
.
def
_getset
<&
TensorWrapper
::
format
>
(
"format"
)
.
def
<&
TensorWrapper
::
format
>
(
"format"
)
.
def
<&
TensorWrapper
::
reset
>
(
"_reset"
)
.
def
<&
TensorWrapper
::
isscalar
>
(
"_isscalar"
)
.
def
<&
TensorWrapper
::
detach
>
(
"detach"
)
...
...
@@ -820,6 +826,7 @@ void init_tensor(py::module m) {
.
def
<&
TensorWrapper
::
_dev_tensor
>
(
"_dev_tensor"
)
.
def
<&
TensorWrapper
::
_drop
>
(
"_drop"
)
.
def
<&
TensorWrapper
::
_detail
>
(
"_detail"
)
.
def
<&
TensorWrapper
::
_set_format
>
(
"_set_format"
)
.
def
<&
TensorWrapper
::
_set_name
>
(
"_set_name"
)
.
def
<&
TensorWrapper
::
_watch
>
(
"_watch"
)
.
def
<&
TensorWrapper
::
_var
>
(
"var"
)
...
...
imperative/python/src/tensor.h
浏览文件 @
fd41302c
...
...
@@ -59,6 +59,11 @@ public:
return
*
shape
;
}
inline
Format
format
()
{
return
*
data
().
format
();
}
inline
void
set_format
(
std
::
string
format
)
{
if
(
!
format
.
empty
())
{
m_data
=
imperative
::
apply
(
SetFormat
(
format
),
m_data
)[
0
];
}
}
inline
HostValue
::
ref_t
numpy
()
{
return
data
().
numpy
();
}
inline
void
reset
(
ValueRef
value
)
{
m_data
=
value
;
...
...
@@ -130,6 +135,7 @@ public:
PyObject
*
copied
();
PyObject
*
module_trace_info
();
void
set_module_trace_info
(
PyObject
*
);
void
_set_format
(
PyObject
*
);
void
_set_name
(
PyObject
*
);
PyObject
*
_detail
();
PyObject
*
_var
();
...
...
imperative/python/test/unit/core/test_formatted_tensor.py
浏览文件 @
fd41302c
...
...
@@ -31,6 +31,9 @@ def test_basic():
b
[...]
=
tensor
(
data
,
format
=
"nchw"
)
assert
b
.
format
==
"nchw"
# set tensor's format
b
.
format
=
"nhwc"
assert
b
.
format
==
"nhwc"
def
_compare_nchw_nhwc
(
data
,
func
,
is_symbolic
=
None
):
x1
=
tensor
(
data
)
...
...
imperative/src/impl/basic_operators.cpp
浏览文件 @
fd41302c
...
...
@@ -105,9 +105,16 @@ std::string IsScalar::to_string() const {
return
"IsScalar"
;
}
std
::
string
GetFormat
::
to_string
()
const
{
return
"GetFormat{}"
;
}
std
::
string
SetFormat
::
to_string
()
const
{
return
ssprintf
(
"SetFormat{format=%s}"
,
m_format
.
to_string
().
c_str
());
}
std
::
string
GetVarVal
::
to_string
()
const
{
return
"GetVarVal"
;
}
}
// namespace imperative
}
// namespace mgb
imperative/src/impl/transformations/format.cpp
浏览文件 @
fd41302c
...
...
@@ -57,15 +57,15 @@ inline ValueRefList FormatTransformation::unwrap_inputs(
}
inline
ValueRef
FormatTransformation
::
wrap_output
(
const
ValueRef
&
output
,
F
T
type
)
const
{
return
m_value_type
.
make
(
output
,
type
);
const
ValueRef
&
output
,
F
ormat
format
)
const
{
return
m_value_type
.
make
(
output
,
format
);
}
inline
ValueRefList
FormatTransformation
::
wrap_outputs
(
const
ValueRefList
&
outputs
,
F
T
type
)
const
{
const
ValueRefList
&
outputs
,
F
ormat
format
)
const
{
ValueRefList
wrapped_outputs
(
outputs
.
size
());
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
wrapped_outputs
[
i
]
=
wrap_output
(
outputs
[
i
],
type
);
wrapped_outputs
[
i
]
=
wrap_output
(
outputs
[
i
],
format
);
}
return
wrapped_outputs
;
}
...
...
@@ -241,7 +241,7 @@ ValueRefList subtensor_rule(
if
(
!
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
))
{
return
{
t
.
wrap_output
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
))[
0
],
src
.
format
()
.
type
()
)};
src
.
format
())};
}
auto
nhwc_items
=
convert_nchw2nhwc_idx_items
(
op
.
items
);
auto
outputs
=
imperative
::
apply
(
...
...
@@ -264,7 +264,7 @@ ValueRefList setsubtensor_rule(
if
(
!
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
))
{
return
{
t
.
wrap_output
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
))[
0
],
src
.
format
()
.
type
()
)};
src
.
format
())};
}
// value has been broadcasted to src's fake NCHW shape.
auto
&
value
=
inputs
[
1
].
cast
(
t
.
value_type
());
...
...
@@ -330,7 +330,7 @@ ValueRefList identity_rule_helper(
// mgb_assert(inputs.size() == 1);
auto
&
src
=
inputs
[
0
].
cast
(
t
.
value_type
());
return
t
.
wrap_outputs
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)),
src
.
format
()
.
type
()
);
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)),
src
.
format
());
}
ValueRefList
batchnorm_rule
(
...
...
@@ -467,7 +467,7 @@ ValueRefList FormatTransformation::apply_transformation(
}
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
auto
format
=
create_tensor
->
format
();
return
{
wrap_output
(
imperative
::
apply
(
op
,
inputs
)[
0
],
format
.
type
()
)};
return
{
wrap_output
(
imperative
::
apply
(
op
,
inputs
)[
0
],
format
)};
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
auto
&&
input
=
inputs
.
item
();
if
(
!
input
.
is
(
m_value_type
))
{
...
...
@@ -500,12 +500,16 @@ ValueRefList FormatTransformation::apply_transformation(
op
.
to_string
().
c_str
(),
inputs
[
0
].
to_string
().
c_str
());
return
{
FormatValue
::
make
(
FT
::
DEFAULT
)};
}
}
else
if
(
auto
*
_op
=
op
.
as
<
SetFormat
>
())
{
auto
&&
inp_ref
=
inputs
[
0
].
as_ref
(
m_value_type
);
mgb_assert
(
inp_ref
,
"Cannot set format for non-format Tensor."
);
return
{
m_value_type
.
make
(
inp_ref
->
value
(),
_op
->
format
())};
}
else
if
(
op
.
is
<
Operator
::
IdentityLike
>
())
{
auto
&&
inp_ref
=
inputs
[
0
].
as_ref
(
m_value_type
);
if
(
inp_ref
)
{
auto
&&
format
=
inp_ref
->
format
();
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
format
.
type
()
);
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
format
);
}
else
{
mgb_log_warn
(
"Not FormattedTensorValue input for IdentityLike op: %s, %s"
,
...
...
@@ -521,13 +525,13 @@ ValueRefList FormatTransformation::apply_transformation(
GenericFunction
new_callback
=
[
this
,
callback
,
format
](
Span
<
ValueRef
>
inputs_
)
->
ValueRefList
{
auto
wrapped_inputs
=
SmallVector
<
ValueRef
>
{
this
->
value_type
().
make
(
inputs_
.
item
(),
format
.
type
()
)};
this
->
value_type
().
make
(
inputs_
.
item
(),
format
)};
auto
ret
=
callback
(
wrapped_inputs
);
return
ret
;
};
auto
&&
outputs
=
imperative
::
apply
(
op
,
inp_ref
->
value
(),
FunctionValue
::
make
(
new_callback
));
return
wrap_outputs
(
outputs
,
format
.
type
()
);
return
wrap_outputs
(
outputs
,
format
);
}
else
{
mgb_log_warn
(
"Not FormattedTensorValue input for AttachGrad op: %s, %s"
,
...
...
@@ -549,7 +553,7 @@ ValueRefList FormatTransformation::apply_transformation(
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
if
(
auto
output_ref
=
outputs_
[
i
].
as_ref
(
m_value_type
))
{
wrapped_outputs
[
i
]
=
m_value_type
.
make
(
outputs
[
i
],
output_ref
->
format
()
.
type
()
);
m_value_type
.
make
(
outputs
[
i
],
output_ref
->
format
());
}
else
{
mgb_log_warn
(
"Not FormattedTensorValue outputs for SetGrad op: %s, %s"
,
...
...
imperative/src/include/megbrain/imperative/basic_operators.h
浏览文件 @
fd41302c
...
...
@@ -164,7 +164,19 @@ public:
class
GetFormat
final
:
public
OperatorImpl
<
GetFormat
,
Operator
::
GetAttrLike
>
{
public:
std
::
string
to_string
()
const
override
{
return
"GetFormat{}"
;
}
std
::
string
to_string
()
const
override
;
};
class
SetFormat
final
:
public
OperatorImpl
<
SetFormat
,
Operator
::
IdentityLike
>
{
private:
Format
m_format
;
public:
SetFormat
(
std
::
string
format
)
:
m_format
(
format
)
{}
Format
format
()
const
{
return
m_format
;
}
std
::
string
to_string
()
const
override
;
};
class
GetVarVal
final
:
public
OperatorImpl
<
GetVarVal
,
Operator
::
GetAttrLike
>
{
...
...
imperative/src/include/megbrain/imperative/transformations/format.h
浏览文件 @
fd41302c
...
...
@@ -26,6 +26,8 @@ public:
const
Format
&
format
()
const
{
return
m_format
;
}
void
set_format
(
Format
format
)
{
m_format
=
format
;
}
void
clear
()
override
{
m_value
=
{};
m_format
=
{};
...
...
@@ -65,10 +67,10 @@ public:
inline
ValueRef
unwrap_input
(
const
ValueRef
&
input
)
const
;
inline
ValueRefList
unwrap_inputs
(
const
Span
<
ValueRef
>&
inputs
)
const
;
inline
ValueRef
wrap_output
(
const
ValueRef
&
output
,
Format
::
Type
type
=
Format
::
Type
::
DEFAULT
)
const
;
const
ValueRef
&
output
,
Format
format
=
Format
::
Type
::
DEFAULT
)
const
;
inline
ValueRefList
wrap_outputs
(
const
ValueRefList
&
outputs
,
Format
::
Type
type
=
Format
::
Type
::
DEFAULT
)
const
;
Format
format
=
Format
::
Type
::
DEFAULT
)
const
;
TypedValueRef
<
FormattedTensorValue
>
as
(
const
FormattedTensorValue
&
,
const
Format
::
Type
&
target
)
const
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录