提交 2ee4fdad 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1165 new control sink entry

Merge pull request !1165 from zhoufeng/new-control-sink-entry
......@@ -78,6 +78,10 @@ const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");
const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
// Structure
const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
......
......@@ -84,6 +84,10 @@ extern const PrimitivePtr kPrimEmbed;
extern const PrimitivePtr kPrimRefToEmbed;
extern const PrimitivePtr kPrimCreateInstance;
extern const PrimitivePtr kPrimLabelGoto;
extern const PrimitivePtr kPrimLabelSwitch;
extern const PrimitivePtr kPrimLabelSet;
// Structure
extern const PrimitivePtr kPrimStringEqual;
extern const PrimitivePtr kPrimStringConcat;
......
......@@ -269,13 +269,41 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa
bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
static bool IsCtrlSink() {
auto ms_ctx = MsContext::GetInstance();
std::string device_target = ms_ctx->device_target();
if (device_target != kAscendDevice) {
return false;
}
if (!ms_ctx->enable_task_sink()) {
return false;
}
char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK");
if (enable_ctrl_sink == nullptr) {
return false;
}
std::string enable_ctrl_sink_str(enable_ctrl_sink);
if (enable_ctrl_sink_str == "0") {
return false;
}
return true;
}
bool TaskEmitAction(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "TaskEmit args error";
}
FuncGraphPtr func_graph = res->func_graph();
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
if (IsCtrlSink()) {
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
return true;
}
std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
if (bc_ptr->name() == kMsConvert) {
cut_list = compile::GetMsNonlinearOps();
......@@ -286,10 +314,31 @@ bool TaskEmitAction(const ResourcePtr &res) {
}
bool ExecuteAction(const ResourcePtr &res) {
if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is<compile::FinalVMPtr>()) {
if (res->results().count(kOutput) == 0) {
MS_LOG(EXCEPTION) << "Execute args error";
}
if (IsCtrlSink()) {
if (!res->results()[kOutput].is<GraphId>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}
auto graph_id = res->results()[kOutput].cast<GraphId>();
auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>();
compile::VmEvalFuncPtr run =
std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef {
MS_LOG(INFO) << "Execute args size" << args.size();
auto outs = bc_ptr->RunGraph(graph_id, args);
MS_LOG(DEBUG) << "out size" << outs.size();
return outs[0];
});
res->results()[kOutput] = run;
return true;
}
if (!res->results()[kOutput].is<compile::FinalVMPtr>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}
compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>();
if (vm == nullptr) {
MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM";
......
......@@ -138,7 +138,7 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
return graph_id;
}
GraphId AscendSession::CompileGraph(const FuncGraphPtr &func_graph) {
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph);
// split switch
......
......@@ -42,7 +42,7 @@ class AscendSession : public SessionBasic {
context_ = std::make_shared<Context>(kAscendDevice, device_id);
}
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
GraphId CompileGraph(const FuncGraphPtr &func_graph) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildGraph(GraphId) override;
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
......
......@@ -28,6 +28,7 @@
#include "ir/meta_tensor.h"
#include "utils/any.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
#include "pynative/pynative_execute.h"
#include "device/kernel_info.h"
......@@ -57,7 +58,7 @@ class SessionBasic {
virtual ~SessionBasic() { summary_callback_ = nullptr; }
virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
virtual GraphId CompileGraph(const FuncGraphPtr &) { return kInvalidGraphId; }
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
// build graph, used to handle multiple child graphs
virtual void BuildGraph(GraphId) {}
......
......@@ -327,5 +327,9 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_
sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
}
GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); }
VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
} // namespace compile
} // namespace mindspore
......@@ -22,6 +22,7 @@
#include <unordered_map>
#include <utility>
#include "utils/contract.h"
#include "ir/anf.h"
#include "vm/segment_runner.h"
#include "vm/vm.h"
......@@ -49,7 +50,7 @@ class Backend {
virtual void SetSwitchActive(const BaseRef &, bool) {}
virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {}
virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {}
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
void set_curr_switch(const BaseRef &value) {
curr_switch_ = value;
is_switch_call_ = true;
......@@ -104,6 +105,8 @@ class MsBackend : public Backend {
void Link(GraphId) override;
AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &);
LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
private:
session::SessionPtr sess_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册