diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 95aafec05361a8b66b849268c7a738bb2ee5da86..5a7d04ed19437757bda6e0657ab45a320712ce86 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -484,8 +484,11 @@ def memory_optimize(input_program, if level != 0 and level != 1: raise ValueError("only support opt_level 0 or 1.") - if skip_opt_set is not None and not isinstance(skip_opt_set, set): - raise ValueError("only support skip_opt_set as set.") + if skip_opt_set is not None: + if isinstance(skip_opt_set, set) or isinstance(skip_opt_set, list): + skip_opt_set = set(skip_opt_set) + else: + raise ValueError("only support skip_opt_set as set.") global PRINT_LOG PRINT_LOG = print_log if skip_grads: