提交 97ba539e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1269 fix issue of loadding control input tensors failed in control sink mode

Merge pull request !1269 from wenchunjiang/fix_task_sink_bug
......@@ -375,18 +375,16 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(
return assign_add_one;
}
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context,
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
if (!NeedInsertSwitch()) {
return true;
}
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
auto input_nodes = kernel_graph_ptr->inputs();
std::vector<tensor::TensorPtr> inputs;
LoadSwitchInputs(&inputs);
std::shared_ptr<std::vector<tensor::TensorPtr>> inputsPtr = std::make_shared<std::vector<tensor::TensorPtr>>(inputs);
context->SetResult(session::kInputCtrlTensors, inputsPtr);
kernel_graph_ptr->set_input_ctrl_tensors(inputsPtr);
size_t input_ctrl_size = inputs.size();
// inputs_node:include four ctrl nodes in the back. such as:conv,loop_cnt, ites_loop, zero, one.
// deal four ctrl nodes.
......
......@@ -53,8 +53,7 @@ class KernelAdjust {
}
void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
bool StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context,
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
bool StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr);
static bool NeedInsertSwitch();
CNodePtr CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
......
......@@ -517,7 +517,7 @@ void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input
void AscendSession::GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
(void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(context_, kernel_graph);
(void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
bool ret_ok = runtime_instance->GenTask(kernel_graph.get());
......
......@@ -107,6 +107,12 @@ class KernelGraph : public FuncGraph {
std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; }
// checkout whether current graph is leaf graph
bool IsLeafGraph() const;
// set input_tensors pointer of control parameter
void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) {
input_ctrl_tensors_ = input_tensors_ptr;
}
// get input_tensors pointer of control parameter
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors() const { return input_ctrl_tensors_; }
private:
// remove value node form graph
......@@ -150,6 +156,8 @@ class KernelGraph : public FuncGraph {
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
// child graph execute order in root graph
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
// input_tensors of control parameter
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
};
} // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
......
......@@ -268,23 +268,12 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input,
return make_tuple;
}
bool NeedInsertSwitch() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() &&
ConfigManager::GetInstance().iter_num() > 1);
}
size_t LoadCtrlInputTensor(const std::shared_ptr<Context> &context, std::vector<tensor::TensorPtr> *inputs) {
MS_EXCEPTION_IF_NULL(context);
if (!NeedInsertSwitch()) {
(void)context->results_.erase(kInputCtrlTensors);
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
MS_LOG(INFO) << "Load kInputCtrlTensors";
auto inputs_params = graph->input_ctrl_tensors();
if (inputs_params == nullptr) {
return 0;
}
MS_LOG(INFO) << "Load kInputCtrlTensors";
auto inputs_params =
context->GetResult(kInputCtrlTensors).cast<const std::shared_ptr<std::vector<tensor::TensorPtr>>>();
MS_EXCEPTION_IF_NULL(inputs_params);
if (inputs_params->empty()) {
MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
}
......@@ -689,11 +678,10 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
const std::vector<tensor::TensorPtr> &inputs_const) const {
std::vector<tensor::TensorPtr> inputs(inputs_const);
size_t input_ctrl_size = 1;
MS_EXCEPTION_IF_NULL(context_);
if (context_->HasResult(kInputCtrlTensors)) {
input_ctrl_size = LoadCtrlInputTensor(context_, &inputs);
}
MS_EXCEPTION_IF_NULL(kernel_graph);
if (kernel_graph->input_ctrl_tensors()) {
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
}
auto input_nodes = kernel_graph->inputs();
if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
......
......@@ -39,8 +39,7 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
} // namespace ascend
void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; }
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; }
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context,
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
return true;
}
bool KernelAdjust::NeedInsertSwitch() { return true; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册