Commit 8cad30f0 authored by u010070587, committed by kolinwei

update ce (#4106)

* modify ce for deeplabv3+

modify ce for auto_dialogue_evaluation

add ce for human_pose_estimation

* modify gru4rec ce
Parent 4d3bec4c
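
The diffs below all apply the same continuous-evaluation (CE) pattern: when a CE switch is present (the `ce_mode` environment variable or an `--enable_ce` argument), the random seeds are pinned and data shuffling is turned off so that repeated runs produce comparable losses. A minimal sketch of that pattern follows; the helper name `make_deterministic_for_ce` is illustrative and not part of the repository:

import os
import random

import numpy as np


def make_deterministic_for_ce():
    """Sketch of the CE pattern used in this commit (hypothetical helper).

    When the `ce_mode` environment variable is set, pin the Python and
    NumPy random seeds and tell the caller to disable reader shuffling,
    so continuous-evaluation runs are reproducible.
    """
    if 'ce_mode' in os.environ:
        random.seed(0)      # deterministic python-level sampling and flips
        np.random.seed(0)   # deterministic numpy-based augmentation
        return False        # caller should use shuffle=False
    return True             # normal training keeps shuffling enabled

The per-model diffs additionally seed the Paddle programs themselves (random_seed = 102 or 90) and, for the pose-estimation trainer, emit a tab-separated "kpis" line that the CE system can parse from stdout.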
@@ -28,6 +28,10 @@ default_config = {
     "crop_size": 769,
 }

+# used for ce
+if 'ce_mode' in os.environ:
+    np.random.seed(0)
+
 def slice_with_pad(a, s, value=0):
     pads = []
...
@@ -145,12 +145,6 @@ deeplabv3p = models.deeplabv3p
 sp = fluid.Program()
 tp = fluid.Program()

-# only for ce
-if args.enable_ce:
-    SEED = 102
-    sp.random_seed = SEED
-    tp.random_seed = SEED
-
 crop_size = args.train_crop_size
 batch_size = args.batch_size
 image_shape = [crop_size, crop_size]
@@ -162,6 +156,13 @@ weight_decay = 0.00004
 base_lr = args.base_lr
 total_step = args.total_step

+# only for ce
+if args.enable_ce:
+    SEED = 102
+    sp.random_seed = SEED
+    tp.random_seed = SEED
+    reader.default_config['shuffle'] = False
+
 with fluid.program_guard(tp, sp):
     if args.use_py_reader:
         batch_size_each = batch_size // utility.get_device_count()
@@ -255,7 +256,7 @@ with profile_context(args.profile):
         train_loss = np.mean(train_loss)
         end_time = time.time()
         total_time += end_time - begin_time

         if i % 100 == 0:
             print("Model is saved to", args.save_weights_path)
             save_model()
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 ##############################################################################
 """Data reader for COCO dataset."""
 from __future__ import absolute_import
@@ -60,6 +59,7 @@ from pycocotools.coco import COCO
 # [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]
 # ]

 class Config:
     """Configurations for COCO dataset.
     """
@@ -68,13 +68,14 @@ class Config:
     # For reader
     BUF_SIZE = 102400
     THREAD = 1 if DEBUG else 8  # have to be larger than 0

     # Fixed infos of dataset
     DATAROOT = 'data/coco'
     IMAGEDIR = 'images'
     NUM_JOINTS = 17
-    FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
+    FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14],
+                  [15, 16]]
     PARENT_IDS = None

     # CFGS
@@ -90,12 +91,15 @@ class Config:
     STD = [0.229, 0.224, 0.225]
     PIXEL_STD = 200

 cfg = Config()

 def _box2cs(box):
     x, y, w, h = box[:4]
     return _xywh2cs(x, y, w, h)

 def _xywh2cs(x, y, w, h):
     center = np.zeros((2), dtype=np.float32)
     center[0] = x + w * 0.5
@@ -106,21 +110,20 @@ def _xywh2cs(x, y, w, h):
     elif w < cfg.ASPECT_RATIO * h:
         w = h * cfg.ASPECT_RATIO
     scale = np.array(
-        [w * 1.0 / cfg.PIXEL_STD, h * 1.0 / cfg.PIXEL_STD],
-        dtype=np.float32)
+        [w * 1.0 / cfg.PIXEL_STD, h * 1.0 / cfg.PIXEL_STD], dtype=np.float32)
     if center[0] != -1:
         scale = scale * 1.25
     return center, scale

 def _select_data(db):
     db_selected = []
     for rec in db:
         num_vis = 0
         joints_x = 0.0
         joints_y = 0.0
-        for joint, joint_vis in zip(
-                rec['joints_3d'], rec['joints_3d_vis']):
+        for joint, joint_vis in zip(rec['joints_3d'], rec['joints_3d_vis']):
             if joint_vis[0] <= 0:
                 continue
             num_vis += 1
@@ -135,8 +138,8 @@ def _select_data(db):
         area = rec['scale'][0] * rec['scale'][1] * (cfg.PIXEL_STD**2)
         joints_center = np.array([joints_x, joints_y])
         bbox_center = np.array(rec['center'])
-        diff_norm2 = np.linalg.norm((joints_center-bbox_center), 2)
-        ks = np.exp(-1.0*(diff_norm2**2) / ((0.2)**2*2.0*area))
+        diff_norm2 = np.linalg.norm((joints_center - bbox_center), 2)
+        ks = np.exp(-1.0 * (diff_norm2**2) / ((0.2)**2 * 2.0 * area))

         metric = (0.2 / 16) * num_vis + 0.45 - 0.2 / 16
         if ks > metric:
@@ -146,7 +149,9 @@ def _select_data(db):
     print('=> num selected db: {}'.format(len(db_selected)))
     return db_selected

-def _load_coco_keypoint_annotation(image_set_index, coco, _coco_ind_to_class_ind, image_set):
+def _load_coco_keypoint_annotation(image_set_index, coco,
+                                   _coco_ind_to_class_ind, image_set):
     """Ground truth bbox and keypoints.
     """
     print('generating coco gt_db...')
@@ -168,7 +173,7 @@ def _load_coco_keypoint_annotation(image_set_index, coco, _coco_ind_to_class_ind
             x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
             y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
             if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
-                obj['clean_bbox'] = [x1, y1, x2-x1, y2-y1]
+                obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
                 valid_objs.append(obj)
         objs = valid_objs
@@ -197,7 +202,8 @@ def _load_coco_keypoint_annotation(image_set_index, coco, _coco_ind_to_class_ind
         center, scale = _box2cs(obj['clean_bbox'][:4])
         rec.append({
-            'image': os.path.join(cfg.DATAROOT, cfg.IMAGEDIR, image_set+'2017', '%012d.jpg' % index),
+            'image': os.path.join(cfg.DATAROOT, cfg.IMAGEDIR,
+                                  image_set + '2017', '%012d.jpg' % index),
             'center': center,
             'scale': scale,
             'joints_3d': joints_3d,
@@ -209,6 +215,7 @@ def _load_coco_keypoint_annotation(image_set_index, coco, _coco_ind_to_class_ind
         gt_db.extend(rec)
     return gt_db

 def data_augmentation(sample, is_train):
     image_file = sample['image']
     filename = sample['filename'] if 'filename' in sample else ''
@@ -220,28 +227,32 @@ def data_augmentation(sample, is_train):
     # imgnum = sample['imgnum'] if 'imgnum' in sample else ''
     r = 0

-    data_numpy = cv2.imread(
-        image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
+    # used for ce
+    if 'ce_mode' in os.environ:
+        random.seed(0)
+        np.random.seed(0)
+
+    data_numpy = cv2.imread(image_file, cv2.IMREAD_COLOR |
+                            cv2.IMREAD_IGNORE_ORIENTATION)

     if is_train:
         sf = cfg.SCALE_FACTOR
         rf = cfg.ROT_FACTOR
-        s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
+        s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
         r = np.clip(np.random.randn()*rf, -rf*2, rf*2) \
                 if random.random() <= 0.6 else 0

         if cfg.FLIP and random.random() <= 0.5:
             data_numpy = data_numpy[:, ::-1, :]
             joints, joints_vis = fliplr_joints(
                 joints, joints_vis, data_numpy.shape[1], cfg.FLIP_PAIRS)
             c[0] = data_numpy.shape[1] - c[0] - 1

     trans = get_affine_transform(c, s, r, cfg.IMAGE_SIZE)
     input = cv2.warpAffine(
         data_numpy,
-        trans,
-        (int(cfg.IMAGE_SIZE[0]), int(cfg.IMAGE_SIZE[1])),
+        trans, (int(cfg.IMAGE_SIZE[0]), int(cfg.IMAGE_SIZE[1])),
         flags=cv2.INTER_LINEAR)

     for i in range(cfg.NUM_JOINTS):
         if joints_vis[i, 0] > 0.0:
@@ -263,23 +274,30 @@ def data_augmentation(sample, is_train):
     else:
         return input, target, target_weight, c, s, score, image_file

 # Create a reader
-def _reader_creator(root, image_set, shuffle=False, is_train=False, use_gt_bbox=False):
+def _reader_creator(root,
+                    image_set,
+                    shuffle=False,
+                    is_train=False,
+                    use_gt_bbox=False):
     def reader():
         if image_set in ['train', 'val']:
-            file_name = os.path.join(root, 'annotations', 'person_keypoints_'+image_set+'2017.json')
+            file_name = os.path.join(
+                root, 'annotations',
+                'person_keypoints_' + image_set + '2017.json')
         elif image_set in ['test', 'test-dev']:
-            file_name = os.path.join(root, 'annotations', 'image_info_'+image_set+'2017.json')
+            file_name = os.path.join(root, 'annotations',
+                                     'image_info_' + image_set + '2017.json')
         else:
-            raise ValueError("The dataset '{}' is not supported".format(image_set))
+            raise ValueError("The dataset '{}' is not supported".format(
+                image_set))

         # Load annotations
         coco = COCO(file_name)

         # Deal with class names
-        cats = [cat['name']
-                for cat in coco.loadCats(coco.getCatIds())]
+        cats = [cat['name'] for cat in coco.loadCats(coco.getCatIds())]
         classes = ['__background__'] + cats
         print('=> classes: {}'.format(classes))
         num_classes = len(classes)
@@ -287,7 +305,7 @@ def _reader_creator(root, image_set, shuffle=False, is_train=False, use_gt_bbox=
         _class_to_coco_ind = dict(zip(cats, coco.getCatIds()))
         _coco_ind_to_class_ind = dict([(_class_to_coco_ind[cls],
                                         _class_to_ind[cls])
                                        for cls in classes[1:]])

         # Load image file names
         image_set_index = coco.getImgIds()
@@ -296,7 +314,7 @@ def _reader_creator(root, image_set, shuffle=False, is_train=False, use_gt_bbox=
         if is_train or use_gt_bbox:
             gt_db = _load_coco_keypoint_annotation(
                 image_set_index, coco, _coco_ind_to_class_ind, image_set)
             gt_db = _select_data(gt_db)

         if shuffle:
@@ -308,23 +326,40 @@ def _reader_creator(root, image_set, shuffle=False, is_train=False, use_gt_bbox=
     mapper = functools.partial(data_augmentation, is_train=is_train)
     return reader, mapper

 def train():
-    reader, mapper = _reader_creator(cfg.DATAROOT, 'train', shuffle=True, is_train=True)
+    reader, mapper = _reader_creator(
+        cfg.DATAROOT, 'train', shuffle=True, is_train=True)
+
+    # used for ce
+    if 'ce_mode' in os.environ:
+        reader, mapper = _reader_creator(
+            cfg.DATAROOT, 'train', shuffle=False, is_train=True)

     def pop():
         for i, x in enumerate(reader()):
             yield mapper(x)

     return pop

 def valid():
-    reader, mapper = _reader_creator(cfg.DATAROOT, 'val', shuffle=False, is_train=False, use_gt_bbox=True)
+    reader, mapper = _reader_creator(
+        cfg.DATAROOT, 'val', shuffle=False, is_train=False, use_gt_bbox=True)

     def pop():
         for i, x in enumerate(reader()):
             yield mapper(x)

     return pop

 def test():
-    reader, mapper = _reader_creator(cfg.DATAROOT, 'test', shuffle=False, is_train=False, use_gt_bbox=True)
+    reader, mapper = _reader_creator(
+        cfg.DATAROOT, 'test', shuffle=False, is_train=False, use_gt_bbox=True)

     def pop():
         for i, x in enumerate(reader()):
             yield mapper(x)

     return pop
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 ##############################################################################
 """Functions for training."""
 import os
@@ -42,8 +41,10 @@ add_arg('pretrained_model', str, "pretrained/resnet_50/115", "Whether to use
 add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
 add_arg('lr', float, 0.001, "Set learning rate.")
 add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
+add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.")
 # yapf: enable

 def optimizer_setting(args, params):
     lr_drop_ratio = 0.1
@@ -64,8 +65,8 @@ def optimizer_setting(args, params):
         # AdamOptimizer
         optimizer = paddle.fluid.optimizer.AdamOptimizer(
             learning_rate=fluid.layers.piecewise_decay(
                 boundaries=bd, values=lr))
     else:
         lr = params["lr"]
         optimizer = fluid.optimizer.Momentum(
@@ -87,28 +88,41 @@ def train(args):
         IMAGE_SIZE = [288, 384]
         HEATMAP_SIZE = [72, 96]
         args.kp_dim = 17
         args.total_images = 144406  # 149813
     elif args.dataset == 'mpii':
         import lib.mpii_reader as reader
         IMAGE_SIZE = [384, 384]
         HEATMAP_SIZE = [96, 96]
         args.kp_dim = 16
         args.total_images = 22246
     else:
-        raise ValueError('The dataset {} is not supported yet.'.format(args.dataset))
+        raise ValueError('The dataset {} is not supported yet.'.format(
+            args.dataset))

     print_arguments(args)

     # Image and target
-    image = layers.data(name='image', shape=[3, IMAGE_SIZE[1], IMAGE_SIZE[0]], dtype='float32')
-    target = layers.data(name='target', shape=[args.kp_dim, HEATMAP_SIZE[1], HEATMAP_SIZE[0]], dtype='float32')
-    target_weight = layers.data(name='target_weight', shape=[args.kp_dim, 1], dtype='float32')
+    image = layers.data(
+        name='image', shape=[3, IMAGE_SIZE[1], IMAGE_SIZE[0]], dtype='float32')
+    target = layers.data(
+        name='target',
+        shape=[args.kp_dim, HEATMAP_SIZE[1], HEATMAP_SIZE[0]],
+        dtype='float32')
+    target_weight = layers.data(
+        name='target_weight', shape=[args.kp_dim, 1], dtype='float32')
+
+    # used for ce
+    if args.enable_ce:
+        fluid.default_startup_program().random_seed = 90
+        fluid.default_main_program().random_seed = 90

     # Build model
     model = pose_resnet.ResNet(layers=50, kps_num=args.kp_dim)

     # Output
-    loss, output = model.net(input=image, target=target, target_weight=target_weight)
+    loss, output = model.net(input=image,
+                             target=target,
+                             target_weight=target_weight)

     # Parameters from model and arguments
     params = {}
@@ -127,11 +141,13 @@ def train(args):
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())

     if args.pretrained_model:
         def if_exist(var):
-            exist_flag = os.path.exists(os.path.join(args.pretrained_model, var.name))
+            exist_flag = os.path.exists(
+                os.path.join(args.pretrained_model, var.name))
             return exist_flag

         fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

     if args.checkpoint is not None:
@@ -139,7 +155,8 @@ def train(args):
     # Dataloader
     train_reader = paddle.batch(reader.train(), batch_size=args.batch_size)
-    feeder = fluid.DataFeeder(place=place, feed_list=[image, target, target_weight])
+    feeder = fluid.DataFeeder(
+        place=place, feed_list=[image, target, target_weight])

     train_exe = fluid.ParallelExecutor(
         use_cuda=True if args.use_gpu else False, loss_name=loss.name)
@@ -147,29 +164,40 @@ def train(args):
     for pass_id in range(params["num_epochs"]):
         for batch_id, data in enumerate(train_reader()):
-            current_lr = np.array(paddle.fluid.global_scope().find_var('learning_rate').get_tensor())
+            current_lr = np.array(paddle.fluid.global_scope().find_var(
+                'learning_rate').get_tensor())

             input_image, loss, out_heatmaps = train_exe.run(
                 fetch_list, feed=feeder.feed(data))

             loss = np.mean(np.array(loss))
             print_immediately('Epoch [{:4d}/{:3d}] LR: {:.10f} '
-                              'Loss = {:.5f}'.format(
-                                  batch_id, pass_id, current_lr[0], loss))
+                              'Loss = {:.5f}'.format(batch_id, pass_id,
+                                                     current_lr[0], loss))

             if batch_id % 10 == 0:
-                save_batch_heatmaps(input_image, out_heatmaps, file_name='visualization@train.jpg', normalize=True)
-
-        model_path = os.path.join(args.model_save_dir + '/' + 'simplebase-{}'.format(args.dataset),
-                                  str(pass_id))
+                save_batch_heatmaps(
+                    input_image,
+                    out_heatmaps,
+                    file_name='visualization@train.jpg',
+                    normalize=True)
+
+        model_path = os.path.join(
+            args.model_save_dir + '/' + 'simplebase-{}'.format(args.dataset),
+            str(pass_id))
         if not os.path.isdir(model_path):
             os.makedirs(model_path)
         fluid.io.save_persistables(exe, model_path)

+    # used for ce
+    if args.enable_ce:
+        device_num = fluid.core.get_cuda_device_count() if args.use_gpu else 1
+        print("kpis\t{}_train_cost_card{}\t{:.5f}".format(args.dataset,
+                                                          device_num, loss))
+

 if __name__ == '__main__':
     args = parser.parse_args()
     check_cuda(args.use_gpu)
     train(args)
@@ -19,45 +19,42 @@ import sys
 import time
 import random
 import numpy as np
+import os

 import paddle
 import paddle.fluid as fluid


 class DataProcessor(object):
     def __init__(self, data_path, max_seq_length, batch_size):
         """init"""
         self.data_file = data_path
         self.max_seq_len = max_seq_length
         self.batch_size = batch_size
         self.num_examples = {'train': -1, 'dev': -1, 'test': -1}

     def get_examples(self):
         """load examples"""
         examples = []
         index = 0
         fr = io.open(self.data_file, 'r', encoding="utf8")
         for line in fr:
-            if index !=0 and index % 100 == 0:
+            if index != 0 and index % 100 == 0:
                 print("processing data: %d" % index)
             index += 1
             examples.append(line.strip())
         return examples

     def get_num_examples(self, phase):
         """Get number of examples for train, dev or test."""
         if phase not in ['train', 'dev', 'test']:
             raise ValueError(
                 "Unknown phase, which should be in ['train', 'dev', 'test'].")
         count = len(io.open(self.data_file, 'r', encoding="utf8").readlines())
         self.num_examples[phase] = count
         return self.num_examples[phase]

-    def data_generator(self,
-                       place,
-                       phase="train",
-                       shuffle=True,
-                       sample_pro=1):
+    def data_generator(self, place, phase="train", shuffle=True, sample_pro=1):
         """
         Generate data for train, dev or test.
@@ -67,25 +64,34 @@ class DataProcessor(object):
         sample_pro: sample data ratio
         """
         examples = self.get_examples()

+        # used for ce
+        if 'ce_mode' in os.environ:
+            np.random.seed(0)
+            random.seed(0)
+            shuffle = False
+
         if shuffle:
             np.random.shuffle(examples)

         def batch_reader():
             """read batch data"""
             batch = []
             for example in examples:
                 if sample_pro < 1:
                     if random.random() > sample_pro:
                         continue

                 tokens = example.strip().split('\t')
                 if len(tokens) != 3:
                     print("data format error: %s" % example.strip())
                     print("please input data: context \t response \t label")
                     continue

-                context = [int(x) for x in tokens[0].split()[: self.max_seq_len]]
-                response = [int(x) for x in tokens[1].split()[: self.max_seq_len]]
+                context = [int(x) for x in tokens[0].split()[:self.max_seq_len]]
+                response = [
+                    int(x) for x in tokens[1].split()[:self.max_seq_len]
+                ]
                 label = [int(tokens[2])]

                 instance = (context, response, label)
@@ -96,15 +102,15 @@ class DataProcessor(object):
                     yield batch
                     batch = [instance]

             if len(batch) > 0:
                 yield batch

         def create_lodtensor(data_ids, place):
             """create LodTensor for input ids"""
             cur_len = 0
             lod = [cur_len]
             seq_lens = [len(ids) for ids in data_ids]
             for l in seq_lens:
                 cur_len += l
                 lod.append(cur_len)
             flattened_data = np.concatenate(data_ids, axis=0).astype("int64")
@@ -114,9 +120,9 @@ class DataProcessor(object):
             res.set_lod([lod])
             return res

         def wrapper():
             """yield batch data to network"""
             for batch_data in batch_reader():
                 context_ids = [batch[0] for batch in batch_data]
                 response_ids = [batch[1] for batch in batch_data]
                 label_ids = [batch[2] for batch in batch_data]
@@ -125,6 +131,5 @@ class DataProcessor(object):
                 label_ids = np.array(label_ids).astype("int64").reshape([-1, 1])
                 input_batch = [context_res, response_res, label_ids]
                 yield input_batch

         return wrapper
@@ -100,7 +100,7 @@ def prepare_data(file_dir,
                  is_train=True):
     """ prepare the English Pann Treebank (PTB) data """
     print("start constuct word dict")
-    if is_train:
+    if is_train and 'ce_mode' not in os.environ:
         vocab_size = get_vocab_size(vocab_path)
         reader = sort_batch(
             paddle.reader.shuffle(
...
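
When `--enable_ce` is passed, the pose-estimation trainer above ends by printing a tab-separated KPI record of the form `kpis\t<metric_name>_card<N>\t<value>`. The CE harness that consumes this output is not part of this commit; as an assumption about how such a line is typically parsed from stdout, here is a minimal sketch with a hypothetical helper:

def parse_kpi_line(line):
    """Parse a 'kpis\t<name>\t<value>' line emitted under --enable_ce.

    Hypothetical helper: the real CE harness is not part of this diff.
    Returns (name, value) or None when the line is not a KPI record.
    """
    parts = line.rstrip("\n").split("\t")
    if len(parts) == 3 and parts[0] == "kpis":
        name, value = parts[1], parts[2]
        return name, float(value)
    return None


# Example: the line printed by train.py for the 'coco' dataset on one card
assert parse_kpi_line("kpis\tcoco_train_cost_card1\t0.00123") == (
    "coco_train_cost_card1", 0.00123)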