提交 aa3266b1 编写于 作者: Y yukavio

update for enable static

上级 e66b53ca
...@@ -24,6 +24,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..')) ...@@ -24,6 +24,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..'))
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools')) sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import program import program
import paddle
from paddle import fluid from paddle import fluid
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
...@@ -32,6 +33,12 @@ from paddleslim.prune import load_model ...@@ -32,6 +33,12 @@ from paddleslim.prune import load_model
def main(): def main():
# Run code with static graph mode.
try:
paddle.enable_static()
except:
pass
startup_prog, eval_program, place, config, _ = program.preprocess() startup_prog, eval_program, place, config, _ = program.preprocess()
feeded_var_names, target_vars, fetches_var_name = program.build_export( feeded_var_names, target_vars, fetches_var_name = program.build_export(
......
...@@ -50,7 +50,12 @@ skip_list = [ ...@@ -50,7 +50,12 @@ skip_list = [
def main(): def main():
paddle.enable_static() # Run code with static graph mode.
try:
paddle.enable_static()
except:
pass
config = program.load_config(FLAGS.config) config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt) program.merge_config(FLAGS.opt)
logger.info(config) logger.info(config)
......
...@@ -25,6 +25,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools')) ...@@ -25,6 +25,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import json import json
import cv2 import cv2
import paddle
from paddle import fluid from paddle import fluid
import paddleslim as slim import paddleslim as slim
from copy import deepcopy from copy import deepcopy
...@@ -60,6 +61,12 @@ def eval_function(eval_args, mode='eval'): ...@@ -60,6 +61,12 @@ def eval_function(eval_args, mode='eval'):
def main(): def main():
# Run code with static graph mode.
try:
paddle.enable_static()
except:
pass
config = program.load_config(FLAGS.config) config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt) program.merge_config(FLAGS.opt)
logger.info(config) logger.info(config)
......
...@@ -77,7 +77,12 @@ def main(): ...@@ -77,7 +77,12 @@ def main():
# The decay coefficient of moving average, default is 0.9 # The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9, 'moving_rate': 0.9,
} }
paddle.enable_static() # Run code with static graph mode.
try:
paddle.enable_static()
except:
pass
startup_prog, eval_program, place, config, alg_type = program.preprocess() startup_prog, eval_program, place, config, alg_type = program.preprocess()
feeded_var_names, target_vars, fetches_var_name = program.build_export( feeded_var_names, target_vars, fetches_var_name = program.build_export(
......
...@@ -85,7 +85,12 @@ def get_optimizer(): ...@@ -85,7 +85,12 @@ def get_optimizer():
def main(): def main():
paddle.enable_static() # Run code with static graph mode.
try:
paddle.enable_static()
except:
pass
train_build_outputs = program.build( train_build_outputs = program.build(
config, train_program, startup_program, mode='train') config, train_program, startup_program, mode='train')
train_loader = train_build_outputs[0] train_loader = train_build_outputs[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册