提交 a26a6bc7 编写于 作者: D dzhwinter

add flag. test=develop

上级 06f24488
...@@ -1725,6 +1725,18 @@ class Program(object): ...@@ -1725,6 +1725,18 @@ class Program(object):
self._trainers_endpoints = [] self._trainers_endpoints = []
# the distributed lookup table names # the distributed lookup table names
self._distributed_lookup_table = None self._distributed_lookup_table = None
# whether the program is optimized by memory_optimize_transpiler
self.__is_optimized = False
@property
def _is_optimized(self):
# if the program is optimized, operator input/outputs
# maybe same, which conflict with save_inference_model.
return self.__is_optimized
@_is_optimized.setter
def set__is_optimized(self, target):
self.__is_optimized = target
@property @property
def op_role(self): def op_role(self):
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import os import os
import errno import errno
import warnings
import time import time
import shutil import shutil
import six import six
...@@ -930,6 +931,13 @@ def save_inference_model(dirname, ...@@ -930,6 +931,13 @@ def save_inference_model(dirname,
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
if main_program.is_optimized:
warnings.warn(
"save_inference_model must put before you call memory_optimize. \
the memory_optimize will modify the original program, \
is not suitable for saving inference model \
we save the original program as inference model.",
RuntimeWarning)
# when a pserver and a trainer running on the same machine, mkdir may conflict # when a pserver and a trainer running on the same machine, mkdir may conflict
try: try:
......
...@@ -540,6 +540,7 @@ def memory_optimize(input_program, ...@@ -540,6 +540,7 @@ def memory_optimize(input_program,
if skip_opt_set is not None: if skip_opt_set is not None:
skip_opt_set = set(map(to_name_str, skip_opt_set)) skip_opt_set = set(map(to_name_str, skip_opt_set))
cfgs = _get_cfgs(input_program) cfgs = _get_cfgs(input_program)
input_program.is_optimized = True
for cfg in cfgs: for cfg in cfgs:
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
...@@ -559,5 +560,6 @@ def release_memory(input_program, skip_opt_set=None): ...@@ -559,5 +560,6 @@ def release_memory(input_program, skip_opt_set=None):
None None
""" """
cfgs = _get_cfgs(input_program) cfgs = _get_cfgs(input_program)
input_program.is_optimized = True
for cfg in cfgs: for cfg in cfgs:
cfg.release_memory(skip_opt_set=skip_opt_set) cfg.release_memory(skip_opt_set=skip_opt_set)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册