Unverified commit a1446709, authored by Guanghua Yu, committed by GitHub

[Dygraph] update dygraph export model (#1857)

* update dygraph export model

* delete get_export

* adapt faster and cascade
Parent: 9b279ee3
@@ -133,7 +133,7 @@ class Detector(object):
             boxes_tensor = self.predictor.get_output_handle(output_names[0])
             np_boxes = boxes_tensor.copy_to_cpu()
             if self.pred_config.mask_resolution is not None:
-                masks_tensor = self.predictor.get_output_handle(output_names[1])
+                masks_tensor = self.predictor.get_output_handle(output_names[2])
                 np_masks = masks_tensor.copy_to_cpu()
         t2 = time.time()
         ms = (t2 - t1) * 1000.0 / repeats
...
@@ -79,7 +79,7 @@ class ResizeOp(object):
         im_info['scale_factor'] = np.array(
             [im_scale_y, im_scale_x]).astype('float32')
         # padding im when image_shape fixed by infer_cfg.yml
-        if self.keep_ratio:
+        if self.keep_ratio and im_info['input_shape'][1] is not None:
             max_size = im_info['input_shape'][1]
             padding_im = np.zeros(
                 (max_size, max_size, im_channel), dtype=np.float32)
...
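For context, a tiny standalone numpy sketch of what this fixed-shape padding branch does: a keep-ratio-resized image is (presumably) pasted into a square max_size canvas. The 608 input size and the random image below are assumptions for illustration, not values from this patch.

import numpy as np

# assumed values, only for illustration
max_size, im_channel = 608, 3
im = np.random.rand(416, 608, im_channel).astype('float32')  # keep-ratio resized image

# zero canvas of the fixed inference shape, image pasted into the top-left corner
padding_im = np.zeros((max_size, max_size, im_channel), dtype=np.float32)
padding_im[:im.shape[0], :im.shape[1], :] = im
print(padding_im.shape)  # (608, 608, 3)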
@@ -158,17 +158,12 @@ class CascadeRCNN(BaseArch):
         loss.update({'loss': total_loss})
         return loss

-    def get_pred(self, return_numpy=True):
+    def get_pred(self):
         bbox, bbox_num = self.bboxes
         output = {
-            'bbox': bbox.numpy(),
-            'bbox_num': bbox_num.numpy(),
-            'im_id': self.inputs['im_id'].numpy(),
+            'bbox': bbox,
+            'bbox_num': bbox_num,
         }
         if self.with_mask:
-            mask = self.mask_post_process(self.bboxes, self.mask_head_out,
-                                          self.inputs['im_shape'],
-                                          self.inputs['scale_factor'])
-            output.update(mask)
+            output.update(self.mask_head_out)
         return output
@@ -92,12 +92,10 @@ class FasterRCNN(BaseArch):
         loss.update({'loss': total_loss})
         return loss

-    def get_pred(self, return_numpy=True):
+    def get_pred(self):
         bbox, bbox_num = self.bboxes
         output = {
-            'bbox': bbox.numpy(),
-            'bbox_num': bbox_num.numpy(),
-            'im_id': self.inputs['im_id'].numpy()
+            'bbox': bbox,
+            'bbox_num': bbox_num,
         }
         return output
@@ -133,15 +133,11 @@ class MaskRCNN(BaseArch):
         loss.update({'loss': total_loss})
         return loss

-    def get_pred(self, return_numpy=True):
-        mask = self.mask_post_process(self.bboxes, self.mask_head_out,
-                                      self.inputs['im_shape'],
-                                      self.inputs['scale_factor'])
+    def get_pred(self):
         bbox, bbox_num = self.bboxes
         output = {
-            'bbox': bbox.numpy(),
-            'bbox_num': bbox_num.numpy(),
-            'im_id': self.inputs['im_id'].numpy()
+            'bbox': bbox,
+            'bbox_num': bbox_num,
+            'mask': self.mask_head_out
         }
-        output.update(mask)
         return output
@@ -16,18 +16,24 @@ class BaseArch(nn.Layer):
     def __init__(self):
         super(BaseArch, self).__init__()

-    def forward(self, data, input_def, mode, input_tensor=None):
+    def forward(self,
+                input_tensor=None,
+                data=None,
+                input_def=None,
+                mode='infer'):
         if input_tensor is None:
+            assert data is not None and input_def is not None
             self.inputs = self.build_inputs(data, input_def)
         else:
             self.inputs = input_tensor
         self.inputs['mode'] = mode
         self.model_arch()

         if mode == 'train':
             out = self.get_loss()
         elif mode == 'infer':
-            out = self.get_pred(input_tensor is None)
+            out = self.get_pred()
         else:
             out = None
             raise "Now, only support train and infer mode!"
@@ -47,6 +53,3 @@ class BaseArch(nn.Layer):
     def get_pred(self, ):
         raise NotImplementedError("Should implement get_pred method!")
-
-    def get_export_model(self, input_tensor):
-        return self.forward(None, None, 'infer', input_tensor)
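A minimal, self-contained sketch (not taken from this patch) of the two ways the reworked forward is meant to be driven: the train/eval path builds self.inputs from a raw reader batch plus the reader's field names, while the export path receives an already-assembled tensor dict. StubArch is a hypothetical stand-in so the snippet runs without a full detector.

import paddle
import paddle.nn as nn


class StubArch(nn.Layer):
    # mimics the new BaseArch.forward contract with trivial loss/pred outputs
    def forward(self, input_tensor=None, data=None, input_def=None,
                mode='infer'):
        if input_tensor is None:
            # train/eval path: pair the raw batch with the reader's field names
            assert data is not None and input_def is not None
            self.inputs = {name: data[input_def.index(name)] for name in input_def}
        else:
            # export path: an already-assembled dict of tensors
            self.inputs = input_tensor
        if mode == 'train':
            return {'loss': paddle.to_tensor(0.)}
        return dict(self.inputs)


model = StubArch()
batch = [paddle.rand([1, 3, 32, 32]), paddle.to_tensor([[32., 32.]])]

# keyword-style call used by the training/eval tools after this patch
print(model(data=batch, input_def=['image', 'im_shape'], mode='train'))

# dict-style call used on the export path
print(model(input_tensor={'image': paddle.rand([1, 3, 32, 32])}, mode='infer').keys())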
@@ -43,16 +43,12 @@ class YOLOv3(BaseArch):
         loss = self.yolo_head.get_loss(self.yolo_head_outs, self.inputs)
         return loss

-    def get_pred(self, return_numpy=True):
+    def get_pred(self):
         bbox, bbox_num = self.post_process(
             self.yolo_head_outs, self.yolo_head.mask_anchors,
             self.inputs['im_shape'], self.inputs['scale_factor'])
-        if return_numpy:
-            outs = {
-                "bbox": bbox.numpy(),
-                "bbox_num": bbox_num.numpy(),
-                'im_id': self.inputs['im_id'].numpy()
-            }
-        else:
-            outs = [bbox, bbox_num]
+        outs = {
+            "bbox": bbox,
+            "bbox_num": bbox_num,
+        }
         return outs
@@ -160,12 +160,19 @@ class MaskHead(Layer):
         bbox, bbox_num = bboxes
         if bbox.shape[0] == 0:
-            mask_head_out = bbox
+            mask_head_out = paddle.full([1, 6], -1)
+            return mask_head_out
         else:
-            scale_factor_list = []
+            # TODO(guanghua): Remove fluid dependency
+            scale_factor_list = paddle.fluid.layers.create_array('float32')
+            num_count = 0
             for idx, num in enumerate(bbox_num):
                 for n in range(num):
-                    scale_factor_list.append(scale_factor[idx, 0])
+                    paddle.fluid.layers.array_write(
+                        x=scale_factor[idx, 0],
+                        i=paddle.to_tensor(num_count),
+                        array=scale_factor_list)
+                    num_count += 1
             scale_factor_list = paddle.cast(
                 paddle.concat(scale_factor_list), 'float32')
             scale_factor_list = paddle.reshape(scale_factor_list, shape=[-1, 1])
@@ -182,7 +189,7 @@ class MaskHead(Layer):
                     mode='infer')
                 mask_logit = self.mask_fcn_logits[stage](mask_feat)
                 mask_head_out = F.sigmoid(mask_logit)
         return mask_head_out

     def forward(self,
                 inputs,
...
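A small numpy illustration (not from this patch) of what the array_write loop above accumulates: each image's scale factor repeated once per predicted box, concatenated into a single column. The two-image example values are made up.

import numpy as np

# assumed example: 2 images with 3 and 1 predicted boxes
scale_factor = np.array([[0.5, 0.5], [0.8, 0.8]], dtype='float32')  # per-image scale
bbox_num = np.array([3, 1])

# per-box scale column, equivalent to the tensor-array build-up in the diff
per_box_scale = np.repeat(scale_factor[:, 0], bbox_num).reshape([-1, 1])
print(per_box_scale)  # [[0.5] [0.5] [0.5] [0.8]]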
@@ -36,7 +36,6 @@ class RoIAlign(object):
     def __call__(self, feats, rois, spatial_scale):
         roi, rois_num = rois
         if self.start_level == self.end_level:
             rois_feat = ops.roi_align(
                 feats[self.start_level],
@@ -44,28 +43,28 @@ class RoIAlign(object):
                 self.resolution,
                 spatial_scale,
                 rois_num=rois_num)
-            return rois_feat
+        else:
             offset = 2
             k_min = self.start_level + offset
             k_max = self.end_level + offset
             rois_dist, restore_index, rois_num_dist = ops.distribute_fpn_proposals(
                 roi,
                 k_min,
                 k_max,
                 self.canconical_level,
                 self.canonical_size,
                 rois_num=rois_num)
             rois_feat_list = []
             for lvl in range(self.start_level, self.end_level + 1):
                 roi_feat = ops.roi_align(
                     feats[lvl],
                     rois_dist[lvl],
                     self.resolution,
                     spatial_scale[lvl],
                     sampling_ratio=self.sampling_ratio,
                     rois_num=rois_num_dist[lvl])
                 rois_feat_list.append(roi_feat)
             rois_feat_shuffle = paddle.concat(rois_feat_list)
             rois_feat = paddle.gather(rois_feat_shuffle, restore_index)
         return rois_feat
@@ -80,7 +80,8 @@ class FPN(Layer):
         for lvl in range(self.min_level, self.max_level):
             laterals.append(self.lateral_convs[lvl](body_feats[lvl]))

-        for lvl in range(self.max_level - 1, self.min_level, -1):
+        for i in range(self.min_level + 1, self.max_level):
+            lvl = self.max_level + self.min_level - i
             upsample = F.interpolate(
                 laterals[lvl],
                 scale_factor=2.,
...
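A quick standalone check (not from this patch) that the rewritten top-down FPN loop visits the same pyramid levels as before, just via an ascending index, presumably to ease dygraph-to-static conversion; min_level=0 and max_level=4 are assumed only for illustration.

# assumed example levels
min_level, max_level = 0, 4

old_order = list(range(max_level - 1, min_level, -1))
new_order = [max_level + min_level - i for i in range(min_level + 1, max_level)]

assert old_order == new_order == [3, 2, 1]
print(new_order)  # same top-down order, produced by an ascending loop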
This diff is collapsed.
@@ -73,7 +73,8 @@ def bbox_post_process(bboxes,

 @jit
-def mask_post_process(bboxes,
+def mask_post_process(bbox,
+                      bbox_nums,
                       masks,
                       im_shape,
                       scale_factor,
@@ -81,7 +82,6 @@ def mask_post_process(bboxes,
                       binary_thresh=0.5):
     if masks.shape[0] == 0:
         return masks
-    bbox, bbox_nums = bboxes
     M = resolution
     scale = (M + 2.0) / M
     boxes = bbox[:, 2:]
@@ -98,7 +98,6 @@ def mask_post_process(bboxes,
         boxes_n = boxes[st_num:end_num]
         labels_n = labels[st_num:end_num]
         masks_n = masks[st_num:end_num]
         im_h = int(round(im_shape[i][0] / scale_factor[i]))
         im_w = int(round(im_shape[i][1] / scale_factor[i]))
         boxes_n = expand_bbox(boxes_n, scale)
...
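For orientation, a tiny arithmetic check (not from this patch) of the scale computed above: in the usual Detectron-style pasting trick the M x M mask is padded to (M + 2) x (M + 2), so the boxes are expanded by the same (M + 2) / M ratio before the mask is resized back; M = 14 is only an assumed example resolution.

# assumed example mask resolution
M = 14
scale = (M + 2.0) / M
print(round(scale, 4))  # 1.1429 -- boxes grow roughly 14% before pasting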
@@ -5,7 +5,7 @@ from __future__ import print_function
 import os
 import sys
 import json
-from ppdet.py_op.post_process import get_det_res, get_seg_res
+from ppdet.py_op.post_process import get_det_res, get_seg_res, mask_post_process
 import logging
 logger = logging.getLogger(__name__)
@@ -33,7 +33,8 @@ def json_eval_results(metric, json_directory=None, dataset=None):
         logger.info("{} not exists!".format(v_json))


-def get_infer_results(outs_res, eval_type, catid):
+def get_infer_results(outs_res, eval_type, catid, im_info,
+                      mask_resolution=None):
     """
     Get result at the stage of inference.
     The output format is dictionary containing bbox or mask result.
@@ -49,16 +50,25 @@ def get_infer_results(outs_res, eval_type, catid):
     if 'bbox' in eval_type:
         box_res = []
-        for outs in outs_res:
-            box_res += get_det_res(outs['bbox'], outs['bbox_num'],
-                                   outs['im_id'], catid)
+        for i, outs in enumerate(outs_res):
+            im_ids = im_info[i][2]
+            box_res += get_det_res(outs['bbox'].numpy(),
+                                   outs['bbox_num'].numpy(), im_ids, catid)
         infer_res['bbox'] = box_res

     if 'mask' in eval_type:
         seg_res = []
-        for outs in outs_res:
-            seg_res += get_seg_res(outs['mask'], outs['bbox_num'],
-                                   outs['im_id'], catid)
+        # mask post process
+        for i, outs in enumerate(outs_res):
+            im_shape = im_info[i][0]
+            scale_factor = im_info[i][1]
+            im_ids = im_info[i][2]
+            mask = mask_post_process(outs['bbox'].numpy(),
+                                     outs['bbox_num'].numpy(),
+                                     outs['mask'].numpy(), im_shape,
+                                     scale_factor[0], mask_resolution)
+            seg_res += get_seg_res(mask, outs['bbox_num'].numpy(), im_ids,
+                                   catid)
         infer_res['mask'] = seg_res

     return infer_res
...
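A standalone sketch (not from this patch) of the im_info structure the reworked get_infer_results now expects: one [im_shape, scale_factor, im_id] triple of numpy arrays per evaluated batch, in the same order that eval.py collects them below. All values here are made up for illustration.

import numpy as np

# assumed single-image batch
im_info = [[
    np.array([[800., 1344.]], dtype='float32'),  # im_shape
    np.array([[1.25, 1.25]], dtype='float32'),   # scale_factor
    np.array([[0]], dtype='int64'),              # im_id
]]

im_shape, scale_factor, im_ids = im_info[0]
print(im_shape.shape, scale_factor[0], int(im_ids[0][0]))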
@@ -75,12 +75,18 @@ def run(FLAGS, cfg, place):
     outs_res = []
     start_time = time.time()
     sample_num = 0
+    im_info = []
     for iter_id, data in enumerate(eval_loader):
         # forward
+        fields = cfg['EvalReader']['inputs_def']['fields']
         model.eval()
-        outs = model(data, cfg['EvalReader']['inputs_def']['fields'], 'infer')
+        outs = model(data=data, input_def=fields, mode='infer')
         outs_res.append(outs)
+        im_info.append([
+            data[fields.index('im_shape')].numpy(),
+            data[fields.index('scale_factor')].numpy(),
+            data[fields.index('im_id')].numpy()
+        ])

         # log
         sample_num += len(data)
         if iter_id % 100 == 0:
@@ -102,7 +108,15 @@ def run(FLAGS, cfg, place):
     clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                 use_default_label)

-    infer_res = get_infer_results(outs_res, eval_type, clsid2catid)
+    mask_resolution = None
+    if cfg['MaskPostProcess']['mask_resolution'] is not None:
+        mask_resolution = int(cfg['MaskPostProcess']['mask_resolution'])
+    infer_res = get_infer_results(
+        outs_res,
+        eval_type,
+        clsid2catid,
+        im_info,
+        mask_resolution=mask_resolution)

     eval_results(infer_res, cfg.metric, anno_file)
...
@@ -53,63 +53,43 @@ def parse_args():
     return args


+def dygraph_to_static(model, save_dir, cfg):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    inputs_def = cfg['TestReader']['inputs_def']
+    image_shape = inputs_def.get('image_shape')
+    if image_shape is None:
+        image_shape = [3, None, None]
+    # Save infer cfg
+    dump_infer_config(cfg, os.path.join(save_dir, 'infer_cfg.yml'), image_shape)
+    input_spec = [{
+        "image": InputSpec(
+            shape=[None] + image_shape, name='image'),
+        "im_shape": InputSpec(
+            shape=[None, 2], name='im_shape'),
+        "scale_factor": InputSpec(
+            shape=[None, 2], name='scale_factor')
+    }]
+    export_model = to_static(model, input_spec=input_spec)
+    # save Model
+    paddle.jit.save(export_model, os.path.join(save_dir, 'model'))
+
+
 def run(FLAGS, cfg):
     # Model
     main_arch = cfg.architecture
     model = create(cfg.architecture)
-    inputs_def = cfg['TestReader']['inputs_def']
-    assert 'image_shape' in inputs_def, 'image_shape must be specified.'
-    image_shape = inputs_def.get('image_shape')
-    assert not None in image_shape, 'image_shape should not contain None'
     cfg_name = os.path.basename(FLAGS.config).split('.')[0]
     save_dir = os.path.join(FLAGS.output_dir, cfg_name)
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir)
-    image_shape = dump_infer_config(cfg,
-                                    os.path.join(save_dir, 'infer_cfg.yml'),
-                                    image_shape)
-
-    class ExportModel(nn.Layer):
-        def __init__(self, model):
-            super(ExportModel, self).__init__()
-            self.model = model
-
-        @to_static(input_spec=[
-            {
-                'image': InputSpec(
-                    shape=[None] + image_shape, name='image')
-            },
-            {
-                'im_shape': InputSpec(
-                    shape=[None, 2], name='im_shape')
-            },
-            {
-                'scale_factor': InputSpec(
-                    shape=[None, 2], name='scale_factor')
-            },
-        ])
-        def forward(self, image, im_shape, scale_factor):
-            inputs = {}
-            inputs_tensor = [image, im_shape, scale_factor]
-            for t in inputs_tensor:
-                inputs.update(t)
-            outs = self.model.get_export_model(inputs)
-            return outs
-
-    export_model = ExportModel(model)
-    # debug for dy2static, remove later
-    #paddle.jit.set_code_level()

     # Init Model
-    load_weight(export_model.model, cfg.weights)
-    export_model.eval()
+    load_weight(model, cfg.weights)

     # export config and model
-    paddle.jit.save(export_model, os.path.join(save_dir, 'model'))
+    dygraph_to_static(model, save_dir, cfg)
     logger.info('Export model to {}'.format(save_dir))
...
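As a rough standalone sketch (not part of this patch) of the InputSpec / to_static / jit.save pattern that dygraph_to_static builds on, here is the same flow applied to a toy layer; ToyModel, its dict-style input, and the output path are all assumptions so the snippet runs without the PaddleDetection configs.

import os

import paddle
import paddle.nn as nn
from paddle.static import InputSpec


class ToyModel(nn.Layer):
    # hypothetical stand-in for a detector that consumes a dict of tensors
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv = nn.Conv2D(3, 8, 3, padding=1)

    def forward(self, inputs):
        return self.conv(inputs['image'])


model = ToyModel()
model.eval()

# same shape of input_spec as dygraph_to_static: a list holding a dict of InputSpec
input_spec = [{
    'image': InputSpec(shape=[None, 3, None, None], name='image'),
}]
static_model = paddle.jit.to_static(model, input_spec=input_spec)

# serialize program and parameters to an assumed output directory
save_dir = 'toy_export'
os.makedirs(save_dir, exist_ok=True)
paddle.jit.save(static_model, os.path.join(save_dir, 'model'))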
@@ -109,7 +109,8 @@ def dump_infer_config(config, path, image_shape):
         os._exit(0)
     if 'Mask' in config['architecture']:
-        infer_cfg['mask_resolution'] = config['Mask']['mask_resolution']
+        infer_cfg['mask_resolution'] = config['MaskPostProcess'][
+            'mask_resolution']
     infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[
         'label_list'], image_shape = parse_reader(
             config['TestReader'], config['TestDataset'], config['metric'],
...
@@ -147,15 +147,32 @@ def run(FLAGS, cfg, place):
     # Run Infer
     for iter_id, data in enumerate(test_loader):
         # forward
+        fields = cfg.TestReader['inputs_def']['fields']
         model.eval()
-        outs = model(data, cfg.TestReader['inputs_def']['fields'], 'infer')
-
-        batch_res = get_infer_results([outs], outs.keys(), clsid2catid)
+        outs = model(
+            data=data,
+            input_def=cfg.TestReader['inputs_def']['fields'],
+            mode='infer')
+        im_info = [[
+            data[fields.index('im_shape')].numpy(),
+            data[fields.index('scale_factor')].numpy(),
+            data[fields.index('im_id')].numpy()
+        ]]
+        im_ids = data[fields.index('im_id')].numpy()
+        mask_resolution = None
+        if cfg['MaskPostProcess']['mask_resolution'] is not None:
+            mask_resolution = int(cfg['MaskPostProcess']['mask_resolution'])
+        batch_res = get_infer_results(
+            [outs],
+            outs.keys(),
+            clsid2catid,
+            im_info,
+            mask_resolution=mask_resolution)
         logger.info('Infer iter {}'.format(iter_id))

         bbox_res = None
         mask_res = None
-        im_ids = outs['im_id']
         bbox_num = outs['bbox_num']
         start = 0
         for i, im_id in enumerate(im_ids):
...
@@ -35,6 +35,7 @@ from ppdet.utils.stats import TrainingStats
 from ppdet.utils.check import check_gpu, check_version, check_config
 from ppdet.utils.cli import ArgsParser
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model
+from export_model import dygraph_to_static
 from paddle.distributed import ParallelEnv

 import logging
 FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
@@ -149,6 +150,8 @@ def run(FLAGS, cfg, place):
         model = paddle.DataParallel(model)

     fields = train_loader.collate_fn.output_fields
+    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
+    save_dir = os.path.join(cfg.save_dir, cfg_name)
     # Run Train
     time_stat = deque(maxlen=cfg.log_iter)
     start_time = time.time()
@@ -167,7 +170,7 @@ def run(FLAGS, cfg, place):
             # Model Forward
             model.train()
-            outputs = model(data, fields, 'train')
+            outputs = model(data=data, input_def=fields, mode='train')

             # Model Backward
             loss = outputs['loss']
@@ -193,11 +196,12 @@ def run(FLAGS, cfg, place):
         if ParallelEnv().local_rank == 0 and (
                 cur_eid % cfg.snapshot_epoch == 0 or
                 (cur_eid + 1) == int(cfg.epoch)):
-            cfg_name = os.path.basename(FLAGS.config).split('.')[0]
             save_name = str(cur_eid) if cur_eid + 1 != int(
                 cfg.epoch) else "model_final"
-            save_dir = os.path.join(cfg.save_dir, cfg_name)
             save_model(model, optimizer, save_dir, save_name, cur_eid + 1)
+    # TODO(guanghua): dygraph model to static model
+    # if ParallelEnv().local_rank == 0 and (cur_eid + 1) == int(cfg.epoch)):
+    #     dygraph_to_static(model, os.path.join(save_dir, 'static_model_final'), cfg)


 def main():
...