Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
533fb5bf
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
533fb5bf
编写于
1月 07, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): support formatted tensor and add special op rules
GitOrigin-RevId: 77ff909f2371f768442fb103ec7038832f9310f6
上级
4aa79c45
变更
18
显示空白变更内容
内联
并排
Showing
18 changed file
with
985 addition
and
33 deletion
+985
-33
imperative/python/megengine/__init__.py
imperative/python/megengine/__init__.py
+1
-0
imperative/python/megengine/core/_config.py
imperative/python/megengine/core/_config.py
+52
-16
imperative/python/megengine/functional/vision.py
imperative/python/megengine/functional/vision.py
+0
-1
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+8
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+35
-3
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+5
-2
imperative/python/src/transformation.h
imperative/python/src/transformation.h
+1
-0
imperative/python/test/unit/amp/test_autocast.py
imperative/python/test/unit/amp/test_autocast.py
+1
-1
imperative/python/test/unit/core/test_formatted_tensor.py
imperative/python/test/unit/core/test_formatted_tensor.py
+307
-0
imperative/src/impl/basic_operators.cpp
imperative/src/impl/basic_operators.cpp
+12
-5
imperative/src/impl/transformations/format.cpp
imperative/src/impl/transformations/format.cpp
+406
-0
imperative/src/impl/value.cpp
imperative/src/impl/value.cpp
+4
-0
imperative/src/include/megbrain/imperative/basic_operators.h
imperative/src/include/megbrain/imperative/basic_operators.h
+11
-1
imperative/src/include/megbrain/imperative/basic_values.h
imperative/src/include/megbrain/imperative/basic_values.h
+8
-0
imperative/src/include/megbrain/imperative/transformations/format.h
.../src/include/megbrain/imperative/transformations/format.h
+70
-0
imperative/src/include/megbrain/imperative/utils/data_format.h
...ative/src/include/megbrain/imperative/utils/data_format.h
+56
-0
imperative/src/include/megbrain/imperative/value.h
imperative/src/include/megbrain/imperative/value.h
+7
-3
src/opr/impl/dnn/batch_norm.cpp
src/opr/impl/dnn/batch_norm.cpp
+1
-1
未找到文件。
imperative/python/megengine/__init__.py
浏览文件 @
533fb5bf
...
@@ -156,6 +156,7 @@ _atexit(_persistent_cache.flush)
...
@@ -156,6 +156,7 @@ _atexit(_persistent_cache.flush)
# subpackages
# subpackages
import
megengine.amp
import
megengine.amp
import
megengine.autodiff
import
megengine.autodiff
import
megengine.config
import
megengine.data
import
megengine.data
import
megengine.distributed
import
megengine.distributed
import
megengine.dtr
import
megengine.dtr
...
...
imperative/python/megengine/core/_config.py
浏览文件 @
533fb5bf
...
@@ -2,7 +2,13 @@
...
@@ -2,7 +2,13 @@
import
os
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
._imperative_rt.core2
import
_clear_algorithm_cache
,
get_option
,
set_option
from
._imperative_rt.core2
import
(
_clear_algorithm_cache
,
get_auto_format_convert
,
get_option
,
set_auto_format_convert
,
set_option
,
)
__compute_mode
=
"default"
__compute_mode
=
"default"
__conv_format
=
"default"
__conv_format
=
"default"
...
@@ -153,20 +159,41 @@ def _conv_format(mod, format: str):
...
@@ -153,20 +159,41 @@ def _conv_format(mod, format: str):
__conv_format
=
format
__conv_format
=
format
@
property
def
_auto_format_convert
(
mod
):
r
"""Automatically convert indexing params' order for NCHW Tensor to NHWC order.
The default value is False, which means no convert.
Examples:
.. code-block::
import megengine as mge
mge.config._auto_format_convert = True
"""
return
get_auto_format_convert
()
@
_auto_format_convert
.
setter
def
_auto_format_convert
(
mod
,
option
:
bool
):
set_auto_format_convert
(
option
)
def
_reset_execution_config
(
def
_reset_execution_config
(
benchmark_kernel
=
None
,
benchmark_kernel
=
None
,
deterministic_kernel
=
None
,
deterministic_kernel
=
None
,
async_level
=
None
,
async_level
=
None
,
compute_mode
=
None
,
compute_mode
=
None
,
conv_format
=
None
,
conv_format
=
None
,
auto_format_convert
=
None
,
):
):
global
_benchmark_kernel
,
_deterministic_kernel
,
_
async_level
,
_
_compute_mode
,
__conv_format
global
_benchmark_kernel
,
_deterministic_kernel
,
__compute_mode
,
__conv_format
orig_flags
=
(
orig_flags
=
(
_benchmark_kernel
,
_benchmark_kernel
,
_deterministic_kernel
,
_deterministic_kernel
,
get_option
(
"async_level"
),
get_option
(
"async_level"
),
__compute_mode
,
__compute_mode
,
__conv_format
,
__conv_format
,
get_auto_format_convert
(),
)
)
if
benchmark_kernel
is
not
None
:
if
benchmark_kernel
is
not
None
:
_benchmark_kernel
=
benchmark_kernel
_benchmark_kernel
=
benchmark_kernel
...
@@ -178,6 +205,8 @@ def _reset_execution_config(
...
@@ -178,6 +205,8 @@ def _reset_execution_config(
__compute_mode
=
compute_mode
__compute_mode
=
compute_mode
if
conv_format
is
not
None
:
if
conv_format
is
not
None
:
__conv_format
=
conv_format
__conv_format
=
conv_format
if
auto_format_convert
is
not
None
:
set_auto_format_convert
(
auto_format_convert
)
return
orig_flags
return
orig_flags
...
@@ -189,6 +218,7 @@ def _override(
...
@@ -189,6 +218,7 @@ def _override(
async_level
=
None
,
async_level
=
None
,
compute_mode
=
None
,
compute_mode
=
None
,
conv_format
=
None
,
conv_format
=
None
,
auto_format_convert
=
None
,
):
):
r
"""A context manager that users can opt in by attaching the decorator to set
r
"""A context manager that users can opt in by attaching the decorator to set
the config of the global variable.
the config of the global variable.
...
@@ -204,11 +234,17 @@ def _override(
...
@@ -204,11 +234,17 @@ def _override(
async_level=2,
async_level=2,
compute_mode="float32",
compute_mode="float32",
conv_format="NHWC",
conv_format="NHWC",
auto_format_convert=True,
)
)
def train():
def train():
"""
"""
orig_flags
=
_reset_execution_config
(
orig_flags
=
_reset_execution_config
(
benchmark_kernel
,
deterministic_kernel
,
async_level
,
compute_mode
,
conv_format
,
benchmark_kernel
,
deterministic_kernel
,
async_level
,
compute_mode
,
conv_format
,
auto_format_convert
,
)
)
try
:
try
:
yield
yield
...
...
imperative/python/megengine/functional/vision.py
浏览文件 @
533fb5bf
...
@@ -564,7 +564,6 @@ def interpolate(
...
@@ -564,7 +564,6 @@ def interpolate(
if
inp
.
dtype
==
np
.
float16
:
if
inp
.
dtype
==
np
.
float16
:
inp
=
inp
.
astype
(
"float32"
)
inp
=
inp
.
astype
(
"float32"
)
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
assert
conv_format
==
"NCHW"
,
"Currently resize only support NCHW mode"
op
=
builtin
.
Resize
(
imode
=
mode_map
[
mode
],
format
=
conv_format
)
op
=
builtin
.
Resize
(
imode
=
mode_map
[
mode
],
format
=
conv_format
)
shape
=
astensor1d
(
dsize
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
shape
=
astensor1d
(
dsize
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
ret
,)
=
apply
(
op
,
inp
,
shape
)
(
ret
,)
=
apply
(
op
,
inp
,
shape
)
...
...
imperative/python/megengine/tensor.py
浏览文件 @
533fb5bf
...
@@ -4,6 +4,7 @@ from typing import Union
...
@@ -4,6 +4,7 @@ from typing import Union
import
numpy
as
np
import
numpy
as
np
from
.core._imperative_rt
import
CompNode
from
.core._imperative_rt
import
CompNode
from
.core._imperative_rt.core2
import
FormatType
from
.core._imperative_rt.core2
import
Tensor
as
_Tensor
from
.core._imperative_rt.core2
import
Tensor
as
_Tensor
from
.core._imperative_rt.core2
import
apply
,
set_py_tensor_type
from
.core._imperative_rt.core2
import
apply
,
set_py_tensor_type
from
.core._trace_option
import
use_symbolic_shape
from
.core._trace_option
import
use_symbolic_shape
...
@@ -45,6 +46,8 @@ class Tensor(_Tensor, ArrayMethodMixin):
...
@@ -45,6 +46,8 @@ class Tensor(_Tensor, ArrayMethodMixin):
is_const: Whether make it a ``ImutableTensor`` in tracing mode, refer to :class:`.jit.trace`.
is_const: Whether make it a ``ImutableTensor`` in tracing mode, refer to :class:`.jit.trace`.
no_cache: Whether cache it for memory sharing.
no_cache: Whether cache it for memory sharing.
name: Used to improve convenience in graph operation on dumped model.
name: Used to improve convenience in graph operation on dumped model.
format: Used to indicate which memory format Tensor uses. It will not affect actual memory order or stride,
but may affect some operators related to indexing and dimension. Only support "default", "nchw" and "nhwc".
.. note::
.. note::
...
@@ -73,6 +76,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
...
@@ -73,6 +76,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
is_const
:
bool
=
False
,
is_const
:
bool
=
False
,
no_cache
:
bool
=
False
,
no_cache
:
bool
=
False
,
name
:
str
=
None
,
name
:
str
=
None
,
format
:
str
=
"default"
,
):
):
if
name
is
None
:
if
name
is
None
:
name
=
""
name
=
""
...
@@ -116,6 +120,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
...
@@ -116,6 +120,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
r
"""Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`."""
r
"""Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`."""
return
super
().
dtype
return
super
().
dtype
@
property
def
format
(
self
)
->
str
:
return
super
().
format
@
property
@
property
def
qparams
(
self
):
def
qparams
(
self
):
r
"""Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`."""
r
"""Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`."""
...
...
imperative/python/src/tensor.cpp
浏览文件 @
533fb5bf
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/symbol.h"
...
@@ -492,6 +493,9 @@ ssize_t name2idx(const char* name) {
...
@@ -492,6 +493,9 @@ ssize_t name2idx(const char* name) {
// name
// name
case
'a'
:
return
compare_cstr
<
'm'
,
'e'
>
(
ch
)
?
5
:
-
1
;
case
'a'
:
return
compare_cstr
<
'm'
,
'e'
>
(
ch
)
?
5
:
-
1
;
}
}
case
'f'
:
// format
return
compare_cstr
<
'o'
,
'r'
,
'm'
,
'a'
,
't'
>
(
ch
)
?
6
:
-
1
;
}
}
// clang-format on
// clang-format on
return
-
1
;
return
-
1
;
...
@@ -508,6 +512,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
...
@@ -508,6 +512,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
{
"is_const"
,
[]()
->
py
::
object
{
return
py
::
bool_
(
false
);
}},
{
"is_const"
,
[]()
->
py
::
object
{
return
py
::
bool_
(
false
);
}},
{
"no_cache"
,
[]()
->
py
::
object
{
return
py
::
bool_
(
false
);
}},
{
"no_cache"
,
[]()
->
py
::
object
{
return
py
::
bool_
(
false
);
}},
{
"name"
,
[]()
->
py
::
object
{
return
py
::
none
();
}},
{
"name"
,
[]()
->
py
::
object
{
return
py
::
none
();
}},
{
"format"
,
[]()
->
py
::
object
{
return
py
::
none
();
}},
},
},
name2idx
};
name2idx
};
py
::
detail
::
loader_life_support
life_sup
;
// FIXME!!!required to cast DType
py
::
detail
::
loader_life_support
life_sup
;
// FIXME!!!required to cast DType
...
@@ -518,19 +523,23 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
...
@@ -518,19 +523,23 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
}
else
{
}
else
{
tup
=
parse_args
(
tup
,
descs
);
tup
=
parse_args
(
tup
,
descs
);
}
}
mgb_assert
(
tup
.
size
()
==
6
);
mgb_assert
(
tup
.
size
()
==
7
);
if
(
auto
*
t
=
try_cast
(
tup
[
0
].
ptr
()))
{
if
(
auto
*
t
=
try_cast
(
tup
[
0
].
ptr
()))
{
m_tensor
=
t
->
m_tensor
->
copy
();
m_tensor
=
t
->
m_tensor
->
copy
();
}
else
{
}
else
{
auto
data
=
tup
[
0
];
auto
data
=
tup
[
0
];
DType
dtype
=
tup
[
1
].
cast
<
DType
>
();
DType
dtype
=
tup
[
1
].
cast
<
DType
>
();
CompNode
cn
=
as_comp_node
(
tup
[
2
]);
bool
is_const
=
tup
[
3
].
cast
<
bool
>
();
bool
is_const
=
tup
[
3
].
cast
<
bool
>
();
bool
no_cache
=
tup
[
4
].
cast
<
bool
>
();
bool
no_cache
=
tup
[
4
].
cast
<
bool
>
();
std
::
string
name
;
std
::
string
name
;
if
(
!
tup
[
5
].
is_none
())
{
if
(
!
tup
[
5
].
is_none
())
{
name
=
tup
[
5
].
cast
<
std
::
string
>
();
name
=
tup
[
5
].
cast
<
std
::
string
>
();
}
}
CompNode
cn
=
as_comp_node
(
tup
[
2
]);
Format
format
;
if
(
!
tup
[
6
].
is_none
())
{
format
=
tup
[
6
].
cast
<
std
::
string
>
();
}
{
{
CreateTensor
::
Kind
kind
=
is_const
?
CreateTensor
::
Const
CreateTensor
::
Kind
kind
=
is_const
?
CreateTensor
::
Const
...
@@ -544,7 +553,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
...
@@ -544,7 +553,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
}
else
{
}
else
{
auto
&&
hval
=
pyobj2hval
(
data
,
cn
,
dtype
);
auto
&&
hval
=
pyobj2hval
(
data
,
cn
,
dtype
);
val
=
imperative
::
apply
(
val
=
imperative
::
apply
(
CreateTensor
(
kind
,
cn
,
hval
.
dtype
,
hval
.
shape
),
CreateTensor
(
kind
,
cn
,
hval
.
dtype
,
hval
.
shape
,
format
),
hval
.
storage
)[
0
];
hval
.
storage
)[
0
];
}
}
m_tensor
.
emplace
(
val
);
m_tensor
.
emplace
(
val
);
...
@@ -610,6 +619,10 @@ PyObject* TensorWrapper::device() {
...
@@ -610,6 +619,10 @@ PyObject* TensorWrapper::device() {
return
py
::
cast
(
m_tensor
->
comp_node
()).
release
().
ptr
();
return
py
::
cast
(
m_tensor
->
comp_node
()).
release
().
ptr
();
}
}
PyObject
*
TensorWrapper
::
format
()
{
return
py
::
cast
(
m_tensor
->
format
().
to_string
()).
release
().
ptr
();
}
PyObject
*
TensorWrapper
::
numpy
()
{
PyObject
*
TensorWrapper
::
numpy
()
{
auto
hv
=
m_tensor
->
numpy
();
auto
hv
=
m_tensor
->
numpy
();
if
(
!
hv
)
{
if
(
!
hv
)
{
...
@@ -722,6 +735,7 @@ WRAP_FUNC_PY35(pixel_shuffle_cpp);
...
@@ -722,6 +735,7 @@ WRAP_FUNC_PY35(pixel_shuffle_cpp);
void
init_tensor
(
py
::
module
m
)
{
void
init_tensor
(
py
::
module
m
)
{
imperative
::
Tensor
::
static_initialize
();
imperative
::
Tensor
::
static_initialize
();
// Transformations
static
auto
&
transformations
=
TransformationManager
::
get_instance
();
static
auto
&
transformations
=
TransformationManager
::
get_instance
();
using
Segment
=
TransformationManager
::
Segment
;
using
Segment
=
TransformationManager
::
Segment
;
...
@@ -755,6 +769,9 @@ void init_tensor(py::module m) {
...
@@ -755,6 +769,9 @@ void init_tensor(py::module m) {
.
register_at
<
Segment
::
DimExpansion
>
(
.
register_at
<
Segment
::
DimExpansion
>
(
std
::
make_shared
<
DimExpansionTransformation
>
())
std
::
make_shared
<
DimExpansionTransformation
>
())
.
release
());
.
release
());
auto
format_trans
=
std
::
make_shared
<
FormatTransformation
>
();
MGB_MARK_USED_VAR
(
transformations
.
register_at
<
Segment
::
Format
>
(
format_trans
).
release
());
static
py
::
exception
<
interpreter
::
AsyncError
>
py_async_error
(
static
py
::
exception
<
interpreter
::
AsyncError
>
py_async_error
(
m
,
"AsyncError"
,
PyExc_RuntimeError
);
m
,
"AsyncError"
,
PyExc_RuntimeError
);
...
@@ -788,12 +805,14 @@ void init_tensor(py::module m) {
...
@@ -788,12 +805,14 @@ void init_tensor(py::module m) {
}
}
});
});
// Tensor
auto
*
tensor_type
=
auto
*
tensor_type
=
TensorWrapper
::
wrap_t
::
type
()
TensorWrapper
::
wrap_t
::
type
()
.
def
<&
TensorWrapper
::
numpy
>
(
"numpy"
)
.
def
<&
TensorWrapper
::
numpy
>
(
"numpy"
)
.
def_getset
<&
TensorWrapper
::
shape
>
(
"shape"
)
.
def_getset
<&
TensorWrapper
::
shape
>
(
"shape"
)
.
def_getset
<&
TensorWrapper
::
dtype
>
(
"dtype"
)
.
def_getset
<&
TensorWrapper
::
dtype
>
(
"dtype"
)
.
def_getset
<&
TensorWrapper
::
device
>
(
"device"
)
.
def_getset
<&
TensorWrapper
::
device
>
(
"device"
)
.
def_getset
<&
TensorWrapper
::
format
>
(
"format"
)
.
def
<&
TensorWrapper
::
reset
>
(
"_reset"
)
.
def
<&
TensorWrapper
::
reset
>
(
"_reset"
)
.
def
<&
TensorWrapper
::
isscalar
>
(
"_isscalar"
)
.
def
<&
TensorWrapper
::
isscalar
>
(
"_isscalar"
)
.
def
<&
TensorWrapper
::
detach
>
(
"detach"
)
.
def
<&
TensorWrapper
::
detach
>
(
"detach"
)
...
@@ -812,6 +831,11 @@ void init_tensor(py::module m) {
...
@@ -812,6 +831,11 @@ void init_tensor(py::module m) {
if
(
!
tensor_type
)
if
(
!
tensor_type
)
throw
py
::
error_already_set
();
throw
py
::
error_already_set
();
py
::
setattr
(
m
,
"Tensor"
,
tensor_type
);
py
::
setattr
(
m
,
"Tensor"
,
tensor_type
);
py
::
enum_
<
Format
::
Type
>
(
m
,
"FormatType"
)
.
value
(
"DEFAULT"
,
Format
::
Type
::
DEFAULT
)
.
value
(
"NCHW"
,
Format
::
Type
::
NCHW
)
.
value
(
"NHWC"
,
Format
::
Type
::
NHWC
)
.
export_values
();
py
::
class_
<
TensorWeakRef
>
(
m
,
"TensorWeakRef"
)
py
::
class_
<
TensorWeakRef
>
(
m
,
"TensorWeakRef"
)
.
def
(
py
::
init
<
const
TensorWrapper
&>
())
.
def
(
py
::
init
<
const
TensorWrapper
&>
())
...
@@ -911,6 +935,7 @@ void init_tensor(py::module m) {
...
@@ -911,6 +935,7 @@ void init_tensor(py::module m) {
sync_py_task_q
();
sync_py_task_q
();
});
});
// GradTransformation
py
::
handle
grad_key_type
=
py
::
handle
grad_key_type
=
GradKeyWrapper
::
wrap_t
::
type
()
GradKeyWrapper
::
wrap_t
::
type
()
.
def
<&
GradKeyWrapper
::
attach
>
(
"attach"
)
.
def
<&
GradKeyWrapper
::
attach
>
(
"attach"
)
...
@@ -1203,6 +1228,7 @@ void init_tensor(py::module m) {
...
@@ -1203,6 +1228,7 @@ void init_tensor(py::module m) {
return
wrapped_outputs
;
return
wrapped_outputs
;
});
});
// ModuleTraceTransformation
static
py
::
function
module_trace_hook
;
static
py
::
function
module_trace_hook
;
static
auto
get_module_trace
=
[]
{
static
auto
get_module_trace
=
[]
{
...
@@ -1309,6 +1335,12 @@ void init_tensor(py::module m) {
...
@@ -1309,6 +1335,12 @@ void init_tensor(py::module m) {
m
.
def
(
"_clear_algorithm_cache"
,
[]
{
megdnn
::
AlgorithmCache
::
instance
().
clear
();
});
m
.
def
(
"_clear_algorithm_cache"
,
[]
{
megdnn
::
AlgorithmCache
::
instance
().
clear
();
});
// FormatTransformation
m
.
def
(
"set_auto_format_convert"
,
[
format_trans
](
bool
enabled
)
{
format_trans
->
set_auto_convert
(
enabled
);
});
m
.
def
(
"get_auto_format_convert"
,
[
format_trans
]()
{
return
format_trans
->
get_auto_convert
();
});
py
::
register_exception
<
TraceError
>
(
m
,
"TraceError"
);
py
::
register_exception
<
TraceError
>
(
m
,
"TraceError"
);
}
}
...
...
imperative/python/src/tensor.h
浏览文件 @
533fb5bf
#pragma once
#pragma once
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include <variant>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <variant>
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"
#include "pybind11/pybind11.h"
...
@@ -57,6 +58,7 @@ public:
...
@@ -57,6 +58,7 @@ public:
}
}
return
*
shape
;
return
*
shape
;
}
}
inline
Format
format
()
{
return
*
data
().
format
();
}
inline
HostValue
::
ref_t
numpy
()
{
return
data
().
numpy
();
}
inline
HostValue
::
ref_t
numpy
()
{
return
data
().
numpy
();
}
inline
void
reset
(
ValueRef
value
)
{
inline
void
reset
(
ValueRef
value
)
{
m_data
=
value
;
m_data
=
value
;
...
@@ -116,6 +118,7 @@ public:
...
@@ -116,6 +118,7 @@ public:
PyObject
*
shape
();
PyObject
*
shape
();
PyObject
*
dtype
();
PyObject
*
dtype
();
PyObject
*
device
();
PyObject
*
device
();
PyObject
*
format
();
PyObject
*
numpy
();
PyObject
*
numpy
();
void
reset
(
PyObject
*
);
void
reset
(
PyObject
*
);
PyObject
*
detach
();
PyObject
*
detach
();
...
...
imperative/python/src/transformation.h
浏览文件 @
533fb5bf
...
@@ -19,6 +19,7 @@ public:
...
@@ -19,6 +19,7 @@ public:
DTypePromote
,
DTypePromote
,
DimExpansion
,
DimExpansion
,
Grad
,
Grad
,
Format
,
Scalar
,
Scalar
,
Symbol
,
Symbol
,
Trace
,
Trace
,
...
...
imperative/python/test/unit/amp/test_autocast.py
浏览文件 @
533fb5bf
...
@@ -2,7 +2,7 @@ from megengine import amp
...
@@ -2,7 +2,7 @@ from megengine import amp
from
megengine.core.tensor
import
amp
as
origin_amp
from
megengine.core.tensor
import
amp
as
origin_amp
def
test_
grad_scaler
():
def
test_
autocast
():
def
check
(
enabled
,
low
,
high
):
def
check
(
enabled
,
low
,
high
):
assert
amp
.
enabled
==
enabled
assert
amp
.
enabled
==
enabled
assert
origin_amp
.
_enabled
==
enabled
assert
origin_amp
.
_enabled
==
enabled
...
...
imperative/python/test/unit/core/test_formatted_tensor.py
0 → 100644
浏览文件 @
533fb5bf
import
numpy
as
np
import
pytest
import
megengine
as
mge
import
megengine.functional
as
F
from
megengine
import
tensor
from
megengine.autodiff
import
GradManager
def
test_basic
():
a
=
tensor
(
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
)),
dtype
=
"float32"
,
format
=
"nhwc"
)
assert
a
.
format
==
"nhwc"
b
=
tensor
(
a
)
assert
b
.
format
==
"nhwc"
# TODO: fix Tensor init bug for another Tensor
# c = tensor(a, format="nchw")
# assert c.format == "nchw"
def
_compare_nchw_nhwc
(
data
,
func
):
x1
=
tensor
(
data
,
format
=
"nchw"
)
x2
=
tensor
(
data
.
transpose
(
0
,
2
,
3
,
1
),
format
=
"nhwc"
)
out1
=
func
(
x1
)
with
mge
.
config
.
_override
(
auto_format_convert
=
True
):
out2
=
func
(
x2
)
np
.
testing
.
assert_equal
(
out1
,
out2
)
def
test_dimshuffle
():
def
func
(
x
):
out
=
F
.
transpose
(
x
,
[
2
,
3
,
0
,
1
])
assert
out
.
format
==
"default"
return
out
.
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
def
test_reshape
():
# maintain NHWC format
def
func
(
x
):
out
=
F
.
reshape
(
x
,
(
1
,
2
,
6
,
2
))
if
x
.
format
==
"nhwc"
:
assert
out
.
format
==
"nhwc"
return
out
.
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
# not maintain NHWC format
def
func2
(
x
):
out
=
F
.
reshape
(
x
,
(
1
,
24
))
assert
out
.
format
==
"default"
return
out
.
numpy
()
_compare_nchw_nhwc
(
data
,
func2
)
def
test_flatten
():
def
func
(
x
):
return
F
.
flatten
(
x
).
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
def
test_broadcast
():
# maintain NHWC format
def
func
(
x
):
out
=
F
.
broadcast_to
(
x
,
(
4
,
3
,
2
,
3
))
if
x
.
format
==
"nhwc"
:
assert
out
.
format
==
"nhwc"
return
out
.
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
4
,
3
,
2
,
1
))
_compare_nchw_nhwc
(
data
,
func
)
# not maintain NHWC format
def
func2
(
x
):
out
=
F
.
broadcast_to
(
x
,
(
3
,
4
,
3
,
2
,
1
))
assert
out
.
format
==
"default"
return
out
.
numpy
()
_compare_nchw_nhwc
(
data
,
func2
)
@
pytest
.
mark
.
skip
(
"repeat cannot maintain format yet"
)
def
test_repeat
():
def
func
(
x
):
rst
=
F
.
repeat
(
x
,
3
,
axis
=
1
)
assert
rst
.
format
==
x
.
format
return
rst
.
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
def
test_getshape
():
def
func
(
x
):
return
x
.
shape
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
@
pytest
.
mark
.
skip
(
"symbolic shape is not supported yet"
)
def
test_get_symbolic_shape
():
from
megengine.core._trace_option
import
set_symbolic_shape
origin_opt
=
set_symbolic_shape
(
True
)
def
func
(
x
):
return
x
.
shape
.
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
set_symbolic_shape
(
origin_opt
)
def
test_getvalue
():
def
func
(
x
):
return
x
.
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
def
test_get_set_subtensor
():
def
get_subtensor
(
x
):
return
x
[:,
:
1
,
:
2
,
:
3
].
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
get_subtensor
)
def
set_subtensor
(
x
):
x
[:,
:
1
,
:
2
,
:
3
]
=
0
return
x
.
numpy
()
_compare_nchw_nhwc
(
data
,
set_subtensor
)
def
test_get_set_advanced_indexing
():
def
get_advanced_indexing
(
x
):
x
=
x
[:,
:
mge
.
tensor
(
2
),
:
mge
.
tensor
(
2
),
[
1
,
2
]].
numpy
()
return
x
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
get_advanced_indexing
)
def
set_advanced_indexing
(
x
):
x
[:,
:
mge
.
tensor
(
2
),
:
mge
.
tensor
([
2
]),
[
1
,]]
=
0
return
x
.
numpy
()
_compare_nchw_nhwc
(
data
,
set_advanced_indexing
)
def
test_typecvt
():
def
typecvt
(
x
):
return
x
.
astype
(
"float16"
).
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
typecvt
)
def
test_elemwise
():
def
elemwise
(
x
):
return
(
x
*
2
+
x
/
2
).
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
elemwise
)
def
test_concat
():
def
func
(
x
):
rst
=
F
.
concat
([
x
/
2
,
x
*
2
],
axis
=
1
)
assert
rst
.
format
==
x
.
format
return
rst
.
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
@
pytest
.
mark
.
parametrize
(
"mode"
,
[
"bilinear"
,
"nearest"
],
)
def
test_interpolate
(
mode
):
def
func
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
conv_format
=
"NHWC"
):
rst
=
F
.
vision
.
interpolate
(
x
,
scale_factor
=
3
,
mode
=
mode
)
assert
rst
.
format
==
"nhwc"
return
rst
.
numpy
()
else
:
return
F
.
vision
.
interpolate
(
x
,
scale_factor
=
3
,
mode
=
mode
).
numpy
()
# NHWC interpolate only suppoted channel is 1 or 3
data
=
np
.
arange
(
0
,
48
).
reshape
((
1
,
3
,
4
,
4
)).
astype
(
"float32"
)
_compare_nchw_nhwc
(
data
,
func
)
def
test_conv2d
():
def
conv2d
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
conv_format
=
"NHWC"
):
x
=
F
.
conv2d
(
x
,
weight
=
mge
.
tensor
(
np
.
ones
((
3
,
1
,
1
,
2
)),
format
=
"nhwc"
),
bias
=
mge
.
tensor
(
np
.
ones
((
1
,
1
,
1
,
3
)),
format
=
"nhwc"
),
)
assert
x
.
format
==
"nhwc"
return
x
.
numpy
()
else
:
return
F
.
conv2d
(
x
,
F
.
ones
((
3
,
2
,
1
,
1
)),
F
.
ones
((
1
,
3
,
1
,
1
))).
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
conv2d
)
def
test_group_conv2d
():
def
conv2d
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
conv_format
=
"NHWC"
):
x
=
F
.
conv2d
(
x
,
weight
=
mge
.
tensor
(
np
.
ones
((
2
,
2
,
1
,
1
,
2
)),
format
=
"nhwc"
),
bias
=
mge
.
tensor
(
np
.
ones
((
1
,
1
,
1
,
4
)),
format
=
"nhwc"
),
groups
=
2
,
)
assert
x
.
format
==
"nhwc"
return
x
.
numpy
()
else
:
return
F
.
conv2d
(
x
,
F
.
ones
((
2
,
2
,
2
,
1
,
1
)),
F
.
ones
((
1
,
4
,
1
,
1
)),
groups
=
2
).
numpy
()
data
=
np
.
arange
(
0
,
48
).
reshape
((
1
,
4
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
conv2d
)
def
test_bn
():
def
func
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
bn_format
=
"dim_111c"
):
oups
=
F
.
batch_norm
(
x
.
astype
(
"float32"
),
running_mean
=
mge
.
tensor
(
np
.
ones
((
1
,
1
,
1
,
2
)),
format
=
"nhwc"
),
running_var
=
mge
.
tensor
(
np
.
ones
((
1
,
1
,
1
,
2
)),
format
=
"nhwc"
),
weight
=
mge
.
tensor
(
np
.
ones
((
1
,
1
,
1
,
2
)),
format
=
"nhwc"
),
bias
=
mge
.
tensor
(
np
.
ones
((
1
,
1
,
1
,
2
)),
format
=
"nhwc"
),
training
=
True
,
inplace
=
False
,
)
assert
oups
[
0
].
format
==
"nhwc"
,
"y's format is wrong"
assert
oups
[
1
].
format
==
"nhwc"
,
"running_mean's format is wrong"
assert
oups
[
2
].
format
==
"nhwc"
,
"running_var's format is wrong"
return
oups
[
0
].
numpy
()
else
:
return
F
.
batch_norm
(
x
.
astype
(
"float32"
),
running_mean
=
mge
.
tensor
(
np
.
ones
((
1
,
2
,
1
,
1
))),
running_var
=
mge
.
tensor
(
np
.
ones
((
1
,
2
,
1
,
1
))),
weight
=
mge
.
tensor
(
np
.
ones
((
1
,
2
,
1
,
1
))),
bias
=
mge
.
tensor
(
np
.
ones
((
1
,
2
,
1
,
1
))),
training
=
True
,
inplace
=
False
,
)[
0
].
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
@
pytest
.
mark
.
parametrize
(
"pooling"
,
[
F
.
max_pool2d
,
F
.
avg_pool2d
,
F
.
adaptive_avg_pool2d
,
F
.
adaptive_max_pool2d
],
)
def
test_pooling2d
(
pooling
):
def
func
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
conv_format
=
"NHWC"
):
x
=
pooling
(
x
.
astype
(
"float32"
),
2
)
assert
x
.
format
==
"nhwc"
return
x
.
numpy
()
else
:
return
pooling
(
x
.
astype
(
"float32"
),
2
).
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
_compare_nchw_nhwc
(
data
,
func
)
def
test_backward
():
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
x
=
tensor
(
data
.
transpose
(
0
,
2
,
3
,
1
),
format
=
"nhwc"
)
w
=
mge
.
tensor
(
np
.
ones
((
3
,
1
,
1
,
2
)),
format
=
"nhwc"
)
b
=
mge
.
tensor
(
np
.
ones
((
1
,
1
,
1
,
3
)),
format
=
"nhwc"
)
gm
=
GradManager
().
attach
([
w
,
b
])
with
gm
:
with
mge
.
config
.
_override
(
auto_format_convert
=
True
,
conv_format
=
"NHWC"
):
x
=
F
.
conv2d
(
x
,
w
,
b
)
gm
.
backward
(
x
)
# TODO: backward grad has no format yet
np
.
testing
.
assert_equal
(
w
.
grad
.
numpy
(),
np
.
array
([
66
,
210
,
66
,
210
,
66
,
210
]).
reshape
((
3
,
1
,
1
,
2
)),
)
np
.
testing
.
assert_equal
(
b
.
grad
.
numpy
(),
np
.
array
([
12
,
12
,
12
]).
reshape
((
1
,
1
,
1
,
3
))
)
imperative/src/impl/basic_operators.cpp
浏览文件 @
533fb5bf
...
@@ -33,14 +33,20 @@ std::string GetAttr::to_string() const {
...
@@ -33,14 +33,20 @@ std::string GetAttr::to_string() const {
return
ssprintf
(
"GetAttr{attr=%s}"
,
attr_name
);
return
ssprintf
(
"GetAttr{attr=%s}"
,
attr_name
);
}
}
CreateTensor
::
CreateTensor
(
Kind
kind
,
CompNode
device
,
DType
dtype
,
ValueShape
shape
)
CreateTensor
::
CreateTensor
(
:
m_kind
(
kind
),
m_device
(
device
),
m_dtype
(
dtype
),
m_shape
(
shape
)
{}
Kind
kind
,
CompNode
device
,
DType
dtype
,
ValueShape
shape
,
Format
format
)
:
m_kind
(
kind
),
m_device
(
device
),
m_dtype
(
dtype
),
m_shape
(
shape
),
m_format
(
format
)
{}
CreateTensor
::
CreateTensor
(
Kind
kind
,
CompNode
device
,
TensorLayout
layout
)
CreateTensor
::
CreateTensor
(
Kind
kind
,
CompNode
device
,
TensorLayout
layout
)
:
m_kind
(
kind
),
:
m_kind
(
kind
),
m_device
(
device
),
m_device
(
device
),
m_dtype
(
layout
.
dtype
),
m_dtype
(
layout
.
dtype
),
m_shape
(
ValueShape
::
from
(
layout
))
{
m_shape
(
ValueShape
::
from
(
layout
)),
m_format
(
Format
::
Type
::
DEFAULT
)
{
mgb_assert
(
mgb_assert
(
layout
.
is_contiguous
()
||
layout
.
is_empty
(),
"layout should be contiguous"
);
layout
.
is_contiguous
()
||
layout
.
is_empty
(),
"layout should be contiguous"
);
}
}
...
@@ -74,8 +80,9 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
...
@@ -74,8 +80,9 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
std
::
string
CreateTensor
::
to_string
()
const
{
std
::
string
CreateTensor
::
to_string
()
const
{
return
ssprintf
(
return
ssprintf
(
"CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s}"
,
(
int
)
m_kind
,
"CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s, format=%s}"
,
m_device
.
to_string
().
c_str
(),
m_dtype
.
name
(),
m_shape
.
to_string
().
c_str
());
(
int
)
m_kind
,
m_device
.
to_string
().
c_str
(),
m_dtype
.
name
(),
m_shape
.
to_string
().
c_str
(),
m_format
.
to_string
().
c_str
());
}
}
std
::
string
DTRCommand
::
to_string
()
const
{
std
::
string
DTRCommand
::
to_string
()
const
{
...
...
imperative/src/impl/transformations/format.cpp
0 → 100644
浏览文件 @
533fb5bf
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/ops/autogen.h"
namespace
mgb
{
namespace
imperative
{
using
FT
=
Format
::
Type
;
TypedValueRef
<
FormattedTensorValue
>
FormattedTensorValue
::
as
(
const
FT
&
target
)
const
{
return
FormattedTensorValue
::
make
(
m_value
,
target
);
}
TypedValueRef
<
FormattedTensorValue
>
FormattedTensorValue
::
to
(
const
FT
&
target
,
const
std
::
string
&
scope
)
const
{
std
::
vector
<
int32_t
>
pattern
;
if
(
m_format
==
FT
::
NHWC
&&
target
==
FT
::
NCHW
)
{
pattern
=
{
0
,
3
,
1
,
2
};
}
else
if
(
m_format
==
FT
::
NCHW
&&
target
==
FT
::
NHWC
)
{
pattern
=
{
0
,
2
,
3
,
1
};
}
else
{
mgb_throw
(
MegBrainError
,
"Unsupport format conversion from %s to %s"
,
m_format
.
to_string
().
c_str
(),
Format
(
target
).
to_string
().
c_str
());
}
auto
output
=
imperative
::
apply
(
*
Dimshuffle
::
make
(
pattern
,
scope
),
std
::
vector
<
ValueRef
>
{
m_value
})[
0
];
return
FormattedTensorValue
::
make
(
output
,
target
);
}
namespace
{
ValueRef
unwrap_input
(
const
ValueRef
&
input
)
{
if
(
auto
format_input
=
input
.
as_ref
<
FormattedTensorValue
>
())
{
return
format_input
->
value
();
}
else
{
return
input
;
}
}
std
::
vector
<
ValueRef
>
unwrap_inputs
(
const
Span
<
ValueRef
>&
inputs
)
{
std
::
vector
<
ValueRef
>
unwrapped_inputs
;
for
(
auto
&&
input
:
inputs
)
{
unwrapped_inputs
.
push_back
(
unwrap_input
(
input
));
}
return
unwrapped_inputs
;
}
std
::
vector
<
ValueRef
>
wrap_outputs
(
const
std
::
vector
<
ValueRef
>&
outputs
,
FT
type
=
FT
::
DEFAULT
)
{
std
::
vector
<
ValueRef
>
wrapped_outputs
;
for
(
auto
&&
output
:
outputs
)
{
wrapped_outputs
.
push_back
(
FormattedTensorValue
::
make
(
output
,
type
));
}
return
wrapped_outputs
;
}
ValueShape
convert_nhwc2nchw_shape
(
const
ValueShape
&
shape
)
{
mgb_assert
(
shape
.
ndim
==
4
);
auto
out
=
ValueShape
(
shape
);
out
[
3
]
=
shape
[
2
];
out
[
2
]
=
shape
[
1
];
out
[
1
]
=
shape
[
3
];
return
out
;
}
using
FormatRule
=
std
::
function
<
std
::
vector
<
ValueRef
>
(
const
OpDef
&
,
Span
<
ValueRef
>&
,
const
bool
&
)
>
;
static
std
::
unordered_map
<
Typeinfo
*
,
FormatRule
>
format_rules
;
template
<
typename
T
>
void
register_format_rule
(
std
::
vector
<
ValueRef
>
(
*
rule
)(
const
T
&
,
Span
<
ValueRef
>&
,
const
bool
&
))
{
format_rules
[
T
::
typeinfo
()]
=
[
rule
](
const
OpDef
&
def
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
)
{
return
(
*
rule
)(
def
.
cast_final_safe
<
T
>
(),
inputs
,
auto_convert
);
};
}
auto
convert_nchw2nhwc_pattern
(
const
std
::
vector
<
int32_t
>&
pattern
)
{
mgb_assert
(
pattern
.
size
()
==
4
);
auto
nhwc_pattern
=
pattern
;
for
(
size_t
idx
=
0
;
idx
<
4
;
++
idx
)
{
auto
dim
=
pattern
[
idx
];
if
(
dim
==
1
)
{
nhwc_pattern
[
idx
]
=
3
;
}
else
if
(
dim
==
2
)
{
nhwc_pattern
[
idx
]
=
1
;
}
else
if
(
dim
==
3
)
{
nhwc_pattern
[
idx
]
=
2
;
}
}
return
nhwc_pattern
;
}
std
::
vector
<
ValueRef
>
dimshuffle_rule
(
const
Dimshuffle
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
auto
&
src
=
inputs
[
0
].
cast
<
FormattedTensorValue
>
();
// Only support converting pattern from NCHW to NHWC currently.
if
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
)
{
auto
pattern
=
convert_nchw2nhwc_pattern
(
op
.
pattern
);
// dimshuffle will not maintain NHWC Format
return
wrap_outputs
(
imperative
::
apply
(
*
Dimshuffle
::
make
(
std
::
move
(
pattern
),
op
.
scope
()),
unwrap_inputs
(
inputs
)));
}
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)));
}
ValueRef
convert_nchw2nhwc_tensornd
(
const
HostTensorND
&
shape
)
{
mgb_assert
(
shape
.
layout
().
total_nr_elems
()
==
4
);
auto
*
old_ptr
=
shape
.
ptr
<
dt_int32
>
();
auto
cn
=
shape
.
comp_node
();
auto
layout
=
shape
.
layout
();
auto
nhwc_shape
=
HostTensorND
(
cn
,
layout
);
auto
*
new_ptr
=
nhwc_shape
.
ptr
<
dt_int32
>
();
new_ptr
[
0
]
=
old_ptr
[
0
];
new_ptr
[
1
]
=
old_ptr
[
2
];
new_ptr
[
2
]
=
old_ptr
[
3
];
new_ptr
[
3
]
=
old_ptr
[
1
];
auto
hv
=
HostStorage
::
make
(
nhwc_shape
.
storage
());
auto
nhwc_shape_input
=
imperative
::
apply
(
CreateTensor
(
CreateTensor
::
Const
,
cn
,
layout
),
hv
)[
0
];
return
nhwc_shape_input
;
}
std
::
vector
<
ValueRef
>
reshape_rule
(
const
Reshape
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
auto
&
src
=
inputs
[
0
].
cast
<
FormattedTensorValue
>
();
if
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
)
{
auto
shape
=
unwrap_input
(
inputs
[
1
]).
numpy
().
cast
<
HostValue
>
().
as_nd
();
if
(
shape
.
layout
().
total_nr_elems
()
==
4
)
{
// output is still NHWC format
auto
nhwc_shape
=
convert_nchw2nhwc_tensornd
(
shape
);
auto
outputs
=
imperative
::
apply
(
op
,
std
::
vector
<
ValueRef
>
{
unwrap_input
(
inputs
[
0
]),
nhwc_shape
});
return
wrap_outputs
(
outputs
,
FT
::
NHWC
);
}
else
{
// will not maintain src's format
auto
nchw_src
=
src
.
to
(
FT
::
NCHW
,
op
.
scope
())
->
value
();
auto
outputs
=
imperative
::
apply
(
op
,
std
::
vector
<
ValueRef
>
{
nchw_src
,
unwrap_input
(
inputs
[
1
])});
return
wrap_outputs
(
outputs
);
}
}
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)));
}
std
::
vector
<
ValueRef
>
broadcast_rule
(
const
Broadcast
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
auto
&
src
=
inputs
[
0
].
cast
<
FormattedTensorValue
>
();
if
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
)
{
auto
shape
=
unwrap_input
(
inputs
[
1
]).
numpy
().
cast
<
HostValue
>
().
as_nd
();
if
(
shape
.
layout
().
total_nr_elems
()
==
4
)
{
// output is still NHWC format
auto
nhwc_shape
=
convert_nchw2nhwc_tensornd
(
shape
);
auto
outputs
=
imperative
::
apply
(
op
,
std
::
vector
<
ValueRef
>
{
unwrap_input
(
inputs
[
0
]),
nhwc_shape
});
return
wrap_outputs
(
outputs
,
FT
::
NHWC
);
}
else
{
// will not maintain src's format
auto
nchw_src
=
src
.
to
(
FT
::
NCHW
,
op
.
scope
())
->
value
();
auto
outputs
=
imperative
::
apply
(
op
,
std
::
vector
<
ValueRef
>
{
nchw_src
,
unwrap_input
(
inputs
[
1
])});
return
wrap_outputs
(
outputs
);
}
}
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)));
}
bool
is_reduce_ndim_idx_items
(
const
std
::
vector
<
std
::
tuple
<
int8_t
,
bool
,
bool
,
bool
,
bool
>>&
items
,
const
Span
<
ValueRef
>&
inputs
)
{
for
(
auto
i
=
0
;
i
<
items
.
size
();
++
i
)
{
auto
&&
[
axis
,
begin
,
end
,
step
,
idx
]
=
items
[
i
];
if
(
idx
)
{
// if inputs[i] contains more than one value, ndim will not be reduced.
return
inputs
[
i
].
is_scalar
();
}
}
return
false
;
}
auto
convert_nchw2nhwc_idx_items
(
const
std
::
vector
<
std
::
tuple
<
int8_t
,
bool
,
bool
,
bool
,
bool
>>&
items
)
{
auto
nhwc_items
=
items
;
for
(
auto
i
=
0
;
i
<
nhwc_items
.
size
();
++
i
)
{
auto
&&
[
axis
,
begin
,
end
,
step
,
idx
]
=
nhwc_items
[
i
];
if
(
axis
==
2
||
axis
==
3
)
{
nhwc_items
[
i
]
=
{
axis
-
1
,
begin
,
end
,
step
,
idx
};
}
else
if
(
axis
==
1
)
{
nhwc_items
[
i
]
=
{
3
,
begin
,
end
,
step
,
idx
};
}
}
return
nhwc_items
;
}
template
<
typename
T
>
std
::
vector
<
ValueRef
>
subtensor_rule
(
const
T
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
)
{
mgb_assert
(
inputs
.
size
()
>=
1
);
auto
&
src
=
inputs
[
0
].
cast
<
FormattedTensorValue
>
();
bool
is_reduce_ndim
=
is_reduce_ndim_idx_items
(
op
.
items
,
{
&
inputs
[
1
],
&
inputs
[
inputs
.
size
()
-
1
]});
if
(
!
is_reduce_ndim
)
{
// only support NHWC2NCHW convert, otherwise maintain src's format
if
(
!
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
))
{
return
{
FormattedTensorValue
::
make
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
))[
0
],
src
.
format
())};
}
auto
nhwc_items
=
convert_nchw2nhwc_idx_items
(
op
.
items
);
auto
outputs
=
imperative
::
apply
(
*
T
::
make
(
std
::
move
(
nhwc_items
),
op
.
scope
()),
unwrap_inputs
(
inputs
));
return
wrap_outputs
(
outputs
,
FT
::
NHWC
);
}
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)));
}
template
<
typename
T
>
std
::
vector
<
ValueRef
>
setsubtensor_rule
(
const
T
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
)
{
mgb_assert
(
inputs
.
size
()
>=
2
);
auto
&
src
=
inputs
[
0
].
cast
<
FormattedTensorValue
>
();
bool
is_reduce_ndim
=
is_reduce_ndim_idx_items
(
op
.
items
,
{
&
inputs
[
2
],
&
inputs
[
inputs
.
size
()
-
1
]});
if
(
!
is_reduce_ndim
)
{
// only support NHWC2NCHW convert, otherwise maintain src's format
if
(
!
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
))
{
return
{
FormattedTensorValue
::
make
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
))[
0
],
src
.
format
())};
}
// value has been broadcasted to src's fake NCHW shape.
auto
&
value
=
inputs
[
1
].
cast
<
FormattedTensorValue
>
();
auto
&
format
=
value
.
format
();
auto
nhwc_inputs
=
std
::
vector
<
ValueRef
>
(
inputs
.
size
());
if
(
format
==
FT
::
DEFAULT
||
format
==
FT
::
NCHW
)
{
// value for setsubtensor should transpose to match shape.
auto
nhwc_value
=
value
.
as
(
FT
::
NCHW
)
->
to
(
FT
::
NHWC
);
// make new inputs for setsubtensor
nhwc_inputs
[
0
]
=
src
.
value
();
nhwc_inputs
[
1
]
=
nhwc_value
->
value
();
for
(
auto
i
=
2
;
i
<
inputs
.
size
();
++
i
)
{
nhwc_inputs
[
i
]
=
inputs
[
i
].
as_ref
<
FormattedTensorValue
>
()
->
value
();
}
}
else
if
(
format
!=
FT
::
NHWC
)
{
mgb_throw
(
MegBrainError
,
"Unsupported format(%s) of value for setsubtensor."
,
format
.
to_string
().
c_str
());
}
auto
nhwc_items
=
convert_nchw2nhwc_idx_items
(
op
.
items
);
auto
outputs
=
imperative
::
apply
(
*
T
::
make
(
std
::
move
(
nhwc_items
),
op
.
scope
()),
nhwc_inputs
);
return
wrap_outputs
(
outputs
,
FT
::
NHWC
);
}
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)));
}
FT
get_inputs_format
(
Span
<
ValueRef
>&
inputs
)
{
FT
format
(
FT
::
DEFAULT
);
for
(
auto
&
inp
:
inputs
)
{
auto
&
inp_format
=
inp
.
cast
<
FormattedTensorValue
>
().
format
();
if
(
inp_format
!=
FT
::
DEFAULT
)
{
mgb_assert
(
format
==
FT
::
DEFAULT
||
inp_format
==
format
);
format
=
inp_format
.
type
();
}
}
return
format
;
}
std
::
vector
<
ValueRef
>
concat_rule
(
const
Concat
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
)
{
FT
format
=
get_inputs_format
(
inputs
);
if
(
!
(
format
==
FT
::
NHWC
&&
auto_convert
))
{
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
format
);
}
// TODO: handle 5D NHWC Tensor from group conv
auto
axis
=
op
.
axis
;
if
(
axis
==
2
||
axis
==
3
)
{
axis
=
axis
-
1
;
}
else
if
(
axis
==
1
)
{
axis
=
3
;
}
return
wrap_outputs
(
imperative
::
apply
(
*
Concat
::
make
(
axis
,
op
.
comp_node
,
op
.
scope
()),
unwrap_inputs
(
inputs
)),
format
);
}
std
::
vector
<
ValueRef
>
elemwise_rule
(
const
Elemwise
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
)
{
FT
format
=
get_inputs_format
(
inputs
);
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
format
);
}
std
::
vector
<
ValueRef
>
identity_rule_helper
(
const
OpDef
&
op
,
const
Span
<
ValueRef
>&
inputs
)
{
// mgb_assert(inputs.size() == 1);
auto
&
src
=
inputs
[
0
].
cast
<
FormattedTensorValue
>
();
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
src
.
format
().
type
());
}
// clang-format off
#define FOREACH_IDENTITY_OP(cb) \
cb(Copy) \
cb(FastpathCopy) \
cb(TypeCvt) \
cb(Pooling) \
cb(AdaptivePooling) \
cb(Dropout) \
cb(Convolution) \
cb(BatchNorm) \
cb(Resize) \
cb(Identity)
// clang-format on
#define CREATE_IDENTITY_OP_RULE(op) \
std::vector<ValueRef> op##_rule( \
const op& _op, Span<ValueRef>& inputs, const bool& auto_convert) { \
return identity_rule_helper(_op, inputs); \
}
FOREACH_IDENTITY_OP
(
CREATE_IDENTITY_OP_RULE
)
#undef CREATE_IDENTITY_OP_RULE
#define REGISTER_IDENTITY_OP_RULE(op) register_format_rule(op##_rule);
struct
FormatRuleRegistry
{
FormatRuleRegistry
()
{
register_format_rule
(
dimshuffle_rule
);
register_format_rule
(
reshape_rule
);
register_format_rule
(
broadcast_rule
);
register_format_rule
(
subtensor_rule
<
Subtensor
>
);
register_format_rule
(
subtensor_rule
<
IndexingMultiAxisVec
>
);
register_format_rule
(
setsubtensor_rule
<
SetSubtensor
>
);
register_format_rule
(
setsubtensor_rule
<
IndexingSetMultiAxisVec
>
);
register_format_rule
(
concat_rule
);
register_format_rule
(
elemwise_rule
);
FOREACH_IDENTITY_OP
(
REGISTER_IDENTITY_OP_RULE
)
}
}
_
;
#undef REGISTER_IDENTITY_OP_RULE
}
// namespace
std
::
vector
<
ValueRef
>
FormatTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
// all inputs should be FormattedTensorValue
auto
iter
=
format_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
if
(
iter
!=
format_rules
.
end
())
{
return
iter
->
second
(
apply_op
->
op
(),
inputs
,
m_auto_convert
);
}
else
{
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)));
}
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
auto
format
=
create_tensor
->
format
();
return
{
FormattedTensorValue
::
make
(
imperative
::
apply
(
op
,
inputs
)[
0
],
format
)};
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
auto
*
src
=
inputs
.
as_array
<
1
>
()[
0
].
as
<
FormattedTensorValue
>
();
if
(
!
m_auto_convert
||
!
src
||
src
->
format
()
!=
FT
::
NHWC
)
{
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
));
}
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
Shape
:
{
auto
output
=
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
))[
0
];
auto
shape
=
convert_nhwc2nchw_shape
(
output
.
cast
<
ShapeValue
>
());
return
{
ShapeValue
::
make
(
shape
)};
}
case
GetAttr
::
Value
:
{
auto
nchw_src
=
unwrap_input
(
src
->
to
(
FT
::
NCHW
,
""
));
return
imperative
::
apply
(
op
,
std
::
vector
<
ValueRef
>
{
nchw_src
});
}
default:
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
));
}
}
else
if
(
op
.
is
<
GetFormat
>
())
{
bool
is_formatted_tensor
=
inputs
.
as_array
<
1
>
()[
0
].
is
<
FormattedTensorValue
>
();
if
(
is_formatted_tensor
)
{
return
{
FormatValue
::
make
(
inputs
[
0
].
cast
<
FormattedTensorValue
>
().
format
())};
}
else
{
mgb_log_warn
(
"Not FormattedTensorValue input for GetFormat op: %s"
,
inputs
[
0
].
to_string
().
c_str
());
return
{
FormatValue
::
make
(
FT
::
DEFAULT
)};
}
}
else
if
(
op
.
is
<
Operator
::
IdentityLike
>
())
{
bool
is_formatted_tensor
=
inputs
.
as_array
<
1
>
()[
0
].
is
<
FormattedTensorValue
>
();
if
(
is_formatted_tensor
)
{
auto
&
format
=
inputs
[
0
].
cast
<
FormattedTensorValue
>
().
format
();
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
format
.
type
());
}
else
{
mgb_log_warn
(
"Not FormattedTensorValue input for IdentityLike op: %s"
,
inputs
[
0
].
to_string
().
c_str
());
return
imperative
::
apply
(
op
,
inputs
);
}
}
else
{
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
));
}
};
}
// namespace imperative
}
// namespace mgb
imperative/src/impl/value.cpp
浏览文件 @
533fb5bf
...
@@ -58,6 +58,10 @@ TypedValueRef<DTypeValue> ValueRef::dtype() const {
...
@@ -58,6 +58,10 @@ TypedValueRef<DTypeValue> ValueRef::dtype() const {
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
DType
),
*
this
)[
0
].
cast_ref
<
DTypeValue
>
();
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
DType
),
*
this
)[
0
].
cast_ref
<
DTypeValue
>
();
}
}
TypedValueRef
<
FormatValue
>
ValueRef
::
format
()
const
{
return
imperative
::
apply
(
GetFormat
(),
*
this
)[
0
].
as_ref
<
FormatValue
>
();
}
TypedValueRef
<
StringValue
>
ValueRef
::
name
()
const
{
TypedValueRef
<
StringValue
>
ValueRef
::
name
()
const
{
return
imperative
::
apply
(
GetName
(),
*
this
)[
0
].
cast_ref
<
StringValue
>
();
return
imperative
::
apply
(
GetName
(),
*
this
)[
0
].
cast_ref
<
StringValue
>
();
}
}
...
...
imperative/src/include/megbrain/imperative/basic_operators.h
浏览文件 @
533fb5bf
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/operator.h"
#include "megbrain/imperative/operator.h"
#include "megbrain/imperative/utils/data_format.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/value_shape.h"
#include "megbrain/imperative/utils/value_shape.h"
...
@@ -82,9 +83,12 @@ private:
...
@@ -82,9 +83,12 @@ private:
CompNode
m_device
;
CompNode
m_device
;
DType
m_dtype
;
DType
m_dtype
;
ValueShape
m_shape
;
ValueShape
m_shape
;
Format
m_format
;
public:
public:
CreateTensor
(
Kind
kind
,
CompNode
device
,
DType
dtype
,
ValueShape
shape
);
CreateTensor
(
Kind
kind
,
CompNode
device
,
DType
dtype
,
ValueShape
shape
,
Format
format
=
Format
::
Type
::
DEFAULT
);
CreateTensor
(
Kind
kind
,
CompNode
device
,
TensorLayout
layout
);
CreateTensor
(
Kind
kind
,
CompNode
device
,
TensorLayout
layout
);
/**
/**
...
@@ -99,6 +103,7 @@ public:
...
@@ -99,6 +103,7 @@ public:
CompNode
device
()
const
{
return
m_device
;
}
CompNode
device
()
const
{
return
m_device
;
}
DType
dtype
()
const
{
return
m_dtype
;
}
DType
dtype
()
const
{
return
m_dtype
;
}
ValueShape
shape
()
const
{
return
m_shape
;
}
ValueShape
shape
()
const
{
return
m_shape
;
}
Format
format
()
const
{
return
m_format
;
}
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
...
@@ -157,6 +162,11 @@ public:
...
@@ -157,6 +162,11 @@ public:
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
class
GetFormat
final
:
public
OperatorImpl
<
GetFormat
,
Operator
::
GetAttrLike
>
{
public:
std
::
string
to_string
()
const
override
{
return
"GetFormat{}"
;
}
};
class
GetVarVal
final
:
public
OperatorImpl
<
GetVarVal
,
Operator
::
GetAttrLike
>
{
class
GetVarVal
final
:
public
OperatorImpl
<
GetVarVal
,
Operator
::
GetAttrLike
>
{
public:
public:
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
...
...
imperative/src/include/megbrain/imperative/basic_values.h
浏览文件 @
533fb5bf
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <future>
#include <future>
#include <iomanip>
#include <iomanip>
#include "megbrain/imperative/utils/data_format.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/value_shape.h"
#include "megbrain/imperative/utils/value_shape.h"
#include "megbrain/imperative/value.h"
#include "megbrain/imperative/value.h"
...
@@ -148,6 +149,13 @@ public:
...
@@ -148,6 +149,13 @@ public:
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
class
FormatValue
final
:
public
PrimitiveValue
<
FormatValue
,
Format
>
{
public:
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
{
return
Format
::
to_string
();
}
};
class
StringValue
final
:
public
PrimitiveValue
<
StringValue
,
std
::
string
>
{
class
StringValue
final
:
public
PrimitiveValue
<
StringValue
,
std
::
string
>
{
public:
public:
using
PrimitiveValue
::
PrimitiveValue
;
using
PrimitiveValue
::
PrimitiveValue
;
...
...
imperative/src/include/megbrain/imperative/transformations/format.h
0 → 100644
浏览文件 @
533fb5bf
#pragma once
#include "megbrain/imperative/basic_values.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/data_format.h"
namespace
mgb
::
imperative
{
class
FormattedTensorValue
final
:
public
ValueImpl
<
FormattedTensorValue
>
{
private:
ValueRef
m_value
;
Format
m_format
;
public:
FormattedTensorValue
(
ValueRef
value
,
Format
format
)
:
m_value
(
value
),
m_format
(
format
)
{}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"FormattedTensorValue{value=%s, format=%s}"
,
m_value
.
to_string
().
c_str
(),
m_format
.
to_string
().
c_str
());
}
ValueRef
value
()
const
{
return
m_value
;
}
const
Format
&
format
()
const
{
return
m_format
;
}
TypedValueRef
<
FormattedTensorValue
>
as
(
const
Format
::
Type
&
target
)
const
;
TypedValueRef
<
FormattedTensorValue
>
to
(
const
Format
::
Type
&
target
,
const
std
::
string
&
scope
=
""
)
const
;
void
clear
()
override
{
m_value
=
{};
m_format
=
{};
}
void
on_watch
()
override
{
m_value
.
watch
();
}
void
on_unwatch
()
override
{
m_value
.
unwatch
();
}
};
/**
* \brief simulates scalar because megbrain graph system don't support scalar
*
* Assume that we has 'a = ScalarValue(b)', thus 'a.shape == []', 'b.shape == [1]'.
* This transformation simulates scalars with a flag. If a value is ScalarValue, it is
* scalar, vice versa. So there is not scalar down this layer.
*/
class
FormatTransformation
final
:
public
Transformation
{
private:
bool
m_auto_convert
=
false
;
public:
std
::
vector
<
ValueRef
>
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
<
FormattedTensorValue
>
());
return
value
;
}
std
::
string
name
()
const
override
{
return
ssprintf
(
"FormatTransformation{auto_convert=%d}"
,
m_auto_convert
);
}
void
set_auto_convert
(
bool
enabled
)
{
m_auto_convert
=
enabled
;
}
bool
get_auto_convert
()
const
{
return
m_auto_convert
;
}
};
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/utils/data_format.h
0 → 100644
浏览文件 @
533fb5bf
#pragma once
#include "megbrain/tensor.h"
namespace
mgb
::
imperative
{
/**
* \brief like TensorFormats, but only including common formats and DEFAULT.
*
*/
class
Format
{
public:
enum
class
Type
{
DEFAULT
=
0
,
NCHW
=
1
,
///< [N, C, H, W]
NHWC
=
2
,
///< [N, H, W, C]
};
std
::
string
to_string
()
const
{
switch
(
m_type
)
{
case
Type
::
DEFAULT
:
return
"default"
;
case
Type
::
NCHW
:
return
"nchw"
;
case
Type
::
NHWC
:
return
"nhwc"
;
default:
mgb_throw
(
MegBrainError
,
"bad format type"
);
}
}
Format
()
:
m_type
(
Type
::
DEFAULT
)
{}
Format
(
std
::
string
str
)
{
if
(
str
==
"default"
)
{
m_type
=
Type
::
DEFAULT
;
}
else
if
(
str
==
"nchw"
)
{
m_type
=
Type
::
NCHW
;
}
else
if
(
str
==
"nhwc"
)
{
m_type
=
Type
::
NHWC
;
}
else
{
mgb_throw
(
MegBrainError
,
"Invalid format type."
" Only support
\"
default
\"
,
\"
nchw
\"
and
\"
nhwc
\"
"
);
}
}
Format
(
Type
type
)
:
m_type
(
type
)
{}
Type
type
()
const
{
return
m_type
;
}
bool
operator
==
(
const
Format
&
b
)
const
{
return
m_type
==
b
.
type
();
}
bool
operator
==
(
const
Format
::
Type
&
b
)
const
{
return
m_type
==
b
;
}
bool
operator
!=
(
const
Format
&
b
)
const
{
return
m_type
!=
b
.
type
();
}
bool
operator
!=
(
const
Format
::
Type
&
b
)
const
{
return
m_type
!=
b
;
}
private:
Type
m_type
=
Type
::
DEFAULT
;
};
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/value.h
浏览文件 @
533fb5bf
...
@@ -31,6 +31,7 @@ class HostValue;
...
@@ -31,6 +31,7 @@ class HostValue;
class
DeviceValue
;
class
DeviceValue
;
class
ShapeValue
;
class
ShapeValue
;
class
DTypeValue
;
class
DTypeValue
;
class
FormatValue
;
class
CompNodeValue
;
class
CompNodeValue
;
class
StringValue
;
class
StringValue
;
class
NodeValue
;
class
NodeValue
;
...
@@ -219,6 +220,7 @@ public:
...
@@ -219,6 +220,7 @@ public:
TypedValueRef
<
CompNodeValue
>
device
()
const
;
TypedValueRef
<
CompNodeValue
>
device
()
const
;
TypedValueRef
<
ShapeValue
>
shape
()
const
;
TypedValueRef
<
ShapeValue
>
shape
()
const
;
TypedValueRef
<
DTypeValue
>
dtype
()
const
;
TypedValueRef
<
DTypeValue
>
dtype
()
const
;
TypedValueRef
<
FormatValue
>
format
()
const
;
TypedValueRef
<
StringValue
>
name
()
const
;
TypedValueRef
<
StringValue
>
name
()
const
;
bool
is_scalar
()
const
;
bool
is_scalar
()
const
;
...
@@ -431,9 +433,11 @@ inline const TypedValueRef<TValue>& ValueRef::cast_ref(const Type<TValue>& type)
...
@@ -431,9 +433,11 @@ inline const TypedValueRef<TValue>& ValueRef::cast_ref(const Type<TValue>& type)
inline
void
ValueRef
::
on_cast_failure
(
const
IType
&
type
)
const
{
inline
void
ValueRef
::
on_cast_failure
(
const
IType
&
type
)
const
{
// if this is ErrorValue, rethrow directly
// if this is ErrorValue, rethrow directly
storage
()
->
try_rethrow
();
storage
()
->
try_rethrow
();
mgb_assert
(
if
(
storage
()
->
type
()
!=
type
)
{
storage
()
->
type
()
!=
type
,
"expect type %s, got %s"
,
type
.
name
().
c_str
(),
mgb_throw
(
to_string
().
c_str
());
MegBrainError
,
"Unable to cast ValueRef: expect type %s, got %s"
,
type
.
name
().
c_str
(),
to_string
().
c_str
());
}
}
}
/**
/**
...
...
src/opr/impl/dnn/batch_norm.cpp
浏览文件 @
533fb5bf
...
@@ -200,7 +200,7 @@ void BatchNormForward::get_output_var_shape(
...
@@ -200,7 +200,7 @@ void BatchNormForward::get_output_var_shape(
bias_c
=
inp_shape
[
2
][
channel_idx
];
bias_c
=
inp_shape
[
2
][
channel_idx
];
mgb_assert
(
mgb_assert
(
inp_c
==
scale_c
&&
inp_c
==
bias_c
,
inp_c
==
scale_c
&&
inp_c
==
bias_c
,
"inconsistent channel size, input ch
e
nnel: %zu, scale channel: %zu, bias "
"inconsistent channel size, input ch
a
nnel: %zu, scale channel: %zu, bias "
"channel: %zu"
,
"channel: %zu"
,
inp_c
,
scale_c
,
bias_c
);
inp_c
,
scale_c
,
bias_c
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录