提交 70abe362 编写于 作者: L leilei_snow

add case process

上级 b8292613
...@@ -230,6 +230,20 @@ bool ValueToBool(const ValuePtr &v, bool *value) { ...@@ -230,6 +230,20 @@ bool ValueToBool(const ValuePtr &v, bool *value) {
return true; return true;
} }
bool BaseRefToInt(const ValuePtr &v, int *value) {
MS_EXCEPTION_IF_NULL(v);
if (v->isa<tensor::Tensor>()) {
auto tensor = v->cast<tensor::TensorPtr>();
(void)tensor->data_sync();
int *tensor_data = static_cast<int *>(tensor->data_c());
auto vb = tensor_data[0];
*value = vb;
return true;
}
MS_LOG(ERROR) << "Index must be tensor type.";
return false;
}
bool BaseRefToBool(const BaseRef &v, bool *value) { bool BaseRefToBool(const BaseRef &v, bool *value) {
if (utils::isa<ValuePtr>(v)) { if (utils::isa<ValuePtr>(v)) {
return ValueToBool(utils::cast<ValuePtr>(v), value); return ValueToBool(utils::cast<ValuePtr>(v), value);
......
...@@ -42,6 +42,7 @@ using TensorPtr = std::shared_ptr<Tensor>; ...@@ -42,6 +42,7 @@ using TensorPtr = std::shared_ptr<Tensor>;
py::object AnyToPyData(const Any &value); py::object AnyToPyData(const Any &value);
py::object BaseRefToPyData(const BaseRef &value); py::object BaseRefToPyData(const BaseRef &value);
bool BaseRefToBool(const BaseRef &in, bool *out); bool BaseRefToBool(const BaseRef &in, bool *out);
bool BaseRefToInt(const ValuePtr &v, int *value);
bool ValueToBool(const ValuePtr &in, bool *out); bool ValueToBool(const ValuePtr &in, bool *out);
py::object ValuePtrToPyData(const ValuePtr &value); py::object ValuePtrToPyData(const ValuePtr &value);
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
namespace mindspore { namespace mindspore {
namespace compile { namespace compile {
bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) { LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
// multi_graph merge to one, big graph have paramters in begin and only have one output // multi_graph merge to one, big graph have paramters in begin and only have one output
......
...@@ -46,6 +46,7 @@ class Backend { ...@@ -46,6 +46,7 @@ class Backend {
virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {} virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {}
virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; } virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; }
virtual bool GetCond(const BaseRef &c, bool *value); virtual bool GetCond(const BaseRef &c, bool *value);
virtual bool GetIndex(const BaseRef &c, int *value);
virtual void SetSwitchGraph() {} virtual void SetSwitchGraph() {}
virtual void SetSwitchActive(const BaseRef &, bool) {} virtual void SetSwitchActive(const BaseRef &, bool) {}
virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {}
......
...@@ -46,8 +46,9 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiv ...@@ -46,8 +46,9 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiv
std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
prim::kPrimMakeTuple, prim::kPrimBpropCut}; prim::kPrimMakeTuple, prim::kPrimBpropCut};
const std::vector<PrimitivePtr> &GetMsNonlinearOps() { const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial,
prim::kPrimBpropCut}; prim::kPrimSwitch, prim::kPrimMakeTuple,
prim::kPrimBpropCut, prim::kPrimSwitchLayer};
return ms_nonlinear_ops; return ms_nonlinear_ops;
} }
...@@ -187,6 +188,30 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & ...@@ -187,6 +188,30 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
std::reverse(result.begin(), result.end()); std::reverse(result.begin(), result.end());
return result; return result;
} }
bool IsSubGraph(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr fn = inputs[0];
MS_EXCEPTION_IF_NULL(fn);
if (!IsValueNode<Primitive>(fn)) {
return false;
}
auto node_prim = GetValueNode<PrimitivePtr>(fn);
if (node_prim->name() == prim::kPrimPartial->name()) {
return true;
}
} else if (IsValueNode<FuncGraph>(node)) {
return true;
}
return false;
}
} // namespace } // namespace
CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
...@@ -235,6 +260,15 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { ...@@ -235,6 +260,15 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_enable_pynative_hook(true); ms_context->set_enable_pynative_hook(true);
} }
if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
if (inputs.size() < 2) {
return false;
}
auto ret = IsSubGraph(inputs[1]);
return ret;
}
return true; return true;
} }
} }
...@@ -466,6 +500,8 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) ...@@ -466,6 +500,8 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
} else if (IsPrimitive(fn, prim::kPrimSwitch)) { } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
AddSwitch(node); AddSwitch(node);
AddSinkSwitch(node); AddSinkSwitch(node);
} else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
AddSwitchLayer(node);
} else if (IsPrimitive(fn, prim::kPrimMakeTuple)) { } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
AddMakeTuple(node); AddMakeTuple(node);
} else { } else {
...@@ -622,6 +658,17 @@ void CompileGraph::AddSwitch(const CNodePtr &node) { ...@@ -622,6 +658,17 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
AddInst(Instruction::kSwitch, args); AddInst(Instruction::kSwitch, args);
} }
void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
auto inputs = node->inputs();
if (inputs.size() != 3) {
MS_LOG(EXCEPTION) << "Switch layer must have index and branches.";
}
VectorRef args;
args.emplace_back(Ref(inputs[1]));
args.emplace_back(Ref(inputs[2]));
AddInst(Instruction::kSwitchLayer, args);
}
void CompileGraph::AddReturn(const CNodePtr &node) { void CompileGraph::AddReturn(const CNodePtr &node) {
VectorRef args; VectorRef args;
if (backend_->simu_flag()) { if (backend_->simu_flag()) {
......
...@@ -90,6 +90,7 @@ class CompileGraph { ...@@ -90,6 +90,7 @@ class CompileGraph {
void AddPartial(const CNodePtr &node); void AddPartial(const CNodePtr &node);
void AddMakeTuple(const CNodePtr &node); void AddMakeTuple(const CNodePtr &node);
void AddSwitch(const CNodePtr &node); void AddSwitch(const CNodePtr &node);
void AddSwitchLayer(const CNodePtr &node);
void AddReturn(const CNodePtr &node); void AddReturn(const CNodePtr &node);
void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim); void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
void AddInput(const AnfNodePtr &node); void AddInput(const AnfNodePtr &node);
......
...@@ -480,6 +480,35 @@ void FinalVM::InstSwitch(const VectorRef &args) { ...@@ -480,6 +480,35 @@ void FinalVM::InstSwitch(const VectorRef &args) {
MS_LOG(DEBUG) << "End"; MS_LOG(DEBUG) << "End";
} }
void FinalVM::InstSwitchLayer(const VectorRef &args) {
MS_LOG(DEBUG) << "Start";
const size_t args_size = 2;
if (args.size() != args_size) {
MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
<< ".";
return;
}
int idx = utils::cast<int>(args[0]);
VectorRef branches = utils::cast<VectorRef>(Ref(utils::cast<int>(args[1])));
int size = static_cast<int>(branches.size());
BaseRef index = Ref(idx);
int idx_value = 0;
if (!backend_->GetIndex(index, &idx_value)) {
MS_LOG(EXCEPTION) << "Not supported type to be casted to int.";
}
if (idx_value < 0) {
// Add support negative index range [-size, -1].
idx_value += size;
}
if (idx_value < 0 || idx_value >= size) {
MS_LOG(EXCEPTION) << __FUNCTION__ << " given index " << idx_value << " out of range.";
}
Push(branches[idx_value]);
MS_LOG(DEBUG) << "End";
}
void FinalVM::InstTuple(const VectorRef &args) { void FinalVM::InstTuple(const VectorRef &args) {
MS_LOG(DEBUG) << "Start"; MS_LOG(DEBUG) << "Start";
VectorRef tuple; VectorRef tuple;
......
...@@ -51,15 +51,17 @@ enum Instruction { ...@@ -51,15 +51,17 @@ enum Instruction {
kPush, kPush,
kPrim, kPrim,
kGraph, kGraph,
kPadStack kPadStack,
kSwitchLayer
}; };
using InstType = std::pair<Instruction, VectorRef>; using InstType = std::pair<Instruction, VectorRef>;
using InstSet = std::vector<InstType>; using InstSet = std::vector<InstType>;
using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>; using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>;
const std::vector<std::string> inst_str{"call", "tail_call", "return", "partial", "switch", "switch_return", "tuple", const std::vector<std::string> inst_str{"call", "tail_call", "return", "partial", "switch",
"input", "external", "push", "primitive", "graph", "pad_stack"}; "switch_return", "tuple", "input", "external", "push",
"primitive", "graph", "pad_stack", "switch_layer"};
class StructPartial : public Base { class StructPartial : public Base {
public: public:
// Initialize StructPartial. // Initialize StructPartial.
...@@ -114,6 +116,7 @@ class FinalVM { ...@@ -114,6 +116,7 @@ class FinalVM {
void InstExternal(const VectorRef &args); void InstExternal(const VectorRef &args);
void InstPushPrim(const VectorRef &args); void InstPushPrim(const VectorRef &args);
void InstSwitchReturn(const VectorRef &args); void InstSwitchReturn(const VectorRef &args);
void InstSwitchLayer(const VectorRef &args);
void set_insts(const InstSet &value) { insts_ = value; } void set_insts(const InstSet &value) { insts_ = value; }
BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg); BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg);
...@@ -157,7 +160,7 @@ class FinalVM { ...@@ -157,7 +160,7 @@ class FinalVM {
{Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }},
{Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
{Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
}; {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}};
std::map<std::string, py::object> _hook_grad; std::map<std::string, py::object> _hook_grad;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册