Commit e28dc606 authored by: Megvii Engine Team

refactor(opr/dnn): support new MegDNN conv interface

GitOrigin-RevId: 924aa749bf9b6ffb8b5eeb8f2bdc64e541915ed3
Parent 0b320568
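Summary of the change: the new MegDNN conv interface takes one extra pointer argument, passed as nullptr at every call site in this diff — inserted just before the workspace in exec() and appended to get_workspace_in_bytes(). Judging from the `_has_preprocessed_filter` macro flag below, the argument is presumably a preprocessed-filter handle. A minimal self-contained sketch of the call-site change, with toy types standing in for the MegDNN ones (names are illustrative, not copied from the headers):

```cpp
#include <cstddef>

struct TensorND {};
struct TensorLayout {};
struct Workspace {};
struct PreprocessedFilter {};  // assumed meaning of the new pointer argument

struct ConvolutionForwardToy {
    // old interface: exec(src, filter, dst, workspace)
    void exec(const TensorND&, const TensorND&, const TensorND&,
              const PreprocessedFilter* /*may be nullptr*/, const Workspace&) {}
    // old interface: get_workspace_in_bytes(src, filter, dst)
    std::size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                       const TensorLayout&,
                                       const PreprocessedFilter*) {
        return 0;
    }
};

int main() {
    ConvolutionForwardToy opr;
    TensorND src, filter, dst;
    TensorLayout lsrc, lfilter, ldst;
    std::size_t ws = opr.get_workspace_in_bytes(lsrc, lfilter, ldst, nullptr);
    (void)ws;
    opr.exec(src, filter, dst, /*preprocessed_filter=*/nullptr, Workspace{});
    return 0;
}
```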
@@ -103,7 +103,9 @@ struct OprArityTrait;
#define cb_ref(x) (&(x))
#define cb_dnn(x) ((x).as_megdnn())
#define INST_ARITY(_Opr, _in, _out) \
#define WS_ARG_true ,nullptr
#define WS_ARG_false
#define INST_ARITY(_Opr, _in, _out, _has_preprocessed_filter) \
template <> \
struct OprArityTrait<_Opr> { \
static constexpr int arity_in = _in; \
@@ -114,7 +116,8 @@ struct OprArityTrait;
_Opr* opr, typename _Opr::Algorithm* algo, \
const TensorLayoutArray& layouts) { \
opr->execution_policy() = {algo}; \
return opr->get_workspace_in_bytes(LAYOUTS(cb)); \
return opr->get_workspace_in_bytes( \
LAYOUTS(cb) WS_ARG_##_has_preprocessed_filter); \
} \
\
static std::vector<typename _Opr::Algorithm*> get_all_algorithms( \
@@ -138,8 +141,7 @@ struct OprArityTrait;
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0])
#define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2])
#define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1)
INST_ARITY_2_1(megdnn::Convolution);
#define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1, false)
INST_ARITY_2_1(megdnn::ConvolutionBackwardData);
INST_ARITY_2_1(megdnn::ConvolutionBackwardFilter);
INST_ARITY_2_1(megdnn::Convolution3DForward);
@@ -149,6 +151,9 @@ INST_ARITY_2_1(megdnn::LocalShareForward);
INST_ARITY_2_1(megdnn::LocalShareBackwardData);
INST_ARITY_2_1(megdnn::LocalShareBackwardFilter);
#undef TENSORS
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]), nullptr
INST_ARITY(megdnn::Convolution, 2, 1, true);
#undef TENSORS
#undef LAYOUTS
#undef INST_ARITY_2_1
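A note on the macro plumbing above: `INST_ARITY` now takes a `_has_preprocessed_filter` flag, and `WS_ARG_##_has_preprocessed_filter` pastes it onto `WS_ARG_` so that `true` expands to an extra `, nullptr` argument for get_workspace_in_bytes() while `false` expands to nothing; the exec() side gets its extra nullptr by redefining `TENSORS` with a trailing nullptr instead. A small standalone sketch of the token-pasting trick (toy operator types, not the real MegBrain macros):

```cpp
#include <cstddef>
#include <cstdio>

// Two toy operators whose get_workspace_in_bytes() differ by one trailing
// pointer parameter, mimicking the old and new MegDNN conv interfaces.
struct TwoArgOpr {
    static std::size_t get_workspace_in_bytes(int a, int b) { return a + b; }
};
struct ThreeArgOpr {
    static std::size_t get_workspace_in_bytes(int a, int b, const void* pf) {
        return a + b + (pf ? 1 : 0);
    }
};

// The flag is pasted onto WS_ARG_, selecting either ", nullptr" or nothing.
#define WS_ARG_true , nullptr
#define WS_ARG_false
#define INST_CALL(_Opr, _has_preprocessed_filter)                       \
    std::size_t call_##_Opr(int a, int b) {                             \
        return _Opr::get_workspace_in_bytes(                            \
                a, b WS_ARG_##_has_preprocessed_filter);                \
    }

INST_CALL(TwoArgOpr, false)   // -> get_workspace_in_bytes(a, b)
INST_CALL(ThreeArgOpr, true)  // -> get_workspace_in_bytes(a, b, nullptr)

int main() {
    std::printf("%zu %zu\n", call_TwoArgOpr(1, 2), call_ThreeArgOpr(1, 2));
    return 0;
}
```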
@@ -158,12 +163,16 @@ INST_ARITY_2_1(megdnn::LocalShareBackwardFilter);
#define LAYOUTS(cb) \
cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), cb(layouts[3]), \
cb(layouts[4])
#define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1)
INST_ARITY_4_1(megdnn::ConvBias);
#define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1, false)
INST_ARITY_4_1(megdnn::DeformableConvForward);
INST_ARITY_4_1(megdnn::DeformableConvBackwardFilter);
INST_ARITY_4_1(megdnn::BatchConvBiasForward);
#undef TENSORS
#define TENSORS(cb) \
cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \
cb(out_val[0]), nullptr
INST_ARITY(megdnn::ConvBias, 4, 1, true);
#undef TENSORS
#undef LAYOUTS
#undef INST_ARITY_4_1
@@ -174,7 +183,7 @@ INST_ARITY_4_1(megdnn::BatchConvBiasForward);
cb(layouts[3]), cb(layouts[4]), cb(layouts[5]), \
cb(layouts[6]), cb(layouts[7])
#define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3)
#define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3, false)
INST_ARITY_5_3(megdnn::DeformableConvBackwardData);
#undef TENSORS
#undef LAYOUTS
@@ -183,6 +192,8 @@ INST_ARITY_5_3(megdnn::DeformableConvBackwardData);
#undef cb_ref
#undef cb_dnn
#undef INST_ARITY
#undef WS_ARG_true
#undef WS_ARG_false
// timeout delta to be added with fastest known algorithm for new algos
constexpr double TIMEOUT_TOLERANCE = 2;
@@ -924,6 +935,41 @@ void ConvolutionForward::init_output_format() {
output(0)->format(input(0)->format());
}
void ConvolutionForward::scn_do_execute() {
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
input(1)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(), nullptr,
intl::get_megdnn_workspace_from_var(output().back()));
}
void ConvolutionForward::add_input_layout_constraint() {
mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}
void ConvolutionForward::init_output_static_infer_desc() {
Super::set_nr_managed_outputs(this->output().size() - 1);
Super::init_output_static_infer_desc();
init_output_static_infer_desc_workspace(
intl::AutoAddWorkspaceNeedLimitGetter<
megdnn::ConvolutionForward>::val);
}
void ConvolutionForward::get_output_var_shape(
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
TensorLayout input_layout{inp_shape[0], input(0)->dtype(),
input(0)->format()};
TensorLayout filter_layout{inp_shape[1], input(1)->dtype(),
input(1)->format()};
TensorLayout dst_layout{output(0)->dtype(), output(0)->format()};
megdnn_opr()->deduce_layout(input_layout, filter_layout, dst_layout);
out_shape[0] = dst_layout;
}
void ConvolutionForward::record_execute_deps(
cg::GraphExecutable::ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
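The overrides added above replace the defaults from MegDNNOprWrapperFwd: get_output_var_shape() builds TensorLayouts from the input shapes, dtypes and formats and lets the MegDNN opr deduce the destination layout, while init_output_static_infer_desc() excludes the trailing workspace var from the managed outputs before wiring up workspace-size inference. A toy sketch of that shape-deduction step, using a simple valid-convolution rule in place of megdnn_opr()->deduce_layout() (stride 1, no padding, illustration only):

```cpp
#include <cstddef>
#include <vector>

// Toy stand-in for TensorShape; the real code works on TensorLayout,
// which also carries dtype and format.
using Shape = std::vector<std::size_t>;

// Hypothetical replacement for megdnn_opr()->deduce_layout(): plain NCHW
// convolution output size with stride 1 and no padding.
Shape deduce_conv_dst(const Shape& src, const Shape& filter) {
    // src = {N, C, H, W}, filter = {OC, C, FH, FW}
    return {src[0], filter[0], src[2] - filter[2] + 1, src[3] - filter[3] + 1};
}

// Mirrors get_output_var_shape(): out_shape[0] comes from deduction; the
// trailing workspace output is sized elsewhere and not managed here.
void get_output_var_shape(const std::vector<Shape>& inp_shape,
                          std::vector<Shape>& out_shape) {
    out_shape.resize(1);
    out_shape[0] = deduce_conv_dst(inp_shape[0], inp_shape[1]);
}

int main() {
    std::vector<Shape> inp{{8, 3, 32, 32}, {16, 3, 3, 3}};
    std::vector<Shape> out;
    get_output_var_shape(inp, out);  // out[0] == {8, 16, 30, 30}
    return 0;
}
```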
/* ==================== ConvolutionBackwardData ==================== */
IMPL_CONV(ConvolutionBackwardData, "conv_bwd_data");
@@ -1429,6 +1475,7 @@ void ConvBiasForward::scn_do_execute() {
mo->exec(inp[0]->dev_tensor().as_megdnn(),
inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor,
output(0)->dev_tensor().as_megdnn(),
nullptr,
intl::get_megdnn_workspace_from_var(output().back()));
} else if (inp.size() == 3) {
@@ -1441,6 +1488,7 @@ void ConvBiasForward::scn_do_execute() {
inp[1]->dev_tensor().as_megdnn(),
inp[2]->dev_tensor().as_megdnn(), z_tensor,
output(0)->dev_tensor().as_megdnn(),
nullptr,
intl::get_megdnn_workspace_from_var(output().back()));
} else {
mgb_assert(inp.size() == 4);
@@ -1449,6 +1497,7 @@ void ConvBiasForward::scn_do_execute() {
inp[2]->dev_tensor().as_megdnn(),
inp[3]->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(),
nullptr,
intl::get_megdnn_workspace_from_var(output().back()));
}
}
......
@@ -89,18 +89,26 @@ namespace intl {
cg::OutshapePureByInshapeOpr<>,
mixin::MegDNNOprHolderImpl<megdnn::BatchConvBiasForward>>;
using BatchConvBiasForwardBase = WorkspaceSizeInfer<BatchConvBiasBase>;
using ConvolutionForwardBase = WorkspaceSizeInfer<
typename MegDNNOprWrapperFwdBase<megdnn::ConvolutionForward>::Base>;
} // namespace intl
MGB_DEFINE_OPR_CLASS(ConvolutionForward,
intl::MegDNNOprWrapperFwd<megdnn::ConvolutionForward>,
public mixin::Convolution) // {
intl::ConvolutionForwardBase, public mixin::Convolution) // {
void init_profile_cache() override;
void init_output_dtype() override;
size_t get_workspace_size_bytes(
const TensorShapeArray &input_shapes,
const TensorShapeArray &output_shapes) const override final;
void init_output_format() override;
void scn_do_execute() override;
void add_input_layout_constraint() override;
void init_output_static_infer_desc() override;
void get_output_var_shape(const TensorShapeArray& inp_shape,
TensorShapeArray& out_shape) const override final;
void record_execute_deps(
cg::GraphExecutable::ExecDependencyArray& deps) override;
public:
ConvolutionForward(VarNode *src, VarNode *filter,
......
@@ -532,11 +532,11 @@ TEST(TestOprDNN, DilatedConvolution) {
TensorLayout dest_layout;
opr->deduce_layout(inp[0]->layout(), inp[1]->layout(), dest_layout);
std::vector<dt_byte> workspace(opr->get_workspace_in_bytes(
inp[0]->layout(), inp[1]->layout(), dest_layout));
inp[0]->layout(), inp[1]->layout(), dest_layout, nullptr));
dest[0].dtype(dtype::Float32()).
comp_node(inp[0]->comp_node()).resize(dest_layout);
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(),
dest[0].as_megdnn(), {workspace.data(), workspace.size()});
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(),
nullptr, {workspace.data(), workspace.size()});
};
Checker::RunOptions option;
option.numdiff_eps = 0.1;
......