提交 48d0b27d 编写于 作者: W wuzewu

Do not create checkpoint dir during the predict phase

上级 d4dbc11c
......@@ -249,16 +249,11 @@ class BaseTask(object):
self.exe = fluid.Executor(place=self.place)
self.build_strategy = fluid.BuildStrategy()
# log item
if not os.path.exists(self.config.checkpoint_dir):
mkdir(self.config.checkpoint_dir)
tb_log_dir = os.path.join(self.config.checkpoint_dir, "visualization")
self.tb_writer = SummaryWriter(tb_log_dir)
# run environment
self._phases = []
self._envs = {}
self._predict_data = None
self._tb_writer = None
# event hooks
self._hooks = TaskHooks()
......@@ -559,6 +554,15 @@ class BaseTask(object):
return [metric.name for metric in self.metrics] + [self.loss.name]
return [output.name for output in self.outputs]
@property
def tb_writer(self):
if not os.path.exists(self.config.checkpoint_dir):
mkdir(self.config.checkpoint_dir)
tb_log_dir = os.path.join(self.config.checkpoint_dir, "visualization")
if not self._tb_writer:
self._tb_writer = SummaryWriter(tb_log_dir)
return self._tb_writer
def create_event_function(self, hook_type):
def hook_function(self, *args):
for name, func in self._hooks[hook_type].items():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册