Commit 2ed76b16 authored by Megvii Engine Team

feat(mgb/gopt): add graph dumper for graph partition

GitOrigin-RevId: 6dbcb67009678ce9a3c895d2115db8c429531cfb
Parent 76b28408
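The new dumper lets a GraphPartition extracted by SubGraphExtractor be serialized to JSON (when built with MGB_ENABLE_JSON). A minimal usage sketch, mirroring the new tests below; the operator list, the endpoint var `z` and the output path are illustrative, not part of this commit:

    using OprList = SubGraphExtractor::OprList;
    OprList opr_list = {
            opr::ConvolutionForward::typeinfo(),
            opr::Elemwise::typeinfo(),
            opr::TypeCvt::typeinfo(),
    };
    SubGraphExtractor extractor(opr_list);
    // z is an endpoint SymbolVar of the graph to be partitioned
    auto partitions = extractor.extract({z});
    // dump the operators, vars and comp_seq of the first partition
    partitions[0].to_json()->writeto_fpath("partition.json");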
@@ -11,17 +11,214 @@
 */
#include "megbrain/gopt/subgraph_extractor.h"
#include <atomic>
#include "megbrain/serialization/opr_shallow_copy.h"
using namespace mgb;
using namespace cg;
using namespace gopt;
/* ================== GraphPartition::InputPlaceholder =================*/
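// InputPlaceholder stands in for an input var of a graph partition when the
// partition is shallow-copied for dumping: it records the statically inferred
// shape/value of the var it replaces (when available) and must never be
// executed (scn_do_execute throws).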
// clang-format off
MGB_DEFINE_OPR_CLASS(GraphPartition::InputPlaceholder,
cg::SingleCNOperatorNodeBase) // {
public:
InputPlaceholder(VarNode* src_var, const TensorShape& infer_shp,
std::unique_ptr<HostTensorND> infer_val = nullptr);
static SymbolVar make(VarNode* src_var, const TensorShape& infer_shp,
std::unique_ptr<HostTensorND> infer_val = nullptr);
size_t input_id() const { return m_id; }
private:
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void init_output_comp_node() override;
const size_t m_id;
TensorShape m_infer_shp;
std::unique_ptr<HostTensorND> m_infer_val;
static std::atomic_size_t sm_id;
};
// clang-format on
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GraphPartition::InputPlaceholder);
std::atomic_size_t GraphPartition::InputPlaceholder::sm_id{0};
GraphPartition::InputPlaceholder::InputPlaceholder(
VarNode* src_var, const TensorShape& infer_shp,
std::unique_ptr<HostTensorND> infer_val)
: Super(src_var->owner_graph(), {}, {}, {}),
m_id{sm_id.fetch_add(1, std::memory_order_relaxed)},
m_infer_shp{infer_shp},
m_infer_val{std::move(infer_val)} {
name(ssprintf("InputPlaceholder@%zu", m_id));
add_equivalence_component<ScalarHash<DTypeEnum>>(src_var->dtype().enumv());
add_equivalence_component<ScalarHash<size_t>>(m_id);
add_output(None)->dtype(src_var->dtype());
}
void GraphPartition::InputPlaceholder::init_output_comp_node() {
output(0)->comp_node(CompNode::default_cpu());
}
void GraphPartition::InputPlaceholder::scn_do_execute() {
mgb_throw(InternalError, "InputPlaceholder opr can not be executed");
}
void GraphPartition::InputPlaceholder::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
if (m_infer_shp.ndim == 0) {
auto infer_shape = [](TensorShape&, const InpVal&) { return false; };
mgr.register_shape_infer(output(0),
{SourceType::MUTABLE, {}, infer_shape});
} else {
mgr.register_shape_infer(output(0),
ShapeInferDesc::make_const(m_infer_shp));
}
if (m_infer_val == nullptr) {
auto infer_value = [](DeviceTensorND&, const InpVal&) { return false; };
mgr.register_value_infer(output(0),
{SourceType::MUTABLE, {}, infer_value});
} else {
auto infer_value = [this](DeviceTensorND& dest, const InpVal&) {
dest.copy_from(*m_infer_val).sync();
return true;
};
mgr.register_value_infer(output(0),
{SourceType::CONSTANT, {}, infer_value});
}
}
SymbolVar GraphPartition::InputPlaceholder::make(
VarNode* src_var, const TensorShape& infer_shp,
std::unique_ptr<HostTensorND> infer_val) {
return src_var->owner_graph()
->insert_opr(std::make_unique<InputPlaceholder>(
src_var, infer_shp, std::move(infer_val)))
->output(0);
}
/* ================== GraphPartition =================*/
#if MGB_ENABLE_JSON
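// Serialize the partition to JSON: the partition inputs are first replaced by
// InputPlaceholder oprs, then all reachable operators, their vars and the
// traversal order ("comp_seq") are dumped.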
std::shared_ptr<json::Value> GraphPartition::to_json() const {
auto replaced_outputs = std::get<1>(replace_graph_by_placeholder());
ThinHashSet<VarNode*> all_var_node;
ThinHashSet<OperatorNodeBase*> all_opr_node;
auto comp_seq = json::Array::make();
auto cb = [&](OperatorNodeBase* opr) {
comp_seq->add(json::String::make(opr->id_str()));
for (const auto& i : opr->input()) {
if (all_var_node.count(i) == 0) {
all_var_node.insert(i);
}
}
all_opr_node.insert(opr);
for (const auto& o : opr->output()) {
all_var_node.insert(o);
}
};
cg::DepOprIter iter{cb};
for (const auto& o : replaced_outputs)
iter.add(o->owner_opr());
auto dump_node_coll = [](auto&& collection) {
auto objptr = json::Object::make();
auto&& obj = *objptr;
for (auto&& i : collection)
obj[i->id_str()] = i->to_json();
return objptr;
};
return json::Object::make({{"operator", dump_node_coll(all_opr_node)},
{"var", dump_node_coll(all_var_node)},
{"comp_seq", comp_seq}});
}
#endif
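// Shallow-copy the partition's operators inside the owner graph, substituting
// each partition input var with an InputPlaceholder that carries the
// statically inferred shape/value of the original var; returns the created
// placeholders together with the copied output vars.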
std::pair<VarNodeArray, VarNodeArray>
GraphPartition::replace_graph_by_placeholder() const {
ThinHashMap<VarNode*, VarNode*> old2new;
auto graph_partition_copy_opr_shallow = [](OperatorNodeBase* opr,
const VarNodeArray& inps) {
OperatorNodeConfig config = opr->config();
return serialization::copy_opr_shallow(*opr, inps, config)->output(0);
};
OperatorNodeSet input_opr_set;
for (const auto& i : m_inputs)
input_opr_set.insert(i->owner_opr());
VarNodeArray placeholders;
VarNodeArray replaced_outputs;
VarNodeArray new_i;
auto cb = [&](OperatorNodeBase* opr) {
for (const auto& o : opr->output()) {
if (o->contain_flag(VarNode::Flag::VOLATILE_CONTENT) ||
(input_opr_set.count(opr) && !m_inputs.count(o))) {
continue;
}
VarNode* new_o;
if (m_inputs.count(o)) {
auto&& mgr = opr->owner_graph()->static_infer_manager();
const TensorShape* shp_ptr = nullptr;
if (cg::is_static_var_shape(o)) {
shp_ptr = mgr.infer_shape_fallible(o);
}
TensorShape infer_shp;
if (shp_ptr)
infer_shp = *shp_ptr;
std::unique_ptr<HostTensorND> hval = nullptr;
const DeviceTensorND* dval_ptr = nullptr;
if (cg::is_static_var_value(o)) {
dval_ptr = mgr.infer_value_fallible(o);
}
if (dval_ptr) {
hval.reset(new HostTensorND(CompNode::default_cpu(),
dval_ptr->dtype()));
hval->resize(dval_ptr->shape()).copy_from(*dval_ptr).sync();
}
new_o = InputPlaceholder::make(o, infer_shp, std::move(hval))
.node();
placeholders.push_back(new_o);
} else {
new_i.clear();
for (const auto& i : opr->input()) {
new_i.push_back(old2new.at(i));
}
new_o = graph_partition_copy_opr_shallow(o->owner_opr(), new_i);
}
old2new[o] = new_o;
}
};
cg::DepOprIter iter{cb};
for (auto&& i : m_inputs) {
for (auto&& j : i->owner_opr()->input()) {
if (!input_opr_set.count(j->owner_opr()) &&
!m_opr_set.count(j->owner_opr())) {
iter.set_visited(j->owner_opr());
}
}
}
for (auto&& o : m_outputs)
iter.add(o->owner_opr());
for (auto&& o : m_outputs) {
replaced_outputs.push_back(old2new.at(o));
}
return std::make_pair(placeholders, replaced_outputs);
}
/* ================== SubGraphExtractor =================*/
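// Group the operators reachable from endpoint_vars with union-find: directly
// connected operators whose types appear in m_opr_list end up in the same set,
// and each set is returned as one GraphPartition.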
std::vector<GraphPartition> SubGraphExtractor::extract(
        const SymbolVarArray& endpoint_vars) const {
    ThinHashMap<OperatorNodeBase*, std::pair<OperatorNodeBase*, int>> parent;
    thin_function<OperatorNodeBase*(OperatorNodeBase*)> union_find;
    union_find = [&parent, &union_find](OperatorNodeBase* o) {
        if (parent[o].first == o)
            return o;
        else {
@@ -34,7 +231,7 @@ std::vector<InternalGraph> SubGraphExtractor::extract(
                               OperatorNodeBase* y) {
        auto root_x = union_find(x), root_y = union_find(y);
        if (root_x != root_y) {
            OperatorNodeBase *large, *small;
            if (parent[root_x].second < parent[root_y].second) {
                small = root_x, large = root_y;
            } else {
@@ -42,25 +239,23 @@ std::vector<InternalGraph> SubGraphExtractor::extract(
            }
            parent[small].first = large;
            if (parent[large].second == parent[small].second) {
                parent[large].second += 1;
            }
        }
    };
    std::vector<OperatorNodeBase*> topo;
    auto cb = [this, &parent, &union_merge, &topo](OperatorNodeBase* opr) {
        topo.push_back(opr);
        if (m_opr_list.count(opr->dyn_typeinfo()) == 0)
            return;
        auto find = parent.find(opr);
        if (find == parent.end()) {
            parent.insert(std::make_pair(opr, std::make_pair(opr, 0)));
        }
        for (auto&& i : opr->input()) {
            auto&& o = i->owner_opr();
            if (m_opr_list.count(o->dyn_typeinfo()) == 0)
                continue;
            union_merge(opr, o);
        }
@@ -69,33 +264,51 @@ std::vector<InternalGraph> SubGraphExtractor::extract(
    for (const auto& v : endpoint_vars)
        iter.add(v.node()->owner_opr());
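    // Assemble the partitions in reverse topological order: each union-find
    // root maps to one GraphPartition. An operator outside m_opr_list only
    // marks the partition vars it reads as partition outputs; an operator
    // inside a partition records itself and its inputs, and each non-volatile
    // output var either cancels a pending partition input (it is consumed
    // inside the partition) or becomes a partition output.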
    std::vector<GraphPartition> partitions;
    partitions.reserve(topo.size());
    ThinHashMap<OperatorNodeBase*, GraphPartition*> roots;
    for (const auto& opr : reverse_adaptor(topo)) {
        if (m_opr_list.count(opr->dyn_typeinfo()) == 0) {
            for (const auto& i : opr->input()) {
                if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) {
                    auto root = union_find(i->owner_opr());
                    GraphPartition* partition;
                    auto find = roots.find(root);
                    if (find != roots.end()) {
                        partition = find->second;
                        partition->output().insert(i);
                    }
                }
            }
        } else {
            auto root = union_find(opr);
            auto find = roots.find(root);
            GraphPartition* partition = nullptr;
            if (find == roots.end()) {
                partitions.emplace_back(GraphPartition{});
                auto insert =
                        roots.insert(std::make_pair(root, &partitions.back()));
                partition = insert.first->second;
                for (auto&& o : opr->output()) {
                    if (!o->contain_flag(cg::VarNode::Flag::VOLATILE_CONTENT))
                        partition->output().insert(o);
                }
            } else {
                partition = find->second;
                for (auto&& o : opr->output()) {
                    if (!o->contain_flag(cg::VarNode::Flag::VOLATILE_CONTENT)) {
                        auto erase = partition->input().erase(o);
                        if (erase == 0)
                            partition->output().insert(o);
                    }
                }
            }
            partition->opr_set().insert(opr);
            for (const auto& i : opr->input())
                partition->input().insert(i);
        }
    }
    return partitions;
}
/* ============= SubGraphExtractor =================*/
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
@@ -16,17 +16,37 @@
namespace mgb {
namespace gopt {

class GraphPartition {
public:
    using VarNodeSet = ThinHashSet<VarNode*>;
    using OperatorNodeSet = ThinHashSet<cg::OperatorNodeBase*>;
    class InputPlaceholder;
    GraphPartition() = default;
#if MGB_ENABLE_JSON
    std::shared_ptr<json::Value> to_json() const;
#endif
    const OperatorNodeSet& opr_set() const { return m_opr_set; }
    const VarNodeSet& input() const { return m_inputs; }
    const VarNodeSet& output() const { return m_outputs; }
    OperatorNodeSet& opr_set() { return m_opr_set; }
    VarNodeSet& input() { return m_inputs; }
    VarNodeSet& output() { return m_outputs; }

private:
    OperatorNodeSet m_opr_set;
    VarNodeSet m_inputs;
    VarNodeSet m_outputs;
    std::pair<VarNodeArray, VarNodeArray> replace_graph_by_placeholder() const;
};

class SubGraphExtractor {
public:
    using OprList = ThinHashSet<Typeinfo*>;
    SubGraphExtractor(OprList opr_list) : m_opr_list{opr_list} {};
    std::vector<GraphPartition> extract(
            const SymbolVarArray& endpoint_vars) const;

private:
......
/**
* \file src/gopt/test/subgraph_extractor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./helper.h"
#include "megbrain/gopt/subgraph_extractor.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/internal/identical_fwd.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/serializer.h"
using namespace mgb;
using namespace gopt;
using namespace serialization;
namespace {
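// MultipleInputOutput is a test-only operator that declares one output per
// input var (plus a trailing workspace output), so extraction across
// multi-output operators can be exercised.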
// clang-format off
MGB_DEFINE_OPR_CLASS(MultipleInputOutput,
cg::SingleCNOperatorNodeBase) // {
public:
MultipleInputOutput(const VarNodeArray& inputs, const OperatorNodeConfig& config);
static SymbolVarArray make(const SymbolVarArray& inputs, const OperatorNodeConfig& config = {});
private:
void scn_do_execute() override { }
void init_output_static_infer_desc() override { }
};
// clang-format on
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultipleInputOutput);
MultipleInputOutput::MultipleInputOutput(const VarNodeArray& inputs,
const OperatorNodeConfig& config)
: Super(inputs[0]->owner_graph(), config, "multiple_input_output",
inputs) {
for (auto&& i : inputs)
add_input({i});
if (inputs.size() == 1) {
add_output(None);
} else {
for (size_t i = 0; i < inputs.size(); ++i)
add_output(ssprintf("o%zu", i));
}
cg::add_workspace_output(this);
}
SymbolVarArray MultipleInputOutput::make(const SymbolVarArray& inputs,
const OperatorNodeConfig& config) {
auto src = cg::to_var_node_array(inputs);
auto multiple_io = std::make_unique<MultipleInputOutput>(src, config);
auto ret =
cg::to_symbol_var_array(src[0]->owner_graph()
->insert_opr(std::move(multiple_io))
->output());
ret.pop_back();
return ret;
}
}  // anonymous namespace
TEST(TestSubGraphExtractor, MultipleOutputs) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
};
graph->options().graph_opt_level = 0;
auto x = mkvar("x", {8, 8, 8, 8}), w1 = mkcvar("w1", {4, 8, 3, 3});
auto y = mkvar("y", {1, 8, 1, 1});
auto add = x + y;
opr::Convolution::Param param;
param.pad_h = param.pad_w = 1;
auto c1 = opr::Convolution::make(add, w1, param);
auto w2 = mkcvar("w2", {8, 4, 3, 3});
auto c2 = opr::ConvolutionBackwardData::make(w2, add, param, {}, {});
auto sym_var_arr = MultipleInputOutput::make({c1, c2});
auto z = sym_var_arr[1];
z = z + (-128);
using OprList = SubGraphExtractor::OprList;
static const OprList opr_list = {
opr::ConvolutionForward::typeinfo(),
opr::Elemwise::typeinfo(),
opr::TypeCvt::typeinfo(),
MultipleInputOutput::typeinfo(),
};
SubGraphExtractor extractor(opr_list);
auto partitions = extractor.extract({z});
ASSERT_EQ(partitions.size(), 1u);
// outputs: sym_var_arr[0], z, add
ASSERT_EQ(partitions[0].output().size(), 3u);
ASSERT_TRUE(partitions[0].output().count(add.node()) > 0);
ASSERT_TRUE(partitions[0].output().count(z.node()) > 0);
ASSERT_TRUE(partitions[0].output().count(sym_var_arr[0].node()) > 0);
ASSERT_TRUE(partitions[0].output().count(sym_var_arr[1].node()) == 0);
// inputs: x, y, w1, c2, (-128)
ASSERT_EQ(partitions[0].input().size(), 5u);
ASSERT_TRUE(partitions[0].input().count(x.node()) > 0);
ASSERT_TRUE(partitions[0].input().count(c2.node()) > 0);
// oprs: (x + y), conv1, multi_io, (z + (-128))
ASSERT_EQ(partitions[0].opr_set().size(), 4u);
ASSERT_TRUE(partitions[0].opr_set().count(add.node()->owner_opr()) > 0);
ASSERT_TRUE(partitions[0].opr_set().count(c1.node()->owner_opr()) > 0);
ASSERT_TRUE(partitions[0].opr_set().count(
sym_var_arr[0].node()->owner_opr()) > 0);
ASSERT_TRUE(partitions[0].opr_set().count(z.node()->owner_opr()) > 0);
}
TEST(TestSubGraphExtractor, MultipleReaders) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
};
graph->options().graph_opt_level = 0;
auto x = mkvar("x", {8, 8, 8, 8}), w1 = mkcvar("w1", {4, 8, 3, 3});
auto y = mkvar("y", {1, 8, 1, 1});
auto add = x + y;
opr::Convolution::Param param;
param.pad_h = param.pad_w = 1;
auto c1 = opr::Convolution::make(add, w1, param);
auto w2 = mkcvar("w2", {8, 4, 3, 3});
auto c2 = opr::ConvolutionBackwardData::make(w2, add, param, {}, {});
auto z = c1 + c2;
using OprList = SubGraphExtractor::OprList;
static const OprList opr_list = {
opr::ConvolutionForward::typeinfo(),
opr::Elemwise::typeinfo(),
opr::TypeCvt::typeinfo(),
};
SubGraphExtractor extractor(opr_list);
auto partitions = extractor.extract({z});
ASSERT_EQ(partitions.size(), 1u);
ASSERT_EQ(partitions[0].output().size(), 2u);
ASSERT_TRUE(partitions[0].output().count(add.node()) > 0);
ASSERT_TRUE(partitions[0].output().count(z.node()) > 0);
ASSERT_EQ(partitions[0].input().size(), 4u);
ASSERT_TRUE(partitions[0].input().count(x.node()) > 0);
partitions[0].to_json()->writeto_fpath(
output_file("TestSubGraphExtractor.MultipleReaders.json"));
}
TEST(TestSubGraphExtractor, Complicated) {
const size_t N = 16, C = 3, H = 768, W = 1280;
HostTensorGenerator<dtype::Uint8> gen;
auto graph = ComputingGraph::make();
/* h2d
|
v
astype(f32)
|
add(-128)
|
v
astype(q8)
|
v
conv1
|
v
astype(u4)
|
/ \
conv2 conv3 -> astype(q32) -> output
\ /
qadd
|
v
astype(q8)
/ \
deconv conv4
\ /
concat -> output */
auto h2d = opr::Host2DeviceCopy::make(*graph, gen({N, C, H, W}));
auto data = opr::TypeCvt::make(h2d, dtype::Float32());
auto sub_128 = data + (-128);
auto x = opr::TypeCvt::make(sub_128, dtype::QuantizedS8(1.f));
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name),
dtype);
};
auto w1 = mkcvar("w1", {16, 3, 3, 3}, dtype::QuantizedS8(1.f));
auto b1 = mkcvar("b1", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
opr::ConvBias::Param param;
param.stride_h = param.stride_w = 2;
param.pad_h = param.pad_w = 1;
auto conv1 = opr::ConvBias::make(
x, w1, b1, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
conv1 = opr::TypeCvt::make(
conv1, dtype::Quantized4Asymm(1.f, static_cast<uint8_t>(8)));
auto w2 = mkcvar("w2", {16, 16, 3, 3}, dtype::QuantizedS4(1.f));
auto b2 = mkcvar("b2", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
auto conv2 = opr::ConvBias::make(conv1, w2, b2, param, {},
OperatorNodeConfig(dtype::Quantized4Asymm(
1.f, static_cast<uint8_t>(8))));
param.pad_h = param.pad_w = 0;
auto w3 = mkcvar("w3", {16, 16, 1, 1}, dtype::QuantizedS4(1.f));
auto b3 = mkcvar("b3", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
auto conv3 = opr::ConvBias::make(conv1, w3, b3, param, {},
OperatorNodeConfig(dtype::Quantized4Asymm(
1.f, static_cast<uint8_t>(8))));
auto conv3f = opr::TypeCvt::make(conv3, dtype::Float32());
auto qadd = opr::ElemwiseMultiType::make(
{conv2, conv3}, {opr::ElemwiseMultiType::Mode::QADD},
OperatorNodeConfig(
dtype::Quantized4Asymm(1.f, static_cast<uint8_t>(8))));
auto q8 = opr::TypeCvt::make(qadd, dtype::QuantizedS8(1.f));
auto w4 = mkcvar("w4", {16, 16, 3, 3}, dtype::QuantizedS8(1.f));
param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = 1;
auto conv4 = opr::ConvBiasForward::make(
q8, w4, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
conv4 = opr::TypeCvt::make(conv4, dtype::Float32());
opr::Convolution::Param conv_param;
conv_param.stride_h = conv_param.stride_w = 1;
conv_param.pad_h = conv_param.pad_w = 0;
auto w5 = mkcvar("w5", {16, 16, 1, 1}, dtype::QuantizedS8(1.f));
auto deconv = opr::ConvolutionBackwardData::make(
w5, q8, conv_param, {},
OperatorNodeConfig(dtype::QuantizedS8(1.f)));
deconv = opr::TypeCvt::make(deconv, dtype::Float32());
auto z = opr::Concat::make({conv4, deconv}, 1);
using OprList = SubGraphExtractor::OprList;
static const OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
opr::ConvolutionBackwardData::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
opr::Elemwise::typeinfo(),
opr::TypeCvt::typeinfo(),
opr::PoolingForward::typeinfo(),
opr::WarpPerspectiveForward::typeinfo(),
};
SubGraphExtractor extractor(opr_list);
auto partitions = extractor.extract({conv3f.node(), z.node()});
ASSERT_EQ(partitions.size(), 1u);
const char* prefix = "TestSubGraphExtractor.Complicated";
partitions[0].to_json()->writeto_fpath(
output_file(ssprintf("%s.json", prefix).c_str()));
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}