Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
0ad5eeae
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
0ad5eeae
编写于
12月 23, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/gopt): global layout transform support opencl
GitOrigin-RevId: 132605c7d946d403dc2164a71cd3769b29ccfb31
上级
26146e5a
变更
9
隐藏空白更改
内联
并排
Showing
9 changed files
with
236 additions
and
11 deletions
+236
-11
dnn/src/common/named_tensor.cpp
dnn/src/common/named_tensor.cpp
+2
-0
src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
...mpl/global_layout_transform/opr_tensor_formats_config.cpp
+130
-0
src/gopt/impl/global_layout_transform/profiler_impl.cpp
src/gopt/impl/global_layout_transform/profiler_impl.cpp
+60
-7
src/gopt/impl/global_layout_transform/reformat_manager.cpp
src/gopt/impl/global_layout_transform/reformat_manager.cpp
+29
-4
src/gopt/impl/global_layout_transform/utils.h
src/gopt/impl/global_layout_transform/utils.h
+4
-0
src/gopt/include/megbrain/gopt/profiler.h
src/gopt/include/megbrain/gopt/profiler.h
+5
-0
src/gopt/test/cache_data.h
src/gopt/test/cache_data.h
+0
-0
src/opr/impl/io.cpp
src/opr/impl/io.cpp
+4
-0
src/opr/include/megbrain/opr/io.h
src/opr/include/megbrain/opr/io.h
+2
-0
未找到文件。
dnn/src/common/named_tensor.cpp
浏览文件 @
0ad5eeae
...
...
@@ -246,6 +246,8 @@ NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) {
return
{{
"N//8"
},
{
"C//8"
},
{
"H"
},
{
"W"
},
{
"C%8"
},
{
"N%8"
}};
case
Format
::
NCHW44_DOT
:
return
{{
"N//4"
},
{
"C//4"
},
{
"H"
},
{
"W"
},
{
"N%4"
},
{
"C%4"
}};
case
Format
::
NHWCD4
:
return
{{
"N"
},
{
"H"
},
{
"C//4"
},
{
"W"
},
{
"C%4"
}};
default:
megdnn_throw
(
ssprintf
(
"Format unimplement(%d)"
,
static_cast
<
int
>
(
format
))
.
c_str
());
...
...
src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
浏览文件 @
0ad5eeae
...
...
@@ -229,6 +229,30 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> {
}
};
// Dispatcher specialization for single-input/single-output operators using the
// NHWCD4 (image-friendly, channel-packed-by-4) operator format.
// Returns a format configuration describing NHCWc4 tensor layouts for both the
// input and the output, or None when the input dtype is not supported.
template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWCD4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NHWCD4;
        config.config_id = OprFormatConfigID::NHWCD4;
        // NHWCD4 is only available for Float32 (and Float16 when fp16 support
        // is compiled in — DNN_FLOAT16_SELECT falls back to true otherwise),
        // Int8, and QuantizedS8 inputs.
        bool available =
                opr->input(0)->dtype().enumv() == DTypeEnum::Float32 ||
                DNN_FLOAT16_SELECT(
                        (opr->input(0)->dtype().enumv() == DTypeEnum::Float16),
                        true) ||
                opr->input(0)->dtype().enumv() == DTypeEnum::Int8 ||
                opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        // The single input of such an operator is always a feature map,
        // never a weight.
        config.input_tensor_types = {TensorType::FEATURE};
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        // Both sides use the NHCWc4 layout (NHWCD4's tensor-format form).
        config.input_tensor_formats = {TensorFormats::NHCWc4};
        config.output_tensor_formats = {TensorFormats::NHCWc4};
        if (available)
            return config;
        return None;
    }
};
template
<
typename
Opr
,
OprFormatConfigID
config_id
>
struct
ConvTensorFormatsDispatcherImpl
;
...
...
@@ -814,6 +838,55 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT_HYBRID
}
};
// Dispatcher for forward convolution-like operators (Opr) under the NHWCD4
// operator format. Selects the weight layout according to sparsity
// (dense / group / channel-wise) and the weight dtype (quantized 8-bit weights
// get an extra k4c4 packing), while feature maps always use NHCWc4.
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NHWCD4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NHWCD4;
        config.config_id = OprFormatConfigID::NHWCD4;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            // For forward conv, input(1) is the weight; all other inputs
            // (src, bias, z) are treated as feature tensors.
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            // Dense conv: quantized 8-bit weights use the KRSCk4c4 packing,
            // float weights use KRSCk4.
            if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
                config.input_tensor_formats =
                        {TensorFormats::NHCWc4, TensorFormats::KRSCk4c4,
                         TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            } else {
                config.input_tensor_formats =
                        {TensorFormats::NHCWc4, TensorFormats::KRSCk4,
                         TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            }
        } else {
            // Only DENSE and GROUP sparsity are expected here.
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                // Channel-wise conv uses the dedicated C1RSc4 weight layout.
                config.input_tensor_formats =
                        {TensorFormats::NHCWc4, TensorFormats::C1RSc4,
                         TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            } else {
                // General group conv: same quantized-vs-float split as the
                // dense case, with an extra leading group dimension (G...).
                if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                    opr->input(1)->dtype().enumv() ==
                            DTypeEnum::Quantized8Asymm) {
                    config.input_tensor_formats =
                            {TensorFormats::NHCWc4, TensorFormats::GKRSCk4c4,
                             TensorFormats::NHCWc4, TensorFormats::NHCWc4};
                } else {
                    config.input_tensor_formats =
                            {TensorFormats::NHCWc4, TensorFormats::GKRSCk4,
                             TensorFormats::NHCWc4, TensorFormats::NHCWc4};
                }
            }
        }
        config.output_tensor_formats = {TensorFormats::NHCWc4};
        return config;
    }
};
template
<
>
struct
ConvTensorFormatsDispatcherImpl
<
opr
::
ConvolutionBackwardData
,
OprFormatConfigID
::
NCHW
>
{
...
...
@@ -919,6 +992,57 @@ struct ConvTensorFormatsDispatcherImpl<
}
};
// Dispatcher for ConvolutionBackwardData (deconvolution) under the NHWCD4
// operator format. Mirrors the forward-conv NHWCD4 dispatcher, except that
// for backward-data the WEIGHT is input(0) (and the incoming gradient is
// input(1)), so the weight-format checks are performed on input(0) and the
// weight layout comes first in input_tensor_formats.
template <>
struct ConvTensorFormatsDispatcherImpl<
        opr::ConvolutionBackwardData, OprFormatConfigID::NHWCD4> {
    using Opr = opr::ConvolutionBackwardData;
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NHWCD4;
        config.config_id = OprFormatConfigID::NHWCD4;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            // For backward-data, input(0) is the weight; the rest are
            // feature tensors.
            TensorType tensor_type =
                    i == 0 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            // Quantized 8-bit weights use the k4c4 packing; float weights
            // use the plain k4 packing.
            if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
                config.input_tensor_formats =
                        {TensorFormats::KRSCk4c4, TensorFormats::NHCWc4,
                         TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            } else {
                config.input_tensor_formats =
                        {TensorFormats::KRSCk4, TensorFormats::NHCWc4,
                         TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            }
        } else {
            // Only DENSE and GROUP sparsity are expected here.
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                // Channel-wise deconv uses the dedicated C1RSc4 weight layout.
                config.input_tensor_formats =
                        {TensorFormats::C1RSc4, TensorFormats::NHCWc4,
                         TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            } else {
                // General group deconv: grouped weight layouts, split on
                // quantized vs float weight dtype as in the dense case.
                if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                    opr->input(0)->dtype().enumv() ==
                            DTypeEnum::Quantized8Asymm) {
                    config.input_tensor_formats =
                            {TensorFormats::GKRSCk4c4, TensorFormats::NHCWc4,
                             TensorFormats::NHCWc4, TensorFormats::NHCWc4};
                } else {
                    config.input_tensor_formats =
                            {TensorFormats::GKRSCk4, TensorFormats::NHCWc4,
                             TensorFormats::NHCWc4, TensorFormats::NHCWc4};
                }
            }
        }
        config.output_tensor_formats = {TensorFormats::NHCWc4};
        return config;
    }
};
struct
StaticData
{
struct
KeyHash
{
size_t
operator
()(
const
std
::
pair
<
Typeinfo
*
,
OprFormatConfigID
>&
val
)
const
{
...
...
@@ -969,6 +1093,7 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvBias
,
NCHW44_DOT
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvBias
,
NCHW44_HYBRID
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvBias
,
NCHW44_DOT_HYBRID
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvBias
,
NHWCD4
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionForward
,
NCHW
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionForward
,
NHWC
);
...
...
@@ -979,15 +1104,18 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionForward
,
NCHW44_DOT
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionForward
,
NCHW44_HYBRID
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionForward
,
NCHW44_DOT_HYBRID
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionForward
,
NHWCD4
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionBackwardData
,
NCHW
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionBackwardData
,
NHWC
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionBackwardData
,
NCHW4
);
OPR_TENSOR_FORMATS_CONFIG_REG
(
ConvolutionBackwardData
,
NHWCD4
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
WarpPerspectiveForward
,
NCHW
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
WarpPerspectiveForward
,
NHWC
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
WarpPerspectiveForward
,
NCHW4
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
WarpPerspectiveForward
,
NCHW64
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
WarpPerspectiveForward
,
NHWCD4
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
PoolingForward
,
NCHW
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
PoolingForward
,
NHWC
);
...
...
@@ -997,10 +1125,12 @@ StaticData::StaticData() {
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
PoolingForward
,
NCHW64
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
PoolingForward
,
NCHW44
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
PoolingForward
,
NCHW88
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
PoolingForward
,
NHWCD4
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
ResizeForward
,
NCHW
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
ResizeForward
,
NCHW44
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
ResizeForward
,
NCHW88
);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
(
ResizeForward
,
NHWCD4
);
#undef OPR_TENSOR_FORMATS_CONFIG_REG
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
...
...
src/gopt/impl/global_layout_transform/profiler_impl.cpp
浏览文件 @
0ad5eeae
...
...
@@ -22,6 +22,7 @@
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/plugin/base.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/tensor_format.h"
using
namespace
mgb
;
using
namespace
cg
;
...
...
@@ -281,9 +282,6 @@ float ProfilerImpl::profile_operator(
std
::
min
(
config
.
input_tensor_formats
.
size
(),
opr
->
input
().
size
());
for
(;
i
<
nr_input_tensor
;
++
i
)
{
auto
&&
var
=
opr
->
input
(
i
);
auto
&&
cn
=
var
->
comp_node
();
auto
&&
dtype
=
var
->
dtype
();
auto
dval
=
std
::
make_shared
<
DeviceTensorND
>
(
cn
,
dtype
);
TensorShape
aligned_shape
;
if
(
config
.
input_tensor_types
[
i
]
==
TensorType
::
WEIGHT
)
{
mgb_assert
(
base_config
.
input_tensor_types
[
i
]
==
TensorType
::
WEIGHT
);
...
...
@@ -299,9 +297,12 @@ float ProfilerImpl::profile_operator(
var
,
base_config
.
input_tensor_formats
[
i
],
config
.
input_tensor_formats
[
i
],
extra_attribute
);
}
dval
->
resize
(
aligned_shape
);
std
::
shared_ptr
<
DeviceTensorND
>
dval
=
create_device_tensor_helper
(
config
,
i
,
var
,
aligned_shape
,
extra_attribute
);
if
(
config
.
input_tensor_types
[
i
]
==
TensorType
::
WEIGHT
)
{
new_inps
[
i
]
=
opr
::
SharedDeviceTensor
::
make_const
(
*
graph
,
dval
).
node
();
new_inps
[
i
]
=
opr
::
SharedDeviceTensorWithFormat
::
make_const
(
*
graph
,
dval
).
node
();
}
else
{
new_inps
[
i
]
=
opr
::
VolatileSharedDeviceTensor
::
make
(
*
graph
,
dval
).
node
();
}
...
...
@@ -368,10 +369,27 @@ float ProfilerImpl::profile_var_node(
const
VarNode
*
var
,
TensorFormats
base_format
,
const
ReformatKey
&
key
)
const
{
auto
&&
cn
=
var
->
comp_node
();
auto
&&
dtype
=
var
->
dtype
();
auto
dval
=
std
::
make_shared
<
DeviceTensorND
>
(
cn
,
dtype
);
auto
aligned_tensor_shape
=
ReformatManager
::
make_aligned_tensor_shape
(
var
,
base_format
,
key
.
input_format
,
key
.
attribute
);
dval
->
resize
(
aligned_tensor_shape
);
std
::
shared_ptr
<
DeviceTensorND
>
dval
;
if
(
key
.
input_format
==
TensorFormats
::
NHCWc4
&&
key
.
attribute
&
ReformatAttribute
::
IMAGE2D
)
{
size_t
align_axis
=
2
;
auto
named_tensor
=
tensor_formats_to_named_tensor_shape
(
key
.
input_format
);
for
(
size_t
n
=
0
;
n
<
named_tensor
.
ndim
;
n
++
)
{
if
(
named_tensor
[
n
].
name
()
==
megdnn
::
Dimension
::
Name
::
C
)
{
align_axis
=
n
;
break
;
}
}
dval
=
std
::
make_shared
<
DeviceTensorND
>
(
cn
,
aligned_tensor_shape
,
dtype
,
megdnn
::
Image2DPack4TensorFormat
::
make
(
align_axis
,
opr
::
intl
::
get_megdnn_handle
(
cn
)));
}
else
dval
=
std
::
make_shared
<
DeviceTensorND
>
(
cn
,
aligned_tensor_shape
,
dtype
);
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
graph
->
options
().
var_sanity_check_first_run
=
false
;
...
...
@@ -516,6 +534,8 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id(
return
OprFormatConfigID
::
NHWC
;
case
TensorFormats
::
CHWNc4
:
return
OprFormatConfigID
::
CHWN4
;
case
TensorFormats
::
NHCWc4
:
return
OprFormatConfigID
::
NHWCD4
;
default:
mgb_throw
(
MegBrainError
,
"tensor format(%u) is not supported"
,
...
...
@@ -523,6 +543,39 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id(
}
}
// Create the device tensor used for profiling input `inp_idx` of an operator.
// For the NHWCD4 config combined with the IMAGE2D reformat attribute the
// tensor is allocated with an Image2DPack4 tensor format (as used by the
// OpenCL image2d backend); otherwise a plain device tensor is created.
// Note: `extra_attribute & ReformatAttribute::IMAGE2D` relies on bitwise-&
// binding tighter than &&, i.e. the condition tests the IMAGE2D bit.
std::shared_ptr<DeviceTensorND> ProfilerImpl::create_device_tensor_helper(
        const OprTensorFormatsConfiguration& config, const size_t inp_idx,
        const VarNode* var, const TensorShape aligned_shape,
        ReformatAttribute extra_attribute) const {
    auto&& cn = var->comp_node();
    auto&& dtype = var->dtype();
    std::shared_ptr<DeviceTensorND> dval;
    if (config.config_id == OprFormatConfigID::NHWCD4 &&
        extra_attribute & ReformatAttribute::IMAGE2D) {
        // The image2d alignment axis is the position of the C dimension in
        // the named tensor shape of this input's format; default to 2
        // (the C position in NHCWc4) if no C dimension is found.
        size_t align_axis = 2;
        auto named_tensor = tensor_formats_to_named_tensor_shape(
                config.input_tensor_formats[inp_idx]);
        for (size_t n = 0; n < named_tensor.ndim; n++) {
            if (named_tensor[n].name() == megdnn::Dimension::Name::C) {
                align_axis = n;
                break;
            }
        }
        // channel wise weight
        bool is_channel_wise =
                config.input_tensor_formats[inp_idx] == TensorFormats::C1RSc4;
        if (is_channel_wise)
            align_axis = 1;
        dval = std::make_shared<DeviceTensorND>(
                cn, aligned_shape, dtype,
                megdnn::Image2DPack4TensorFormat::make(
                        align_axis, opr::intl::get_megdnn_handle(cn)));
    } else {
        dval = std::make_shared<DeviceTensorND>(cn, aligned_shape, dtype);
    }
    return dval;
}
/* ================== ProfilerBase =================*/
std
::
string
ProfilerBase
::
OperatorNodeRecord
::
to_string
()
const
{
auto
str
=
ssprintf
(
...
...
src/gopt/impl/global_layout_transform/reformat_manager.cpp
浏览文件 @
0ad5eeae
...
...
@@ -249,7 +249,7 @@ ReformatManager::ReformatManager() {
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
},
impl
);
}
{
auto
i
=
TensorFormats
::
KCRS
,
o
=
TensorFormats
::
GKRSCk4
;
auto
i
=
TensorFormats
::
G
KCRS
,
o
=
TensorFormats
::
GKRSCk4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
...
...
@@ -259,7 +259,7 @@ ReformatManager::ReformatManager() {
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
},
impl
);
}
{
auto
i
=
TensorFormats
::
KC
RS
,
o
=
TensorFormats
::
C1RSc4
;
auto
i
=
TensorFormats
::
C11
RS
,
o
=
TensorFormats
::
C1RSc4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
...
...
@@ -268,6 +268,21 @@ ReformatManager::ReformatManager() {
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
},
impl
);
}
{
auto
i
=
TensorFormats
::
NCHW
,
o
=
TensorFormats
::
NHCWc4
;
auto
&&
impl1
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NCHW_NHWCD4
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
},
impl1
);
auto
&&
impl2
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NHWCD4_NCHW
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
o
,
i
},
impl2
);
}
{
auto
i
=
TensorFormats
::
NCHW
,
o
=
TensorFormats
::
NHCWc4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
...
...
@@ -281,7 +296,7 @@ ReformatManager::ReformatManager() {
auto
i
=
TensorFormats
::
NHCWc4
,
o
=
TensorFormats
::
NCHW
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
N
CHW_NHWCD4I
)
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
N
HWCD4I_NCHW
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
},
impl
);
...
...
@@ -346,6 +361,15 @@ ReformatManager::ReformatImpl ReformatManager::get(const ReformatKey& key) const
return
rst
;
}
}
if
(
key
.
attribute
==
Attribute
::
IMAGE2D
)
{
auto
key_
=
key
;
key_
.
input_dtype
=
DTypeEnum
::
Float32
;
key_
.
output_dtype
=
DTypeEnum
::
Float32
;
auto
find
=
m_cache
.
find
(
key_
);
if
(
find
!=
m_cache
.
end
())
{
return
find
->
second
;
}
}
mgb_assert
(
!
(
key
.
attribute
&
Attribute
::
IMAGE2D
)
&&
!
(
key
.
attribute
&
Attribute
::
IC_SMALL
));
...
...
@@ -682,7 +706,8 @@ TensorShape ReformatManager::make_aligned_weight_shape(
auto
target_shape
=
tensor_formats_to_named_tensor_shape
(
target_formats
);
for
(
size_t
i
=
0
;
i
<
target_shape
.
ndim
;
++
i
)
{
auto
name
=
target_shape
[
i
].
name
();
if
((
name
==
Dimension
::
Name
::
K
||
name
==
Dimension
::
Name
::
N
)
&&
if
((
name
==
Dimension
::
Name
::
K
||
name
==
Dimension
::
Name
::
N
||
(
extra_formats
==
TensorFormats
::
NHCWc4
&&
name
==
Dimension
::
Name
::
C
))
&&
target_shape
[
i
].
extent
()
==
UNDETERMINED_EXTENT
)
{
size_t
out_channels
=
tshp
[
i
]
*
target_shape
[
i
].
stride
();
tshp
[
i
]
=
divup
(
out_channels
,
out_channel_alignment
)
*
...
...
src/gopt/impl/global_layout_transform/utils.h
浏览文件 @
0ad5eeae
...
...
@@ -32,6 +32,7 @@ static inline const char* opr_format_to_string(
cb
(
NCHW44
);
cb
(
NCHW88
);
cb
(
NCHW44_DOT
);
cb
(
NHWCD4
);
default:
mgb_assert
(
false
,
"Invalid opr format(got:%u)"
,
...
...
@@ -63,6 +64,7 @@ static inline const char* config_id_to_string(
cb
(
NCHW88_HYBRID
);
cb
(
NCHW44_DOT
);
cb
(
NCHW44_DOT_HYBRID
);
cb
(
NHWCD4
);
default:
mgb_assert
(
false
,
"Invalid config id(got:%u)"
,
...
...
@@ -95,6 +97,8 @@ static inline TensorFormats opr_format_to_tensor_formats(
return
TensorFormats
::
NCHWc8
;
case
OprFormat
::
NCHW44_DOT
:
return
TensorFormats
::
NCHWc4
;
case
OprFormat
::
NHWCD4
:
return
TensorFormats
::
NHCWc4
;
default:
mgb_throw
(
AssertionError
,
"format(%s) is not supported"
,
...
...
src/gopt/include/megbrain/gopt/profiler.h
浏览文件 @
0ad5eeae
...
...
@@ -202,6 +202,11 @@ protected:
const
ReformatKey
&
key
)
const
;
OprFormatConfigID
tensor_formats_to_config_id
(
TensorFormats
tensor_format
)
const
;
std
::
shared_ptr
<
DeviceTensorND
>
create_device_tensor_helper
(
const
OprTensorFormatsConfiguration
&
config
,
const
size_t
inp_idx
,
const
VarNode
*
var
,
const
TensorShape
aligned_shape
,
ReformatAttribute
extra_attribute
)
const
;
OprFootprint
m_opr_footprint
;
float
m_opr_threshold
;
/// a threshold, when the computation of the newly
/// created operator that is built in some opr
...
...
src/gopt/test/cache_data.h
浏览文件 @
0ad5eeae
此差异由.gitattributes 抑制。
src/opr/impl/io.cpp
浏览文件 @
0ad5eeae
...
...
@@ -336,6 +336,10 @@ cg::OperatorNodeBase::NodeProp* VolatileSharedDeviceTensor::do_make_node_prop()
return
ret
;
}
// Propagate the tensor format of the underlying device tensor to the single
// output var, so downstream operators see the correct (possibly non-default,
// e.g. image2d) format.
void VolatileSharedDeviceTensor::init_output_format() {
    output(0)->format(get_dev_tensor().format());
}
SymbolVar
VolatileSharedDeviceTensor
::
make
(
ComputingGraph
&
graph
,
const
std
::
shared_ptr
<
DeviceTensorND
>&
dev_data
,
const
OperatorNodeConfig
&
config
)
{
...
...
src/opr/include/megbrain/opr/io.h
浏览文件 @
0ad5eeae
...
...
@@ -337,6 +337,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
public
:
using
Super
::
Super
;
void
init_output_format
()
override
;
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
ComputingGraph
&
graph
,
const
std
::
shared_ptr
<
DeviceTensorND
>&
dev_data
,
const
OperatorNodeConfig
&
config
=
{});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录