diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
index 99df43ca237f155c77fb06f4becace0a67f0260a..823b1dca08771aafa051f033c52150b2c8e83a25 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -220,11 +220,18 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
     }
   }
   std::shared_ptr<Graph> new_graph(new Graph);
-  for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
+  for (size_t i = 0; i < graph->nodes.size(); i++) {
     if (index_list->at(i) > SIZE_MAX / 2) {
       continue;
     }
+    new_graph->nodes.push_back(graph->nodes[i]);
+    for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_in.size(); j++) {
+      new_graph->nodes[index_list->at(i)].node_in[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_in[j]);
+    }
+    for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_out.size(); j++) {
+      new_graph->nodes[index_list->at(i)].node_out[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_out[j]);
+    }
   }
   return new_graph;
 }
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
index af9b2669cb5cb1128656d9468162e7c836480869..ac8e52eed67c035f85b1813fd2c74a231f6d98df 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
@@ -232,7 +232,7 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
     }
   }
 
-  if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
+  if (DevicesMemoryControl(num_device, device_memory, graph) != SUCCESS) {
     return FAILED;
   } else {
     return SUCCESS;
@@ -257,16 +257,15 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
   return Node;
 }
 
-Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
   MS_EXCEPTION_IF_NULL(graph);
   uint64_t iter_nodes = graph->nodes.size();
+  double used_memory = 0.0;
 
   for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
     if (graph->nodes[i_node].info == 0) {
       Graph::NodeType &Node = graph->nodes[i_node];
-      double used_memory = 0.0;
-
       for (int index = 0; index < 2; index++) {
         used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
                        Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c *
@@ -274,21 +273,15 @@ Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> g
                        Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w *
                        GetDataTypeSize(Node.apply.arguments[index].tensor_type);
       }
-
-      used_memory += Node.tensor_parm.tensor_str.str_n * Node.tensor_parm.tensor_shape.shape_n *
-                     Node.tensor_parm.tensor_str.str_c * Node.tensor_parm.tensor_shape.shape_c *
-                     Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
-                     Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
-                     GetDataTypeSize(Node.tensor_parm.tensor_type);
-
-      if (device_memory < used_memory) {
-        MS_LOG(EXCEPTION) << "Failure: Out of memory!";
-        return FAILED;
-      }
     }
   }
-  return SUCCESS;
+
+  if (device_memory < (used_memory / num_device)) {
+    MS_LOG(EXCEPTION) << "Failure: Out of memory!";
+    return FAILED;
+  } else {
+    return SUCCESS;
+  }
 }
 
 size_t GetDataTypeSize(const TensorType &type) {
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
index fc504b3cb2c0d562bbd804a3f4727d1cfc6f208d..b2fbeddebd857116b91238e3d4e56a58ba3616e3 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
@@ -44,7 +44,7 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
 
 Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
 
-Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph);
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
 
 size_t GetDataTypeSize(const TensorType &type);
 }  // namespace parallel
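The behavioral change in `DevicesMemoryControl` is that memory is no longer checked node by node: the strategy-scaled sizes of each node's first two operands are summed over the whole graph, and the total divided by `num_device` is compared against the per-device budget. The sketch below illustrates that averaged check in isolation; `SimpleTensor`, `SimpleNode`, and `DevicesMemoryControlSketch` are simplified stand-ins invented for illustration, not MindSpore's `Graph::NodeType` or the actual function.

```cpp
// Minimal sketch of the averaged memory check introduced by this diff.
// Types and names here are hypothetical stand-ins, not MindSpore APIs.
#include <cstddef>
#include <cstdio>
#include <vector>

struct SimpleTensor {
  double str_n, str_c, str_h, str_w;          // partition-strategy factors kept per device
  double shape_n, shape_c, shape_h, shape_w;  // full tensor shape
  double type_size;                           // bytes per element
};

struct SimpleNode {
  std::vector<SimpleTensor> inputs;  // operands counted by the check (indices 0 and 1 in the diff)
};

// Sum memory over all nodes, then compare the per-device average with the budget.
bool DevicesMemoryControlSketch(size_t num_device, double device_memory,
                                const std::vector<SimpleNode> &nodes) {
  double used_memory = 0.0;
  for (const SimpleNode &node : nodes) {
    for (const SimpleTensor &t : node.inputs) {
      used_memory += t.str_n * t.shape_n * t.str_c * t.shape_c *
                     t.str_h * t.shape_h * t.str_w * t.shape_w * t.type_size;
    }
  }
  if (device_memory < used_memory / static_cast<double>(num_device)) {
    std::fprintf(stderr, "Failure: Out of memory!\n");
    return false;
  }
  return true;
}

int main() {
  // Two nodes, each with one 4 x 16 x 32 x 32 float32 input, spread over 8 devices.
  SimpleTensor t{1.0, 1.0, 1.0, 1.0, 4, 16, 32, 32, 4.0};
  std::vector<SimpleNode> nodes{{{t}}, {{t}}};
  bool ok = DevicesMemoryControlSketch(8, 1.0e6, nodes);  // 1 MB budget per device
  std::printf("memory check %s\n", ok ? "passed" : "failed");
  return 0;
}
```

Compared with the pre-diff per-node check, averaging the total over `num_device` assumes the partitioner has already spread the strategy-scaled tensors across devices, so the budget is enforced once for the whole graph rather than per operator.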