diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index 8753b2700157416f9f078a8ee6a14b64c8fec718..e632a22a070c6ae57774ea36704bdda8dc6a7ee7 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -35,7 +35,7 @@ class MemoryOptimizer(object): def get_mem_size(self, op_type, output_shape): mem_size = [0, 0] - if op_type == 'WinogradTransform' or op_type == 'GEMM': + if op_type == 'WinogradTransform' or op_type == 'MatMul': mem_size[0] = output_shape[2] * output_shape[3] mem_size[1] = output_shape[0] * int((output_shape[1]+3)/4) else: