Commit a26a6bc7 authored by: D dzhwinter

add flag. test=develop

Parent 06f24488
......@@ -1725,6 +1725,18 @@ class Program(object):
        self._trainers_endpoints = []
        # the distributed lookup table names
        self._distributed_lookup_table = None
        # whether the program has been optimized by the memory_optimize transpiler
        self.__is_optimized = False

    @property
    def _is_optimized(self):
        # If the program has been optimized, operator inputs/outputs may be
        # shared, which conflicts with save_inference_model.
        return self.__is_optimized

    @_is_optimized.setter
    def _is_optimized(self, target):
        self.__is_optimized = target

    @property
    def op_role(self):
......
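For readers skimming the hunk above, here is a stripped-down, standalone sketch of the flag pattern it adds (my illustration, not the framework's full Program class). Two points worth noting: the setter must reuse the property's name, and the double-underscored attribute is name-mangled inside the class.

```python
# Standalone illustration of the _is_optimized flag pattern; the real Program
# class carries many more members than shown here.
class Program(object):
    def __init__(self):
        # whether the program has been rewritten by the memory optimizer
        self.__is_optimized = False

    @property
    def _is_optimized(self):
        return self.__is_optimized

    @_is_optimized.setter
    def _is_optimized(self, target):
        self.__is_optimized = target


p = Program()
p._is_optimized = True
print(p._is_optimized)           # -> True
print(p._Program__is_optimized)  # same attribute after name mangling
```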
......@@ -16,6 +16,7 @@ from __future__ import print_function
import os
import errno
import warnings
import time
import shutil
import six
......@@ -930,6 +931,13 @@ def save_inference_model(dirname,
    if main_program is None:
        main_program = default_main_program()
    if main_program._is_optimized:
        warnings.warn(
            "save_inference_model must be called before memory_optimize. "
            "memory_optimize modifies the original program, making it "
            "unsuitable for saving an inference model; the original "
            "program is saved as the inference model instead.",
            RuntimeWarning)

    # when a pserver and a trainer run on the same machine, mkdir may conflict
    try:
......
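The warning added above asks callers to reorder their code. A minimal usage sketch of that ordering, assuming the fluid 1.x API (`fluid.io.save_inference_model`, `fluid.memory_optimize`) and a placeholder `save_dir`: save the inference model first, then run memory_optimize, so the persisted program is the un-rewritten one and the warning never fires.

```python
import paddle.fluid as fluid

# build a tiny network
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(y)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

save_dir = "./inference_model"  # placeholder path
# save first: the persisted program is the original, un-rewritten one
fluid.io.save_inference_model(save_dir, ['x'], [y], exe)
# optimize afterwards: this is the call that flips _is_optimized
fluid.memory_optimize(fluid.default_main_program())
```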
......@@ -540,6 +540,7 @@ def memory_optimize(input_program,
    if skip_opt_set is not None:
        skip_opt_set = set(map(to_name_str, skip_opt_set))
    cfgs = _get_cfgs(input_program)
    input_program._is_optimized = True
    for cfg in cfgs:
        cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
......@@ -559,5 +560,6 @@ def release_memory(input_program, skip_opt_set=None):
        None
    """
    cfgs = _get_cfgs(input_program)
    input_program._is_optimized = True
    for cfg in cfgs:
        cfg.release_memory(skip_opt_set=skip_opt_set)
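A short sketch, under the same fluid 1.x assumptions, of what the new flag means to callers: either transpiler entry point marks the program, and that mark is what save_inference_model later checks before warning.

```python
import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    fluid.layers.fc(input=x, size=2)

# memory_optimize (and likewise release_memory) marks the program
fluid.memory_optimize(prog)
print(prog._is_optimized)  # -> True, so a later save_inference_model would warn
```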