Commit 777f3ea9 authored by Megvii Engine Team, committed by Xinran Xu

refactor(gopt): format code

GitOrigin-RevId: 9d5c87000fdfa291d91306365f7401c8af443dc1
Parent b44e0549
@@ -10,23 +10,23 @@
  * implied.
  */
-#include "megbrain/gopt/inference.h"
-#include "megbrain/gopt/gtrans.h"
 #include "megbrain/gopt/basic_arith.h"
+#include "megbrain/gopt/gtrans.h"
+#include "megbrain/gopt/inference.h"
 #include "megbrain/graph/event.h"
-#include "megbrain/opr/dnn/batch_norm.h"
-#include "megbrain/opr/dnn/local.h"
-#include "megbrain/utils/shared_set.h"
-#include "megbrain/serialization/opr_shallow_copy.h"
 #include "megbrain/opr/basic_arith.h"
-#include "megbrain/opr/dnn/convolution.h"
 #include "megbrain/opr/blas.h"
-#include "megbrain/opr/misc.h"
+#include "megbrain/opr/dnn/batch_norm.h"
-#include "megbrain/opr/utility.h"
+#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/dnn/pooling.h"
-#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/imgproc.h"
+#include "megbrain/opr/misc.h"
 #include "megbrain/opr/nn_int.h"
+#include "megbrain/opr/tensor_manip.h"
+#include "megbrain/opr/utility.h"
+#include "megbrain/serialization/opr_shallow_copy.h"
+#include "megbrain/utils/shared_set.h"
 #include "megdnn/opr_param_defs.h"
 #include "megdnn/tensor_format.h"
@@ -66,71 +66,74 @@ using namespace gopt;
  * oprs should not get involved in any actual computing.
  */
 MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
                      cg::SingleCNOperatorNodeBase) // {
 public:
     //! relayout type of this opr
     enum class LayoutType {
         NCHW4_TO_NCHW32,  //!< from nchw4 layout to nchw32 layout
         NCHW32_TO_NCHW4,  //!< from nchw32 layout to nchw4 layout
         NCHW4_TO_CHWN4,   //!< from nchw4 layout to chwn4 layout
         CHWN4_TO_NCHW4,   //!< from chwn4 layout to nchw4 layout
         NCHW_TO_NCHW4,    //!< from nchw layout to nchw4 layout
         NCHW_TO_NCHW4_IC_SMALL_CONV,  ///< from nchw layout to nchw4 whose
                                       ///< channel size less than 4
         NCHW4_TO_NCHW,    //!< from nchw4 layout to nchw layout
         NCHW_TO_NCHW88,   //!< from nchw layout to nchw88 layout
         NCHW88_TO_NCHW,   //!< from nchw88 layout to nchw layout
         WEIGHT_NCHW_TO_NCHW4_DENSE,  //!< weight from nchw layout to nchw4
                                      //!< layout
         WEIGHT_NCHW_TO_NCHW4_GROUP,  //!< group weight from nchw layout to
                                      //!< nchw4 layout
         WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,  //!< weight from nchw layout
                                                    //!< to nchw4 layout whose
                                                    //! channel size less than 4
         WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                       //!< layout
         WEIGHT_NCHW_TO_NCHW88_GROUP,  //!< group weight from nchw layout to
                                       //!< nchw88 layout
         WEIGHT_NCHW_TO_NCHW88_CHAN,   //!< channel wise weight from nchw layout
                                       //!< to nchw88 layout
         //!< the weight layout of input is nchw output is nchw88, special for
         //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
         WEIGHT_HYBIRD_NCHW_NCHW88,
         WEIGHT_NCHW_TO_NCHW44_DENSE,  //!< weight from nchw layout to nchw44
                                       //!< layout
         WEIGHT_NCHW_TO_NCHW44_GROUP,  //!< group weight from nchw layout to
                                       //!< nchw44 layout
         WEIGHT_NCHW_TO_NCHW44_CHAN,   //!< channel wise weight from nchw layout
                                       //!< to nchw44 layout
         //!< the weight layout of input is nchw output is nchw44, special for
         //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
         WEIGHT_HYBIRD_NCHW_NCHW44,
         WEIGHT_NCHW_TO_NCHW44_DOT_DENSE,  //!< weight from NCHW44 layout to
                                           //!< NCHW44_DOT layout dense
         WEIGHT_NCHW_TO_NCHW44_DOT_GROUP,  //!< weight from NCHW44 layout to
                                           //!< NCHW44_DOT layout group
     };
     RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
     /*!
      * \param src_var the input var
      * \param layout_type tensor layout transform type of this relayout
      * placeholder as described in LayoutType
      */
     static SymbolVar make(VarNode* src_var, LayoutType layout_type);
-    LayoutType layout_type() const { return m_layout_type; }
+    LayoutType layout_type() const {
+        return m_layout_type;
+    }
 private:
     void init_output_static_infer_desc() override;
     void scn_do_execute() override;
     void init_output_comp_node() override;
     const LayoutType m_layout_type;
-};
+}
+;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder);
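Note: the relayout modes above are pure data movements; per the class comment, the placeholder oprs do no actual computing. As a point of reference, here is a minimal standalone C++ sketch of what NCHW_TO_NCHW4 does to a dense tensor. The helper name and the C % 4 == 0 assumption are illustrative only and are not part of this commit:

#include <cstddef>
#include <vector>

// Hypothetical helper (our name, not from megbrain): repack a dense NCHW
// float tensor into NCHW4, i.e. (N, C, H, W) -> (N, C/4, H, W, 4).
// Assumes C % 4 == 0, which is what the non-IC_SMALL modes above require.
std::vector<float> nchw_to_nchw4(const std::vector<float>& src, size_t N,
                                 size_t C, size_t H, size_t W) {
    std::vector<float> dst(N * C * H * W);
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w) {
                    size_t src_idx = ((n * C + c) * H + h) * W + w;
                    // dst is indexed as (n, c / 4, h, w, c % 4)
                    size_t dst_idx =
                            (((n * (C / 4) + c / 4) * H + h) * W + w) * 4 +
                            c % 4;
                    dst[dst_idx] = src[src_idx];
                }
    return dst;
}

The WEIGHT_HYBIRD_* modes block a different axis: as the enum comment notes, a {64, 2, 3, 3} filter in nchw becomes {8, 3, 3, 2, 8} in nchw88, i.e. the output channels are packed in groups of 8.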
@@ -211,7 +214,7 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[3];
             dst[4] = 4;
         } else if (layout_type() ==
-                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW){
+                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) {
             mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
             dst.ndim = 4;
             dst[0] = inp_shape[0];

@@ -249,7 +252,7 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[3];
             dst[4] = inp_shape[4];
             dst[5] = 4;
-        }else if (layout_type() ==
+        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
             dst.ndim = 5;

@@ -489,7 +492,7 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
             return opr::IndexAt::make(xshp, {{0, cv(idx)}});
         };
         auto tshp0 = opr::Concat::make(
-            {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
+                {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
         auto y0 = opr::Reshape::make(x, tshp0);
         auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
         return y1.node();
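Note: reading the snippet above as shape algebra (our annotation; it assumes the input var x is NCHW with a channel count divisible by 4):

// x     : (N, C, H, W)
// tshp0 : {N, C/4, 4, H, W}
// y0 = Reshape(x, tshp0)               -> (N, C/4, 4, H, W)
// y1 = Dimshuffle(y0, {0, 1, 3, 4, 2}) -> (N, C/4, H, W, 4), i.e. NCHW4

This is exactly the reshape-then-transpose pair that realizes a RelayoutPlaceholder once the pass is translated into real operators.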
@@ -1033,7 +1036,6 @@ EnableTensorCorePass::make_tensorcore_converter() {
                    "can not be changed in this opt "
                    "pass");
         return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
     };
-
     auto replace_warp_affine_opr =
             [replace_inps_to_nchw4, replace_non_nchw4_opr](

@@ -1247,7 +1249,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
         }
         if (nr_shape_changed) {
             auto inps = new_inp;
-            if (nr_shape_changed >= nr_inps / 2) { // CHWN4 > NCHW4 -> use CHWN4
+            if (nr_shape_changed >=
+                nr_inps / 2) {  // CHWN4 > NCHW4 -> use CHWN4
                 for (size_t i = 0; i < nr_inps; ++i) {
                     if (varshape_changed.count(new_inp[i]) == 0) {
                         auto symvar = RelayoutPlaceholder::make(
@@ -1309,7 +1312,6 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
                    "can not be changed in this opt "
                    "pass");
         return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
     };
-
     // capture by copy to avoid use after return
     auto replace_warp_affine_opr =

@@ -1410,7 +1412,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
     return new_var;
 }

-std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
+std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
     MIDOUT_B("EnableNCHW4Pass::make")
     auto ret = std::make_unique<EnableNCHW4Pass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
@@ -1469,14 +1471,12 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
             return serialization::copy_opr_shallow(*opr, new_inp,
                                                    opr->config());
         }
-        auto conv_mode =
-                trans_nchw4(conv_opr.param().sparse, new_inp[1]);
+        auto conv_mode = trans_nchw4(conv_opr.param().sparse, new_inp[1]);
         VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
         // src: NCHW --> NCWH4
         if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src =
-                    RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
             conv_src = new_src.node();
         }
         // weight: NCHW --> NCHW4

@@ -1488,21 +1488,21 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         new_param.format = conv_format;
         // dst
         auto new_conv_opr = opr::Convolution::make(
-                conv_src, conv_filter, new_param,
-                conv_opr.execution_policy(), conv_opr.config());
+                conv_src, conv_filter, new_param, conv_opr.execution_policy(),
+                conv_opr.config());
         OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
         mgb_assert(new_conv_opr.shape().ndim == 5,
                    "The conv dst dim is not trans to nchw4");
         return new_opr;
     };
     auto replace_batch_conv_bias_opr = [batch_conv_bias_format,
                                         src_to_nchw4_mode](
                                                OperatorNodeBase* opr,
                                                const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& batch_conv_bias_opr =
                 opr->cast_final_safe<opr::BatchConvBiasForward>();
         if (batch_conv_bias_opr.param().format !=
             megdnn::param::BatchConvBias::Format::NCHW) {
             return serialization::copy_opr_shallow(*opr, new_inp,
@@ -1515,10 +1515,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         // what should be converted: src, weight
         VarNode *src = new_inp[0], *filter = new_inp[1];
         // src: NCHW --> NCHW4
-        if (new_inp[0]->shape().ndim !=5) {
+        if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src = RelayoutPlaceholder::make(new_inp[0],
-                                                     src_to_nchw4_mode);
+            auto new_src =
+                    RelayoutPlaceholder::make(new_inp[0], src_to_nchw4_mode);
             src = new_src.node();
         }
         // weight: BNCHW --> BNCHW4

@@ -1531,9 +1531,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         auto new_param = batch_conv_bias_opr.param();
         new_param.format = batch_conv_bias_format;
         if (new_inp.size() == 2) {
-            auto dst = opr::BatchConvBias::make(src, filter, new_param,
-                    batch_conv_bias_opr.execution_policy(),
-                    batch_conv_bias_opr.config());
+            auto dst = opr::BatchConvBias::make(
+                    src, filter, new_param,
+                    batch_conv_bias_opr.execution_policy(),
+                    batch_conv_bias_opr.config());
             OperatorNodeBase* new_opr = dst.node()->owner_opr();
             mgb_assert(dst.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchw4");

@@ -1542,14 +1543,15 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         // bias: NCHW --> NCHW4
         VarNode* bias = new_inp[2];
         if (new_inp[2]->shape().ndim == 4) {
-            auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                      src_to_nchw4_mode);
+            auto new_bias =
+                    RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode);
             bias = new_bias.node();
         }
         if (new_inp.size() == 3) {
-            auto dst = opr::BatchConvBias::make(src, filter, bias, new_param,
-                    batch_conv_bias_opr.execution_policy(),
-                    batch_conv_bias_opr.config());
+            auto dst = opr::BatchConvBias::make(
+                    src, filter, bias, new_param,
+                    batch_conv_bias_opr.execution_policy(),
+                    batch_conv_bias_opr.config());
             OperatorNodeBase* new_opr = dst.node()->owner_opr();
             mgb_assert(dst.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchw4");
@@ -1558,13 +1560,14 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         // z_inp: NCHW --> NCHW4
         VarNode* z_inp = new_inp[3];
         if (new_inp[3]->shape().ndim == 4) {
-            auto new_z = RelayoutPlaceholder::make(new_inp[3],
-                                                   src_to_nchw4_mode);
+            auto new_z =
+                    RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
             z_inp = new_z.node();
         }
-        auto dst = opr::BatchConvBias::make(src, filter, bias, z_inp,
-                new_param,batch_conv_bias_opr.execution_policy(),
-                batch_conv_bias_opr.config());
+        auto dst =
+                opr::BatchConvBias::make(src, filter, bias, z_inp, new_param,
+                                         batch_conv_bias_opr.execution_policy(),
+                                         batch_conv_bias_opr.config());
         OperatorNodeBase* new_opr = dst.node()->owner_opr();
         mgb_assert(dst.shape().ndim == 5,
                    "The conv_bias dst dim is not trans to nchw4");

@@ -1584,13 +1587,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         // what should be converted: src, weight
         VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
-        auto conv_mode =
-                trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
+        auto conv_mode = trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
         // src: NCHW --> NCHW4
         if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src =
-                    RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
             conv_bias_src = new_src.node();
         }
         // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
@@ -1602,8 +1603,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         new_param.format = conv_bias_format;
         if (new_inp.size() == 2) {
             auto new_conv_bias_opr = opr::ConvBias::make(
-                conv_bias_src, conv_bias_filter, new_param,
-                conv_bias_opr.execution_policy(), conv_bias_opr.config());
+                    conv_bias_src, conv_bias_filter, new_param,
+                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchw4");

@@ -1618,8 +1619,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         }
         if (new_inp.size() == 3) {
             auto new_conv_bias_opr = opr::ConvBias::make(
-                conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                conv_bias_opr.execution_policy(), conv_bias_opr.config());
+                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
+                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchw4");

@@ -1632,9 +1633,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
                     RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
             z_inp = new_z.node();
         }
-        auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src,
-                conv_bias_filter, conv_bias_bias, z_inp, new_param,
-                conv_bias_opr.execution_policy(), conv_bias_opr.config());
+        auto new_conv_bias_opr = opr::ConvBias::make(
+                conv_bias_src, conv_bias_filter, conv_bias_bias, z_inp,
+                new_param, conv_bias_opr.execution_policy(),
+                conv_bias_opr.config());
         OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
         mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                    "The conv_bias dst dim is not trans to nchw4");
@@ -1654,8 +1656,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         auto temp_inp = new_inp;
         for (size_t i = 0; i < opr->input().size(); i++) {
             if (new_inp[i]->shape().ndim == 4) {
-                auto new_var = RelayoutPlaceholder::make(
-                        new_inp[i], src_to_nchw4_mode);
+                auto new_var = RelayoutPlaceholder::make(new_inp[i],
+                                                         src_to_nchw4_mode);
                 temp_inp[i] = new_var.node();
             } else {
                 mgb_assert((new_inp[i]->shape().ndim == 5) ||

@@ -1670,7 +1672,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         }
     };
     auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
-                                  const VarNodeArray& new_inp) {
+                                    const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         VarNodeArray temp_inp = new_inp;
         for (size_t i = 0; i < opr->input().size(); i++) {
@@ -1697,8 +1699,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
         auto new_param = pooling.param();
         new_param.format = Format::NCHW4;
-        auto new_pooling =
-                opr::PoolingForward::make(new_inp[0], new_param, opr->config());
+        auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
+                                                     opr->config());
         mgb_assert(new_pooling.shape().ndim == 5,
                    "out var of Pooling opr after transform must be 5 (got: "
                    "%zu).",

@@ -1767,8 +1769,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
     //! supportted nchw4
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
-    replace_func[opr::BatchConvBias::typeinfo()] =
-            replace_batch_conv_bias_opr;
+    replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr;
     replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
     replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
     replace_func[opr::WarpPerspectiveForward::typeinfo()] =
@@ -1811,7 +1812,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
     return new_var;
 }

-void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
+void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
     using TestFilterResult = std::pair<TransType, RelayoutMode>;
     RelayoutMode weight_to_nchwxx_mode_dense =

@@ -2045,7 +2046,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     };
     auto replace_pooling_opr = [=](OperatorNodeBase* opr,
-                                 const VarNodeArray& new_inp) {
+                                   const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
         mgb_assert(pooling_opr.param().format ==

@@ -2115,7 +2116,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     };
     auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
-                                  const VarNodeArray& new_inp) {
+                                    const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         VarNodeArray temp_inp = new_inp;
         for (size_t i = 0; i < opr->input().size(); i++) {
@@ -2205,7 +2206,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         size_t OC = filter->shape()[0];
         if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
             ret.trans_type = TransType::TRANS_PURE_NCHWXX;
-            ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
+            ret.relayout_mod =
+                    RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
             ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
         } else if (IC < pack_c_size && OC % pack_c_size == 0) {
             ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;

@@ -2223,7 +2225,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
         } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
             ret.trans_type = TransType::TRANS_PURE_NCHWXX;
-            ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
+            ret.relayout_mod =
+                    RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
             ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
         }
     }
@@ -2538,7 +2541,6 @@ public:
         param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
         auto reformat = opr::RelayoutFormat::make(inp, param);
         return reformat.node();
     };
-
     m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] =

@@ -2563,7 +2565,6 @@ public:
         auto y0 = opr::Reshape::make(x, tshp);
         auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2});
         return y1.node();
     };
-
     m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] =
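Note: the reformat lambda ending above pairs a Reshape with Dimshuffle({1, 3, 4, 0, 2}). Assuming its input is NCHW with C divisible by 4 (an assumption on our part; the assignment target of this lambda lies outside the hunk), the axis bookkeeping is:

// x  : (N, C, H, W)
// y0 = Reshape(x, {N, C/4, 4, H, W})
// y1 = Dimshuffle(y0, {1, 3, 4, 0, 2}) -> (C/4, H, W, N, 4), i.e. CHWN4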
@@ -2591,24 +2592,29 @@
  * \brief abstract operator representation of shuffle operation
  */
 MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
                      cg::SingleCNOperatorNodeBase) // {
 public:
     AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
                        TensorFormat out_format);
     static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
                           TensorFormat out_format);
-    TensorFormat inp_format() const { return m_inp_format; }
+    TensorFormat inp_format() const {
+        return m_inp_format;
+    }
-    TensorFormat out_format() const { return m_out_format; }
+    TensorFormat out_format() const {
+        return m_out_format;
+    }
 private:
     void init_output_static_infer_desc() override;
     void scn_do_execute() override;
     const TensorFormat m_inp_format;
     const TensorFormat m_out_format;
-};
+}
+;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr);
@@ -2914,7 +2920,8 @@ void ShuffleShuffleRemovePass::Impl::do_replace() {
     bool force_folding_typecvt = false;
     bool first_shuffle = false;
     // initialize inp_format and out_format
-    TensorFormat out_format = TensorFormat::NCHW, inp_format = out_format;
+    TensorFormat out_format = TensorFormat::NCHW,
+                 inp_format = out_format;
     megdnn::DType inp_dtype = cur->input(0)->dtype(),
                   out_dtype = cur->output(0)->dtype();
     SmallVector<megdnn::DType> out_dtype_vec;

...
@@ -6,32 +6,31 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "megbrain/opr/dnn/local.h"
 #include "megbrain/test/helper.h"
-#include "megbrain/gopt/inference.h"
 #include "megbrain/gopt/basic_arith.h"
 #include "megbrain/gopt/gtrans.h"
+#include "megbrain/gopt/inference.h"
-#include "megbrain/opr/io.h"
 #include "megbrain/opr/basic_arith_wrapper.h"
-#include "megbrain/opr/tensor_manip.h"
+#include "megbrain/opr/blas.h"
 #include "megbrain/opr/dnn/batch_norm.h"
 #include "megbrain/opr/dnn/convolution.h"
-#include "megbrain/opr/utility.h"
+#include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/imgproc.h"
-#include "megbrain/opr/tensor_manip.h"
+#include "megbrain/opr/io.h"
 #include "megbrain/opr/nn_int.h"
-#include "megbrain/opr/imgproc.h"
-#include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/tensor_gen.h"
-#include "megbrain/opr/blas.h"
+#include "megbrain/opr/tensor_manip.h"
+#include "megbrain/opr/utility.h"
-#include "megbrain/comp_node_env.h"
 #include "./helper.h"
+#include "megbrain/comp_node_env.h"
 #include "megdnn/tensor_format.h"
@@ -121,20 +120,17 @@ TEST(TestGoptInference, ParamFuseConstEndPoint) {
     graph->options().graph_opt_level = 0;
     auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
          y = opr::SharedDeviceTensor::make(*graph, *host_y),
-         p = opr::Host2DeviceCopy::make(*graph, host_p),
-         q = p + x,
-         a = y + 3,
-         z0 = a + q,
-         z1 = a + 4;
+         p = opr::Host2DeviceCopy::make(*graph, host_p), q = p + x, a = y + 3,
+         z0 = a + q, z1 = a + 4;

     HostTensorND host_z0, host_z1;
     SymbolVar z0_1, z1_1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamFusePass>().
-            apply({{z1, z0}}).endpoint_vars(),
-            z1_1, z0_1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{z1, z0}})
+                          .endpoint_vars(),
+                  z1_1, z0_1);

     auto func = graph->compile({make_callback_copy(z0_1, host_z0),
                                 make_callback_copy(z1_1, host_z1)});

@@ -143,7 +139,10 @@ TEST(TestGoptInference, ParamFuseConstEndPoint) {
     func->execute();

     int nr_opr = 0;
-    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {
+        ++nr_opr;
+        return true;
+    });
     ASSERT_EQ(8, nr_opr);

     auto px = host_x->ptr<float>(), pz0 = host_z0.ptr<float>();

@@ -151,13 +150,12 @@ TEST(TestGoptInference, ParamFuseConstEndPoint) {
     auto yv = host_y->ptr<float>()[0], pv = host_p->ptr<float>()[0],
          pz1 = host_z1.ptr<float>()[0];
-    for (size_t i = 0; i < SIZE; ++ i) {
+    for (size_t i = 0; i < SIZE; ++i) {
         MGB_ASSERT_FLOAT_EQ(px[i] + yv + 3 + pv, pz0[i]);
     }
     MGB_ASSERT_FLOAT_EQ(yv + 7, pz1);
 }
-
 TEST(TestGoptInference, ParamFuse) {
     constexpr size_t SIZE = 23;
     HostTensorGenerator<> gen;
@@ -168,35 +166,37 @@ TEST(TestGoptInference, ParamFuse) {
     auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
          y = opr::SharedDeviceTensor::make(*graph, *host_y),
          p = opr::Host2DeviceCopy::make(*graph, host_p),
-         z = x + y,     // endpoint
-         q = x * y + p; // middle point
+         z = x + y,      // endpoint
+         q = x * y + p;  // middle point

     SymbolVar z1, q1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamFusePass>().
-            apply({{z, q}}).endpoint_vars(),
-            z1, q1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{z, q}})
+                          .endpoint_vars(),
+                  z1, q1);

     ASSERT_TRUE(z1.node()->owner_opr()->same_type<opr::SharedDeviceTensor>());
     ASSERT_NE(q1.node()->owner_opr(), q.node()->owner_opr());
     ASSERT_EQ(q1.node()->owner_opr()->dyn_typeinfo(),
               q.node()->owner_opr()->dyn_typeinfo());
     HostTensorND host_z, host_q;
     auto func = graph->compile(
-            {make_callback_copy(z1, host_z),
-             make_callback_copy(q1, host_q)});
+            {make_callback_copy(z1, host_z), make_callback_copy(q1, host_q)});
     func->execute();

     int nr_opr = 0;
-    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {
+        ++nr_opr;
+        return true;
+    });
     ASSERT_EQ(6, nr_opr);

     auto px = host_x->ptr<float>(), pz = host_z.ptr<float>(),
          pq = host_q.ptr<float>();
     auto yv = host_y->ptr<float>()[0], pv = host_p->ptr<float>()[0];
-    for (size_t i = 0; i < SIZE; ++ i) {
+    for (size_t i = 0; i < SIZE; ++i) {
         MGB_ASSERT_FLOAT_EQ(px[i] + yv, pz[i]);
         MGB_ASSERT_FLOAT_EQ(px[i] * yv + pv, pq[i]);
     }
@@ -212,8 +212,8 @@ TEST(TestGoptInference, ParamFuseMultiDeviceTensorHolder) {
     auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
          y = opr::SharedDeviceTensor::make(*graph, *host_y),
          p = opr::Host2DeviceCopy::make(*graph, host_p),
-         z = x + y,     // endpoint
-         q = x * y + p; // middle point
+         z = x + y,      //! endpoint
+         q = x * y + p;  //! middle point

     SymbolVar z1, q1;
     unpack_vector(gopt::GraphOptimizer{}

@@ -223,34 +223,38 @@ TEST(TestGoptInference, ParamFuseMultiDeviceTensorHolder) {
                   z1);
     ASSERT_TRUE(z1.node()
-                ->owner_opr()->input(0)->owner_opr()
+                        ->owner_opr()
+                        ->input(0)
+                        ->owner_opr()
                         ->same_type<opr::MultipleDeviceTensorHolder>());
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamMergePass>().
-            add_pass<gopt::ParamFusePass>().
-            apply({{z, q}}).endpoint_vars(),
-            z1, q1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamMergePass>()
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{z, q}})
+                          .endpoint_vars(),
+                  z1, q1);

     ASSERT_TRUE(z1.node()->owner_opr()->same_type<opr::SharedDeviceTensor>());
     ASSERT_NE(q1.node()->owner_opr(), q.node()->owner_opr());
     ASSERT_EQ(q1.node()->owner_opr()->dyn_typeinfo(),
               q.node()->owner_opr()->dyn_typeinfo());
     HostTensorND host_z, host_q;
     auto func = graph->compile(
-            {make_callback_copy(z1, host_z),
-             make_callback_copy(q1, host_q)});
+            {make_callback_copy(z1, host_z), make_callback_copy(q1, host_q)});
     func->execute();

     int nr_opr = 0;
-    func->iter_opr_seq([&](cg::OperatorNodeBase*op) {++ nr_opr; return true; });
+    func->iter_opr_seq([&](cg::OperatorNodeBase* op) {
+        ++nr_opr;
+        return true;
+    });
     ASSERT_EQ(6, nr_opr);

     auto px = host_x->ptr<float>(), pz = host_z.ptr<float>(),
          pq = host_q.ptr<float>();
     auto yv = host_y->ptr<float>()[0], pv = host_p->ptr<float>()[0];
-    for (size_t i = 0; i < SIZE; ++ i) {
+    for (size_t i = 0; i < SIZE; ++i) {
         MGB_ASSERT_FLOAT_EQ(px[i] + yv, pz[i]);
         MGB_ASSERT_FLOAT_EQ(px[i] * yv + pv, pq[i]);
     }
@@ -262,33 +266,42 @@ TEST(TestGoptInference, ParamFuseMultiRead) {
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;

-    auto mkvar = [&](const char *name, const TensorShape &shp) {
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
-    auto mkcvar = [&](const char *name, const TensorShape &shp) {
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
         return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
     };

-    auto x = mkvar("x", {23}),
-         p0 = mkcvar("p0", {1}),
-         p1 = mkcvar("p1", {1}),
+    auto x = mkvar("x", {23}), p0 = mkcvar("p0", {1}), p1 = mkcvar("p1", {1}),
         z0 = x * (p0 + p1) + x / (p0 + p1);

     SymbolVar z1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamFusePass>().
-            apply({{z0}}).endpoint_vars(),
-            z1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{z0}})
+                          .endpoint_vars(),
+                  z1);

     ASSERT_NE(z0.node(), z1.node());
-    ASSERT_TRUE(z1.node()->owner_opr()->input(0)->owner_opr()
-            ->input(1)->owner_opr()->same_type<opr::SharedDeviceTensor>());
-    ASSERT_TRUE(z1.node()->owner_opr()->input(1)->owner_opr()
-            ->input(1)->owner_opr()->same_type<opr::SharedDeviceTensor>());
+    ASSERT_TRUE(z1.node()
+                        ->owner_opr()
+                        ->input(0)
+                        ->owner_opr()
+                        ->input(1)
+                        ->owner_opr()
+                        ->same_type<opr::SharedDeviceTensor>());
+    ASSERT_TRUE(z1.node()
+                        ->owner_opr()
+                        ->input(1)
+                        ->owner_opr()
+                        ->input(1)
+                        ->owner_opr()
+                        ->same_type<opr::SharedDeviceTensor>());

     HostTensorND host_z0, host_z1;
     graph->compile({make_callback_copy(z0, host_z0),
-                    make_callback_copy(z1, host_z1)})->execute();
+                    make_callback_copy(z1, host_z1)})
+            ->execute();

     MGB_ASSERT_TENSOR_EQ(host_z0, host_z1);
 }
@@ -297,10 +310,10 @@ TEST(TestGoptInference, ParamFuseStaticInfer) {
     auto graph = ComputingGraph::make();

-    auto mkvar = [&](const char *name, const TensorShape &shp) {
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
-    auto mkcvar = [&](const char *name, const TensorShape &shp) {
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
         return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
     };

@@ -308,11 +321,11 @@ TEST(TestGoptInference, ParamFuseStaticInfer) {
         b = a.reshape(opr::GetVarShape::make(mkcvar("tshp", {2, 2})));

     SymbolVar b1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamFusePass>().
-            apply({{b}}).endpoint_vars(),
-            b1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{b}})
+                          .endpoint_vars(),
+                  b1);

     ASSERT_EQ(b1, a.reshape({2, 2}));
 }
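Note: the unpack_vector call chain recurs throughout these tests, so the idiom is worth isolating. A minimal sketch, paraphrasing the test code above (no new API; b is the endpoint SymbolVar from the preceding test):

SymbolVar b_opt;
unpack_vector(gopt::GraphOptimizer{}
                      .add_pass<gopt::ParamFusePass>()
                      .apply({{b}})
                      .endpoint_vars(),
              b_opt);
// apply() rewrites the graph reachable from the given endpoints;
// endpoint_vars() returns the rewritten endpoints, and unpack_vector
// binds them back to individual SymbolVars for the assertions.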
@@ -333,11 +346,11 @@ TEST(TestGoptInference, ParamRedistributeConvMul) {
          y0 = opr::Convolution::make(x * k, w);

     SymbolVar y1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamRedistributePass>().
-            apply({{y0}}).endpoint_vars(),
-            y1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamRedistributePass>()
+                          .apply({{y0}})
+                          .endpoint_vars(),
+                  y1);

     ASSERT_NE(y0.node(), y1.node());
@@ -364,18 +377,18 @@ TEST(TestGoptInference, ParamRedistributeConvMulUniqReader) {
                                     {-1, 0, -1, -1}),
          w = opr::SharedDeviceTensor::make(*graph, *host_w),
          // y0 should be replaced
-         y0 = opr::powf(opr::Convolution::make(x * k, w).rename("y0") + 2, 2),
+         y0 = opr::powf(opr::Convolution::make(x * k, w).rename("y0") + 2,
+                        2),
          y0k = (y0 * k).rename("y0k"),
          // y0k is accessed twice, so it should not be replaced
-         y1 = opr::Convolution::make(y0k, w).rename("y1"),
-         z0 = y1 / y0k;
+         y1 = opr::Convolution::make(y0k, w).rename("y1"), z0 = y1 / y0k;

     SymbolVar z1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamRedistributePass>().
-            apply({{z0}}).endpoint_vars(),
-            z1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamRedistributePass>()
+                          .apply({{z0}})
+                          .endpoint_vars(),
+                  z1);

     ASSERT_NE(z0.node(), z1.node());
     auto y1_repl = z1.node()->owner_opr()->input(0)->owner_opr();
@@ -394,10 +407,8 @@ TEST(TestGoptInference, ParamRedistributeMulConvMul) {
     constexpr size_t N = 4, IC = 3, IH = 5, IW = 4, OC = 4, KH = 3, KW = 2;
     HostTensorGenerator<> gen;
-    auto host_x = gen({N, IC, IH, IW}),
-         host_k1 = gen({IC}),
-         host_k2 = gen({1, OC, 1, 1}),
-         host_w = gen({OC, IC, KH, KW});
+    auto host_x = gen({N, IC, IH, IW}), host_k1 = gen({IC}),
+         host_k2 = gen({1, OC, 1, 1}), host_w = gen({OC, IC, KH, KW});

     auto graph = ComputingGraph::make();
     auto x = opr::Host2DeviceCopy::make(*graph, host_x),

@@ -409,12 +420,12 @@ TEST(TestGoptInference, ParamRedistributeMulConvMul) {
          y0 = opr::Convolution::make(x * k1, w) * k2;

     SymbolVar y1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamRedistributePass>().
-            add_pass<gopt::ParamFusePass>().
-            apply({{y0}}).endpoint_vars(),
-            y1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamRedistributePass>()
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{y0}})
+                          .endpoint_vars(),
+                  y1);

     auto y1opr = y1.node()->owner_opr();
     ASSERT_TRUE(y1opr->same_type<opr::Convolution>());
@@ -444,12 +455,12 @@ TEST(TestGoptInference, ParamRedistributeConvAdd) {
          y0 = opr::Convolution::make(x + b, w);

     SymbolVar y1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamRedistributePass>().
-            add_pass<gopt::ParamFusePass>().
-            apply({{y0}}).endpoint_vars(),
-            y1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamRedistributePass>()
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{y0}})
+                          .endpoint_vars(),
+                  y1);

     ASSERT_NE(y0.node(), y1.node());
@@ -462,41 +473,37 @@ TEST(TestGoptInference, ParamRedistributeConvAdd) {
 }

 TEST(TestGoptInference, ParamRedistributeDistThenReasso) {
-    constexpr size_t N = 4, IC0 = 3, IC1 = 6, IH = 5,
-                     IW = 4, OC = 4, KH = 3, KW = 2;
+    constexpr size_t N = 4, IC0 = 3, IC1 = 6, IH = 5, IW = 4, OC = 4, KH = 3,
+                     KW = 2;
     HostTensorGenerator<> gen;
     auto graph = ComputingGraph::make();
-    auto mkvar = [&](const char *name, const TensorShape &shp) {
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
-    auto mkcvar = [&](const char *name, const TensorShape &shp) {
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
         return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
     };
-    auto x0 = mkvar("x0", {N, IC0, IH, IW}),
-         x1 = mkvar("x1", {N, IC1, IH, IW}),
-         k0 = opr::Dimshuffle::make(
-                 mkcvar("x1_", {IC0}), {-1, 0, -1, -1}).rename("x1"),
+    auto x0 = mkvar("x0", {N, IC0, IH, IW}), x1 = mkvar("x1", {N, IC1, IH, IW}),
+         k0 = opr::Dimshuffle::make(mkcvar("x1_", {IC0}), {-1, 0, -1, -1})
+                      .rename("x1"),
          w0 = mkcvar("w0", {OC, IC0, KH, KW}),
          k1 = mkcvar("k1", {1, IC1, 1, 1}),
-         w1 = mkcvar("w1", {OC, IC1, KH, KW}),
-         b0 = mkvar("b0", {1, OC, 1, 1}),
-         b1 = mkcvar("b1", {1}),
-         k2 = mkcvar("k2", {1}),
-         y0 = (
-                 opr::Convolution::make(x0 * k0, w0) +
-                 opr::Convolution::make(x1 + k1, w1) +
-                 b0 + b1) * k2;
+         w1 = mkcvar("w1", {OC, IC1, KH, KW}), b0 = mkvar("b0", {1, OC, 1, 1}),
+         b1 = mkcvar("b1", {1}), k2 = mkcvar("k2", {1}),
+         y0 = (opr::Convolution::make(x0 * k0, w0) +
+               opr::Convolution::make(x1 + k1, w1) + b0 + b1) *
+              k2;

     SymbolVar y1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamRedistributePass>().
-            add_pass<gopt::ReorderArithChainPass>(
-                gopt::ConstVarType::IMMUTABLE_AND_PARAM).
-            add_pass<gopt::ParamFusePass>().
-            apply({{y0}}).endpoint_vars(),
-            y1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamRedistributePass>()
+                          .add_pass<gopt::ReorderArithChainPass>(
+                                  gopt::ConstVarType::IMMUTABLE_AND_PARAM)
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{y0}})
+                          .endpoint_vars(),
+                  y1);
     ASSERT_NE(y0.node(), y1.node());

     HostTensorND host_y0, host_y1;
@@ -506,19 +513,21 @@ TEST(TestGoptInference, ParamRedistributeDistThenReasso) {
     MGB_ASSERT_TENSOR_NEAR(host_y0, host_y1, 1e-5);

-    auto chain = gopt::extract_opr_leaves(y1.node(),
-            [](cg::OperatorNodeBase*opr){
-                return gopt::as_elem_opr(opr, opr::Elemwise::Mode::ADD);
-            });
+    auto chain =
+            gopt::extract_opr_leaves(y1.node(), [](cg::OperatorNodeBase* opr) {
+                return gopt::as_elem_opr(opr, opr::Elemwise::Mode::ADD);
+            });
     size_t nr_conv = 0;
-    for (auto i: chain) {
+    for (auto i : chain) {
         auto opr = i->owner_opr();
         if (opr->same_type<opr::Convolution>()) {
-            ++ nr_conv;
+            ++nr_conv;
-            ASSERT_TRUE(opr->input(0)->owner_opr()
-                    ->same_type<opr::Host2DeviceCopy>());
-            ASSERT_TRUE(opr->input(1)->owner_opr()
-                    ->same_type<opr::SharedDeviceTensor>());
+            ASSERT_TRUE(opr->input(0)
+                                ->owner_opr()
+                                ->same_type<opr::Host2DeviceCopy>());
+            ASSERT_TRUE(opr->input(1)
+                                ->owner_opr()
+                                ->same_type<opr::SharedDeviceTensor>());
         }
     }
     ASSERT_EQ(2u, nr_conv);
@@ -531,27 +540,24 @@ TEST(TestGoptInference, ParamRedistributeMultiChange) {
     HostTensorGenerator<> gen;
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
-    auto mkvar = [&](const char *name, const TensorShape &shp) {
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
-    auto mkcvar = [&](const char *name, const TensorShape &shp) {
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
         return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
     };

-    auto x = mkvar("x", {N, IC, IH, IW}),
-         k0 = mkcvar("k0", {1, IC, 1, 1}),
-         b0 = mkcvar("b0", {1, IC, 1, 1}),
-         k1 = mkcvar("k0", {1}),
-         b1 = mkcvar("b0", {1}),
-         w = mkcvar("w", {OC, IC, KH, KW}),
+    auto x = mkvar("x", {N, IC, IH, IW}), k0 = mkcvar("k0", {1, IC, 1, 1}),
+         b0 = mkcvar("b0", {1, IC, 1, 1}), k1 = mkcvar("k0", {1}),
+         b1 = mkcvar("b0", {1}), w = mkcvar("w", {OC, IC, KH, KW}),
          y0 = (opr::Convolution::make(x * k0 + b0, w) + b1) * k1;

     SymbolVar y1;
-    unpack_vector(
-            gopt::GraphOptimizer{}.
-            add_pass<gopt::ParamRedistributePass>().
-            add_pass<gopt::ParamFusePass>().
-            apply({{y0}}).endpoint_vars(),
-            y1);
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamRedistributePass>()
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{y0}})
+                          .endpoint_vars(),
+                  y1);

     ASSERT_NE(y0.node(), y1.node());
     HostTensorND host_y0, host_y1;
...@@ -577,16 +583,15 @@ TEST(TestGoptInference, ParamRedistributeMultiReader) { ...@@ -577,16 +583,15 @@ TEST(TestGoptInference, ParamRedistributeMultiReader) {
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0; graph->options().graph_opt_level = 0;
auto mkvar = [&](const char *name, const TensorShape &shp) { auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
}; };
auto mkcvar = [&](const char *name, const TensorShape &shp) { auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name); return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
}; };
auto x = mkvar("x", {N, IC, IH, IW}), auto x = mkvar("x", {N, IC, IH, IW}), k = mkcvar("k", {1, OC, 1, 1}),
k = mkcvar("k", {1, OC, 1, 1}),
w = mkcvar("w", {OC, IC, KH, KW}); w = mkcvar("w", {OC, IC, KH, KW});
auto conv = opr::Convolution::make(x, w); auto conv = opr::Convolution::make(x, w);
...@@ -594,12 +599,12 @@ TEST(TestGoptInference, ParamRedistributeMultiReader) { ...@@ -594,12 +599,12 @@ TEST(TestGoptInference, ParamRedistributeMultiReader) {
auto y0 = t * 4.2f + t * 2.4f; auto y0 = t * 4.2f + t * 2.4f;
SymbolVar y1; SymbolVar y1;
unpack_vector( unpack_vector(gopt::GraphOptimizer{}
gopt::GraphOptimizer{}. .add_pass<gopt::ParamRedistributePass>()
add_pass<gopt::ParamRedistributePass>(). .add_pass<gopt::ParamFusePass>()
add_pass<gopt::ParamFusePass>(). .apply({{y0}})
apply({{y0}}).endpoint_vars(), .endpoint_vars(),
y1); y1);
ASSERT_NE(y0.node(), y1.node()); ASSERT_NE(y0.node(), y1.node());
HostTensorND host_y0, host_y1; HostTensorND host_y0, host_y1;
@@ -616,13 +621,11 @@ TEST(TestGoptInference, ParamRedistributeMultiReader) {
    ASSERT_TRUE(ymul0);
    ASSERT_TRUE(ymul1);
    auto yconv = ymul0->input(0)->owner_opr();
    if (!yconv->same_type<opr::Convolution>()) {
        yconv = ymul0->input(1)->owner_opr();
    }
    ASSERT_TRUE(yconv->same_type<opr::Convolution>());
    if (ymul1->input(0) != yconv->output(0)) {
        ASSERT_EQ(yconv->output(0), ymul1->input(1));
    }
    ASSERT_EQ(x.node(), yconv->input(0));
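    // [editor sketch, not part of the diff] Both readers (t * 4.2f and
    // t * 2.4f) consume one convolution output, so redistribution must not
    // duplicate the conv node. One way to verify that directly, assuming the
    // cg::DepOprIter graph helper:
    size_t nr_conv = 0;
    cg::DepOprIter{[&](cg::OperatorNodeBase* opr) {
        if (opr->same_type<opr::Convolution>())
            ++nr_conv;
    }}.add(y1.node()->owner_opr());
    ASSERT_EQ(1u, nr_conv);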
@@ -751,9 +754,9 @@ TEST(TestGoptInference, Float16IOFloat32ComputeRemap) {
    auto a = mkvar("a", {N, 4, INP_H, INP_W});
    auto gen_map = [&](HostTensorND& mat) {
        auto ptr = mat.ptr<float>();
        for (size_t n = 0; n < N; ++n) {
            for (int h = 0; h < 5; ++h) {
                for (int w = 0; w < 5; ++w) {
                    *ptr++ = (h * 5 * 2) + 5 * 2 + 0;
                    *ptr++ = (h * 5 * 2) + 5 * 2 + 1;
                }
@@ -905,21 +908,26 @@ TEST(TestGoptInference, Float32TOFloat16C32) {
    };
    auto make_f16_graph = [&]() {
        auto d0 = opr::TypeCvt::make(
                     opr::TypeCvt::make(
                             opr::Host2DeviceCopy::make(*graph, host_x0),
                             dtype::Float16{}),
                     dtype::Float32{}),
             d1 = opr::TypeCvt::make(
                     opr::TypeCvt::make(
                             opr::Host2DeviceCopy::make(*graph, host_x1),
                             dtype::Float16{}),
                     dtype::Float32{}),
             d2 = opr::TypeCvt::make(
                     opr::TypeCvt::make(
                             opr::SharedDeviceTensor::make(*graph, *host_x2),
                             dtype::Float16{}),
                     dtype::Float32{});
        auto y = opr::ConvBias::make(d1, d2, d0);
        y = opr::Reduce::make(y, {}, y.make_scalar(1));
        y = opr::TypeCvt::make(opr::TypeCvt::make(y, dtype::Float16{}),
                               dtype::Float32{});
        return y;
    };
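    // [editor sketch, not part of the diff] Every TypeCvt pair above is the
    // same f32 -> f16 -> f32 round trip; a tiny helper would keep the
    // reference graph flat:
    auto round_f16 = [](SymbolVar v) {
        return opr::TypeCvt::make(opr::TypeCvt::make(v, dtype::Float16{}),
                                  dtype::Float32{});
    };
    // e.g.: auto d0 = round_f16(opr::Host2DeviceCopy::make(*graph, host_x0));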
@@ -927,7 +935,7 @@ TEST(TestGoptInference, Float32TOFloat16C32) {
    auto y_opt = make_f32_to_f16_graph();
    auto y = make_f16_graph();
    ASSERT_EQ(find_opr<opr::ConvBias>(y_opt).param().compute_mode,
              opr::ConvBias::Param::ConvBias::ComputeMode::FLOAT32);
    ASSERT_EQ(y_opt.dtype(), dtype::Float32{});
    ASSERT_EQ(y.dtype(), dtype::Float32{});
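    // [editor sketch, not part of the diff] The assertion above checks that
    // the f32-to-f16 conversion kept accumulation in fp32 (f16 storage, f32
    // compute). Building such a mixed-precision ConvBias by hand would look
    // like this (inp_s/filt_s/bias_s are hypothetical placeholders):
    SymbolVar inp_s, filt_s, bias_s;
    opr::ConvBias::Param mixed_param;
    mixed_param.compute_mode = opr::ConvBias::Param::ComputeMode::FLOAT32;
    auto y_mixed = opr::ConvBias::make(inp_s, filt_s, bias_s, mixed_param);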
@@ -1061,16 +1069,14 @@ TEST(TestGoptInference, Float32TOFloat16Endpoints) {
    };
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
    };
    graph->options().graph_opt_level = 0;
    opr::Convolution::Param param;
    param.pad_h = param.pad_w = 0;
    auto x = mkvar("x", {8, 8, 8, 8}), y = mkvar("y", {8, 8, 8, 8}),
         w = mkcvar("w", {4, 8, 3, 3}),
         z = opr::Convolution::make(x + y, w, param);
@@ -1277,9 +1283,8 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Qint8) {
    param.pad_h = param.pad_w = 0;
    auto w = mkcvar("w", {4, 8, 3, 3}, dtype::QuantizedS8(0.1f)),
         b = mkcvar("b", {1, 4, 1, 1}, dtype::QuantizedS32(0.02f)),
         y = opr::ConvBias::make(x, w, b, param, {},
                                 OperatorNodeConfig{dtype::QuantizedS8(0.2f)});
    SymbolVar y_opt;
    auto options = gopt::OptimizeForInferenceOptions{};
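    // [editor sketch, not part of the diff] The usual continuation of an
    // options object like the one above, assuming the enable_nhwcd4 flag and
    // gopt::optimize_for_inference from this module's public API; the test's
    // actual next lines are truncated by the diff:
    options.enable_nhwcd4();
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);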
@@ -1542,7 +1547,6 @@ TEST(TestGoptInference, AlgoWorkspaceLimit) {
    ASSERT_EQ(10000u, conv.execution_policy().workspace_limit);
}
TEST_PASS(FuseConvBiasNonlinPass, Basic) {
    auto cn = CompNode::load("xpux");
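    // [editor sketch, not part of the diff] What FuseConvBiasNonlinPass
    // rewrites, schematically: relu(conv(x, w) + b) becomes a single
    // ConvBias with the activation folded in. Hand-building the fused form
    // (x_s/w_s/b_s are hypothetical placeholders):
    SymbolVar x_s, w_s, b_s;
    opr::ConvBias::Param fused;
    fused.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
    auto y_fused = opr::ConvBias::make(x_s, w_s, b_s, fused);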
@@ -1563,11 +1567,9 @@ TEST_PASS(FuseConvBiasNonlinPass, Basic) {
                dtype);
    };
    for (auto format : {opr::Convolution::Param::Format::NCHW,
                        opr::Convolution::Param::Format::NHWC,
                        opr::Convolution::Param::Format::NCHW4}) {
        opr::Convolution::Param param;
        param.format = format;
        SymbolVar x, w, b;
@@ -1670,7 +1672,6 @@ TEST(TestEnableTensorCore, SmallInputShape) {
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}
TEST(TestEnableTensorCore, Nchw4Nchw) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
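    // [editor sketch, not part of the diff] The NCHW4 tests around here move
    // tensors between plain NCHW and the 4-channel-blocked layout. For a
    // statically known shape with c % 4 == 0 the rearrangement is just
    // reshape + dimshuffle (helper names hypothetical):
    auto to_nchw4 = [](SymbolVar v, size_t n, size_t c, size_t h, size_t w) {
        // {n, c, h, w} -> {n, c/4, 4, h, w} -> {n, c/4, h, w, 4}
        return opr::Dimshuffle::make(v.reshape({n, c / 4, 4, h, w}),
                                     {0, 1, 3, 4, 2});
    };
    auto to_nchw = [](SymbolVar v, size_t n, size_t c, size_t h, size_t w) {
        // {n, c/4, h, w, 4} -> {n, c/4, 4, h, w} -> {n, c, h, w}
        return opr::Dimshuffle::make(v, {0, 1, 4, 2, 3}).reshape({n, c, h, w});
    };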
@@ -2085,12 +2086,11 @@ TEST(TestEnableTensorCore, ShuffleMerge) {
        return y1;
    };
    auto x = mkvar("x", {32, 64, 16, 16}, dtype::QuantizedS8(2.5f)),
         w = mkcvar("w1", {64, 64, 3, 3}, dtype::QuantizedS8(2.5f)),
         b = mkcvar("b", {1, 64, 1, 1}, dtype::QuantizedS32(6.25f)),
         z = mkvar("b1", {32, 64, 16, 16}, dtype::QuantizedS8(2.5f));
    x = nchw2nchw4(x), w = nchw2nchw4(w), b = nchw2nchw4(b), z = nchw2nchw4(z);
    opr::ConvBias::Param param;
    param.format = opr::ConvBias::Param::Format::NCHW4;
    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
@@ -2350,7 +2350,8 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) {
    opr::WarpPerspective::Param warp_param;
    warp_param.format = opr::WarpPerspective::Param::Format::NCHW4;
    auto y1 = opr::WarpPerspective::make(y, mat_var, TensorShape{16, 16},
                                         warp_param);
    y1 = opr::TypeCvt::make(y1, dtype::Float32());
    auto nchw42nchw = [](SymbolVar x) {
        auto xshp = opr::GetVarShape::make(x);
@@ -2366,7 +2367,8 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) {
    };
    y1 = nchw42nchw(y1);
    warp_param.format = opr::WarpPerspective::Param::Format::NCHW;
    auto y2 = opr::WarpPerspective::make(y1, mat_var, TensorShape{16, 16},
                                         warp_param);
    SymbolVar y_opt;
    SymbolVar y_cudnn;
    {
@@ -2833,8 +2835,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4Ic3) {
    auto mkcvar = [&](const char* name, const TensorShape& shp,
                      const DType& dtype) {
        return opr::TypeCvt::make(
                opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name),
                dtype);
    };
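    // [editor sketch, not part of the diff] Typical use of the helper above
    // for this test's quantized parameters; shapes and scales are
    // illustrative only (the test name says IC == 3, which is why the NCHW4
    // converter must pad input channels up to a multiple of 4):
    auto w_q = mkcvar("w_q", {8, 3, 3, 3}, dtype::QuantizedS8(0.1f));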
@@ -2878,7 +2879,6 @@ TEST(TestGoptInference, ConvertFormatNCHW4Ic3) {
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNCHW88) {
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
@@ -2894,12 +2894,12 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
    auto host_x = gen({2, 3, 16, 16}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    //! Hybrid nchw88 mode
    opr::Convolution::Param param_conv;
    param_conv.pad_h = param_conv.pad_w = 1;
    auto w1 = mkcvar("w1", {8, 3, 3, 3}),
         conv1 = opr::Convolution::make(x, w1, param_conv);
    //! channel wise
    opr::ConvBias::Param param_conv_bias;
    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
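    // [editor sketch, not part of the diff] With Sparse::GROUP and groups ==
    // channels, each group convolves a single channel (a true channel-wise
    // conv). A follow-up layer under that convention would look like this
    // (GROUP weight layout {groups, ocpg, icpg, kh, kw}; shapes illustrative):
    auto w2_s = mkcvar("w2_s", {8, 1, 1, 3, 3}),
         b2_s = mkcvar("b2_s", {1, 8, 1, 1}),
         conv2_s = opr::ConvBias::make(conv1, w2_s, b2_s, param_conv_bias);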
@@ -2976,12 +2976,12 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
    auto host_x = gen({2, 3, 16, 16}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    //! Hybrid nchw44 mode
    opr::Convolution::Param param_conv;
    param_conv.pad_h = param_conv.pad_w = 1;
    auto w1 = mkcvar("w1", {8, 3, 3, 3}),
         conv1 = opr::Convolution::make(x, w1, param_conv);
    //! channel wise
    opr::ConvBias::Param param_conv_bias;
    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
@@ -3140,12 +3140,12 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
    auto host_x = gen({2, 3, 16, 16}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    //! Hybrid nchw44-dot mode
    opr::Convolution::Param param_conv;
    param_conv.pad_h = param_conv.pad_w = 1;
    auto w1 = mkcvar("w1", {8, 3, 3, 3}),
         conv1 = opr::Convolution::make(x, w1, param_conv);
    //! channel wise
    opr::ConvBias::Param param_conv_bias;
    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
...