Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
72a70dd6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
72a70dd6
编写于
3月 14, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(imperative): specialize convolution implementation
GitOrigin-RevId: 33634c550f958db7d971c0b0ebfc15b89d19781b
上级
12a3ef8d
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
404 additions
and
29 deletions
+404
-29
dnn/include/megdnn/heuristic_cache.h
dnn/include/megdnn/heuristic_cache.h
+3
-3
imperative/src/impl/algo_chooser.h
imperative/src/impl/algo_chooser.h
+54
-0
imperative/src/impl/ops/convolution.cpp
imperative/src/impl/ops/convolution.cpp
+347
-2
imperative/src/impl/ops/specializations.cpp
imperative/src/impl/ops/specializations.cpp
+0
-24
未找到文件。
dnn/include/megdnn/heuristic_cache.h
浏览文件 @
72a70dd6
...
...
@@ -26,7 +26,7 @@ private:
HeuristicCache
()
=
default
;
public:
static
HeuristicCache
&
instance
();
MGE_WIN_DECLSPEC_FUC
static
HeuristicCache
&
instance
();
struct
KeyStorage
{
std
::
string
category
;
...
...
@@ -67,9 +67,9 @@ public:
size_t
workspace
;
};
void
put
(
const
Key
&
key
,
Result
&
result
);
MGE_WIN_DECLSPEC_FUC
void
put
(
const
Key
&
key
,
Result
&
result
);
Result
get
(
const
Key
&
key
);
MGE_WIN_DECLSPEC_FUC
Result
get
(
const
Key
&
key
);
void
clear
();
...
...
imperative/src/impl/algo_chooser.h
0 → 100644
浏览文件 @
72a70dd6
#include "megbrain/rdnn/algo_chooser.h"
#include "megdnn/heuristic_cache.h"
namespace
mgb
{
namespace
imperative
{
//! Select an execution algorithm for \p megdnn_opr on \p layouts and return the
//! required workspace size in bytes.
//!
//! Fast path: look up the HeuristicCache keyed by (handle, opr type, layouts,
//! param); on a hit, install the cached policy and return the cached workspace.
//! Slow path: run the rdnn::AlgoChooser, install the chosen policy, and — only
//! when the caller asked for the HEURISTIC strategy — store the result back
//! into the cache.
template <
        typename Opr>
MGE_WIN_DECLSPEC_FUC size_t setup_algo(
        const typename mgb::rdnn::AlgoChooser<Opr>::FixedTensorLayouts& layouts,
        Opr* megdnn_opr, uint32_t shared_batch_size, bool binary_equal_between_batch,
        bool no_profiling_on_shape_change, CompNode comp_node,
        megdnn::param::ExecutionPolicy execution_policy,
        bool allow_weight_preprocess) {
    // Cache key covers the opr's handle, type, all tensor layouts and the raw
    // param bytes, so any change to those invalidates the lookup.
    megdnn::HeuristicCache::Key cache_key(
            megdnn_opr->handle(), megdnn_opr->get_opr_type(), layouts.data(),
            layouts.size(), &megdnn_opr->param(), sizeof(megdnn_opr->param()));
    auto rst = megdnn::HeuristicCache::instance().get(cache_key);
    if (rst.policy.algo.valid()) {
        // Cache hit: reuse the previously chosen policy and workspace size.
        megdnn_opr->execution_policy() = rst.policy;
        return rst.workspace;
    }

    // Serialize the param for the AlgoChooser helper (used e.g. as a
    // profiling-cache key inside rdnn).
    std::string param_str;
    megdnn::Algorithm::serialize_write_pod(megdnn_opr->param(), param_str);

    rdnn::AlgoChooserDesc desc;
    desc.shared_batch_size = shared_batch_size;
    desc.binary_equal_between_batch = binary_equal_between_batch;
    desc.no_profiling_on_shape_change = no_profiling_on_shape_change;
    desc.get_workspace_limit = [&](CompNode cn, size_t old_limit) {
        size_t free = cn.get_free_mem();
        size_t lmt = cn.get_max_block_size_available();
        // NOTE(review): `old_limit` is ignored and the larger of the two
        // device-reported values is returned — confirm this is intended
        // (a limit is more commonly the minimum of candidates).
        return std::max(lmt, free);
    };

    using AlgoChooserHelper = typename mgb::rdnn::AlgoChooser<Opr>::AlgoChooserHelper;
    AlgoChooserHelper helper(
            layouts, megdnn_opr, param_str, comp_node, execution_policy,
            allow_weight_preprocess, desc);

    megdnn::ExecutionPolicy policy;
    policy = mgb::rdnn::AlgoChooser<Opr>::get_policy(helper);
    size_t workspace = helper.get_workspace_size_bytes(policy, layouts);

    // Install the chosen policy on the operator before returning.
    megdnn_opr->execution_policy() = policy;

    // Only heuristic choices are cached; profiled choices are managed by the
    // profiling cache inside rdnn.
    if (execution_policy.strategy & rdnn::ExecutionStrategy::HEURISTIC) {
        megdnn::HeuristicCache::Result cache_result{policy, workspace};
        megdnn::HeuristicCache::instance().put(cache_key, cache_result);
    }

    return workspace;
}
}
// namespace imperative
}
// namespace mgb
imperative/src/impl/ops/convolution.cpp
浏览文件 @
72a70dd6
...
...
@@ -10,14 +10,23 @@
*/
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/imperative/ops/autogen.h"
#include "../algo_chooser.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
namespace
mgb
{
namespace
imperative
{
namespace
{
//! Compute one spatial output extent of a convolution:
//! floor((inp + 2 * pad - flt) / stride) + 1.
//! Asserts that the padded input is at least as large as the filter.
size_t infer_conv_shape(size_t inp, size_t flt, size_t stride, size_t pad) {
    const size_t padded = inp + 2 * pad;
    mgb_assert(padded >= flt, "input=%zu padding=%zu filter=%zu", inp, pad, flt);
    return (padded - flt) / stride + 1;
}
namespace
convolution
{
std
::
shared_ptr
<
OpDef
>
make_from_op_node
(
cg
::
OperatorNodeBase
*
node_
)
{
auto
*
node
=
&
node_
->
cast_final_safe
<
opr
::
Convolution
>
();
...
...
@@ -31,13 +40,199 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
inputs
[
0
],
inputs
[
1
],
conv
.
param
(),
conv
.
policy
(),
config
);
}
//! Infer the output layout of a 2D convolution from the source and filter
//! layouts, honoring the op's sparse (DENSE/GROUP) and format (NCHW/NHWC)
//! settings. Only 4-dim input (2 spatial dims) is supported.
TensorLayout do_shape_infer(
        const OpDef& def, size_t src_ndim, TensorLayout src, TensorLayout filter) {
    auto&& conv = static_cast<const Convolution&>(def);
    using Param = ::megdnn::param::Convolution;
    auto img_ndim = src_ndim - 2;  // spatial dims = total dims minus (N, C)
    mgb_assert(
            img_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    size_t group = 1;
    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
    if (conv.sparse == Param::Sparse::DENSE) {
        // dense filter: (oc, ic, spatial...) — the +4 alternative presumably
        // covers a blocked layout variant; verify against megdnn docs.
        mgb_assert(
                filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        group = 1;
        flt_start = 0;
    } else {  // Param::Sparse::GROUP
        mgb_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // grouped filter: (grp, oc, ic, dims[])
        group = filter[0];
        flt_start = 1;
    }

    uint32_t ic_block_size = 1, oc_block_size = 1;
    size_t src_or_dst_c_pos = 0;
    size_t src_or_dst_spatial_start = 0;
    if (conv.format == Param::Format::NCHW) {
        // filter should be (oc, ic, fh, fw)
        flt_spatial_start = 2;
        ocpg_pos = 0;
        icpg_pos = 1;
        src_or_dst_c_pos = 1;
        src_or_dst_spatial_start = 2;
    } else {  // Param::Format::NHWC
        // filter should be (oc, fh, fw, ic)
        flt_spatial_start = 1;
        ocpg_pos = 0;
        icpg_pos = 3;
        src_or_dst_c_pos = 3;
        src_or_dst_spatial_start = 1;
    }
    // output / input channels per group
    size_t ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
    size_t icpg = filter[flt_start + icpg_pos] * ic_block_size;
    uint32_t dilation[2], dilated_spatial[2], stride[2], padding[2];
    dilation[0] = conv.dilate_h;
    dilation[1] = conv.dilate_w;
    stride[0] = conv.stride_h;
    stride[1] = conv.stride_w;
    padding[0] = conv.pad_h;
    padding[1] = conv.pad_w;
    for (size_t i = 0; i < img_ndim; ++i) {
        mgb_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        // effective filter extent after dilation
        dilated_spatial[i] =
                (filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
    }
    // input channel count must equal icpg * group
    mgb_assert(icpg * group == src[src_or_dst_c_pos], "group conv invalid");

    // dst: same batch dim, ocpg*group channels, conv-reduced spatial dims.
    TensorLayout dst{src.dtype};
    dst.ndim = src_ndim;
    dst[0] = src[0];
    dst[src_or_dst_c_pos] = ocpg * group;
    for (size_t i = 0; i < img_ndim; ++i) {
        dst[i + src_or_dst_spatial_start] = infer_conv_shape(
                src[i + src_or_dst_spatial_start], dilated_spatial[i], stride[i],
                padding[i]);
    }
    dst.init_contiguous_stride();
    return dst;
}
//! Fallible shape inference for Convolution: returns a single output
//! descriptor plus a flag saying whether the layout could be computed.
//! When the source layout is still unknown (ndim == 0) the inference is
//! deferred and `false` is returned.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    // NOTE: the op cast and Param alias held here previously were unused;
    // do_shape_infer() performs the cast itself.
    SmallVector<LogicalTensorDesc> dests(1);
    auto&& desc = dests[0];
    desc.comp_node = inputs[0].comp_node;

    TensorLayout src = inputs[0].layout;
    size_t src_ndim = src.ndim;
    if (src_ndim == 0) {
        // Shape not known yet: propagate the (empty) layout and report failure.
        desc.layout = src;
        return {dests, false};
    }

    TensorLayout filter = inputs[1].layout;
    desc.layout = do_shape_infer(def, src_ndim, src, filter);
    return {dests, true};
}
//! Eagerly execute a Convolution on concrete tensors.
//! Note: the op is lowered to megdnn::ConvBiasForward with empty bias/z
//! tensors — presumably because ConvBias has the better-tuned kernels;
//! confirm against the commit intent ("specialize convolution implementation").
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // create megdnn opr
    auto&& conv = static_cast<const Convolution&>(def);
    CompNode cn = inputs[0]->comp_node();

    // Recompute the output layout unless the caller already validated it.
    TensorLayout out_layout = output_descs[0].layout;
    if (!validated)
        out_layout = do_shape_infer(
                def, inputs[0]->layout().ndim, inputs[0]->layout(),
                inputs[1]->layout());

    DeviceTensorND out =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, out_layout);

    using TensorND = megdnn::TensorND;
    SmallVector<TensorND> inp_tensornds(inputs.size());
    TensorLayoutArray inp_shapes(inputs.size()), oup_shapes(output_descs.size());
    for (unsigned i = 0; i < inputs.size(); ++i) {
        inp_tensornds[i] = inputs[i]->dnn_tensor();
        inp_shapes[i] = inputs[i]->layout();
    }
    oup_shapes[0] = out_layout;
    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::ConvBiasForward>(cn);
    // Copy the Convolution params field-by-field onto the ConvBias opr
    // (ConvBias has extra fields, so whole-param assignment is not possible).
    dnn_opr->param().pad_h = conv.pad_h;
    dnn_opr->param().pad_w = conv.pad_w;
    dnn_opr->param().stride_h = conv.stride_h;
    dnn_opr->param().stride_w = conv.stride_w;
    dnn_opr->param().dilate_h = conv.dilate_h;
    dnn_opr->param().dilate_w = conv.dilate_w;
    dnn_opr->param().sparse = conv.sparse;
    dnn_opr->param().compute_mode = conv.compute_mode;
    dnn_opr->param().format = conv.format;

    // shape infer: `shp` is a 0-dim placeholder layout standing in for the
    // absent bias and z inputs of ConvBias.
    TensorLayout shp({0}, inputs[0]->dtype());
    shp.ndim = 0;
    size_t sz = setup_algo<megdnn::ConvBiasForward>(
            {inp_shapes[0], inp_shapes[1], shp, shp, oup_shapes[0]}, dnn_opr.get(), 0,
            false, false, cn, conv.policy(), false);

    // alloc memory: empty bias tensor plus the algorithm workspace.
    DeviceTensorND bias = BlobManager::inst()->alloc_workspace_with_defrag(cn, shp);

    auto wk = Blob::make(cn, sz);
    auto ptr = wk->storage().get();
    megdnn::Workspace dnn_wk(ptr, sz);

    // execute (bias passed for both bias and z slots; no preprocessed filter)
    dnn_opr->exec(
            inp_tensornds[0], inp_tensornds[1], bias.as_megdnn(), bias.as_megdnn(),
            out.as_megdnn(), nullptr, dnn_wk);
    return {Tensor::make(out)};
}
// Register the imperative trait handlers for Convolution; unimplemented
// entry points fall back to the default graph-based implementation.
OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}
// namespace convolution
}
// namespace
namespace
{
namespace
conv_bias
{
//! Build a graph ConvBias operator from the imperative op definition.
//! Dispatches on the operand count: 2 = (src, filter), 3 = (+bias),
//! 4 = (+z). Any other arity is a programming error.
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
    config.name(conv.make_name());
    switch (inputs.size()) {
        case 2:
            return opr::ConvBias::make(
                    inputs[0], inputs[1], conv.param(), conv.policy(), config);
        case 3:
            return opr::ConvBias::make(
                    inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(),
                    config);
        case 4:
            return opr::ConvBias::make(
                    inputs[0], inputs[1], inputs[2], inputs[3], conv.param(),
                    conv.policy(), config);
        default:
            mgb_assert(0);
    }
}
// Register the imperative trait handlers for ConvBias; only graph lowering
// is specialized, everything else uses the fallback path.
OP_TRAIT_REG(ConvBias, ConvBias).apply_on_var_node(apply_on_var_node).fallback();
}
// namespace conv_bias
}
// namespace
namespace
{
namespace
convolution_backward_data
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
...
...
@@ -76,9 +271,159 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return
opr
::
Convolution3D
::
make
(
inputs
[
0
],
inputs
[
1
],
conv
.
param
(),
conv
.
policy
());
}
//! Infer the output layout of a 3D convolution from the source and filter
//! layouts, honoring the op's sparse (DENSE/GROUP) and format (NCDHW/NDHWC)
//! settings. Only 5-dim input (3 spatial dims) is supported.
TensorLayout do_shape_infer(
        const OpDef& def, size_t src_ndim, TensorLayout src, TensorLayout filter) {
    auto&& conv = static_cast<const Convolution3D&>(def);
    using Param = ::megdnn::param::Convolution3D;
    auto img_ndim = src_ndim - 2;  // spatial dims = total dims minus (N, C)
    mgb_assert(
            img_ndim == 3,
            "only 3D convolution is supported, and input should be 5-dim; "
            "got input dim = %zu",
            src_ndim);
    size_t group = 1;
    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
    if (conv.sparse == Param::Sparse::DENSE) {
        mgb_assert(
                filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        group = 1;
        flt_start = 0;
    } else {  // Param::Sparse::GROUP
        mgb_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // grouped filter: (grp, oc, ic, dims[])
        group = filter[0];
        flt_start = 1;
    }

    uint32_t ic_block_size = 1, oc_block_size = 1;
    size_t src_or_dst_c_pos = 0;
    size_t src_or_dst_spatial_start = 0;
    if (conv.format == Param::Format::NCDHW) {
        // filter should be (oc, ic, fd, fh, fw)
        flt_spatial_start = 2;
        ocpg_pos = 0;
        icpg_pos = 1;
        src_or_dst_c_pos = 1;
        src_or_dst_spatial_start = 2;
    } else {  // Param::Format::NDHWC
        // filter should be (oc, fd, fh, fw, ic)
        flt_spatial_start = 1;
        ocpg_pos = 0;
        icpg_pos = 4;
        src_or_dst_c_pos = 4;
        src_or_dst_spatial_start = 1;
    }
    // output / input channels per group
    size_t ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
    size_t icpg = filter[flt_start + icpg_pos] * ic_block_size;
    uint32_t dilation[3], dilated_spatial[3], stride[3], padding[3];
    dilation[0] = conv.dilate_d;
    dilation[1] = conv.dilate_h;
    dilation[2] = conv.dilate_w;
    stride[0] = conv.stride_d;
    stride[1] = conv.stride_h;
    stride[2] = conv.stride_w;
    padding[0] = conv.pad_d;
    padding[1] = conv.pad_h;
    padding[2] = conv.pad_w;
    for (size_t i = 0; i < img_ndim; ++i) {
        mgb_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        // effective filter extent after dilation
        dilated_spatial[i] =
                (filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
    }
    // input channel count must equal icpg * group
    mgb_assert(icpg * group == src[src_or_dst_c_pos], "group conv invalid");

    // dst: same batch dim, ocpg*group channels, conv-reduced spatial dims.
    TensorLayout dst{src.dtype};
    dst.ndim = src_ndim;
    dst[0] = src[0];
    dst[src_or_dst_c_pos] = ocpg * group;
    for (size_t i = 0; i < img_ndim; ++i) {
        dst[i + src_or_dst_spatial_start] = infer_conv_shape(
                src[i + src_or_dst_spatial_start], dilated_spatial[i], stride[i],
                padding[i]);
    }
    dst.init_contiguous_stride();
    return dst;
}
//! Fallible shape inference for Convolution3D: returns a single output
//! descriptor plus a flag saying whether the layout could be computed.
//! When the source layout is still unknown (ndim == 0) the inference is
//! deferred and `false` is returned.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    // NOTE: the op cast and Param alias held here previously were unused;
    // do_shape_infer() performs the cast itself.
    SmallVector<LogicalTensorDesc> dests(1);
    auto&& desc = dests[0];
    desc.comp_node = inputs[0].comp_node;

    TensorLayout src = inputs[0].layout;
    size_t src_ndim = src.ndim;
    if (src_ndim == 0) {
        // Propagate the (empty) source layout so the dtype survives, matching
        // the 2D Convolution handler; previously desc.layout was left default.
        desc.layout = src;
        return {dests, false};
    }

    TensorLayout filter = inputs[1].layout;
    desc.layout = do_shape_infer(def, src_ndim, src, filter);
    return {dests, true};
}
//! Eagerly execute a Convolution3D on concrete tensors: infer/validate the
//! output layout, pick an algorithm via setup_algo, allocate output and
//! workspace, then run megdnn::Convolution3D.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // create megdnn opr
    auto&& conv = static_cast<const Convolution3D&>(def);

    // Recompute the output layout unless the caller already validated it.
    TensorLayout out_layout = output_descs[0].layout;
    if (!validated)
        out_layout = do_shape_infer(
                def, inputs[0]->layout().ndim, inputs[0]->layout(),
                inputs[1]->layout());

    using TensorND = megdnn::TensorND;
    CompNode cn = inputs[0]->comp_node();
    SmallVector<TensorND> inp_tensornds(inputs.size());
    TensorLayoutArray inp_shapes(inputs.size()), oup_shapes(output_descs.size());
    for (unsigned i = 0; i < inputs.size(); ++i) {
        inp_tensornds[i] = inputs[i]->dnn_tensor();
        inp_shapes[i] = inputs[i]->layout();
    }
    oup_shapes[0] = out_layout;
    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Convolution3D>(cn);
    // Unlike the 2D path (which lowers to ConvBias), params map 1:1 here.
    dnn_opr->param() = conv.param();

    // shape infer / algorithm selection; returns the workspace size in bytes.
    size_t sz = setup_algo<megdnn::Convolution3D>(
            {inp_shapes[0], inp_shapes[1], oup_shapes[0]}, dnn_opr.get(), 0, false,
            false, cn, conv.policy(), false);

    // alloc memory
    DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, out_layout);

    auto wk = Blob::make(cn, sz);
    auto ptr = wk->storage().get();
    megdnn::Workspace dnn_wk(ptr, sz);

    // execute
    dnn_opr->exec(inp_tensornds[0], inp_tensornds[1], out.as_megdnn(), dnn_wk);
    return {Tensor::make(out)};
}
// Register the imperative trait handlers for Convolution3D; unimplemented
// entry points fall back to the default graph-based implementation.
OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}
// namespace convolution3d
}
// namespace
...
...
imperative/src/impl/ops/specializations.cpp
浏览文件 @
72a70dd6
...
...
@@ -223,30 +223,6 @@ OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
}
// namespace adaptive_pooling
}
// namespace
namespace
{
namespace
conv_bias
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
conv
=
static_cast
<
const
ConvBias
&>
(
def
);
cg
::
OperatorNodeConfig
config
{
conv
.
dtype
};
config
.
name
(
conv
.
make_name
());
if
(
inputs
.
size
()
==
2
)
{
return
opr
::
ConvBias
::
make
(
inputs
[
0
],
inputs
[
1
],
conv
.
param
(),
conv
.
policy
(),
config
);
}
else
if
(
inputs
.
size
()
==
3
)
{
return
opr
::
ConvBias
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
conv
.
param
(),
conv
.
policy
(),
config
);
}
else
if
(
inputs
.
size
()
==
4
)
{
return
opr
::
ConvBias
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
inputs
[
3
],
conv
.
param
(),
conv
.
policy
(),
config
);
}
mgb_assert
(
0
);
}
OP_TRAIT_REG
(
ConvBias
,
ConvBias
).
apply_on_var_node
(
apply_on_var_node
).
fallback
();
}
// namespace conv_bias
}
// namespace
namespace
{
namespace
batch_conv_bias
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录