From a27ce973ad6388028cc9e20afb6acc6ae8a51f31 Mon Sep 17 00:00:00 2001
From: changzherui
Date: Sun, 14 Jun 2020 12:01:01 +0800
Subject: [PATCH] convert subgraph

---
 mindspore/ccsrc/transform/convert.cc        | 195 +++++++++++++++++++-
 mindspore/ccsrc/transform/convert.h         |   7 +
 mindspore/ccsrc/transform/op_adapter.h      |  22 +++
 mindspore/ccsrc/transform/op_adapter_base.h |  10 +
 mindspore/ccsrc/transform/op_declare.cc     |  23 +++
 mindspore/ccsrc/transform/op_declare.h      |   8 +
 tests/ut/python/automl/case.py              |  41 ++++
 7 files changed, 296 insertions(+), 10 deletions(-)
 create mode 100644 tests/ut/python/automl/case.py

diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index 3f6b31303..3b4b54602 100644
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -28,6 +28,7 @@
 #include "utils/config_manager.h"
 #include "utils/convert_utils.h"
 #include "./common.h"
+#include "utils/context/ms_context.h"
 
 namespace mindspore {
 namespace transform {
@@ -205,6 +206,7 @@ const char kNameRange[] = "Range";
 const char kNameSquareSumAll[] = "SquareSumAll";
 const char kNameAscendQuant[] = "AscendQuant";
 const char kNameAscendDequant[] = "AscendDequant";
+const char kNameCase[] = "Case";
 
 // -----------------OpAdapter initialization--------------
 std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
@@ -411,7 +413,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameRange), ADPT_DESC(RangeD)},
     {string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)},
     {string(kNameAscendQuant), ADPT_DESC(AscendQuant)},
-    {string(kNameAscendDequant), ADPT_DESC(AscendDequant)}};
+    {string(kNameAscendDequant), ADPT_DESC(AscendDequant)},
+    {string(kNameCase), ADPT_DESC(Case)}};
 #ifdef ENABLE_GE
   adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
   adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
@@ -433,13 +436,32 @@ PrimType GetCNodeFuncType(const CNodePtr cnode) {
   return kPrimTypeUnknown;
 }
 
+bool IsCaseNode(const CNodePtr node) {
+  if (!node->inputs().empty() && node->input(0)->isa<CNode>() &&
+      GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
+    return true;
+  }
+  return false;
+}
+
+std::string GetCNodeTargetFuncName(const CNodePtr cnode) {
+  if (IsCaseNode(cnode)) {
+    return string(kNameCase);
+  }
+  auto name = GetCNodeFuncName(cnode);
+  if (name == "switch_layer") {
+    name = "";
+  }
+  return name;
+}
+
 OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) {
   if (node->isa<CNode>()) {
     auto cnode = node->cast<CNodePtr>();
 
     std::string name = kNameCustomOp;
     if (!IsCustomCNode(cnode)) {
-      name = GetCNodeFuncName(cnode);
+      name = GetCNodeTargetFuncName(cnode);
    }
 
     auto it_adpt = get_adpt_map().find(name);
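The two helpers added above carry the whole renaming scheme: a CNode whose operator position (input 0) is itself a `switch_layer` call is reported to the adapter lookup as `Case`, while a bare `switch_layer` call maps to the empty name so that CheckCNode later refuses to convert it as a standalone op. The following standalone sketch mirrors that rule with a toy `Node` type; the struct and `TargetFuncName` are illustrative stand-ins, not MindSpore's AnfNode/CNode API:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in for a CNode: inputs[0] is the operator position.
struct Node {
  std::string func_name;
  std::vector<std::shared_ptr<Node>> inputs;
};

bool IsCaseLike(const std::shared_ptr<Node> &node) {
  return !node->inputs.empty() && node->inputs[0] != nullptr &&
         node->inputs[0]->func_name == "switch_layer";
}

std::string TargetFuncName(const std::shared_ptr<Node> &node) {
  if (IsCaseLike(node)) {
    return "Case";  // a call whose operator is switch_layer's result becomes a Case op
  }
  // a bare switch_layer call maps to "" and is rejected later (see CheckCNode below)
  return node->func_name == "switch_layer" ? "" : node->func_name;
}

int main() {
  auto sw = std::make_shared<Node>(Node{"switch_layer", {}});
  auto call = std::make_shared<Node>(Node{"", {sw}});
  std::cout << TargetFuncName(call) << "\n";        // prints: Case
  std::cout << "[" << TargetFuncName(sw) << "]\n";  // prints: []
  return 0;
}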
@@ -957,7 +979,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
   auto c = anf_out->cast<CNodePtr>();
   std::string name = "";
   if (anf_out->isa<CNode>()) {
-    name = GetCNodeFuncName(c);
+    name = GetCNodeTargetFuncName(c);
   }
 
   if (name == "make_tuple") {
@@ -1029,6 +1051,99 @@ void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
   return;
 }
 
+void DfGraphConvertor::SetSubgraph(AnfNodePtr node) {
+  if (!node->isa<CNode>()) {
+    return;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  if (!IsCaseNode(cnode)) {
+    return;
+  }
+  std::vector<AnfNodePtr> case_inputs;
+  for (size_t i = 1; i < cnode->inputs().size(); i++) {
+    case_inputs.emplace_back(cnode->input(i));
+  }
+  std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>();
+  auto bnode = cnode->input(0)->cast<CNodePtr>()->input(2)->cast<CNodePtr>();
+
+  for (size_t i = 1; i < bnode->inputs().size(); i++) {
+    auto branch_node = bnode->input(i)->cast<CNodePtr>();
+    for (size_t j = 2; j < branch_node->inputs().size(); j++) {
+      if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
+        case_inputs.emplace_back(branch_node->input(j));
+      }
+    }
+  }
+
+  for (size_t i = 1; i < bnode->inputs().size(); i++) {
+    ProcessSubgraph(bnode->input(i), case_inputs);
+  }
+
+  for (size_t i = 1; i < bnode->inputs().size(); i++) {
+    branches->emplace_back(branches_map_[bnode->input(i).get()]);
+  }
+
+  if (op_cache_.find(node.get()) == op_cache_.end()) {
+    return;
+  }
+
+  OpAdapterPtr adpt = FindAdapter(node, training_);
+  if (nullptr == adpt) {
+    MS_LOG(DEBUG) << "Not found adapter";
+    return;
+  }
+
+  OperatorPtr op = Convert(node);
+  adpt->setSubgraph(op, 0, branches);
+  return;
+}
+
+void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) {
+  std::vector<AnfNodePtr> case_inputs;
+  for (size_t i = 1; i < node->inputs().size(); i++) {
+    case_inputs.emplace_back(node->input(i));
+  }
+  std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>();
+  auto bnode = input_node->input(2)->cast<CNodePtr>();
+
+  for (size_t i = 1; i < bnode->inputs().size(); i++) {
+    auto branch_node = bnode->input(i)->cast<CNodePtr>();
+    for (size_t j = 2; j < branch_node->inputs().size(); j++) {
+      if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
+        case_inputs.emplace_back(branch_node->input(j));
+      }
+    }
+  }
+
+  const size_t case_index = 1;
+  const size_t make_tuple_index = 2;
+
+  AnfNodePtr case_index_iter = input_node->input(case_index);
+  AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index);
+  auto make_tuple_node = make_tuple_iter->cast<CNodePtr>();
+  std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
+
+  for (size_t i = 0; i < case_inputs.size(); i++) {
+    auto item = case_inputs[i];
+    auto op = Convert(item);
+    if (op != nullptr) {
+      tuple_items->emplace_back(OutHandler(op, ""));
+    } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) {
+      tuple_items->push_back(out_handle_cache_[item.get()]);
+    } else {
+      MS_LOG(WARNING) << "This anf node is not supported as a case input: " << item->ToString();
+      continue;
+    }
+  }
+
+  tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items;
+
+  std::shared_ptr<std::vector<AnfNodePtr>> case_input_items = std::make_shared<std::vector<AnfNodePtr>>();
+  case_input_items->emplace_back(case_index_iter);
+  case_input_items->emplace_back(make_tuple_iter);
+  case_input_handle_cache_[node.get()] = case_input_items;
+}
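Both SetSubgraph and GetCaseNodeInput share the same collection step: start from the Case call's direct arguments, then append every value a branch body consumes (branch inputs begin at index 2, after the Partial primitive and the graph constant) that is not already present, so all branches can be handed one merged input list. A minimal sketch of that dedup-append step, with strings standing in for AnfNodePtr:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Collect direct arguments first, then append each branch-captured value
// exactly once (the std::find dedup used in SetSubgraph/GetCaseNodeInput).
std::vector<std::string> CollectCaseInputs(
    const std::vector<std::string> &direct_args,
    const std::vector<std::vector<std::string>> &branch_args) {
  std::vector<std::string> case_inputs = direct_args;
  for (const auto &branch : branch_args) {
    for (const auto &arg : branch) {
      if (std::find(case_inputs.begin(), case_inputs.end(), arg) == case_inputs.end()) {
        case_inputs.emplace_back(arg);
      }
    }
  }
  return case_inputs;
}

int main() {
  auto inputs = CollectCaseInputs({"x"}, {{"x", "w1"}, {"x", "w2", "b2"}});
  for (const auto &s : inputs) {
    std::cout << s << " ";  // x w1 w2 b2
  }
  std::cout << "\n";
  return 0;
}

GetCaseNodeInput then converts each collected input to an OutHandler, caches the handler vector under the make_tuple node, and records (branch_index, make_tuple) as the Case node's two effective inputs.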
 
 DfGraphConvertor &DfGraphConvertor::BuildGraph() {
   SetupDatasetIterGetNextNode(dataset_iter_getnext_);
 
@@ -1036,6 +1151,16 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
   if (error_ != 0) {
     return *this;
   }
 
+  // Case node set input.
+  std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return());
+  for (auto &it : nodes) {
+    if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) {
+      auto node = it->cast<CNodePtr>();
+      auto input_node = node->input(0)->cast<CNodePtr>();
+      GetCaseNodeInput(node, input_node);
+    }
+  }
+
   // update tuple_out_handle_cache_
   for (auto it : tuple_out_handle_cache_) {
     std::size_t len = it.second->size();
@@ -1056,10 +1181,11 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
 
   // set up dependices
   MS_LOG(DEBUG) << "set up dependices";
-  std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return());
+  nodes = ::mindspore::TopoSort(anf_graph_->get_return());
   for (auto &it : nodes) {
     SetNodeInput(it);
     SetOpControlInput(it);
+    SetSubgraph(it);
     UpdateOpDesc(it);
   }
 
@@ -1075,6 +1201,18 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
     inputs.push_back(*dataset_iter_getnext_);
   } else {
     auto params = anf_graph_->parameters();
+    if (use_inputs_) {
+      params = inputs_;
+      auto anf_params = anf_graph_->parameters();
+      for (size_t i = 0; i < params.size(); i++) {
+        for (size_t j = 0; j < anf_params.size(); j++) {
+          if (params[i]->ToString() == anf_params[j]->ToString()) {
+            params[i] = anf_params[j];
+          }
+        }
+      }
+    }
+
     int index = 0;
     for (auto &it : params) {
       auto name = std::static_pointer_cast<Parameter>(it)->name();
@@ -1185,10 +1323,21 @@ const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNa
 
 void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
   OperatorPtr src = Convert(node);
+  int case_flag = 0;
   auto &inputs = node->inputs();
-  for (size_t i = 1; i < inputs.size(); i++) {
+  size_t input_size = inputs.size();
+  if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) {
+    case_flag = 1;
+    input_size = case_input_handle_cache_[node.get()]->size() + 1;
+  }
+
+  for (size_t i = 1; i < input_size; i++) {
     auto pred = inputs[i];
-    while (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "Depend") {
+    if (case_flag != 0) {
+      pred = case_input_handle_cache_[node.get()]->at(i - 1);
+    }
+
+    while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "Depend") {
       pred = pred->cast<CNodePtr>()->input(1);
     }
     // skip the None input
     if (IsValueNode<None>(pred)) {
       continue;
     }
     // transform "Const" op to "Variable" op when the next node is "Assign" op.
-    std::string c_name = GetCNodeFuncName(node);
+    std::string c_name = GetCNodeTargetFuncName(node);
     auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
     if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
       std::string name = std::static_pointer_cast<Parameter>(pred)->name();
@@ -1220,7 +1369,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
     if (it != out_handle_cache_.end()) {
       int ret = adpt->setInput(src, SizeToInt(i), it->second);
       if (ret == 0) {
-        if (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") {
+        if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") {
           compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> "
                         << op_draw_name_[node.get()] << ":" << i << endl;
         } else if (pred->isa<Parameter>()) {
@@ -1278,6 +1427,23 @@ void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
   DfGraphConvertor::SetOpInput(adpt, cnode);
 }
 
+void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs) {
+  if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != "Partial") {
+    return;
+  }
+  auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
+  FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
+  DfGraphConvertor convertor(anf_graph);
+  convertor.use_inputs_ = true;
+  convertor.inputs_ = inputs;
+  (void)convertor.ConvertAllNode().BuildGraph();
+  std::string name = graph_node->ToString() + "_ge_graph.dot";
+  if (MsContext::GetInstance()->save_graphs_flag()) {
+    convertor.DrawComputeGraph(name);
+  }
+  branches_map_[node.get()] = *(convertor.df_graph_);
+}
+
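ProcessSubgraph above lowers each `Partial` branch with a fresh, nested DfGraphConvertor (use_inputs_/inputs_ carry the merged input list down into the sub-conversion) and records the resulting DfGraph in branches_map_, keyed by the branch node, so SetSubgraph can collect all branch graphs afterwards. A toy model of that per-branch fan-out, with strings standing in for FuncGraph/DfGraph (ToyGraph, ProcessBranch, and the names in main are illustrative only):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// What ProcessSubgraph produces per branch: in MindSpore this is a DfGraph
// built by a nested converter; a string stands in for it here.
using ToyGraph = std::string;

std::map<std::string, ToyGraph> branches_map;  // keyed by branch node

void ProcessBranch(const std::string &branch_node, const std::vector<std::string> &merged_inputs) {
  // The real code runs ConvertAllNode().BuildGraph() on the branch's FuncGraph
  // with use_inputs_ = true and inputs_ = merged_inputs.
  branches_map[branch_node] =
      branch_node + "_ge_graph(" + std::to_string(merged_inputs.size()) + " inputs)";
}

int main() {
  std::vector<std::string> merged_inputs = {"x", "conv1.weight", "conv2.weight", "conv2.bias"};
  for (const auto &b : {std::string("branch_0"), std::string("branch_1")}) {
    ProcessBranch(b, merged_inputs);  // every branch sees the same merged list
  }
  std::cout << branches_map["branch_1"] << "\n";  // branch_1_ge_graph(4 inputs)
  return 0;
}

Handing every branch the identical merged list is what makes the name-based parameter remapping in BuildGraph's use_inputs_ path work, since each branch graph only binds the subset it actually references.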
 // Update GE op's shape and type info
 void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) {
   if (nullptr == node || !node->isa<CNode>()) {
@@ -1348,6 +1514,7 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
     }
   }
 
+  MS_LOG(WARNING) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size();
   tuple_out_handle_cache_[node.get()] = tuple_items;
 }
 
@@ -1711,6 +1878,14 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
     return false;
   }
 
+  if (name == "" && GetCNodeFuncName(node) == "switch_layer") {
+    return false;
+  }
+
+  if (name == "Partial") {
+    return false;
+  }
+
   // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
   if (name == "make_tuple") {
     ConvertMakeTuple(node);
@@ -1732,7 +1907,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
 }
 
 OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
-  std::string name = GetCNodeFuncName(node);
+  std::string name = GetCNodeTargetFuncName(node);
   if (!CheckCNode(name, node)) {
     return nullptr;
   }
@@ -1879,7 +2054,7 @@ void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) {
   }
 
   compute_sout_ << "\"" << node->ToString()
-                << ":" << GetCNodeFuncName(node) << "\"" << endl;
+                << ":" << GetCNodeTargetFuncName(node) << "\"" << endl;
 
   // print attrs' values
   auto atts = adpt->GetAttrsFromDrawGraph();
diff --git a/mindspore/ccsrc/transform/convert.h b/mindspore/ccsrc/transform/convert.h
index 2f6c9bb0a..cca0371c2 100644
--- a/mindspore/ccsrc/transform/convert.h
+++ b/mindspore/ccsrc/transform/convert.h
@@ -201,6 +201,7 @@ class DfGraphConvertor {
   OperatorPtr ConvertParameter(AnfNodePtr node);
   Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
   OperatorPtr ConvertValueNode(ValueNodePtr node);
+  void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
   void ConvertTupleGetItem(const CNodePtr node);
   void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node,
                                const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
@@ -217,6 +218,8 @@ class DfGraphConvertor {
   void SetNodeInput(AnfNodePtr node);
   void SetOpControlInput(const AnfNodePtr node);
   void UpdateOpDesc(AnfNodePtr node);
+  void SetSubgraph(AnfNodePtr node);
+  void ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs);
   void BuildSaveCheckpointGraph();
   void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
   void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
@@ -228,22 +231,26 @@ class DfGraphConvertor {
   std::shared_ptr<DfGraph> save_ckp_graph_{nullptr};
   std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr};
   std::shared_ptr<DfGraph> broadcast_graph_{nullptr};
+  std::unordered_map<AnfNode *, DfGraph> branches_map_;
   std::unordered_map<AnfNode *, OperatorPtr> op_cache_;
   std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_;
   /* record "tuple_getitem"<->"out_handler" mapping */
   std::unordered_map<AnfNode *, OutHandler> out_handle_cache_;
   /* record "make_tuple"<->"out_handler vector" mapping */
   std::unordered_map<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_;
+  std::unordered_map<AnfNode *, std::shared_ptr<std::vector<AnfNodePtr>>> case_input_handle_cache_;
   std::unordered_map<std::string, AnfNodePtr> params_;
   std::unordered_map<std::string, OperatorPtr> vars_;
   std::vector<std::pair<AnfNodePtr, std::string>> graph_outputs_;
   std::vector<AnfNodePtr> graph_const_inputs_;
   std::vector<OperatorPtr> init_ops_;
   std::vector<OperatorPtr> broadcast_ops_;
+  std::vector<AnfNodePtr> inputs_;
   OperatorPtr dataset_iter_getnext_;
   Status error_ = SUCCESS;
   bool training_ = false;
   bool distribute_ = false;
+  bool use_inputs_ = false;
 };
 }  // namespace transform
 }  // namespace mindspore
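All of the new bookkeeping in this header follows one pattern: a result computed in an early pass is stashed in an unordered_map keyed by the raw AnfNode pointer and consumed by a later pass; GetCaseNodeInput fills case_input_handle_cache_ during BuildGraph's first topo walk, and SetOpInput's case_flag path reads it back during the second. A small self-contained illustration of that pointer-keyed hand-off (toy types, not the real caches):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  auto case_node = std::make_shared<std::string>("Case");  // stands in for an AnfNodePtr
  std::unordered_map<void *, std::vector<std::string>> case_input_handle_cache;

  // Pass 1 (GetCaseNodeInput): stash the rewritten operand list by node address.
  case_input_handle_cache[case_node.get()] = {"branch_index", "make_tuple"};

  // Pass 2 (SetOpInput): presence in the cache is the case_flag signal, and the
  // cached list replaces the node's raw ANF inputs.
  auto it = case_input_handle_cache.find(case_node.get());
  if (it != case_input_handle_cache.end()) {
    std::cout << it->second.size() << " effective inputs\n";  // 2 effective inputs
  }
  return 0;
}

Keying by raw pointer keeps the ANF graph itself untouched; the caches only live as long as one DfGraphConvertor run.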
diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h
index ae678606a..caac4258d 100644
--- a/mindspore/ccsrc/transform/op_adapter.h
+++ b/mindspore/ccsrc/transform/op_adapter.h
@@ -164,6 +164,25 @@ class OpAdapter : public BaseOpAdapter {
   const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; }
   const std::unordered_map<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; }
   const std::unordered_map<int, OutputDesc> &getOutputMap() override { return output_map_; }
+  const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; }
+
+  Status SetOpSubgraphFunc(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) {
+    MS_EXCEPTION_IF_NULL(op);
+    auto it = dyn_subgraph_map_.find(index);
+    if (it != dyn_subgraph_map_.end()) {
+      auto size = branches->size();
+      it->second.create_dyn_subgraph(op, static_cast<unsigned int>(size));
+      for (size_t i = 0; i < size; i++) {
+        it->second.set_subgraph(op, static_cast<unsigned int>(i), std::make_shared<DfGraph>((*branches)[i]));
+      }
+      return SUCCESS;
+    }
+    return NOT_FOUND;
+  }
+
+  int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) override {
+    return static_cast<int>(SetOpSubgraphFunc(op, index, branches));
+  }
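SetOpSubgraphFunc is a two-phase protocol: size the operator's dynamic subgraph list once (`create_dyn_subgraph`), then install one graph per index (`set_subgraph`), copying each branch into a shared_ptr so the graph outlives the caller's vector. The sketch below models those two hooks with a toy op; ToyCaseOp and ToyGraph are invented stand-ins for the GE Case operator and DfGraph:

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct ToyGraph { std::string name; };

// Toy op with GE-style dynamic-subgraph hooks: first size the branch list,
// then install one graph builder per index.
struct ToyCaseOp {
  std::vector<std::function<ToyGraph()>> builders;
  void create_dyn_subgraph(unsigned int num) { builders.resize(num); }
  void set_subgraph(unsigned int index, std::function<ToyGraph()> builder) {
    builders[index] = std::move(builder);
  }
};

int main() {
  auto branches = std::make_shared<std::vector<ToyGraph>>(std::vector<ToyGraph>{{"b0"}, {"b1"}});
  ToyCaseOp op;
  op.create_dyn_subgraph(static_cast<unsigned int>(branches->size()));
  for (size_t i = 0; i < branches->size(); i++) {
    // Mirrors std::make_shared<DfGraph>((*branches)[i]): copy the branch into a
    // shared_ptr that the builder lambda captures by value.
    auto g = std::make_shared<ToyGraph>((*branches)[i]);
    op.set_subgraph(static_cast<unsigned int>(i), [g]() { return *g; });
  }
  std::cout << op.builders[1]().name << "\n";  // b1
  return 0;
}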
 
   Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) {
     MS_EXCEPTION_IF_NULL(op);
@@ -855,6 +874,7 @@ class OpAdapter : public BaseOpAdapter {
   static const std::unordered_map<int, DynInputDesc> dyn_input_map_;
   static const std::unordered_map<int, OutputDesc> output_map_;
   static const std::unordered_map<int, DynOutputDesc> dyn_output_map_;
+  static const std::unordered_map<int, DynSubGraphDesc> dyn_subgraph_map_;
   static const std::unordered_map<std::string, AttrDesc> attr_map_;
   static const std::unordered_map<std::string, int> enum_map_;
   // convert input from anf graph to Attr in Operators
@@ -874,6 +894,8 @@ const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_;
 template <typename T>
 const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
 template <typename T>
+const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
+template <typename T>
 const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_;
 template <typename T>
 const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_;
diff --git a/mindspore/ccsrc/transform/op_adapter_base.h b/mindspore/ccsrc/transform/op_adapter_base.h
index 01f96e251..956b33c42 100644
--- a/mindspore/ccsrc/transform/op_adapter_base.h
+++ b/mindspore/ccsrc/transform/op_adapter_base.h
@@ -88,6 +88,8 @@
 using DynInputOpFunc = std::function<void(const OperatorPtr &, unsigned int, const OperatorPtr &)>;
 using UpdateOutputDescFunc = std::function<void(const OperatorPtr &, const GeTensorDesc &)>;
 using CreateDynOutputOpFunc = std::function<void(const OperatorPtr &, unsigned int)>;
+using CreateDynSubGraphFunc = std::function<void(const OperatorPtr &, unsigned int)>;
+using DynSubGraphFunc = std::function<void(const OperatorPtr &, unsigned int, const DfGraphPtr &)>;
 
 struct AttrDesc {
   std::string name;
@@ -108,6 +110,12 @@ struct DynInputDesc {
   DynInputHandleFunc set_handle;
 };
 
+struct DynSubGraphDesc {
+  std::string name;
+  CreateDynSubGraphFunc create_dyn_subgraph;
+  DynSubGraphFunc set_subgraph;
+};
+
 struct OutputDesc {
   std::string name;
   UpdateOutputDescFunc update_out_desc;
@@ -123,6 +131,7 @@ class BaseOpAdapter {
   virtual ~BaseOpAdapter() {}
   virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
   virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); }
+  virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) = 0;
   virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
   virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;
   virtual int setInput(const OperatorPtr &op, int index,
@@ -146,6 +155,7 @@ class BaseOpAdapter {
   virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0;
   virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0;
   virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0;
+  virtual const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() = 0;
   void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
   const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }
   void clearAttrVect() { attrs_vec_.clear(); }
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index cac526f1f..0dc9089c6 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -64,6 +64,22 @@ namespace transform {
   } \
 }
 
+#define DYN_SUBGRAPH_MAP(T) \
+  template <> \
+  const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_
+#define DYN_SUBGRAPH_DESC(name) \
+  { \
+#name, \
+      [](const OperatorPtr op, unsigned int num) { \
+        auto p = std::static_pointer_cast<Case>(op); \
+        (void)p->create_dynamic_subgraph_##name(num); \
+      }, \
+      [](const OperatorPtr op, unsigned int index, const DfGraphPtr graph) { \
+        auto p = std::static_pointer_cast<Case>(op); \
+        (void)p->set_dynamic_subgraph_builder_##name(index, [graph]() { return *graph; }); \
+      } \
+  }
+
 #define ATTR_MAP(T) \
   template <> \
   const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_
@@ -841,6 +857,13 @@ INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}};
 ATTR_MAP(Cast) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}};
 
+// Case
+INPUT_MAP(Case) = {{1, INPUT_DESC(branch_index)}};
+DYN_INPUT_MAP(Case) = {{2, DYN_INPUT_DESC(input)}};
+ATTR_MAP(Case) = EMPTY_ATTR_MAP;
+DYN_OUTPUT_MAP(Case) = {{0, DYN_OUTPUT_DESC(output)}};
+DYN_SUBGRAPH_MAP(Case) = {{0, DYN_SUBGRAPH_DESC(branches)}};
+
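DYN_SUBGRAPH_DESC leans on two preprocessor tricks: `#name` stringizes the argument for the descriptor's name field, and `##name` pastes it into the generated GE accessor names (`create_dynamic_subgraph_branches`, `set_dynamic_subgraph_builder_branches`), while the `[graph]` lambda hands GE a builder that returns the captured graph by value. A compilable miniature of the stringize/paste part; FakeCase, Desc, and SUBGRAPH_DESC are invented stand-ins, not GE or MindSpore identifiers:

#include <iostream>

struct FakeCase {
  void create_dynamic_subgraph_branches(unsigned int num) {
    std::cout << "create " << num << " branch slots\n";
  }
};

struct Desc {
  const char *name;                          // filled by #name (stringize)
  void (*create)(FakeCase &, unsigned int);  // filled by a ##name-pasted member call
};

#define SUBGRAPH_DESC(name) \
  { #name, [](FakeCase &op, unsigned int num) { op.create_dynamic_subgraph_##name(num); } }

int main() {
  Desc d = SUBGRAPH_DESC(branches);
  FakeCase op;
  std::cout << d.name << "\n";  // branches
  d.create(op, 2);              // create 2 branch slots
  return 0;
}

The registration that follows wires the Case adapter together: input 1 is the scalar branch_index, the dynamic input starting at 2 carries the merged operand tuple, and subgraph slot 0 holds the branch graphs.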
 // Reciprocal
 INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP;
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
index f64dc7b67..ad0371c28 100755
--- a/mindspore/ccsrc/transform/op_declare.h
+++ b/mindspore/ccsrc/transform/op_declare.h
@@ -46,6 +46,10 @@ namespace transform {
   template <> \
   const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_;
 
+#define DECLARE_OP_USE_DYN_SUBGRAPH(T) \
+  template <> \
+  const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
+
 #define DECLARE_OP_USE_DYN_OUTPUT(T) \
   template <> \
   const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
@@ -232,6 +236,10 @@ DECLARE_OP_USE_OUTPUT(RealDiv)
 DECLARE_OP_ADAPTER(Cast)
 DECLARE_OP_USE_INPUT_ATTR(Cast)
 DECLARE_OP_USE_OUTPUT(Cast)
+DECLARE_OP_ADAPTER(Case)
+DECLARE_OP_USE_DYN_INPUT(Case)
+DECLARE_OP_USE_DYN_SUBGRAPH(Case)
+DECLARE_OP_USE_DYN_OUTPUT(Case)
 DECLARE_OP_ADAPTER(Reciprocal)
 DECLARE_OP_USE_OUTPUT(Reciprocal)
 DECLARE_OP_ADAPTER(Neg)
diff --git a/tests/ut/python/automl/case.py b/tests/ut/python/automl/case.py
new file mode 100644
index 000000000..745376277
--- /dev/null
+++ b/tests/ut/python/automl/case.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test case."""
+import numpy as np
+
+import mindspore
+import mindspore.nn as nn
+from mindspore import Tensor, context
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 3, 3)
+        self.conv2 = nn.Conv2d(1, 3, 5, has_bias=True)
+        self.layers = (self.conv1, self.conv2)
+
+    def construct(self, x, index):
+        x = self.layers[index](x)
+        y = self.conv1(x)
+        return x + y
+
+
+def test_case():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    net = Net()
+    data = Tensor(np.ones((1, 1, 224, 224)), mindspore.float32)
+    idx = Tensor(1, mindspore.int32)
+    net(data, idx)
-- 
GitLab