diff --git a/fluid/deep_attention_matching_net/model.py b/fluid/deep_attention_matching_net/model.py
index e0c3f478237bf8b9a467c812fa7bb7589e864c92..537722a419038b2832e9f7234b3fe3c08baa2bf8 100644
--- a/fluid/deep_attention_matching_net/model.py
+++ b/fluid/deep_attention_matching_net/model.py
@@ -15,45 +15,85 @@ class Net(object):
         self._stack_num = stack_num
         self._channel1_num = channel1_num
         self._channel2_num = channel2_num
+        self._feed_names = []
         self.word_emb_name = "shared_word_emb"
         self.use_stack_op = True
         self.use_mask_cache = True
         self.use_sparse_embedding = True
 
-    def set_word_embedding(self, word_emb, place):
-        word_emb_param = fluid.global_scope().find_var(
-            self.word_emb_name).get_tensor()
-        word_emb_param.set(word_emb, place)
-
-    def create_network(self):
-        mask_cache = dict() if self.use_mask_cache else None
-
-        turns_data = []
+    def create_py_reader(self, capacity, name):
+        # turns ids
+        shapes = [[-1, self._max_turn_len, 1]
+                  for i in six.moves.xrange(self._max_turn_num)]
+        dtypes = ["int32" for i in six.moves.xrange(self._max_turn_num)]
+        # turns mask
+        shapes += [[-1, self._max_turn_len, 1]
+                   for i in six.moves.xrange(self._max_turn_num)]
+        dtypes += ["float32" for i in six.moves.xrange(self._max_turn_num)]
+
+        # response ids, response mask, label
+        shapes += [[-1, self._max_turn_len, 1], [-1, self._max_turn_len, 1],
+                   [-1, 1]]
+        dtypes += ["int32", "float32", "float32"]
+
+        py_reader = fluid.layers.py_reader(
+            capacity=capacity,
+            shapes=shapes,
+            lod_levels=[0] * (2 * self._max_turn_num + 3),
+            dtypes=dtypes,
+            name=name,
+            use_double_buffer=True)
+
+        data_vars = fluid.layers.read_file(py_reader)
+
+        self.turns_data = data_vars[0:self._max_turn_num]
+        self.turns_mask = data_vars[self._max_turn_num:2 * self._max_turn_num]
+        self.response = data_vars[-3]
+        self.response_mask = data_vars[-2]
+        self.label = data_vars[-1]
+        return py_reader
+
+    def create_data_layers(self):
+        self._feed_names = []
+
+        self.turns_data = []
         for i in six.moves.xrange(self._max_turn_num):
+            name = "turn_%d" % i
             turn = fluid.layers.data(
-                name="turn_%d" % i,
-                shape=[self._max_turn_len, 1],
-                dtype="int32")
-            turns_data.append(turn)
+                name=name, shape=[self._max_turn_len, 1], dtype="int32")
+            self.turns_data.append(turn)
+            self._feed_names.append(name)
 
-        turns_mask = []
+        self.turns_mask = []
         for i in six.moves.xrange(self._max_turn_num):
+            name = "turn_mask_%d" % i
             turn_mask = fluid.layers.data(
-                name="turn_mask_%d" % i,
-                shape=[self._max_turn_len, 1],
-                dtype="float32")
-            turns_mask.append(turn_mask)
+                name=name, shape=[self._max_turn_len, 1], dtype="float32")
+            self.turns_mask.append(turn_mask)
+            self._feed_names.append(name)
 
-        response = fluid.layers.data(
+        self.response = fluid.layers.data(
             name="response", shape=[self._max_turn_len, 1], dtype="int32")
-        response_mask = fluid.layers.data(
+        self.response_mask = fluid.layers.data(
             name="response_mask",
             shape=[self._max_turn_len, 1],
             dtype="float32")
-        label = fluid.layers.data(name="label", shape=[1], dtype="float32")
+        self.label = fluid.layers.data(name="label", shape=[1], dtype="float32")
+        self._feed_names += ["response", "response_mask", "label"]
+
+    def get_feed_names(self):
+        return self._feed_names
+
+    def set_word_embedding(self, word_emb, place):
+        word_emb_param = fluid.global_scope().find_var(
+            self.word_emb_name).get_tensor()
+        word_emb_param.set(word_emb, place)
+
+    def create_network(self):
+        mask_cache = dict() if self.use_mask_cache else None
 
         response_emb = fluid.layers.embedding(
-            input=response,
+            input=self.response,
             size=[self._vocab_size + 1, self._emb_size],
             is_sparse=self.use_sparse_embedding,
             param_attr=fluid.ParamAttr(
@@ -71,8 +111,8 @@ class Net(object):
                 key=Hr,
                 value=Hr,
                 d_key=self._emb_size,
-                q_mask=response_mask,
-                k_mask=response_mask,
+                q_mask=self.response_mask,
+                k_mask=self.response_mask,
                 mask_cache=mask_cache)
             Hr_stack.append(Hr)
 
@@ -80,7 +120,7 @@ class Net(object):
         sim_turns = []
         for t in six.moves.xrange(self._max_turn_num):
             Hu = fluid.layers.embedding(
-                input=turns_data[t],
+                input=self.turns_data[t],
                 size=[self._vocab_size + 1, self._emb_size],
                 is_sparse=self.use_sparse_embedding,
                 param_attr=fluid.ParamAttr(
@@ -96,8 +136,8 @@ class Net(object):
                     key=Hu,
                     value=Hu,
                     d_key=self._emb_size,
-                    q_mask=turns_mask[t],
-                    k_mask=turns_mask[t],
+                    q_mask=self.turns_mask[t],
+                    k_mask=self.turns_mask[t],
                     mask_cache=mask_cache)
                 Hu_stack.append(Hu)
 
@@ -111,8 +151,8 @@ class Net(object):
                     key=Hr_stack[index],
                     value=Hr_stack[index],
                     d_key=self._emb_size,
-                    q_mask=turns_mask[t],
-                    k_mask=response_mask,
+                    q_mask=self.turns_mask[t],
+                    k_mask=self.response_mask,
                     mask_cache=mask_cache)
                 r_a_t = layers.block(
                     name="r_attend_t_" + str(index),
@@ -120,8 +160,8 @@ class Net(object):
                     key=Hu_stack[index],
                     value=Hu_stack[index],
                     d_key=self._emb_size,
-                    q_mask=response_mask,
-                    k_mask=turns_mask[t],
+                    q_mask=self.response_mask,
+                    k_mask=self.turns_mask[t],
                     mask_cache=mask_cache)
 
                 t_a_r_stack.append(t_a_r)
@@ -158,5 +198,5 @@ class Net(object):
         sim = fluid.layers.concat(input=sim_turns, axis=2)
 
         final_info = layers.cnn_3d(sim, self._channel1_num, self._channel2_num)
-        loss, logits = layers.loss(final_info, label)
+        loss, logits = layers.loss(final_info, self.label)
         return loss, logits
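Note: the refactor above gives Net two interchangeable input paths. create_py_reader() declares a fluid.layers.py_reader whose output slots are sliced into self.turns_data, self.turns_mask, self.response, self.response_mask and self.label, while create_data_layers() declares named fluid.layers.data placeholders and records their feed order in self._feed_names. A minimal sketch of how a caller selects between the two paths, mirroring the train_and_evaluate.py changes below (the hyperparameter values are placeholders and use_pyreader stands in for the new --use_pyreader flag):

    import paddle.fluid as fluid
    from model import Net

    # Placeholder hyperparameters; the real values come from parse_args().
    max_turn_num, max_turn_len = 9, 50
    vocab_size, emb_size, stack_num = 50000, 200, 5
    channel1_num, channel2_num = 32, 16
    dam = Net(max_turn_num, max_turn_len, vocab_size, emb_size,
              stack_num, channel1_num, channel2_num)

    use_pyreader = True  # stands in for args.use_pyreader
    train_program, train_startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(train_program, train_startup):
        with fluid.unique_name.guard():
            if use_pyreader:
                train_pyreader = dam.create_py_reader(
                    capacity=10, name='train_reader')
            else:
                dam.create_data_layers()  # feed order via dam.get_feed_names()
            loss, logits = dam.create_network()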
diff --git a/fluid/deep_attention_matching_net/train_and_evaluate.py b/fluid/deep_attention_matching_net/train_and_evaluate.py
index d4f8374c66938985a946f5cf2da6d1b6e391f37c..bd701904d6445be22d8bcb0634b78ad59713cc21 100644
--- a/fluid/deep_attention_matching_net/train_and_evaluate.py
+++ b/fluid/deep_attention_matching_net/train_and_evaluate.py
@@ -7,7 +7,7 @@ import multiprocessing
 import paddle
 import paddle.fluid as fluid
 import utils.reader as reader
-from utils.util import print_arguments
+from utils.util import print_arguments, mkdir
 
 try:
     import cPickle as pickle #python 2
@@ -49,6 +49,10 @@ def parse_args():
         '--use_cuda',
         action='store_true',
         help='If set, use cuda for training.')
+    parser.add_argument(
+        '--use_pyreader',
+        action='store_true',
+        help='If set, use pyreader for reading data.')
     parser.add_argument(
         '--ext_eval',
         action='store_true',
@@ -105,7 +109,75 @@ def parse_args():
 #yapf: enable
 
 
+def evaluate(score_path, result_file_path):
+    if args.ext_eval:
+        import utils.douban_evaluation as eva
+    else:
+        import utils.evaluation as eva
+    #write evaluation result
+    result = eva.evaluate(score_path)
+    with open(result_file_path, 'w') as out_file:
+        for p_at in result:
+            out_file.write(str(p_at) + '\n')
+    print('finish evaluation')
+    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
+
+
+def test_with_feed(exe, program, feed_names, fetch_list, score_path, batches,
+                   batch_num, dev_count):
+    score_file = open(score_path, 'w')
+    for it in six.moves.xrange(batch_num // dev_count):
+        feed_list = []
+        for dev in six.moves.xrange(dev_count):
+            val_index = it * dev_count + dev
+            batch_data = reader.make_one_batch_input(batches, val_index)
+            feed_dict = dict(zip(feed_names, batch_data))
+            feed_list.append(feed_dict)
+
+        predicts = exe.run(feed=feed_list, fetch_list=fetch_list)
+
+        scores = np.array(predicts[0])
+        for dev in six.moves.xrange(dev_count):
+            val_index = it * dev_count + dev
+            for i in six.moves.xrange(args.batch_size):
+                score_file.write(
+                    str(scores[args.batch_size * dev + i][0]) + '\t' + str(
+                        batches["label"][val_index][i]) + '\n')
+    score_file.close()
+
+
+def test_with_pyreader(exe, program, pyreader, fetch_list, score_path, batches,
+                       batch_num, dev_count):
+    def data_provider():
+        for index in six.moves.xrange(batch_num):
+            yield reader.make_one_batch_input(batches, index)
+
+    score_file = open(score_path, 'w')
+    pyreader.decorate_tensor_provider(data_provider)
+    it = 0
+    pyreader.start()
+    while True:
+        try:
+            predicts = exe.run(fetch_list=fetch_list)
+
+            scores = np.array(predicts[0])
+            for dev in six.moves.xrange(dev_count):
+                val_index = it * dev_count + dev
+                for i in six.moves.xrange(args.batch_size):
+                    score_file.write(
+                        str(scores[args.batch_size * dev + i][0]) + '\t' + str(
+                            batches["label"][val_index][i]) + '\n')
+            it += 1
+        except fluid.core.EOFException:
+            pyreader.reset()
+            break
+    score_file.close()
+
+
 def train(args):
+    if not os.path.exists(args.save_path):
+        os.makedirs(args.save_path)
+
     # data data_config
     data_conf = {
         "batch_size": args.batch_size,
@@ -117,27 +189,47 @@ def train(args):
     dam = Net(args.max_turn_num, args.max_turn_len, args.vocab_size,
               args.emb_size, args.stack_num, args.channel1_num,
               args.channel2_num)
-    loss, logits = dam.create_network()
-    loss.persistable = True
-    logits.persistable = True
-
-    train_program = fluid.default_main_program()
-    test_program = fluid.default_main_program().clone(for_test=True)
-
-    # gradient clipping
-    fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
-        max=1.0, min=-1.0))
-
-    optimizer = fluid.optimizer.Adam(
-        learning_rate=fluid.layers.exponential_decay(
-            learning_rate=args.learning_rate,
-            decay_steps=400,
-            decay_rate=0.9,
-            staircase=True))
-    optimizer.minimize(loss)
-
-    fluid.memory_optimize(train_program)
+    train_program = fluid.Program()
+    train_startup = fluid.Program()
+    with fluid.program_guard(train_program, train_startup):
+        with fluid.unique_name.guard():
+            if args.use_pyreader:
+                train_pyreader = dam.create_py_reader(
+                    capacity=10, name='train_reader')
+            else:
+                dam.create_data_layers()
+            loss, logits = dam.create_network()
+            loss.persistable = True
+            logits.persistable = True
+            # gradient clipping
+            fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
+                max=1.0, min=-1.0))
+
+            optimizer = fluid.optimizer.Adam(
+                learning_rate=fluid.layers.exponential_decay(
+                    learning_rate=args.learning_rate,
+                    decay_steps=400,
+                    decay_rate=0.9,
+                    staircase=True))
+            optimizer.minimize(loss)
+            fluid.memory_optimize(train_program)
+
+    test_program = fluid.Program()
+    test_startup = fluid.Program()
+    with fluid.program_guard(test_program, test_startup):
+        with fluid.unique_name.guard():
+            if args.use_pyreader:
+                test_pyreader = dam.create_py_reader(
+                    capacity=10, name='test_reader')
+            else:
+                dam.create_data_layers()
+
+            loss, logits = dam.create_network()
+            loss.persistable = True
+            logits.persistable = True
+
+    test_program = test_program.clone(for_test=True)
 
     if args.use_cuda:
         place = fluid.CUDAPlace(0)
@@ -152,7 +244,8 @@ def train(args):
             program=train_program, batch_size=args.batch_size))
 
     exe = fluid.Executor(place)
-    exe.run(fluid.default_startup_program())
+    exe.run(train_startup)
+    exe.run(test_startup)
 
     train_exe = fluid.ParallelExecutor(
         use_cuda=args.use_cuda, loss_name=loss.name, main_program=train_program)
 
     test_exe = fluid.ParallelExecutor(
         use_cuda=args.use_cuda,
         main_program=test_program,
         share_vars_from=train_exe)
 
-    if args.ext_eval:
-        import utils.douban_evaluation as eva
-    else:
-        import utils.evaluation as eva
-
     if args.word_emb_init is not None:
         print("start loading word embedding init ...")
         if six.PY2:
@@ -199,17 +287,15 @@ def train(args):
     print("begin model training ...")
     print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
 
-    step = 0
-    for epoch in six.moves.xrange(args.num_scan_data):
-        shuffle_train = reader.unison_shuffle(train_data)
-        train_batches = reader.build_batches(shuffle_train, data_conf)
-
+    # train one epoch by feeding data
+    def train_with_feed(step):
         ave_cost = 0.0
         for it in six.moves.xrange(batch_num // dev_count):
             feed_list = []
             for dev in six.moves.xrange(dev_count):
                 index = it * dev_count + dev
-                feed_dict = reader.make_one_batch_input(train_batches, index)
+                batch_data = reader.make_one_batch_input(train_batches, index)
+                feed_dict = dict(zip(dam.get_feed_names(), batch_data))
                 feed_list.append(feed_dict)
 
             cost = train_exe.run(feed=feed_list, fetch_list=[loss.name])
@@ -226,41 +312,73 @@ def train(args):
                 print("Save model at step %d ... " % step)
                 print(time.strftime('%Y-%m-%d %H:%M:%S',
                                     time.localtime(time.time())))
-                fluid.io.save_persistables(exe, save_path)
+                fluid.io.save_persistables(exe, save_path, train_program)
 
                 score_path = os.path.join(args.save_path, 'score.' + str(step))
-                score_file = open(score_path, 'w')
-                for it in six.moves.xrange(val_batch_num // dev_count):
-                    feed_list = []
-                    for dev in six.moves.xrange(dev_count):
-                        val_index = it * dev_count + dev
-                        feed_dict = reader.make_one_batch_input(val_batches,
-                                                                val_index)
-                        feed_list.append(feed_dict)
-
-                    predicts = test_exe.run(feed=feed_list,
-                                            fetch_list=[logits.name])
-
-                    scores = np.array(predicts[0])
-                    for dev in six.moves.xrange(dev_count):
-                        val_index = it * dev_count + dev
-                        for i in six.moves.xrange(args.batch_size):
-                            score_file.write(
-                                str(scores[args.batch_size * dev + i][0]) + '\t'
-                                + str(val_batches["label"][val_index][
-                                    i]) + '\n')
-                score_file.close()
-
-                #write evaluation result
-                result = eva.evaluate(score_path)
+                test_with_feed(test_exe, test_program,
+                               dam.get_feed_names(), [logits.name], score_path,
+                               val_batches, val_batch_num, dev_count)
+
                 result_file_path = os.path.join(args.save_path,
                                                 'result.' + str(step))
-                with open(result_file_path, 'w') as out_file:
-                    for p_at in result:
-                        out_file.write(str(p_at) + '\n')
-                print('finish evaluation')
-                print(time.strftime('%Y-%m-%d %H:%M:%S',
-                                    time.localtime(time.time())))
+                evaluate(score_path, result_file_path)
+        return step
+
+    # train one epoch with pyreader
+    def train_with_pyreader(step):
+        def data_provider():
+            for index in six.moves.xrange(batch_num):
+                yield reader.make_one_batch_input(train_batches, index)
+
+        train_pyreader.decorate_tensor_provider(data_provider)
+
+        ave_cost = 0.0
+        train_pyreader.start()
+        while True:
+            try:
+                cost = train_exe.run(fetch_list=[loss.name])
+
+                ave_cost += np.array(cost[0]).mean()
+                step = step + 1
+                if step % print_step == 0:
+                    print("processed: [" + str(step * dev_count * 1.0 /
+                                               batch_num) + "] ave loss: [" +
+                          str(ave_cost / print_step) + "]")
+                    ave_cost = 0.0
+
+                if (args.save_path is not None) and (step % save_step == 0):
+                    save_path = os.path.join(args.save_path,
+                                             "step_" + str(step))
+                    print("Save model at step %d ... " % step)
+                    print(time.strftime('%Y-%m-%d %H:%M:%S',
+                                        time.localtime(time.time())))
+                    fluid.io.save_persistables(exe, save_path, train_program)
+
+                    score_path = os.path.join(args.save_path,
+                                              'score.' + str(step))
+                    test_with_pyreader(test_exe, test_program, test_pyreader,
+                                       [logits.name], score_path, val_batches,
+                                       val_batch_num, dev_count)
+
+                    result_file_path = os.path.join(args.save_path,
+                                                    'result.' + str(step))
+                    evaluate(score_path, result_file_path)
+
+            except fluid.core.EOFException:
+                train_pyreader.reset()
+                break
+        return step
+
+    # train over multiple epochs
+    global_step = 0
+    for epoch in six.moves.xrange(args.num_scan_data):
+        shuffle_train = reader.unison_shuffle(train_data)
+        train_batches = reader.build_batches(shuffle_train, data_conf)
+
+        if args.use_pyreader:
+            global_step = train_with_pyreader(global_step)
+        else:
+            global_step = train_with_feed(global_step)
 
 
 if __name__ == '__main__':
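Note: train_with_pyreader and test_with_pyreader both follow the same Fluid py_reader protocol: install a Python generator with decorate_tensor_provider(), call start(), run the executor with no feed dict until it raises fluid.core.EOFException, then reset() the reader so it can be started again on the next pass. Condensed to its skeleton (pyreader and make_batches are illustrative names; the generator must yield lists of numpy arrays in the reader's declared slot order):

    def data_provider():
        for batch in make_batches():  # one flat list of numpy arrays per batch
            yield batch

    pyreader.decorate_tensor_provider(data_provider)
    pyreader.start()
    while True:
        try:
            # no feed= argument: the reader pushes tensors into the program
            cost = train_exe.run(fetch_list=[loss.name])
        except fluid.core.EOFException:
            pyreader.reset()  # re-arm the reader before the next start()
            break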
" % step) + print(time.strftime('%Y-%m-%d %H:%M:%S', + time.localtime(time.time()))) + fluid.io.save_persistables(exe, save_path, train_program) + + score_path = os.path.join(args.save_path, + 'score.' + str(step)) + test_with_pyreader(test_exe, test_program, test_pyreader, + [logits.name], score_path, val_batches, + val_batch_num, dev_count) + + result_file_path = os.path.join(args.save_path, + 'result.' + str(step)) + evaluate(score_path, result_file_path) + + except fluid.core.EOFException: + train_pyreader.reset() + break + return step + + # train over different epoches + global_step = 0 + for epoch in six.moves.xrange(args.num_scan_data): + shuffle_train = reader.unison_shuffle(train_data) + train_batches = reader.build_batches(shuffle_train, data_conf) + + if args.use_pyreader: + global_step = train_with_pyreader(global_step) + else: + global_step = train_with_feed(global_step) if __name__ == '__main__': diff --git a/fluid/deep_attention_matching_net/utils/reader.py b/fluid/deep_attention_matching_net/utils/reader.py index af687a6eb35b2ca4d88eff064190025a284c27c9..96b4bfd71658c037076c4c64a3cccccb467c5e9e 100644 --- a/fluid/deep_attention_matching_net/utils/reader.py +++ b/fluid/deep_attention_matching_net/utils/reader.py @@ -202,30 +202,30 @@ def make_one_batch_input(data_batches, index): every_turn_len[:, i] for i in six.moves.xrange(max_turn_num) ] - feed_dict = {} + feed_list = [] for i, turn in enumerate(turns_list): - feed_dict["turn_%d" % i] = turn - feed_dict["turn_%d" % i] = np.expand_dims( - feed_dict["turn_%d" % i], axis=-1) + turn = np.expand_dims(turn, axis=-1) + feed_list.append(turn) for i, turn_len in enumerate(every_turn_len_list): - feed_dict["turn_mask_%d" % i] = np.ones( - (batch_size, max_turn_len, 1)).astype("float32") + turn_mask = np.ones((batch_size, max_turn_len, 1)).astype("float32") for row in six.moves.xrange(batch_size): - feed_dict["turn_mask_%d" % i][row, turn_len[row]:, 0] = 0 + turn_mask[row, turn_len[row]:, 0] = 0 + feed_list.append(turn_mask) - feed_dict["response"] = response - feed_dict["response"] = np.expand_dims(feed_dict["response"], axis=-1) + response = np.expand_dims(response, axis=-1) + feed_list.append(response) - feed_dict["response_mask"] = np.ones( - (batch_size, max_turn_len, 1)).astype("float32") + response_mask = np.ones((batch_size, max_turn_len, 1)).astype("float32") for row in six.moves.xrange(batch_size): - feed_dict["response_mask"][row, response_len[row]:, 0] = 0 + response_mask[row, response_len[row]:, 0] = 0 + feed_list.append(response_mask) - feed_dict["label"] = np.array([data_batches["label"][index]]).reshape( + label = np.array([data_batches["label"][index]]).reshape( [-1, 1]).astype("float32") + feed_list.append(label) - return feed_dict + return feed_list if __name__ == '__main__':