diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index 2e5716976b5a8cafdd22dceee0785b88a199bc11..8753b2700157416f9f078a8ee6a14b64c8fec718 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -43,6 +43,9 @@ class MemoryOptimizer(object): mem_size[1] = output_shape[0] * output_shape[1] return mem_size + def mem_area(self, memory_size): + return memory_size[0] * memory_size[1] + def optimize(self): for op in self.net_def.op: if self.is_buffer_image_op(op): @@ -54,22 +57,34 @@ class MemoryOptimizer(object): print('WARNING: the number of output shape is not equal to the number of output.') return for i in range(len(op.output)): - if len(self.idle_mem) == 0: - # allocate new mem + op_mem_size = self.get_mem_size(op.type, op.output_shape[i].dims) + mem_id = -1 + if len(self.idle_mem) > 0: + best_mem_candidate_id = -1 + best_mem_candidate_delta_area = sys.maxint + best_mem_candidate_shape = [] + for mid in self.idle_mem: + reuse_mem_size = self.mem_block[mid] + resize_mem_size = [max(reuse_mem_size[0], op_mem_size[0]), max(reuse_mem_size[1], op_mem_size[1])] + delta_mem_area = self.mem_area(resize_mem_size) - self.mem_area(reuse_mem_size) + if delta_mem_area < best_mem_candidate_delta_area: + best_mem_candidate_id = mid + best_mem_candidate_delta_area = delta_mem_area + best_mem_candidate_shape = resize_mem_size + + if best_mem_candidate_delta_area <= self.mem_area(op_mem_size): + # reuse + self.mem_block[best_mem_candidate_id] = best_mem_candidate_shape + mem_id = best_mem_candidate_id + self.idle_mem.remove(mem_id) + + if mem_id == -1: mem_id = self.total_mem_count self.total_mem_count += 1 - else: - # reuse mem - mem_id = self.idle_mem.pop() + self.mem_block[mem_id] = op_mem_size op.mem_id.extend([mem_id]) self.op_mem[op.output[i]] = mem_id - if mem_id not in self.mem_block: - self.mem_block[mem_id] = [0, 0] - mem_size = self.mem_block[mem_id] - op_mem_size = self.get_mem_size(op.type, op.output_shape[i].dims) - mem_size[0] = max(mem_size[0], op_mem_size[0]) - mem_size[1] = max(mem_size[1], op_mem_size[1]) # de-ref input tensor mem for ipt in op.input: