Commit 3c124889 authored by mindspore-ci-bot, committed by Gitee

!901 Refactor vm module for multigraph sink.

Merge pull request !901 from rick_sanchez/master
......@@ -564,42 +564,67 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
}
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
auto final_graph = GetGraph(final_graph_id_);
MS_EXCEPTION_IF_NULL(final_graph);
if (!utils::isa<AnfNodePtr>(output)) {
if (!utils::isa<ValuePtr>(output)) {
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
}
auto value_ptr = utils::cast<ValuePtr>(output);
auto value_node = NewValueNode(value_ptr);
MS_EXCEPTION_IF_NULL(value_node);
auto kernel_info = std::make_shared<device::KernelInfo>();
value_node->set_kernel_info(kernel_info);
value_node->set_abstract(abstract::FromValue(value_ptr));
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
final_graph->set_executable(false);
MS_LOG(INFO) << "Not anf output[" << output.ToString() << "]";
return;
}
void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) {
// get the backend anf node related to the output node of front
auto output_anf_node = utils::cast<AnfNodePtr>(output);
auto output_from_graph_id = GetGraphIdByNode(output_anf_node);
auto output_from_graph_id = GetGraphIdByNode(node);
auto output_from_graph = GetGraph(output_from_graph_id);
MS_EXCEPTION_IF_NULL(output_anf_node);
MS_LOG(INFO) << "Set the output[" << output_anf_node->DebugString() << "] of graph[" << output_from_graph_id
MS_EXCEPTION_IF_NULL(node);
MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id
<< "] to final graph";
MS_EXCEPTION_IF_NULL(output_from_graph);
auto final_graph = GetGraph(final_graph_id_);
MS_EXCEPTION_IF_NULL(final_graph);
// if the output comes from the final graph, it means no child graph exists
if (final_graph_id_ == output_from_graph_id) {
MS_LOG(INFO) << "No child graph,output is " << output_anf_node->DebugString();
final_graph->set_output(ConstructOutput({output_anf_node}, final_graph));
MS_LOG(INFO) << "No child graph,output is " << node->DebugString();
final_graph->set_output(ConstructOutput({node}, final_graph));
final_graph->set_executable(false);
return;
}
final_graph->set_output(output_from_graph->output());
}
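// The output is a constant value: wrap it in a ValueNode with kernel info and an abstract, and mark the
// final graph as non-executable since there is nothing to launch.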
void AscendSession::SetFinalGraphOutput(const ValuePtr &value) {
auto value_node = NewValueNode(value);
auto kernel_info = std::make_shared<device::KernelInfo>();
value_node->set_kernel_info(kernel_info);
value_node->set_abstract(abstract::FromValue(value));
auto final_graph = GetGraph(final_graph_id_);
MS_EXCEPTION_IF_NULL(final_graph);
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
final_graph->set_executable(false);
MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]";
}
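// The output is a tuple: set each element through the AnfNodePtr or ValuePtr overload above.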
void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) {
for (auto &output : vec_output) {
if (utils::isa<AnfNodePtr>(output)) {
auto output_anf_node = utils::cast<AnfNodePtr>(output);
SetFinalGraphOutput(output_anf_node);
} else if (utils::isa<ValuePtr>(output)) {
auto value = utils::cast<ValuePtr>(output);
SetFinalGraphOutput(value);
} else {
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
}
}
}
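// Entry point: dispatch on the concrete BaseRef type to one of the overloads above.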
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
if (utils::isa<AnfNodePtr>(output)) {
auto output_anf_node = utils::cast<AnfNodePtr>(output);
SetFinalGraphOutput(output_anf_node);
} else if (utils::isa<ValuePtr>(output)) {
auto value = utils::cast<ValuePtr>(output);
SetFinalGraphOutput(value);
} else if (utils::isa<VectorRef>(output)) {
auto vec_output = utils::cast<VectorRef>(output);
SetFinalGraphOutput(vec_output);
} else {
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
}
}
KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
auto it = graphs_.find(graph_id);
if (it == graphs_.end()) {
......
......@@ -88,6 +88,10 @@ class AscendSession : public SessionBasic {
size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
void SetFinalGraphOutput(const AnfNodePtr &node);
void SetFinalGraphOutput(const ValuePtr &value);
void SetFinalGraphOutput(const VectorRef &vec_output);
// merge execution order list of child graphs
void MergeGraphExecOrder();
// insert assign op to sync data between different graphs
......
......@@ -243,7 +243,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
AddInst(Instruction::kCall, args);
args.clear();
args.emplace_back(true);
args.emplace_back(node->input(1));
AddInst(Instruction::kSwitchReturn, args);
args.clear();
......
......@@ -141,17 +141,31 @@ void FinalVM::Popsp() {
}
}
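// ret_status_ records, for each pending return, whether the corresponding jump was a switch call;
// PopStatus() returns false when the stack is empty.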
void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); }
bool FinalVM::PopStatus() {
if (ret_status_.empty()) {
return false;
}
bool status = ret_status_.top();
ret_status_.pop();
return status;
}
void FinalVM::DoJmp(const BaseRef &jmp_orig) {
MS_LOG(DEBUG) << "Start";
BaseRef jmp = jmp_orig;
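// In multigraph-sink simulation, remember whether this jump enters a switch branch so that the
// matching InstReturn can merge the branch outputs recorded in cond_out_.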
if (backend_->simu_flag()) {
bool is_switch_call = false;
if (utils::isa<StructSimuSwitch>(jmp)) { // need to inherit from Base
MS_LOG(DEBUG) << "Start jump StructSwitch";
auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp);
jmp = simu_value->fn_;
backend_->set_curr_switch(simu_value->value_);
is_switch_call = true;
}
PushStatus(is_switch_call);
}
if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
......@@ -255,6 +269,13 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
return;
}
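// Record this branch's return value, keyed by the switch condition, so that InstReturn can merge the
// results of both branches.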
auto rv = Ref(-1);
if (utils::isa<AnfNodePtr>(rv) || utils::isa<VectorRef>(rv)) {
auto &c = args[0];
cond_out_[c] = rv;
}
Pop(1);
Popsp();
}
......@@ -272,8 +293,20 @@ void FinalVM::InstReturn(const VectorRef &args) {
int height = utils::cast<int>(args[1]);
auto rv = Ref(rpos);
if (backend_->simu_flag() && backend_->is_switch_call()) {
backend_->SetSwitchGraph();
if (backend_->simu_flag()) {
auto c = backend_->curr_switch();
auto status = PopStatus();
if (status) {
auto iter = cond_out_.find(c);
if (iter != cond_out_.end()) {
rv = MergeArgs(rv, iter->second);
cond_out_.erase(iter);
}
}
if (backend_->is_switch_call()) {
backend_->SetSwitchGraph();
}
}
Pop(height);
......@@ -383,21 +416,30 @@ void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
for (size_t i = 0; i < new_args.size(); ++i) {
auto &old_arg = old_args[i];
auto &new_arg = new_args[i];
if (utils::isa<VectorRef>(old_arg)) {
auto old_vec_ref = utils::cast<VectorRef>(old_arg);
if (utils::isa<VectorRef>(new_arg)) {
auto new_vec_ref = utils::cast<VectorRef>(new_arg);
std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
}
new_arg = old_vec_ref;
} else if (utils::isa<VectorRef>(new_arg)) {
auto new_vec_ref = utils::cast<VectorRef>(new_arg);
new_vec_ref.push_back(old_arg);
new_arg = new_vec_ref;
new_arg = MergeArgs(old_arg, new_arg);
}
}
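// Merge two branch results into one: concatenate when both are VectorRefs, append when only one is,
// otherwise pack the pair into a new VectorRef.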
BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) {
MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString();
if (utils::isa<VectorRef>(first)) {
auto old_vec_ref = utils::cast<VectorRef>(first);
if (utils::isa<VectorRef>(second)) {
auto new_vec_ref = utils::cast<VectorRef>(second);
std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
} else {
new_arg = VectorRef({new_arg, old_arg});
old_vec_ref.push_back(second);
}
return old_vec_ref;
}
if (utils::isa<VectorRef>(second)) {
auto new_vec_ref = utils::cast<VectorRef>(second);
new_vec_ref.push_back(first);
return new_vec_ref;
}
return VectorRef({first, second});
}
void FinalVM::InstRealSwitch(const VectorRef &args) {
......
......@@ -125,17 +125,22 @@ class FinalVM {
void Popp();
void Pushsp();
void Popsp();
void PushStatus(bool is_switch_call);
bool PopStatus();
void DoJmp(const BaseRef &jmp);
void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
private:
InstSet insts_;
std::deque<BaseRef> insts_stack_;
std::stack<int> retp_;
std::stack<int> retsp_;
std::stack<bool> ret_status_;
int pc_;
int sp_;
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_;
BackendPtr backend_;
const InstFunctionMap inst_function_map = {
{Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
......
......@@ -26,6 +26,7 @@ from mindspore.ops import operations as P
def setup_module(module):
context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend")
c1 = Tensor([2], mstype.int32)
c2 = Tensor([14], mstype.int32)
c3 = Tensor([1], mstype.int32)
......@@ -149,6 +150,10 @@ def test_if_by_if():
assert output == expect
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_if_in_if():
output = if_in_if(c1, c2, c3)
expect = Tensor([7], mstype.int32)
......@@ -194,6 +199,7 @@ def test_while_by_while_in_while():
expect = Tensor([350], mstype.int32)
assert output == expect
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
......