Commit 5f15f759 authored by Megvii Engine Team

test(mgb/gopt): add a testcase for SubGraphExtractor with multiple outputs

GitOrigin-RevId: 7785bdc8c090467cf75c864cb4056a0cb2059199
Parent a6230ba9
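For context, the behaviour covered by the new test (full diff below): a Convolution output and its Elemwise negation both feed a Concat, and since Concat is not in the extractor's operator whitelist, the extracted partition is expected to report both vars as outputs. A condensed sketch of that usage, with the graph-building boilerplate elided and all names (opr_list, c, neg_c, z) taken from the test itself:

    // Whitelist contains only Convolution and Elemwise; Concat therefore acts as
    // a partition boundary and every internal var it consumes becomes an output.
    SubGraphExtractor::OprList opr_list = {
            opr::ConvolutionForward::typeinfo(),
            opr::Elemwise::typeinfo(),
    };
    auto c = opr::Convolution::make(x, w, param);  // produced inside the subgraph
    auto neg_c = -c;                               // Elemwise negation, also inside
    auto z = opr::Concat::make({c, neg_c}, 1);     // consumer outside the whitelist
    SubGraphExtractor extractor(opr_list);
    auto partitions = extractor.extract({z});
    // expected: a single partition whose outputs are exactly {c, neg_c}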
@@ -94,8 +94,8 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx(
            opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),
            opr::Resize::typeinfo(),
            opr::PowC::typeinfo(),
            opr::Concat::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
@@ -103,22 +103,23 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx(
            DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
    Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats), attribute);
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormat::NCHW, OprFormat::NCHW44, DNN_INC_FLOAT16(OprFormat::NCHW88),
                OprFormat::NCHW44_DOT})
            .add_opr_config(
                    opr::ConvolutionForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88)})
            .add_opr_config(
                    opr::ResizeForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88)});
    return ctx;
}
}  // namespace
...
@@ -80,8 +80,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> {
template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW44;
@@ -101,8 +100,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {
#if !MEGDNN_DISABLE_FLOAT16
template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW88;
@@ -440,8 +438,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
@@ -451,8 +448,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> {
        for (size_t i = 0; i < opr->input().size(); ++i) {
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
@@ -484,8 +480,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> {
#if !MEGDNN_DISABLE_FLOAT16
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
@@ -495,8 +490,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> {
        for (size_t i = 0; i < opr->input().size(); ++i) {
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16;
@@ -528,8 +522,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> {
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
@@ -538,22 +531,18 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> {
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (i == 2) {
                available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
            } else {
                available &=
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                        opr->input(i)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
            }
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                     opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
@@ -747,7 +736,7 @@ StaticData::StaticData() {
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
#endif
...
@@ -264,4 +264,42 @@ TEST(TestSubGraphExtractor, Complicated) {
            output_file(ssprintf("%s.json", prefix).c_str()));
}

TEST(TestSubGraphExtractor, SubGraphWithMultipleOutputs) {
    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
    };
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
    };
    graph->options().graph_opt_level = 0;
    auto x = mkvar("x", {8, 8, 8, 8}), w = mkcvar("w", {4, 8, 3, 3});
    opr::Convolution::Param param;
    param.pad_h = param.pad_w = 1;
    auto c = opr::Convolution::make(x, w, param);
    auto neg_c = -c;
    auto z = opr::Concat::make({c, neg_c}, 1);
    using OprList = SubGraphExtractor::OprList;
    static const OprList opr_list = {
            opr::ConvolutionForward::typeinfo(),
            opr::Elemwise::typeinfo(),
    };
    SubGraphExtractor extractor(opr_list);
    auto partitions = extractor.extract({z});
    ASSERT_EQ(partitions.size(), 1u);
    ASSERT_EQ(partitions[0].output().size(), 2u);
    ASSERT_TRUE(partitions[0].output().count(c.node()) > 0);
    ASSERT_TRUE(partitions[0].output().count(neg_c.node()) > 0);
    ASSERT_EQ(partitions[0].input().size(), 2u);
    ASSERT_TRUE(partitions[0].input().count(x.node()) > 0);
    partitions[0].to_json()->writeto_fpath(
            output_file("TestSubGraphExtractor.SubGraphMultipleOuputs.json"));
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}