提交 88407fab 编写于 作者: C chengyuz

Merge branch 'develop' of https://github.com/PaddlePaddle/models into yucheng

......@@ -211,13 +211,12 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes):
avg_cost, feature_out, word, mention, target = ner_net(word_dict_len,
label_dict_len)
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
sgd_optimizer.minimize(avg_cost)
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(
name='crfw', ))
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode,
......@@ -289,8 +288,8 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes):
+ str(f1))
save_dirname = os.path.join(model_save_dir,
"params_pass_%d" % pass_id)
fluid.io.save_inference_model(
save_dirname, ['word', 'mention', 'target'], [crf_decode], exe)
fluid.io.save_inference_model(save_dirname, ['word', 'mention'],
[crf_decode], exe)
if __name__ == "__main__":
......
model/
data/
label/
pretrained/
*.swp
......@@ -17,6 +17,8 @@ class sampler():
max_aspect_ratio,
min_jaccard_overlap,
max_jaccard_overlap,
min_object_coverage,
max_object_coverage,
use_square=False):
self.max_sample = max_sample
self.max_trial = max_trial
......@@ -26,6 +28,8 @@ class sampler():
self.max_aspect_ratio = max_aspect_ratio
self.min_jaccard_overlap = min_jaccard_overlap
self.max_jaccard_overlap = max_jaccard_overlap
self.min_object_coverage = min_object_coverage
self.max_object_coverage = max_object_coverage
self.use_square = use_square
......@@ -37,10 +41,36 @@ class bbox():
self.ymax = ymax
def intersect_bbox(bbox1, bbox2):
if bbox2.xmin > bbox1.xmax or bbox2.xmax < bbox1.xmin or \
bbox2.ymin > bbox1.ymax or bbox2.ymax < bbox1.ymin:
intersection_box = bbox(0.0, 0.0, 0.0, 0.0)
else:
intersection_box = bbox(
max(bbox1.xmin, bbox2.xmin),
max(bbox1.ymin, bbox2.ymin),
min(bbox1.xmax, bbox2.xmax), min(bbox1.ymax, bbox2.ymax))
return intersection_box
def bbox_coverage(bbox1, bbox2):
inter_box = intersect_bbox(bbox1, bbox2)
intersect_size = bbox_area(inter_box)
if intersect_size > 0:
bbox1_size = bbox_area(bbox1)
return intersect_size / bbox1_size
else:
return 0.
def bbox_area(src_bbox):
width = src_bbox.xmax - src_bbox.xmin
height = src_bbox.ymax - src_bbox.ymin
return width * height
if src_bbox.xmax < src_bbox.xmin or src_bbox.ymax < src_bbox.ymin:
return 0.
else:
width = src_bbox.xmax - src_bbox.xmin
height = src_bbox.ymax - src_bbox.ymin
return width * height
def generate_sample(sampler, image_width, image_height):
......@@ -91,20 +121,41 @@ def jaccard_overlap(sample_bbox, object_bbox):
def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0:
has_jaccard_overlap = False
else:
has_jaccard_overlap = True
if sampler.min_object_coverage == 0 and sampler.max_object_coverage == 0:
has_object_coverage = False
else:
has_object_coverage = True
if not has_jaccard_overlap and not has_object_coverage:
return True
found = False
for i in range(len(bbox_labels)):
object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1],
bbox_labels[i][2], bbox_labels[i][3])
# now only support constraint by jaccard overlap
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler.min_jaccard_overlap != 0 and \
overlap < sampler.min_jaccard_overlap:
continue
if sampler.max_jaccard_overlap != 0 and \
overlap > sampler.max_jaccard_overlap:
continue
return True
return False
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
bbox_labels[i][3], bbox_labels[i][4])
if has_jaccard_overlap:
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler.min_jaccard_overlap != 0 and \
overlap < sampler.min_jaccard_overlap:
continue
if sampler.max_jaccard_overlap != 0 and \
overlap > sampler.max_jaccard_overlap:
continue
found = True
if has_object_coverage:
object_coverage = bbox_coverage(object_bbox, sample_bbox)
if sampler.min_object_coverage != 0 and \
object_coverage < sampler.min_object_coverage:
continue
if sampler.max_object_coverage != 0 and \
object_coverage > sampler.max_object_coverage:
continue
found = True
if found:
return True
return found
def generate_batch_samples(batch_sampler, bbox_labels, image_width,
......@@ -170,8 +221,8 @@ def transform_labels(bbox_labels, sample_bbox):
sample_labels = []
for i in range(len(bbox_labels)):
sample_label = []
object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1],
bbox_labels[i][2], bbox_labels[i][3])
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
bbox_labels[i][3], bbox_labels[i][4])
if not meet_emit_constraint(object_bbox, sample_bbox):
continue
proj_bbox = project_bbox(object_bbox, sample_bbox)
......
import os
import time
import numpy as np
import argparse
import functools
from PIL import Image
from PIL import ImageDraw
import paddle
import paddle.fluid as fluid
import reader
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
add_arg('confs_threshold', float, 0.15, "Confidence threshold to draw bbox.")
add_arg('image_path', str, '', "The data root path.")
add_arg('model_dir', str, '', "The model path.")
add_arg('resize_h', int, 0, "The resized image height.")
add_arg('resize_w', int, 0, "The resized image height.")
# yapf: enable
def draw_bounding_box_on_image(image_path, nms_out, confs_threshold):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
for dt in nms_out:
category_id, score, xmin, ymin, xmax, ymax = dt.tolist()
if score < confs_threshold:
continue
bbox = dt[2:]
xmin, ymin, xmax, ymax = bbox
(left, right, top, bottom) = (xmin * im_width, xmax * im_width,
ymin * im_height, ymax * im_height)
draw.line(
[(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=4,
fill='red')
image_name = image_path.split('/')[-1]
print("image with bbox drawed saved as {}".format(image_name))
image.save(image_name)
def infer(args, data_args):
num_classes = 2
infer_reader = reader.infer(data_args, args.image_path)
data = infer_reader()
if args.resize_h and args.resize_w:
image_shape = [3, args.resize_h, args.resize_w]
else:
image_shape = data.shape[1:]
fetches = []
network = PyramidBox(
image_shape,
num_classes,
sub_network=args.use_pyramidbox,
is_infer=True)
infer_program, nmsed_out = network.infer()
fetches = [nmsed_out]
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
model_dir = args.model_dir
if not os.path.exists(model_dir):
raise ValueError("The model path [%s] does not exist." % (model_dir))
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
fluid.io.load_vars(exe, model_dir, predicate=if_exist)
feed = {'image': fluid.create_lod_tensor(data, [], place)}
predict, = exe.run(infer_program,
feed=feed,
fetch_list=fetches,
return_numpy=False)
predict = np.array(predict)
draw_bounding_box_on_image(args.image_path, predict, args.confs_threshold)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_dir = 'data/WIDERFACE/WIDER_val/images/'
file_list = 'label/val_gt_widerface.res'
data_args = reader.Settings(
data_dir=data_dir,
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[104., 117., 123],
apply_distort=False,
apply_expand=False,
ap_version='11point')
infer(args, data_args=data_args)
......@@ -45,26 +45,45 @@ def conv_block(input, groups, filters, ksizes, strides=None, with_pool=True):
class PyramidBox(object):
def __init__(self, data_shape, is_infer=False, sub_network=False):
def __init__(self,
data_shape,
num_classes,
is_infer=False,
sub_network=False):
self.data_shape = data_shape
self.min_sizes = [16., 32., 64., 128., 256., 512.]
self.steps = [4., 8., 16., 32., 64., 128.]
self.is_infer = is_infer
self.sub_network = sub_network
self.num_classes = num_classes
# the base network is VGG with atrus layers
# the base network is VGG with atrous layers
self._input()
self._vgg()
if sub_network:
self._low_level_fpn()
self._cpm_module()
self._pyramidbox()
else:
self._vgg_ssd()
def feeds(self):
if self.is_infer:
return [self.image]
else:
return [
self.image, self.face_box, self.head_box, self.gt_label,
self.difficult
]
def _input(self):
self.image = fluid.layers.data(
name='image', shape=self.data_shape, dtype='float32')
if not self.is_infer:
self.gt_box = fluid.layers.data(
name='gt_box', shape=[4], dtype='float32', lod_level=1)
self.face_box = fluid.layers.data(
name='face_box', shape=[4], dtype='float32', lod_level=1)
self.head_box = fluid.layers.data(
name='head_box', shape=[4], dtype='float32', lod_level=1)
self.gt_label = fluid.layers.data(
name='gt_label', shape=[1], dtype='int32', lod_level=1)
self.difficult = fluid.layers.data(
......@@ -176,9 +195,10 @@ class PyramidBox(object):
"""
Get prior-boxes and pyramid-box
"""
self.ssh_conv3_norm = self._l2_norm_scale(self.ssh_conv3)
self.ssh_conv4_norm = self._l2_norm_scale(self.ssh_conv4)
self.ssh_conv5_norm = self._l2_norm_scale(self.ssh_conv5)
self.ssh_conv3_norm = self._l2_norm_scale(
self.ssh_conv3, init_scale=10.)
self.ssh_conv4_norm = self._l2_norm_scale(self.ssh_conv4, init_scale=8.)
self.ssh_conv5_norm = self._l2_norm_scale(self.ssh_conv5, init_scale=5.)
def permute_and_reshape(input, last_dim):
trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1])
......@@ -241,10 +261,10 @@ class PyramidBox(object):
self.prior_boxes = fluid.layers.concat(boxes)
self.box_vars = fluid.layers.concat(vars)
def vgg_ssd(self, num_classes, image_shape):
self.conv3_norm = self._l2_norm_scale(self.conv3)
self.conv4_norm = self._l2_norm_scale(self.conv4)
self.conv5_norm = self._l2_norm_scale(self.conv5)
def _vgg_ssd(self):
self.conv3_norm = self._l2_norm_scale(self.conv3, init_scale=10.)
self.conv4_norm = self._l2_norm_scale(self.conv4, init_scale=8.)
self.conv5_norm = self._l2_norm_scale(self.conv5, init_scale=5.)
mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=[
......@@ -252,40 +272,47 @@ class PyramidBox(object):
self.conv7, self.conv8
],
image=self.image,
num_classes=num_classes,
# min_ratio=20,
# max_ratio=90,
num_classes=self.num_classes,
min_sizes=[16.0, 32.0, 64.0, 128.0, 256.0, 512.0],
max_sizes=[[], [], [], [], [], []],
# max_sizes=[[], 150.0, 195.0, 240.0, 285.0, 300.0],
aspect_ratios=[[1.], [1.], [1.], [1.], [1.], [1.]],
steps=[4.0, 8.0, 16.0, 32.0, 64.0, 128.0],
base_size=image_shape[2],
base_size=self.data_shape[2],
offset=0.5,
flip=False)
# locs, confs, box, box_var = vgg_extra_net(num_classes, image, image_shape)
# nmsed_out = fluid.layers.detection_output(
# locs, confs, box, box_var, nms_threshold=args.nms_threshold)
loss = fluid.layers.ssd_loss(mbox_locs, mbox_confs, self.gt_box,
self.gt_label, box, box_var)
self.face_mbox_loc = mbox_locs
self.face_mbox_conf = mbox_confs
self.prior_boxes = box
self.box_vars = box_var
def vgg_ssd_loss(self):
loss = fluid.layers.ssd_loss(
self.face_mbox_loc,
self.face_mbox_conf,
self.face_box,
self.gt_label,
self.prior_boxes,
self.box_vars,
overlap_threshold=0.35,
neg_overlap=0.35)
loss = fluid.layers.reduce_sum(loss)
return loss
def train(self):
face_loss = fluid.layers.ssd_loss(
self.face_mbox_loc, self.face_mbox_conf, self.gt_box, self.gt_label,
self.prior_boxes, self.box_vars)
self.face_mbox_loc, self.face_mbox_conf, self.face_box,
self.gt_label, self.prior_boxes, self.box_vars)
head_loss = fluid.layers.ssd_loss(
self.head_mbox_loc, self.head_mbox_conf, self.gt_box, self.gt_label,
self.prior_boxes, self.box_vars)
self.head_mbox_loc, self.head_mbox_conf, self.head_box,
self.gt_label, self.prior_boxes, self.box_vars)
face_loss = fluid.layers.reduce_sum(face_loss)
head_loss = fluid.layers.reduce_sum(head_loss)
total_loss = face_loss + head_loss
return face_loss, head_loss, total_loss
def test(self):
def infer(self):
test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
face_nmsed_out = fluid.layers.detection_output(
......@@ -294,24 +321,4 @@ class PyramidBox(object):
self.prior_boxes,
self.box_vars,
nms_threshold=0.45)
head_nmsed_out = fluid.layers.detection_output(
self.head_mbox_loc,
self.head_mbox_conf,
self.prior_boxes,
self.box_vars,
nms_threshold=0.45)
face_map_eval = fluid.evaluator.DetectionMAP(
face_nmsed_out,
self.gt_label,
self.gt_box,
class_num=2,
overlap_threshold=0.5,
ap_version='11point')
head_map_eval = fluid.evaluator.DetectionMAP(
head_nmsed_out,
self.gt_label,
self.gt_box,
class_num=2,
overlap_threshold=0.5,
ap_version='11point')
return test_program, face_map_eval, head_map_eval
return test_program, face_nmsed_out
......@@ -72,7 +72,7 @@ class Settings(object):
return self._toy
@property
def apply_distort(self):
def apply_expand(self):
return self._apply_expand
@property
......@@ -117,15 +117,20 @@ def preprocess(img, bbox_labels, mode, settings):
batch_sampler = []
# hard-code here
batch_sampler.append(
image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
sampled_bbox = image_util.generate_batch_samples(
batch_sampler, bbox_labels, img_width, img_height)
......@@ -190,6 +195,30 @@ def put_txt_in_dict(input_txt):
return dict_input_txt
def expand_bboxes(bboxes,
expand_left=2.,
expand_up=2.,
expand_right=2.,
expand_down=2.):
"""
Expand bboxes, expand 2 times by defalut.
"""
expand_boxes = []
for bbox in bboxes:
xmin = bbox[0]
ymin = bbox[1]
xmax = bbox[2]
ymax = bbox[3]
w = xmax - xmin
h = ymax - ymin
ex_xmin = max(xmin - w / expand_left, 0.)
ex_ymin = max(ymin - h / expand_up, 0.)
ex_xmax = min(xmax + w / expand_right, 1.)
ex_ymax = min(ymax + h / expand_down, 1.)
expand_boxes.append([ex_xmin, ex_ymin, ex_xmax, ex_ymax])
return expand_boxes
def pyramidbox(settings, file_list, mode, shuffle):
dict_input_txt = {}
......@@ -208,12 +237,11 @@ def pyramidbox(settings, file_list, mode, shuffle):
im = im.convert('RGB')
im_width, im_height = im.size
# layout: category_id | xmin | ymin | xmax | ymax | iscrowd
# layout: label | xmin | ymin | xmax | ymax
bbox_labels = []
for index_box in range(len(dict_input_txt[index_image])):
if index_box >= 2:
bbox_sample = []
temp_info_box = dict_input_txt[index_image][
index_box].split(' ')
xmin = float(temp_info_box[0])
......@@ -223,6 +251,7 @@ def pyramidbox(settings, file_list, mode, shuffle):
xmax = xmin + w
ymax = ymin + h
bbox_sample.append(1)
bbox_sample.append(float(xmin) / im_width)
bbox_sample.append(float(ymin) / im_height)
bbox_sample.append(float(xmax) / im_width)
......@@ -233,11 +262,10 @@ def pyramidbox(settings, file_list, mode, shuffle):
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
boxes = sample_labels[:, 0:4]
boxes = sample_labels[:, 1:5]
lbls = [1] * len(boxes)
difficults = [1] * len(boxes)
yield im, boxes, lbls, difficults
yield im, boxes, expand_bboxes(boxes), lbls, difficults
return reader
......@@ -246,5 +274,27 @@ def train(settings, file_list, shuffle=True):
return pyramidbox(settings, file_list, 'train', shuffle)
def test(settings, file_list):
return pyramidbox(settings, file_list, 'test', False)
def infer(settings, image_path):
def batch_reader():
img = Image.open(image_path)
if img.mode == 'L':
img = im.convert('RGB')
im_width, im_height = img.size
if settings.resize_w and settings.resize_h:
img = img.resize((settings.resize_w, settings.resize_h),
Image.ANTIALIAS)
img = np.array(img)
# HWC to CHW
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
# RBG to BGR
img = img[[2, 1, 0], :, :]
img = img.astype('float32')
img -= settings.img_mean
img = img * 0.007843
img = [img]
img = np.array(img)
return img
return batch_reader
import os
import shutil
import numpy as np
import time
import argparse
......@@ -22,7 +23,7 @@ add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, './vgg_model/', "The init model path.")
add_arg('pretrained_model', str, './pretrained/', "The init model path.")
add_arg('resize_h', int, 640, "The resized image height.")
add_arg('resize_w', int, 640, "The resized image height.")
#yapf: enable
......@@ -38,12 +39,15 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
image_shape = [3, data_args.resize_h, data_args.resize_w]
fetches = []
network = PyramidBox(image_shape, num_classes,
sub_network=args.use_pyramidbox)
if args.use_pyramidbox:
network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
face_loss, head_loss, loss = network.train()
fetches = [face_loss, head_loss]
else:
network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
loss = network.vgg_ssd(num_classes, image_shape)
loss = network.vgg_ssd_loss()
fetches = [loss]
epocs = 12880 / batch_size
boundaries = [epocs * 100, epocs * 125, epocs * 150]
......@@ -71,11 +75,12 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# fluid.io.save_inference_model('./vgg_model/', ['image'], [loss], exe)
if pretrained_model:
if not os.path.exists(pretrained_model):
raise ValueError("The pre-trained model path [%s] does not exist." %
(pretrained_model))
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
print('Load pre-trained model.')
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
if args.parallel:
......@@ -84,11 +89,7 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
feeder = fluid.DataFeeder(
place=place,
feed_list=[
network.image, network.gt_box, network.gt_label, network.difficult
])
feeder = fluid.DataFeeder(place=place, feed_list=network.feeds())
def save_model(postfix):
model_path = os.path.join(model_save_dir, postfix)
......@@ -97,8 +98,6 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
print 'save models to %s' % (model_path)
fluid.io.save_persistables(exe, model_path)
best_map = 0.
for pass_id in range(num_passes):
start_time = time.time()
prev_start_time = start_time
......@@ -108,20 +107,27 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
start_time = time.time()
if len(data) < devices_num: continue
if args.parallel:
loss_v, = train_exe.run(fetch_list=[loss.name],
feed=feeder.feed(data))
fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches],
feed=feeder.feed(data))
else:
loss_v, = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[loss])
fetch_vars = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=fetches)
end_time = time.time()
loss_v = np.mean(np.array(loss_v))
fetch_vars = [np.mean(np.array(v)) for v in fetch_vars]
if batch_id % 1 == 0:
print("Pass {0}, batch {1}, loss {2}, time {3}".format(
pass_id, batch_id, loss_v, start_time - prev_start_time))
if pass_id % 10 == 0 or pass_id == num_passes - 1:
if not args.use_pyramidbox:
print("Pass {0}, batch {1}, loss {2}, time {3}".format(
pass_id, batch_id, fetch_vars[0],
start_time - prev_start_time))
else:
print("Pass {0}, batch {1}, face loss {2}, head loss {3}, " \
"time {4}".format(pass_id,
batch_id, fetch_vars[0], fetch_vars[1],
start_time - prev_start_time))
if pass_id % 1 == 0 or pass_id == num_passes - 1:
save_model(str(pass_id))
print("Best test map {0}".format(best_map))
if __name__ == '__main__':
......
......@@ -20,8 +20,8 @@ def calc_diff(f1, f2):
d1 = np.load(f1)
d2 = np.load(f2)
print d1.shape
print d2.shape
#print d1.shape
#print d2.shape
#print d1[0, 0, 0:10, 0:10]
#print d2[0, 0, 0:10, 0:10]
#d1 = d1[:, :, 1:-2, 1:-2]
......
......@@ -78,6 +78,54 @@ def dump_results(results, names, root):
np.save(filename + '.npy', res)
def normalize_name(name_map):
return {
k.replace('/', '_'): v.replace('/', '_')
for k, v in name_map.items()
}
def rename_layer_name(names, net):
""" because the names of output layers from caffe maybe changed for 'INPLACE' operation,
and paddle's layers maybe fused, so we need to re-mapping their relationship for comparing
"""
#build a mapping from paddle's name to caffe's name
trace = getattr(net, 'name_trace', None)
cf_trace = trace['caffe']
real2cf = normalize_name(cf_trace['real2chg'])
pd_trace = trace['paddle']
pd2real = normalize_name(pd_trace['chg2real'])
pd_deleted = normalize_name(pd_trace['deleted'])
pd2cf_name = {}
for pd_name, real_name in pd2real.items():
if real_name in real2cf:
pd2cf_name[pd_name] = '%s.%s.%s.both_changed' \
% (real2cf[real_name], real_name, pd_name)
else:
pd2cf_name[pd_name] = '%s.%s.pd_changed' % (real_name, pd_name)
for pd_name, trace in pd_deleted.items():
assert pd_name not in pd2cf_name, "this name[%s] has already exist" % (
pd_name)
pd2cf_name[pd_name] = '%s.pd_deleted' % (pd_name)
for real_name, cf_name in real2cf.items():
if cf_name not in pd2cf_name:
pd2cf_name[cf_name] = '%s.cf_deleted' % (cf_name)
if real_name not in pd2cf_name:
pd2cf_name[real_name] = '%s.%s.cf_changed' % (cf_name, real_name)
ret = []
for name in names:
new_name = pd2cf_name[name] if name in pd2cf_name else name
print('remap paddle name[%s] to output name[%s]' % (name, new_name))
ret.append(new_name)
return ret
def load_model(exe, place, net_file, net_name, net_weight, debug):
""" load model using xxxnet.py and xxxnet.npy
"""
......@@ -117,7 +165,8 @@ def load_model(exe, place, net_file, net_name, net_weight, debug):
'feed_names': feed_names,
'fetch_vars': fetch_list_var,
'fetch_names': fetch_list_name,
'feed_shapes': feed_shapes
'feed_shapes': feed_shapes,
'net': net
}
......@@ -171,6 +220,7 @@ def infer(model_path, imgfile, net_file=None, net_name=None, debug=True):
fetch_targets = ret['fetch_vars']
fetch_list_name = ret['fetch_names']
feed_shapes = ret['feed_shapes']
net = ret['net']
input_name = feed_names[0]
input_shape = feed_shapes[0]
......@@ -182,7 +232,8 @@ def infer(model_path, imgfile, net_file=None, net_name=None, debug=True):
if debug is True:
dump_path = 'results.paddle'
dump_results(results, fetch_list_name, dump_path)
dump_names = rename_layer_name(fetch_list_name, net)
dump_results(results, dump_names, dump_path)
print('all result of layers dumped to [%s]' % (dump_path))
else:
result = results[0]
......
......@@ -19,4 +19,6 @@ if [[ $# -eq 3 ]];then
else
caffe_file="./results/${model_name}.caffe/${2}.npy"
fi
python ./compare.py $paddle_file $caffe_file
cmd="python ./compare.py $paddle_file $caffe_file"
echo $cmd
eval $cmd
......@@ -3,7 +3,7 @@
#function:
# a tool used to compare all layers' results
#
#set -x
if [[ $# -ne 1 ]];then
echo "usage:"
echo " bash $0 [model_name]"
......@@ -13,11 +13,20 @@ fi
model_name=$1
prototxt="models.caffe/$model_name/${model_name}.prototxt"
layers=$(cat $prototxt | perl -ne 'if(/^\s+name\s*:\s*\"([^\"]+)/){print $1."\n";}')
cat $prototxt | grep name | perl -ne 'if(/^\s*name\s*:\s+\"([^\"]+)/){ print $1."\n";}' >.layer_names
final_layer=$(cat $prototxt | perl -ne 'if(/^\s*top\s*:\s+\"([^\"]+)/){ print $1."\n";}' | tail -n1)
ret=$(grep "^$final_layer$" .layer_names | wc -l)
if [[ $ret -eq 0 ]];then
echo $final_layer >>.layer_names
fi
for i in $layers;do
for i in $(cat .layer_names);do
i=${i//\//_}
cf_npy="results/${model_name}.caffe/${i}.npy"
pd_npy="results/${model_name}.paddle/${i}.npy"
#pd_npy="results/${model_name}.paddle/${i}.npy"
#pd_npy=$(find results/${model_name}.paddle -iname "${i}*.npy" | head -n1)
pd_npy=$(find results/${model_name}.paddle -iname "${i}.*npy" | grep deleted -v | head -n1)
if [[ ! -e $cf_npy ]];then
echo "caffe's result not exist[$cf_npy]"
......
......@@ -29,8 +29,8 @@ fi
mkdir -p $results_root
model_prototxt="models.caffe/$model_name/${model_name}.prototxt"
model_caffemodel="models.caffe/${model_name}/${model_name}.caffemodel"
prototxt="models.caffe/$model_name/${model_name}.prototxt"
caffemodel="models.caffe/${model_name}/${model_name}.caffemodel"
#1, dump layers' results from paddle
paddle_results="$results_root/${model_name}.paddle"
......@@ -51,7 +51,7 @@ PYTHON=`which cfpython`
if [[ -z $PYTHON ]];then
PYTHON=`which python`
fi
$PYTHON ./infer.py caffe $model_prototxt $model_caffemodel $paddle_results/data.npy
$PYTHON ./infer.py caffe $prototxt $caffemodel $paddle_results/data.npy
if [[ $? -ne 0 ]] || [[ ! -e "results.caffe" ]];then
echo "not found caffe's results, maybe failed to do inference with caffe"
exit 1
......@@ -59,10 +59,25 @@ fi
mv results.caffe $caffe_results
#3, extract layer names
cat $model_prototxt | grep name | perl -ne 'if(/^\s*name:\s+\"([^\"]+)/){ print $1."\n";}' >.layer_names
cat $prototxt | grep name | perl -ne 'if(/^\s*name\s*:\s+\"([^\"]+)/){ print $1."\n";}' >.layer_names
final_layer=$(cat $prototxt | perl -ne 'if(/^\s*top\s*:\s+\"([^\"]+)/){ print $1."\n";}' | tail -n1)
ret=$(grep "^$final_layer$" .layer_names | wc -l)
if [[ $ret -eq 0 ]];then
echo $final_layer >>.layer_names
fi
#4, compare one by one
for i in $(cat ".layer_names" | tail -n1);do
#for i in $(cat .layer_names);do
for i in $(cat .layer_names | tail -n1);do
i=${i//\//_}
echo "process $i"
$PYTHON compare.py $caffe_results/${i}.npy $paddle_results/${i}.npy
pd_npy=$(find $paddle_results/ -iname "${i}.*npy" | grep deleted -v | head -n1)
#pd_npy="$paddle_results/${i}.npy"
if [[ -f $pd_npy ]];then
$PYTHON compare.py $caffe_results/${i}.npy $pd_npy
else
echo "not found npy file[${i}.*npy] for layer[$i]"
exit 1
fi
done
......@@ -71,7 +71,9 @@ if [[ -z $only_convert ]];then
if [[ -z $net_name ]];then
net_name="MyNet"
fi
$PYTHON ./infer.py dump $net_file $weight_file $imgfile $net_name
cmd="$PYTHON ./infer.py dump $net_file $weight_file $imgfile $net_name"
echo $cmd
eval $cmd
ret=$?
fi
exit $ret
#!/bin/bash
#
#script to test all models
#
models="alexnet vgg16 googlenet resnet152 resnet101 resnet50"
for i in $models;do
echo "begin to process $i"
bash ./tools/diff.sh $i 2>&1
echo "finished to process $i with ret[$?]"
done
......@@ -58,11 +58,13 @@ def argmax_layer(input, name, out_max_val=False, top_k=1, axis=-1):
if axis < 0:
axis += len(input.shape)
topk_var, index_var = fluid.layers.topk(input=input, k=top_k)
if out_max_val is True:
topk_var, index_var = fluid.layers.topk(input=input, k=top_k)
index_var = fluid.layers.cast(index_var, dtype=topk_var.dtype)
output = fluid.layers.concat([index_var, topk_var], axis=axis)
output = fluid.layers.concat(
[index_var, topk_var], axis=axis, name=name)
else:
topk_var, index_var = fluid.layers.topk(input=input, k=top_k, name=name)
output = index_var
return output
......
......@@ -43,7 +43,7 @@ def axpy_layer(inputs, name):
x = inputs[1]
y = inputs[2]
output = fluid.layers.elementwise_mul(x, alpha, axis=0)
output = fluid.layers.elementwise_add(output, y)
output = fluid.layers.elementwise_add(output, y, name=name)
return output
......
......@@ -63,9 +63,10 @@ class Node(object):
class Graph(object):
def __init__(self, nodes=None, name=None):
def __init__(self, nodes=None, name=None, trace={}):
self.nodes = nodes or []
self.node_lut = {node.name: node for node in self.nodes}
self.output_trace = trace
if name is None or name == '':
self.name = 'MyNet'
else:
......@@ -81,6 +82,15 @@ class Graph(object):
except KeyError:
raise KaffeError('Layer not found: %s' % name)
def add_name_trace(self, trace, which='caffe'):
self.output_trace[which] = trace
def get_name_trace(self, which=None):
if which is not None:
return self.output_trace[which]
else:
return self.output_trace
def get_input_nodes(self):
return [node for node in self.nodes if len(node.parents) == 0]
......@@ -116,7 +126,7 @@ class Graph(object):
*NodeKind.compute_output_shape(node))
def replaced(self, new_nodes):
return Graph(nodes=new_nodes, name=self.name)
return Graph(nodes=new_nodes, name=self.name, trace=self.output_trace)
def transformed(self, transformers):
graph = self
......@@ -262,6 +272,7 @@ class GraphBuilder(object):
# The current implementation only supports single-output nodes (note that a node can still
# have multiple children, since multiple child nodes can refer to the single top's name).
node_outputs = {}
output_trace = {}
for layer in layers:
node = graph.get_node(layer.name)
for input_name in layer.bottom:
......@@ -291,7 +302,26 @@ class GraphBuilder(object):
#
# For both cases, future references to this top re-routes to this node.
node_outputs[output_name] = node
if output_name in output_trace:
output_trace[output_name].append(node.name)
else:
output_trace[output_name] = [output_name, node.name]
#build a mapping from real-name to changed-name(for caffe's INPLACE inference)
real2chg = {}
deleted = {}
for k, v in output_trace.items():
real2chg[v[-1]] = k
for n in v:
if n in real2chg:
continue
if n not in deleted:
deleted[n] = '%s.%s' % (k, v[-1])
graph.add_name_trace({
'real2chg': real2chg,
'deleted': deleted
}, 'caffe')
graph.compute_output_shapes()
return graph
......
......@@ -216,7 +216,7 @@ class LayerAdapter(object):
s_w = self.get_kernel_value(
params.stride_w, params.stride, 1, default=1)
p_h = self.get_kernel_value(params.pad_h, params.pad, 0, default=0)
p_w = self.get_kernel_value(params.pad_h, params.pad, 1, default=0)
p_w = self.get_kernel_value(params.pad_w, params.pad, 1, default=0)
return KernelParameters(k_h, k_w, s_h, s_w, p_h, p_w)
......
......@@ -47,6 +47,8 @@ class Network(object):
self.trainable = trainable
# Switch variable for dropout
self.paddle_env = None
self.output_names = []
self.name_trace = None
self.setup()
def setup(self):
......@@ -79,6 +81,10 @@ class Network(object):
data_dict = np.load(data_path).item()
for op_name in data_dict:
if op_name == 'caffe2fluid_name_trace':
self.name_trace = data_dict[op_name]
continue
layer = self.layers[op_name]
for param_name, data in data_dict[op_name].iteritems():
try:
......@@ -117,6 +123,15 @@ class Network(object):
ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1
return '%s_%d' % (prefix, ident)
def get_unique_output_name(self, prefix, layertype):
'''Returns an index-suffixed unique name for the given prefix.
This is used for auto-generating layer names based on the type-prefix.
'''
ident = sum(t.startswith(prefix) for t in self.output_names) + 1
unique_name = '%s.%s.output.%d' % (prefix, layertype, ident)
self.output_names.append(unique_name)
return unique_name
@layer
def conv(self,
input,
......@@ -152,6 +167,7 @@ class Network(object):
act = None
output = fluid.layers.conv2d(
name=self.get_unique_output_name(name, 'conv2d'),
input=input,
filter_size=[k_h, k_w],
num_filters=c_o,
......@@ -170,7 +186,8 @@ class Network(object):
@layer
def relu(self, input, name):
fluid = import_fluid()
output = fluid.layers.relu(x=input)
output = fluid.layers.relu(
name=self.get_unique_output_name(name, 'relu'), x=input)
return output
def pool(self, pool_type, input, k_h, k_w, s_h, s_w, ceil_mode, padding,
......@@ -182,6 +199,7 @@ class Network(object):
fluid = import_fluid()
output = fluid.layers.pool2d(
name=name,
input=input,
pool_size=k_hw,
pool_stride=s_hw,
......@@ -200,8 +218,16 @@ class Network(object):
ceil_mode,
padding=[0, 0],
name=None):
return self.pool('max', input, k_h, k_w, s_h, s_w, ceil_mode, padding,
name)
return self.pool(
'max',
input,
k_h,
k_w,
s_h,
s_w,
ceil_mode,
padding,
name=self.get_unique_output_name(name, 'max_pool'))
@layer
def avg_pool(self,
......@@ -213,25 +239,41 @@ class Network(object):
ceil_mode,
padding=[0, 0],
name=None):
return self.pool('avg', input, k_h, k_w, s_h, s_w, ceil_mode, padding,
name)
return self.pool(
'avg',
input,
k_h,
k_w,
s_h,
s_w,
ceil_mode,
padding,
name=self.get_unique_output_name(name, 'avg_pool'))
@layer
def sigmoid(self, input, name):
fluid = import_fluid()
return fluid.layers.sigmoid(input)
return fluid.layers.sigmoid(
input, name=self.get_unique_output_name(name, 'sigmoid'))
@layer
def lrn(self, input, radius, alpha, beta, name, bias=1.0):
fluid = import_fluid()
output = fluid.layers.lrn(input=input, \
n=radius, k=bias, alpha=alpha, beta=beta, name=name)
output = fluid.layers.lrn(input=input,
n=radius,
k=bias,
alpha=alpha,
beta=beta,
name=self.get_unique_output_name(name, 'lrn'))
return output
@layer
def concat(self, inputs, axis, name):
fluid = import_fluid()
output = fluid.layers.concat(input=inputs, axis=axis)
output = fluid.layers.concat(
input=inputs,
axis=axis,
name=self.get_unique_output_name(name, 'concat'))
return output
@layer
......@@ -239,7 +281,8 @@ class Network(object):
fluid = import_fluid()
output = inputs[0]
for i in inputs[1:]:
output = fluid.layers.elementwise_add(x=output, y=i)
output = fluid.layers.elementwise_add(
x=output, y=i, name=self.get_unique_output_name(name, 'add'))
return output
@layer
......@@ -251,7 +294,7 @@ class Network(object):
prefix = name + '_'
output = fluid.layers.fc(
name=name,
name=self.get_unique_output_name(name, 'fc'),
input=input,
size=num_out,
act=act,
......@@ -269,7 +312,8 @@ class Network(object):
str(shape))
input = fluid.layers.reshape(input, shape[0:2])
output = fluid.layers.softmax(input)
output = fluid.layers.softmax(
input, name=self.get_unique_output_name(name, 'softmax'))
return output
@layer
......@@ -289,7 +333,7 @@ class Network(object):
mean_name = prefix + 'mean'
variance_name = prefix + 'variance'
output = fluid.layers.batch_norm(
name=name,
name=self.get_unique_output_name(name, 'batch_norm'),
input=input,
is_test=True,
param_attr=param_attr,
......@@ -308,7 +352,10 @@ class Network(object):
output = input
else:
output = fluid.layers.dropout(
input, dropout_prob=drop_prob, is_test=is_test)
input,
dropout_prob=drop_prob,
is_test=is_test,
name=self.get_unique_output_name(name, 'dropout'))
return output
@layer
......@@ -328,8 +375,16 @@ class Network(object):
offset_param = fluid.layers.create_parameter(
shape=scale_shape, dtype=input.dtype, name=name, attr=offset_attr)
output = fluid.layers.elementwise_mul(input, scale_param, axis=axis)
output = fluid.layers.elementwise_add(output, offset_param, axis=axis)
output = fluid.layers.elementwise_mul(
input,
scale_param,
axis=axis,
name=self.get_unique_output_name(name, 'scale_mul'))
output = fluid.layers.elementwise_add(
output,
offset_param,
axis=axis,
name=self.get_unique_output_name(name, 'scale_add'))
return output
def custom_layer_factory(self):
......@@ -342,5 +397,6 @@ class Network(object):
def custom_layer(self, inputs, kind, name, *args, **kwargs):
""" make custom layer
"""
name = self.get_unique_output_name(name, kind)
layer_factory = self.custom_layer_factory()
return layer_factory(kind, inputs, name, *args, **kwargs)
......@@ -3,9 +3,9 @@ import numpy as np
from ..errors import KaffeError, print_stderr
from ..graph import GraphBuilder, NodeMapper
from ..layers import NodeKind
from ..transformers import (DataInjector, DataReshaper, NodeRenamer, ReLUFuser,
BatchNormScaleBiasFuser, BatchNormPreprocessor,
ParameterNamer)
from ..transformers import (DataInjector, DataReshaper, NodeRenamer,
SubNodeFuser, ReLUFuser, BatchNormScaleBiasFuser,
BatchNormPreprocessor, ParameterNamer)
from . import network
......@@ -18,7 +18,7 @@ def get_padding_type(kernel_params, input_shape, output_shape):
https://github.com/Yangqing/caffe2/blob/master/caffe2/proto/caffe2_legacy.proto
'''
k_h, k_w, s_h, s_w, p_h, p_w = kernel_params
if p_h * p_w > 0:
if p_h > 0 or p_w > 0:
return [p_h, p_w]
else:
return None
......@@ -315,6 +315,23 @@ class Transformer(object):
self.graph = graph.transformed(transformers)
#for the purpose of recording name mapping because of fused nodes
trace = SubNodeFuser.traced_names()
chg2real = {}
deleted = {}
for k, v in trace.items():
chg2real[k] = v[-1] #mapping from changed-name to real-name
for n in v:
if n in chg2real:
continue
if n not in deleted:
deleted[n] = '%s.%s' % (k, v[-1])
self.graph.add_name_trace({
'chg2real': chg2real,
'deleted': deleted
}, 'paddle')
# Display the graph
if self.verbose:
print_stderr(self.graph)
......@@ -339,6 +356,8 @@ class Transformer(object):
node.name: node.data
for node in self.graph.nodes if node.data
}
self.params['caffe2fluid_name_trace'] = self.graph.get_name_trace()
return self.params
def transform_source(self):
......
......@@ -181,6 +181,20 @@ class SubNodeFuser(object):
'''
An abstract helper for merging a single-child with its single-parent.
'''
_traced_names = {}
@classmethod
def traced_names(cls):
return cls._traced_names
@classmethod
def trace(cls, fname, tname):
""" recording the names mapping,
the value of 'fname' will be replaced by value of 'tname'
"""
if fname not in cls._traced_names:
cls._traced_names[fname] = []
cls._traced_names[fname].append(tname)
def __call__(self, graph):
nodes = graph.nodes
......@@ -234,6 +248,7 @@ class ReLUFuser(SubNodeFuser):
child.kind == NodeKind.ReLU)
def merge(self, parent, child):
SubNodeFuser.trace(parent.name, child.name)
parent.metadata['relu'] = True
parent.metadata['relu_negative_slope'] = child.parameters.negative_slope
......@@ -255,6 +270,7 @@ class BatchNormScaleBiasFuser(SubNodeFuser):
child.parameters.bias_term == True)
def merge(self, parent, child):
SubNodeFuser.trace(parent.name, child.name)
parent.scale_bias_node = child
......
......@@ -238,7 +238,7 @@ def lstm_net(data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh')
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)
......@@ -273,9 +273,9 @@ def bilstm_net(data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh')
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
rfc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh')
rfc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)
......
......@@ -238,7 +238,7 @@ def lstm_net(data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh')
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)
......@@ -273,9 +273,9 @@ def bilstm_net(data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh')
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
rfc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh')
rfc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)
......
......@@ -75,7 +75,7 @@ def lstm_net(data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh')
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)
......
## Introduction
Sequence is an input data type faced by many machine learning and data mining tasks. Taking Natural Language Processing task as an example, sentence is composed of words, and paragraph is composed of sentences. As a result, a paragraph can be seen as a nested sequence (or called: double sequence), and each element of the sequence is a sequence.
Double sequence is a very flexible data organization method supported by PaddlePaddle, which can help us better describe more complex data such as paragraphs, multiple rounds of dialogues. With a double-layer sequence as input, we can design a hierarchical network to better accomplish some complex tasks.
This unit will introduce how to use a double sequence in PaddlePaddle.
- [Text Classification Based on Double Sequence](https://github.com/PaddlePaddle/models/tree/develop/nested_sequence/text_classification)
......@@ -22,7 +22,7 @@ PaddlePaddle 实现该网络结构的代码见 `network_conf.py`。
对双层时间序列的处理,需要先将双层时间序列数据变换成单层时间序列数据,再对每一个单层时间序列进行处理。 在 PaddlePaddle 中 ,`recurrent_group` 是帮助我们构建处理双层序列的层次化模型的主要工具。这里,我们使用两个嵌套的 `recurrent_group` 。外层的 `recurrent_group` 将段落拆解为句子,`step` 函数中拿到的输入是句子序列;内层的 `recurrent_group` 将句子拆解为词语,`step` 函数中拿到的输入是非序列的词语。
在词语级别,我们通过 CNN 网络以词向量为输入输出学习到的句子表示;在段落级别,将每个句子的表示通过池化作用得到段落表示。
在词语级别,我们运用 CNN 网络,以词向量为输入,输出学习到的句子表示;在段落级别,我们通过池化作用,从若干句子的表示中得到段落的表示。
``` python
nest_group = paddle.layer.recurrent_group(input=[paddle.layer.SubsequenceInput(emb),
......@@ -112,18 +112,18 @@ python train.py
```
将以 PaddlePaddle 内置的情感分类数据集: `imdb` 运行本例。
### 预测
训练结束后模型将存储在指定目录当中(默认models目录),在终端执行:
训练结束后,模型将被存储在指定目录当中(默认models目录),在终端执行:
```bash
python infer.py --model_path 'models/params_pass_00000.tar.gz'
```
默认情况下,预测脚本将加载训练一个pass的模型对 `imdb的测试集` 进行测试。
预测脚本将加载、训练一个pass的模型,并用这个模型对 `imdb的测试集` 进行测试。
## 使用自定义数据训练和预测
### 训练
1.数据组织
输入数据格式如下:每一行为一条样本,以 `\t` 分隔,第一列是类别标签,第二列是输入文本的内容。以下是两条示例数据:
每一行为一条样本,以 `\t` 分隔,第一列是类别标签,第二列是输入文本的内容。
```
positive This movie is very good. The actor is so handsome.
......@@ -132,7 +132,7 @@ negative What a terrible movie. I waste so much time.
2.编写数据读取接口
自定义数据读取接口只需编写一个 Python 生成器实现**从原始输入文本中解析一条训练样本**的逻辑。以下代码片段实现了读取原始数据返回类型为: `paddle.data_type.integer_value_sub_sequence``paddle.data_type.integer_value`
自定义数据读取接口只需编写一个 Python 生成器,实现**解析输入文本**的逻辑。以下代码片段实现了读取原始数据返回类型为: `paddle.data_type.integer_value_sub_sequence``paddle.data_type.integer_value`
```python
def train_reader(data_dir, word_dict, label_dict):
"""
......
Running sample code in this directory requires PaddelPaddle v0.11.0 and later. If the PaddlePaddle on your device is lower than� this version, please follow the instructions in [installation document](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html) and make an update.
---
# Text Classification Based on Double Sequence
## Introduction
This example will demonstrate how to organize long text(usually paragraphs or chapters) input into a double sequence in PaddlePaddle to complete the task of classifying long text.
## Model introduction
We treat a text as a sequence of sentences, and each sentence is a sequence of words.
We first use the convolutional neural network to encode each sentence in the paragraph; then, let the expression vector of each sentence goes through the pooled layer to obtain the encoded vector of the paragraph; finally, the encoded vector of the paragraph is used as the classifier(the full connection of softmax layer) input to obtain the final classification result.
**The model structure is shown in the figure below**
<p align="center">
<img src="images/model.jpg" width = "60%" align="center"/><br/>
Figure1. Text classification model based on double layer sequence
</p>
PaddlePaddle implementation of the network structure is in `network_conf.py`.
To process double-level time series, we need to transform the double layer time series data into single time series data, and then process each single time series. In PaddlePaddle, recurrent_group is the main tool to help us build a hierarchical model for processing double decker sequences. Here, we use two nested recurrent_group. The outer recurrent_group dissolves the paragraph into a sentence, and the input from the step function is the sentence sequence. The recurrent_group in the inner layer dismantles the sentence into word. The input in the step function is a group of non-sequential words.
At the level of words, we obtain the expression of a sentence from word vectors using CNN. At the level of paragraphs, we obtain the expression of a paragraph from the expressions of the sentences in the paragraph through pooling.
``` python
nest_group = paddle.layer.recurrent_group(input=[paddle.layer.SubsequenceInput(emb),
hidden_size],
step=cnn_cov_group)
```
The single layer sequence data after disassembly is represented by a CNN network to learn the corresponding vector, and the network structure of the CNN contains the following parts:
- **Convolution layer**: convolution in text classification is done on time series. The width of convolution kernel is consistent with the matrix generated by word vector level. After convolution, the result is a "feature map". Multiple feature maps can be obtained by using multiple convolutions of different heights. This code uses the convolution kernel of 3 (the red box of Figure 1) and 4 (the blue box of Figure 1) by default.
- **Maximum pool layer**: the maximum pool operation is performed on each feature graph obtained by convolution. Since the feature graph itself is already a vector, the maximum pooling is actually the largest element in the selection of each vector. All the largest elements are spliced together to form a new vector.
- **Linear projection layer**: splices the results from the maximum pool operations into a long vector. Linear projection is used to get the representation vectors of corresponding single layer sequences.
Implementation of CNN network:
```python
def cnn_cov_group(group_input, hidden_size):
"""
Convolution group definition.
:param group_input: The input of this layer.
:type group_input: LayerOutput
:params hidden_size: The size of the fully connected layer.
:type hidden_size: int
"""
conv3 = paddle.networks.sequence_conv_pool(
input=group_input, context_len=3, hidden_size=hidden_size)
conv4 = paddle.networks.sequence_conv_pool(
input=group_input, context_len=4, hidden_size=hidden_size)
linear_proj = paddle.layer.fc(input=[conv3, conv4],
size=hidden_size,
param_attr=paddle.attr.ParamAttr(name='_cov_value_weight'),
bias_attr=paddle.attr.ParamAttr(name='_cov_value_bias'),
act=paddle.activation.Linear())
return linear_proj
```
PaddlePaddle has been encapsulated with a pooled text sequence convolution module: `paddle.networks.sequence_conv_pool`, which can be called directly.
After getting the expression vectors of each sentence, all the sentence vectors are passed through an average pool level, and a vector representation of a sample is obtained. The vector outputs the final prediction result through a fully connected layer. The code:
```python
avg_pool = paddle.layer.pooling(input=nest_group,
pooling_type=paddle.pooling.Avg(),
agg_level=paddle.layer.AggregateLevel.TO_NO_SEQUENCE)
prob = paddle.layer.mixed(size=class_num,
input=[paddle.layer.full_matrix_projection(input=avg_pool)],
act=paddle.activation.Softmax())
```
## Install dependency package
```bash
pip install -r requirements.txt
```
## Specify training configuration parameters
The training and model configuration parameters are modified through the `config.py` script. There are detailed explanations for configurable parameters in the script. The examples are as follows:
```python
class TrainerConfig(object):
# whether to use GPU for training
use_gpu = False
# the number of threads used in one machine
trainer_count = 1
# train batch size
batch_size = 32
...
class ModelConfig(object):
# embedding vector dimension
emb_size = 28
...
```
Modify the `config.py` to adjust the parameters. For example, we can specify whether or not to use GPU for training by modifying `use_gpu`.
## Implement with PaddlePaddle Built-in data
### Train
Execute at the terminal:
```bash
python train.py
```
You will run this example with the PaddlePaddle's built-in emotional categorization dataset, `imdb` .
### Prediction
After training, the model will be stored in the specified directory (the default models directory), execute the following command:
```bash
python infer.py --model_path 'models/params_pass_00000.tar.gz'
```
The prediction script will load and train a pass model to test `test set of the IMDB`.
## Use custom data train and predict
### Train
1.Data structure
Each line is a sample with class label and text. Class label and text content are seperated by `\t`. The following are two samples::
```
positive This movie is very good. The actor is so handsome.
negative What a terrible movie. I waste so much time.
```
2.Write the Data Reading Interface
To define a custom data reading interface, we only need to write a Python generator to **parse the input text**. The following code fragment is implemented to read the return type of the original data: `paddle.data_type.integer_value_sub_sequence` and `paddle.data_type.integer_value`
```python
def train_reader(data_dir, word_dict, label_dict):
"""
Reader interface for training data
:param data_dir: data directory
:type data_dir: str
:param word_dict: path of word dictionary,
the dictionary must has a "UNK" in it.
:type word_dict: Python dict
:param label_dict: path of label dictionary.
:type label_dict: Python dict
"""
def reader():
UNK_ID = word_dict['<unk>']
word_col = 1
lbl_col = 0
for file_name in os.listdir(data_dir):
file_path = os.path.join(data_dir, file_name)
if not os.path.isfile(file_path):
continue
with open(file_path, "r") as f:
for line in f:
line_split = line.strip().split("\t")
doc = line_split[word_col]
doc_ids = []
for sent in doc.strip().split("."):
sent_ids = [
word_dict.get(w, UNK_ID)
for w in sent.split()]
if sent_ids:
doc_ids.append(sent_ids)
yield doc_ids, label_dict[line_split[lbl_col]]
return reader
```
Note that, in this case to English period `'.'` as a separator, the text is divided into a certain number of sentences, and each sentence is expressed as the corresponding index array Thesaurus (`sent_ids`). Since the representation of the current sample (`doc_ids`) contains all the sentences of the text, it is type: `paddle.data_type.integer_value_sub_sequence`.
3.Specify command line parameters for training
`train.py` contains the following parameters:
```
Options:
--train_data_dir TEXT The path of training dataset (default: None). If
this parameter is not set, imdb dataset will be
used.
--test_data_dir TEXT The path of testing dataset (default: None). If this
parameter is not set, imdb dataset will be used.
--word_dict_path TEXT The path of word dictionary (default: None). If this
parameter is not set, imdb dataset will be used. If
this parameter is set, but the file does not exist,
word dictionay will be built from the training data
automatically.
--label_dict_path TEXT The path of label dictionary (default: None).If this
parameter is not set, imdb dataset will be used. If
this parameter is set, but the file does not exist,
label dictionay will be built from the training data
automatically.
--model_save_dir TEXT The path to save the trained models (default:
'models').
--help Show this message and exit.
```
Modify the startup parameters in the `train.py` script to run this example directly. Take the sample data in the data directory for example, execute at the terminal:
```bash
python train.py \
--train_data_dir 'data/train_data' \
--test_data_dir 'data/test_data' \
--word_dict_path 'word_dict.txt' \
--label_dict_path 'label_dict.txt'
```
So you can train with sample data.
### Prediction
1.Specify command line parameters
`infer.py` contains the following parameters:
```
Options:
--data_path TEXT The path of data for inference (default: None). If
this parameter is not set, imdb test dataset will be
used.
--model_path TEXT The path of saved model. [required]
--word_dict_path TEXT The path of word dictionary (default: None). If this
parameter is not set, imdb dataset will be used.
--label_dict_path TEXT The path of label dictionary (default: None).If this
parameter is not set, imdb dataset will be used.
--batch_size INTEGER The number of examples in one batch (default: 32).
--help Show this message and exit.
```
2.take the sample data in the `data` directory as an example, execute at the terminal:
```bash
python infer.py \
--data_path 'data/infer.txt' \
--word_dict_path 'word_dict.txt' \
--label_dict_path 'label_dict.txt' \
--model_path 'models/params_pass_00000.tar.gz'
```
So the sample data can be predicted.
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册