Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d2278f02
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d2278f02
编写于
3月 23, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(imperative): speed up conv_transpose3d
GitOrigin-RevId: e741305446e926086c36affcb54d77f739133bbe
上级
3a5347ed
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
137 addition
and
48 deletion
+137
-48
dnn/include/megdnn/oprs/nn.h
dnn/include/megdnn/oprs/nn.h
+10
-0
dnn/src/common/convolution3d.cpp
dnn/src/common/convolution3d.cpp
+35
-23
dnn/src/common/pooling.cpp
dnn/src/common/pooling.cpp
+20
-17
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+5
-0
imperative/src/impl/algo_chooser.h
imperative/src/impl/algo_chooser.h
+2
-0
imperative/src/impl/dnn_op_helper.h
imperative/src/impl/dnn_op_helper.h
+1
-0
imperative/src/impl/ops/convolution.cpp
imperative/src/impl/ops/convolution.cpp
+59
-0
imperative/src/impl/ops/pooling.cpp
imperative/src/impl/ops/pooling.cpp
+5
-8
未找到文件。
dnn/include/megdnn/oprs/nn.h
浏览文件 @
d2278f02
...
@@ -784,6 +784,10 @@ public:
...
@@ -784,6 +784,10 @@ public:
protected:
protected:
void
deduce_layout_fwd
(
const
TensorLayout
&
src
,
TensorLayout
&
dst
);
void
deduce_layout_fwd
(
const
TensorLayout
&
src
,
TensorLayout
&
dst
);
void
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
);
void
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
);
public:
MGE_WIN_DECLSPEC_FUC
static
void
deduce_layout_impl
(
const
TensorLayout
&
src
,
const
Param
&
param
,
TensorLayout
&
dst
);
};
};
class
PoolingForward
:
public
PoolingBase
,
class
PoolingForward
:
public
PoolingBase
,
...
@@ -1241,6 +1245,8 @@ protected:
...
@@ -1241,6 +1245,8 @@ protected:
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
const
;
const
TensorLayout
&
dst
)
const
;
static
CanonizedFilterMeta
make_canonized_filter_meta_impl
(
size_t
src_ndim
,
const
TensorLayout
&
filter
,
const
Param
&
param
);
CanonizedFilterMeta
make_canonized_filter_meta
(
CanonizedFilterMeta
make_canonized_filter_meta
(
size_t
src_ndim
,
const
TensorLayout
&
filter
)
const
;
size_t
src_ndim
,
const
TensorLayout
&
filter
)
const
;
};
};
...
@@ -1286,6 +1292,10 @@ public:
...
@@ -1286,6 +1292,10 @@ public:
* \param[in] diff (n, oc, od, oh, ow)
* \param[in] diff (n, oc, od, oh, ow)
* \param[out] grad (n, ic, id, ih, iw)
* \param[out] grad (n, ic, id, ih, iw)
*/
*/
MGE_WIN_DECLSPEC_FUC
static
void
deduce_layout_impl
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
Param
&
param
,
TensorLayout
&
grad
);
virtual
void
exec
(
virtual
void
exec
(
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
=
0
;
_megdnn_workspace
workspace
)
=
0
;
...
...
dnn/src/common/convolution3d.cpp
浏览文件 @
d2278f02
...
@@ -38,17 +38,18 @@ std::string get_errmsg(
...
@@ -38,17 +38,18 @@ std::string get_errmsg(
}
}
}
// namespace
}
// namespace
Convolution3DBase
::
CanonizedFilterMeta
Convolution3DBase
::
make_canonized_filter_meta
(
Convolution3DBase
::
CanonizedFilterMeta
Convolution3DBase
::
size_t
src_ndim
,
const
TensorLayout
&
filter
)
const
{
make_canonized_filter_meta_impl
(
size_t
src_ndim
,
const
TensorLayout
&
filter
,
const
Param
&
param
)
{
megdnn_assert_contiguous
(
filter
);
megdnn_assert_contiguous
(
filter
);
auto
img_ndim
=
src_ndim
-
2
;
auto
img_ndim
=
src_ndim
-
2
;
CanonizedFilterMeta
ret
;
CanonizedFilterMeta
ret
;
ret
.
dtype_enum
=
filter
.
dtype
.
enumv
();
ret
.
dtype_enum
=
filter
.
dtype
.
enumv
();
ret
.
format
=
param
()
.
format
;
ret
.
format
=
param
.
format
;
if
(
param
()
.
mode
==
Mode
::
CONVOLUTION
)
{
if
(
param
.
mode
==
Mode
::
CONVOLUTION
)
{
ret
.
should_flip
=
true
;
ret
.
should_flip
=
true
;
}
else
{
}
else
{
megdnn_assert
(
param
()
.
mode
==
Mode
::
CROSS_CORRELATION
,
"invalid conv mode"
);
megdnn_assert
(
param
.
mode
==
Mode
::
CROSS_CORRELATION
,
"invalid conv mode"
);
ret
.
should_flip
=
false
;
ret
.
should_flip
=
false
;
}
}
size_t
flt_start
,
flt_spatial_start
,
ocpg_pos
,
icpg_pos
;
size_t
flt_start
,
flt_spatial_start
,
ocpg_pos
,
icpg_pos
;
...
@@ -56,7 +57,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
...
@@ -56,7 +57,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
MEGDNN_MARK_USED_VAR
(
ocpg_pos
);
MEGDNN_MARK_USED_VAR
(
ocpg_pos
);
MEGDNN_MARK_USED_VAR
(
icpg_pos
);
MEGDNN_MARK_USED_VAR
(
icpg_pos
);
if
(
param
()
.
sparse
==
Param
::
Sparse
::
DENSE
)
{
if
(
param
.
sparse
==
Param
::
Sparse
::
DENSE
)
{
megdnn_assert
(
megdnn_assert
(
filter
.
ndim
==
img_ndim
+
2
,
filter
.
ndim
==
img_ndim
+
2
,
"bad filter ndim for dense convolution: "
"bad filter ndim for dense convolution: "
...
@@ -66,7 +67,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
...
@@ -66,7 +67,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
flt_start
=
0
;
flt_start
=
0
;
}
else
{
}
else
{
megdnn_assert
(
megdnn_assert
(
param
()
.
sparse
==
Param
::
Sparse
::
GROUP
,
param
.
sparse
==
Param
::
Sparse
::
GROUP
,
"invalid convolution sparse type"
);
"invalid convolution sparse type"
);
megdnn_assert
(
megdnn_assert
(
filter
.
ndim
==
img_ndim
+
3
,
filter
.
ndim
==
img_ndim
+
3
,
...
@@ -77,14 +78,14 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
...
@@ -77,14 +78,14 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
flt_start
=
1
;
flt_start
=
1
;
}
}
if
(
param
()
.
format
==
Param
::
Format
::
NCDHW
)
{
if
(
param
.
format
==
Param
::
Format
::
NCDHW
)
{
// filter should be (oc, ic, fd, fh, fw)
// filter should be (oc, ic, fd, fh, fw)
flt_spatial_start
=
2
;
flt_spatial_start
=
2
;
ocpg_pos
=
0
;
ocpg_pos
=
0
;
icpg_pos
=
1
;
icpg_pos
=
1
;
}
else
{
}
else
{
megdnn_assert
(
megdnn_assert
(
param
()
.
format
==
Param
::
Format
::
NDHWC
,
"invalid conv tensor format"
);
param
.
format
==
Param
::
Format
::
NDHWC
,
"invalid conv tensor format"
);
// filter should be (oc, fd, fh, fw, ic)
// filter should be (oc, fd, fh, fw, ic)
flt_spatial_start
=
1
;
flt_spatial_start
=
1
;
ocpg_pos
=
0
;
ocpg_pos
=
0
;
...
@@ -96,15 +97,15 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
...
@@ -96,15 +97,15 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
"only 3D convolution is supported, and input should be 5-dim; "
"only 3D convolution is supported, and input should be 5-dim; "
"got input dim = %zu"
,
"got input dim = %zu"
,
src_ndim
);
src_ndim
);
ret
.
stride
[
0
]
=
this
->
param
()
.
stride_d
;
ret
.
stride
[
0
]
=
param
.
stride_d
;
ret
.
stride
[
1
]
=
this
->
param
()
.
stride_h
;
ret
.
stride
[
1
]
=
param
.
stride_h
;
ret
.
stride
[
2
]
=
this
->
param
()
.
stride_w
;
ret
.
stride
[
2
]
=
param
.
stride_w
;
ret
.
padding
[
0
]
=
this
->
param
()
.
pad_d
;
ret
.
padding
[
0
]
=
param
.
pad_d
;
ret
.
padding
[
1
]
=
this
->
param
()
.
pad_h
;
ret
.
padding
[
1
]
=
param
.
pad_h
;
ret
.
padding
[
2
]
=
this
->
param
()
.
pad_w
;
ret
.
padding
[
2
]
=
param
.
pad_w
;
ret
.
dilation
[
0
]
=
param
()
.
dilate_d
;
ret
.
dilation
[
0
]
=
param
.
dilate_d
;
ret
.
dilation
[
1
]
=
param
()
.
dilate_h
;
ret
.
dilation
[
1
]
=
param
.
dilate_h
;
ret
.
dilation
[
2
]
=
param
()
.
dilate_w
;
ret
.
dilation
[
2
]
=
param
.
dilate_w
;
ret
.
ocpg
=
filter
[
flt_start
+
ocpg_pos
];
ret
.
ocpg
=
filter
[
flt_start
+
ocpg_pos
];
ret
.
icpg
=
filter
[
flt_start
+
icpg_pos
];
ret
.
icpg
=
filter
[
flt_start
+
icpg_pos
];
for
(
size_t
i
=
0
;
i
<
ret
.
spatial_ndim
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ret
.
spatial_ndim
;
++
i
)
{
...
@@ -117,6 +118,11 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
...
@@ -117,6 +118,11 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
return
ret
;
return
ret
;
}
}
Convolution3DBase
::
CanonizedFilterMeta
Convolution3DBase
::
make_canonized_filter_meta
(
size_t
src_ndim
,
const
TensorLayout
&
filter
)
const
{
return
make_canonized_filter_meta_impl
(
src_ndim
,
filter
,
param
());
}
Convolution3DBase
::
CanonizedFilterMeta
Convolution3DBase
::
deduce_layout_fwd
(
Convolution3DBase
::
CanonizedFilterMeta
Convolution3DBase
::
deduce_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
TensorLayout
&
dst
)
const
{
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
TensorLayout
&
dst
)
const
{
auto
errmsg
=
[
&
]()
{
return
get_errmsg
(
src
,
filter
,
dst
,
param
());
};
auto
errmsg
=
[
&
]()
{
return
get_errmsg
(
src
,
filter
,
dst
,
param
());
};
...
@@ -213,12 +219,13 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardData::check_exec(
...
@@ -213,12 +219,13 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardData::check_exec(
return
ret
;
return
ret
;
}
}
void
Convolution3DBackwardData
::
deduce_layout
(
void
Convolution3DBackwardData
::
deduce_layout_impl
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
TensorLayout
&
grad
)
{
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
Param
&
param
,
TensorLayout
&
grad
)
{
megdnn_assert
(
megdnn_assert
(
param
()
.
data_type
==
Param
::
DataType
::
FLOAT
,
param
.
data_type
==
Param
::
DataType
::
FLOAT
,
"only float type is supported for conv backward"
);
"only float type is supported for conv backward"
);
auto
errmsg
=
[
&
]()
{
return
get_errmsg
(
filter
,
diff
,
grad
,
param
()
);
};
auto
errmsg
=
[
&
]()
{
return
get_errmsg
(
filter
,
diff
,
grad
,
param
);
};
MEGDNN_MARK_USED_VAR
(
errmsg
);
MEGDNN_MARK_USED_VAR
(
errmsg
);
megdnn_assert_contiguous
(
filter
);
megdnn_assert_contiguous
(
filter
);
megdnn_assert_contiguous
(
diff
);
megdnn_assert_contiguous
(
diff
);
...
@@ -226,7 +233,7 @@ void Convolution3DBackwardData::deduce_layout(
...
@@ -226,7 +233,7 @@ void Convolution3DBackwardData::deduce_layout(
megdnn_assert
(
diff
.
ndim
==
5
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
diff
.
ndim
==
5
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
filter
.
dtype
==
diff
.
dtype
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
filter
.
dtype
==
diff
.
dtype
,
"%s"
,
errmsg
().
c_str
());
auto
cflt
=
make_canonized_filter_meta
(
diff
.
ndim
,
filter
);
auto
cflt
=
make_canonized_filter_meta
_impl
(
diff
.
ndim
,
filter
,
param
);
megdnn_assert
(
cflt
.
ocpg
*
cflt
.
group
==
diff
[
1
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
cflt
.
ocpg
*
cflt
.
group
==
diff
[
1
],
"%s"
,
errmsg
().
c_str
());
auto
deduce
=
[
&
errmsg
](
size_t
out
,
size_t
filter
,
size_t
stride
,
size_t
pad
)
{
auto
deduce
=
[
&
errmsg
](
size_t
out
,
size_t
filter
,
size_t
stride
,
size_t
pad
)
{
...
@@ -247,6 +254,11 @@ void Convolution3DBackwardData::deduce_layout(
...
@@ -247,6 +254,11 @@ void Convolution3DBackwardData::deduce_layout(
grad
.
init_contiguous_stride
();
grad
.
init_contiguous_stride
();
}
}
void
Convolution3DBackwardData
::
deduce_layout
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
TensorLayout
&
grad
)
{
deduce_layout_impl
(
filter
,
diff
,
param
(),
grad
);
}
Convolution3DBase
::
CanonizedFilterMeta
Convolution3DBackwardFilter
::
check_exec
(
Convolution3DBase
::
CanonizedFilterMeta
Convolution3DBackwardFilter
::
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_in_bytes
)
{
size_t
workspace_in_bytes
)
{
...
...
dnn/src/common/pooling.cpp
浏览文件 @
d2278f02
...
@@ -15,22 +15,22 @@
...
@@ -15,22 +15,22 @@
namespace
megdnn
{
namespace
megdnn
{
void
PoolingBase
::
deduce_layout_
fwd
(
const
TensorLayout
&
src
,
TensorLayout
&
dst
)
{
void
PoolingBase
::
deduce_layout_
impl
(
auto
&
p
=
param
();
const
TensorLayout
&
src
,
const
Param
&
param
,
TensorLayout
&
dst
)
{
auto
pformat
=
p
.
format
;
auto
pformat
=
p
aram
.
format
;
// the overhead of generating error message is about 18x of the other part of this
// the overhead of generating error message is about 18x of the other part of this
// function so we use a function to wrap the error message and get it only when need.
// function so we use a function to wrap the error message and get it only when need.
auto
get_errmsg
=
[
&
](
void
)
->
std
::
string
{
auto
get_errmsg
=
[
&
](
void
)
->
std
::
string
{
std
::
string
errmsg
=
std
::
string
errmsg
=
megdnn_layout_msg
(
src
)
+
", "
+
megdnn_layout_msg
(
dst
)
+
", "
+
megdnn_layout_msg
(
src
)
+
", "
+
megdnn_layout_msg
(
dst
)
+
", "
+
"pad_h="
+
std
::
to_string
(
param
()
.
pad_h
)
+
", "
+
"pad_h="
+
std
::
to_string
(
param
.
pad_h
)
+
", "
+
"pad_w="
+
std
::
to_string
(
param
()
.
pad_w
)
+
", "
+
"pad_w="
+
std
::
to_string
(
param
.
pad_w
)
+
", "
+
"stride_h="
+
std
::
to_string
(
param
()
.
stride_h
)
+
", "
+
"stride_h="
+
std
::
to_string
(
param
.
stride_h
)
+
", "
+
"stride_w="
+
std
::
to_string
(
param
()
.
stride_w
)
+
", "
+
"stride_w="
+
std
::
to_string
(
param
.
stride_w
)
+
", "
+
"window_h="
+
std
::
to_string
(
param
()
.
window_h
)
+
", "
+
"window_h="
+
std
::
to_string
(
param
.
window_h
)
+
", "
+
"window_w="
+
std
::
to_string
(
param
()
.
window_w
)
+
", "
+
"window_w="
+
std
::
to_string
(
param
.
window_w
)
+
", "
+
"is_max="
+
std
::
to_string
(
param
()
.
mode
==
Mode
::
MAX
)
+
", "
+
"is_max="
+
std
::
to_string
(
param
.
mode
==
Mode
::
MAX
)
+
", "
+
"is_nhwc="
+
std
::
to_string
(
pformat
==
Param
::
Format
::
NHWC
)
+
", "
+
"is_nhwc="
+
std
::
to_string
(
pformat
==
Param
::
Format
::
NHWC
)
+
", "
+
"is_nhwcd4="
+
std
::
to_string
(
pformat
==
Param
::
Format
::
NHWCD4
);
"is_nhwcd4="
+
std
::
to_string
(
pformat
==
Param
::
Format
::
NHWCD4
);
return
errmsg
;
return
errmsg
;
...
@@ -90,12 +90,12 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
...
@@ -90,12 +90,12 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
c
*=
64
;
c
*=
64
;
}
}
size_t
oh
,
ow
;
size_t
oh
,
ow
;
size_t
fh
=
p
.
window_h
;
size_t
fh
=
p
aram
.
window_h
;
size_t
fw
=
p
.
window_w
;
size_t
fw
=
p
aram
.
window_w
;
size_t
sh
=
p
.
stride_h
;
size_t
sh
=
p
aram
.
stride_h
;
size_t
sw
=
p
.
stride_w
;
size_t
sw
=
p
aram
.
stride_w
;
size_t
ph
=
p
.
pad_h
;
size_t
ph
=
p
aram
.
pad_h
;
size_t
pw
=
p
.
pad_w
;
size_t
pw
=
p
aram
.
pad_w
;
// moving some python assert to here
// moving some python assert to here
// megdnn_assert()
// megdnn_assert()
...
@@ -128,12 +128,15 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
...
@@ -128,12 +128,15 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
}
}
}
}
void
PoolingBase
::
deduce_layout_fwd
(
const
TensorLayout
&
src
,
TensorLayout
&
dst
)
{
deduce_layout_impl
(
src
,
param
(),
dst
);
}
void
PoolingBase
::
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
void
PoolingBase
::
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
TensorLayout
dst_expected
;
TensorLayout
dst_expected
;
megdnn_assert_eq_dtype
(
src
,
dst
);
megdnn_assert_eq_dtype
(
src
,
dst
);
deduce_layout_fwd
(
src
,
dst_expected
);
deduce_layout_fwd
(
src
,
dst_expected
);
megdnn_assert_eq_layout
(
dst_expected
,
dst
);
megdnn_assert_eq_layout
(
dst_expected
,
dst
);
megdnn_assert
(
src
.
dtype
==
dst
.
dtype
);
megdnn_assert
(
megdnn_assert
(
src
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
||
src
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
||
src
.
dtype
==
dtype
::
Int8
()
||
src
.
dtype
==
dtype
::
Int8
()
||
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
d2278f02
...
@@ -93,12 +93,17 @@ __all__ = [
...
@@ -93,12 +93,17 @@ __all__ = [
def
expand_hw
(
x
):
def
expand_hw
(
x
):
# judge int is 5 times faster than judge Sequence
if
isinstance
(
x
,
int
):
return
x
,
x
if
isinstance
(
x
,
Sequence
):
if
isinstance
(
x
,
Sequence
):
return
int
(
x
[
0
]),
int
(
x
[
1
])
return
int
(
x
[
0
]),
int
(
x
[
1
])
return
int
(
x
),
int
(
x
)
return
int
(
x
),
int
(
x
)
def
expand_dhw
(
x
):
def
expand_dhw
(
x
):
if
isinstance
(
x
,
int
):
return
x
,
x
,
x
if
isinstance
(
x
,
Sequence
):
if
isinstance
(
x
,
Sequence
):
return
int
(
x
[
0
]),
int
(
x
[
1
]),
int
(
x
[
2
])
return
int
(
x
[
0
]),
int
(
x
[
1
]),
int
(
x
[
2
])
return
int
(
x
),
int
(
x
),
int
(
x
)
return
int
(
x
),
int
(
x
),
int
(
x
)
...
...
imperative/src/impl/algo_chooser.h
浏览文件 @
d2278f02
#pragma once
#include "megbrain/rdnn/algo_chooser.h"
#include "megbrain/rdnn/algo_chooser.h"
#include "megdnn/heuristic_cache.h"
#include "megdnn/heuristic_cache.h"
...
...
imperative/src/impl/dnn_op_helper.h
浏览文件 @
d2278f02
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#pragma once
#include "megbrain/comp_node.h"
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/comp_node_env.h"
...
...
imperative/src/impl/ops/convolution.cpp
浏览文件 @
d2278f02
...
@@ -579,6 +579,63 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
...
@@ -579,6 +579,63 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
namespace
{
namespace
{
namespace
convolution3d_backward_data
{
namespace
convolution3d_backward_data
{
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
2
,
"inputs num of conv_transpose3d should be 2 but you give %zu"
,
inputs
.
size
());
auto
&&
op_def
=
def
.
cast_final_safe
<
Convolution3DBackwardData
>
();
auto
&&
weight
=
inputs
[
0
];
auto
&&
diff
=
inputs
[
1
];
auto
&
cn
=
weight
.
comp_node
;
if
(
weight
.
layout
.
ndim
==
0
)
{
return
{{{
TensorLayout
{
weight
.
layout
.
dtype
},
cn
,
{}}},
false
};
}
TensorLayout
oup_layout
;
megdnn
::
Convolution3DBackwardData
::
deduce_layout_impl
(
weight
.
layout
,
diff
.
layout
,
op_def
.
param
(),
oup_layout
);
return
{{{
oup_layout
,
cn
,
{}}},
true
};
}
SmallVector
<
TensorPtr
>
apply_on_physical_tensor
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
,
SmallVector
<
LogicalTensorDesc
>&
output_descs
,
const
bool
&
validated
)
{
auto
&&
op_def
=
def
.
cast_final_safe
<
Convolution3DBackwardData
>
();
auto
cn
=
inputs
[
0
]
->
comp_node
();
megdnn
::
TensorND
weight
=
inputs
[
0
]
->
dnn_tensor
();
megdnn
::
TensorND
diff
=
inputs
[
1
]
->
dnn_tensor
();
DnnOprCaller
<
megdnn
::
Convolution3DBackwardData
>
caller
(
cn
);
auto
&&
dnn_opr
=
caller
.
op
;
dnn_opr
->
param
()
=
op_def
.
param
();
TensorLayout
&
oup_layout
=
output_descs
[
0
].
layout
;
if
(
!
validated
)
{
megdnn
::
Convolution3DBackwardData
::
deduce_layout_impl
(
weight
.
layout
,
diff
.
layout
,
op_def
.
param
(),
oup_layout
);
}
DeviceTensorND
oup
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
cn
,
oup_layout
);
size_t
wk_size
=
setup_algo
<
megdnn
::
Convolution3DBackwardData
>
(
{
weight
.
layout
,
diff
.
layout
,
oup_layout
},
dnn_opr
.
get
(),
0
,
false
,
false
,
cn
,
op_def
.
policy
(),
false
);
megdnn
::
Workspace
dnn_wk
;
if
(
wk_size
!=
0
)
{
auto
wk
=
Blob
::
make
(
cn
,
wk_size
);
dnn_wk
.
raw_ptr
=
wk
->
storage
().
get
();
dnn_wk
.
size
=
wk_size
;
}
dnn_opr
->
exec
(
weight
,
diff
,
oup
.
as_megdnn
(),
dnn_wk
);
return
{
Tensor
::
make
(
oup
)};
}
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
conv
=
static_cast
<
const
Convolution3DBackwardData
&>
(
def
);
auto
&&
conv
=
static_cast
<
const
Convolution3DBackwardData
&>
(
def
);
OperatorNodeConfig
config
{
conv
.
make_name
()};
OperatorNodeConfig
config
{
conv
.
make_name
()};
...
@@ -589,6 +646,8 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
...
@@ -589,6 +646,8 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
OP_TRAIT_REG
(
Convolution3DBackwardData
,
Convolution3DBackwardData
)
OP_TRAIT_REG
(
Convolution3DBackwardData
,
Convolution3DBackwardData
)
.
apply_on_var_node
(
apply_on_var_node
)
.
apply_on_var_node
(
apply_on_var_node
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
apply_on_physical_tensor
(
apply_on_physical_tensor
)
.
fallback
();
.
fallback
();
}
// namespace convolution3d_backward_data
}
// namespace convolution3d_backward_data
}
// namespace
}
// namespace
...
...
imperative/src/impl/ops/pooling.cpp
浏览文件 @
d2278f02
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
...
@@ -25,9 +26,6 @@ namespace mgb::imperative {
...
@@ -25,9 +26,6 @@ namespace mgb::imperative {
namespace
{
namespace
{
namespace
pooling
{
namespace
pooling
{
// using OprHandle = opr::intl::UniqPtrWithCN<megdnn::Pooling>;
// static ThinHashMap<CompNode, OprHandle> dnn_oprs;
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
pool
=
static_cast
<
const
Pooling
&>
(
def
);
auto
&&
pool
=
static_cast
<
const
Pooling
&>
(
def
);
OperatorNodeConfig
config
{
pool
.
make_name
()};
OperatorNodeConfig
config
{
pool
.
make_name
()};
...
@@ -48,11 +46,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
...
@@ -48,11 +46,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return
{{{
TensorLayout
{
inp
.
layout
.
dtype
},
inp_cn
,
{}}},
false
};
return
{{{
TensorLayout
{
inp
.
layout
.
dtype
},
inp_cn
,
{}}},
false
};
}
}
DnnOprCaller
<
megdnn
::
Pooling
>
caller
(
inp_cn
);
auto
&&
dnn_opr
=
caller
.
op
;
dnn_opr
->
param
()
=
op_def
.
param
();
TensorLayout
oup_layout
;
TensorLayout
oup_layout
;
dnn_opr
->
deduce_layout
(
inp
.
layout
,
oup_layout
);
megdnn
::
Pooling
::
deduce_layout_impl
(
inp
.
layout
,
op_def
.
param
(),
oup_layout
);
return
{{{
oup_layout
,
inp_cn
,
{}}},
true
};
return
{{{
oup_layout
,
inp_cn
,
{}}},
true
};
}
}
...
@@ -73,7 +69,8 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
...
@@ -73,7 +69,8 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
TensorLayout
&
oup_layout
=
output_descs
[
0
].
layout
;
TensorLayout
&
oup_layout
=
output_descs
[
0
].
layout
;
if
(
!
validated
)
{
if
(
!
validated
)
{
dnn_opr
->
deduce_layout
(
inp_tensornd
.
layout
,
oup_layout
);
megdnn
::
Pooling
::
deduce_layout_impl
(
inp_tensornd
.
layout
,
op_def
.
param
(),
oup_layout
);
}
}
DeviceTensorND
out_devtensor
=
DeviceTensorND
out_devtensor
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
cn
,
oup_layout
);
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
cn
,
oup_layout
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录