diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index 10a0ab5029722e75c4226556ea7b9620e93f469c..a4688aada47f82532bf0483ea6c9a39e99d26c58 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -148,6 +148,9 @@ class MemoryOptimizer(object): mem_id = self.op_mem.get(op.input[0], -1) else: output_type = mace_pb2.DT_FLOAT + for arg in op.arg: + if arg.name == 'T': + output_type = arg.i if len(op.output_type) > i: output_type = op.output_type[i] op_mem_block = self.get_op_mem_block(