diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index e12d2a6cf7e2745f3e8f5b07a3da526bdc9c4bd5..63af39ab9ed24af0ad102d3e02c02d70c86819aa 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -154,6 +154,9 @@ class MemoryOptimizer(object): mem_id = self.op_mem.get(op.input[0], -1) else: output_type = mace_pb2.DT_FLOAT + for arg in op.arg: + if arg.name == 'T': + output_type = arg.i if len(op.output_type) > i: output_type = op.output_type[i] op_mem_block = self.get_op_mem_block(