Commit 2f7c5f7a authored by hongxing, committed by Sheng

fix edge bug

Parent 9c8e750c
@@ -220,11 +220,18 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
     }
   }
   std::shared_ptr<Graph> new_graph(new Graph);
-  for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
+  for (size_t i = 0; i < graph->nodes.size(); i++) {
     if (index_list->at(i) > SIZE_MAX / 2) {
       continue;
     }
     new_graph->nodes.push_back(graph->nodes[i]);
+    for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_in.size(); j++) {
+      new_graph->nodes[index_list->at(i)].node_in[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_in[j]);
+    }
+    for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_out.size(); j++) {
+      new_graph->nodes[index_list->at(i)].node_out[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_out[j]);
+    }
   }
   return new_graph;
 }
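The "edge bug" named in the commit title lives in EliminateGraph: eliminated nodes are dropped from the node vector, but the survivors' node_in/node_out adjacency lists still hold indices into the old vector. The two added loops translate every edge endpoint through index_list, whose entries above SIZE_MAX / 2 mark eliminated nodes. A minimal self-contained sketch of that remapping, using a toy Node type and hypothetical index_list values rather than the MindSpore ones:

// edge_remap_sketch.cc -- illustrative only, not the MindSpore sources.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct Node {
  std::vector<size_t> node_in;   // indices of predecessor nodes
  std::vector<size_t> node_out;  // indices of successor nodes
};

int main() {
  // Toy graph 0 -> 1 -> 2 after the elimination pass has rewired 0 -> 2
  // but before node 1 is physically removed.
  std::vector<Node> nodes = {{{}, {2}}, {{0}, {2}}, {{0}, {}}};
  // Hypothetical output of the elimination pass: values > SIZE_MAX / 2
  // mark eliminated nodes; other entries give a node's new index.
  std::vector<size_t> index_list = {0, SIZE_MAX, 1};

  std::vector<Node> new_nodes;
  for (size_t i = 0; i < nodes.size(); i++) {
    if (index_list[i] > SIZE_MAX / 2) continue;  // skip eliminated nodes
    Node n = nodes[i];
    // The fix: translate every edge endpoint from old to new indices.
    for (size_t &in : n.node_in) in = index_list[in];
    for (size_t &out : n.node_out) out = index_list[out];
    new_nodes.push_back(n);
  }
  std::cout << new_nodes[0].node_out[0] << std::endl;  // prints 1
  return 0;
}

Without the remapping loops, new_nodes[0].node_out[0] would keep the stale value 2, an index into the old vector that no longer points at the intended node.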
@@ -232,7 +232,7 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
     }
   }
-  if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
+  if (DevicesMemoryControl(num_device, device_memory, graph) != SUCCESS) {
     return FAILED;
   } else {
     return SUCCESS;
@@ -257,16 +257,15 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
   return Node;
 }
-Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
   MS_EXCEPTION_IF_NULL(graph);
   uint64_t iter_nodes = graph->nodes.size();
+  double used_memory = 0.0;
   for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
     if (graph->nodes[i_node].info == 0) {
       Graph::NodeType &Node = graph->nodes[i_node];
-      double used_memory = 0.0;
       for (int index = 0; index < 2; index++) {
         used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
                        Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c *
@@ -274,21 +273,15 @@ Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> g
                        Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w *
                        GetDataTypeSize(Node.apply.arguments[index].tensor_type);
       }
       used_memory += Node.tensor_parm.tensor_str.str_n * Node.tensor_parm.tensor_shape.shape_n *
                      Node.tensor_parm.tensor_str.str_c * Node.tensor_parm.tensor_shape.shape_c *
                      Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
                      Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
                      GetDataTypeSize(Node.tensor_parm.tensor_type);
-      if (device_memory < used_memory) {
-        MS_LOG(EXCEPTION) << "Failure: Out of memory!";
-        return FAILED;
-      }
     }
   }
-  return SUCCESS;
+  if (device_memory < (used_memory / num_device)) {
+    MS_LOG(EXCEPTION) << "Failure: Out of memory!";
+    return FAILED;
+  } else {
+    return SUCCESS;
+  }
 }
 size_t GetDataTypeSize(const TensorType &type) {
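The memory check changes semantics, not just its signature: used_memory is hoisted out of the node loop so it accumulates the footprint of every remaining (info == 0) node, and the single comparison at the end charges each device an equal share, used_memory / num_device, instead of comparing each node's footprint against the whole budget. A condensed sketch of the new check, assuming a simplified TensorCost stand-in for the tensor_str/tensor_shape/tensor_type fields:

// memory_check_sketch.cc -- illustrative only; the field names are simplified
// stand-ins for tensor_str (strategy factors), tensor_shape and tensor_type.
#include <cstddef>
#include <vector>

struct TensorCost {
  double str_n, str_c, str_h, str_w;          // per-dimension strategy (cut) factors
  double shape_n, shape_c, shape_h, shape_w;  // full tensor dimensions
  double dtype_size;                          // bytes per element
  double Bytes() const {
    // Each dimension's effective size is its full size scaled by the
    // strategy factor, mirroring str_x * shape_x in the committed code.
    return str_n * shape_n * str_c * shape_c * str_h * shape_h * str_w * shape_w * dtype_size;
  }
};

// New semantics: sum every footprint once, then amortize over all devices.
bool FitsPerDevice(const std::vector<TensorCost> &tensors, size_t num_device, double device_memory) {
  double used_memory = 0.0;
  for (const TensorCost &t : tensors) used_memory += t.Bytes();
  return device_memory >= used_memory / static_cast<double>(num_device);
}

int main() {
  // One float32 4x8x8x8 tensor halved along N (str_n = 0.5): 4096 bytes total,
  // 2048 bytes per device when split across 2 devices.
  std::vector<TensorCost> t = {{0.5, 1, 1, 1, 4, 8, 8, 8, 4}};
  return FitsPerDevice(t, 2, 2048.0) ? 0 : 1;
}

The practical effect: the old per-node check could only trip on a single oversized node and reset used_memory between nodes, while the amortized form rejects a partition exactly when the summed footprint, spread evenly over num_device devices, exceeds one device's memory.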
@@ -44,7 +44,7 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
 Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
-Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph);
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
 size_t GetDataTypeSize(const TensorType &type);
 } // namespace parallel