提交 0dcbeda2 编写于 作者: N nhzlx

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into...

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_tensorrt_pooling_converter
...@@ -790,101 +790,3 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -790,101 +790,3 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program() program = default_main_program()
var = program.global_block().var(name) var = program.global_block().var(name)
return get_parameter_value(var, executor) return get_parameter_value(var, executor)
def get_test_program(filelist, program=None, startup_program=None):
    """
    Transpile the current train program into a program that reads the test
    dataset, for programs that use reader ops such as "open_files_op".

    For every reader op found in the startup program, a twin op is appended
    to the startup program whose output is a new READER variable named
    "<train_reader_name>_test". READER variables in the main program are then
    renamed to those "_test" names so the program no longer reads train data.

    Args:
        filelist: value stored in the "file_names" attribute of the root
            reader op(s), i.e. the files to read test data from.
        program: main program to transpile. Defaults to
            default_main_program() when None.
        startup_program: startup program holding the reader ops. Defaults to
            default_startup_program() when None.

    Returns:
        The (mutated in place) `program`. If the startup program contains no
        reader op, `program` is returned untouched.
    """

    def _copy_reader_var_(block, var, new_name=None):
        # Create a persistable READER var in `block` with the same shapes and
        # dtypes as `var`; keeps `var`'s name unless `new_name` is given.
        if new_name is None:
            new_name = var.name
        new_var = block.create_var(
            name=str(new_name), type=core.VarDesc.VarType.READER)
        new_var.desc.set_shapes(var.desc.shapes())
        new_var.desc.set_dtypes(var.desc.dtypes())
        new_var.persistable = True
        return new_var

    def _get_test_reader_name(train_reader_name):
        return train_reader_name + "_test"

    def _is_reader_op(op):
        # An op counts as a reader op when its first "Out" output is a
        # READER-typed variable of the op's own block.
        block = op.block
        if "Out" in op.output_names:
            reader_out = block.vars[op.output("Out")[0]]
            if reader_out.type == core.VarDesc.VarType.READER:
                return True
        return False

    if program is None:
        program = default_main_program()
    if startup_program is None:
        startup_program = default_startup_program()
    startup_block = startup_program.global_block()

    # 1. find out the original reader var name
    startup_reader_op_list = [
        op for op in startup_block.ops if _is_reader_op(op)
    ]

    if not startup_reader_op_list:
        return program

    root_reader_op = startup_reader_op_list[0]
    train_test_reader_map = {}

    # 2. add operators to startup to open and read the test data files
    for op in startup_reader_op_list:
        assert (len(op.output("Out")) == 1)
        train_reader_name = op.output("Out")[0]
        train_reader = startup_block.vars[train_reader_name]
        test_reader = _copy_reader_var_(
            startup_block,
            train_reader,
            new_name=_get_test_reader_name(train_reader_name))
        train_test_reader_map[train_reader.name] = test_reader

        test_op_inputs = {}
        for name in op.input_names:
            train_arg_names = op.input(name)
            test_arg_vars = []
            for arg_name in train_arg_names:
                # Decorated readers must chain onto the *test* twin of their
                # underlying reader; any other input is shared as-is.
                if name == "UnderlyingReader":
                    arg_var = train_test_reader_map[arg_name]
                else:
                    arg_var = startup_block.vars[arg_name]
                test_arg_vars.append(arg_var)
            test_op_inputs[name] = test_arg_vars

        test_op = startup_block.append_op(
            type=op.type,
            inputs=test_op_inputs,
            outputs={'Out': [test_reader]},
            attrs=op.attrs)
        # root reader op's filelist attr for reading the test files
        if op.type == root_reader_op.type:
            test_op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            test_op.set_attr("pass_num", 1)

    # 3. rename reader vars in inference program to different name
    # to avoid read from train data.
    main_block = program.global_block()
    # Snapshot the vars first: renaming is expected to touch the block's var
    # table, which must not change size while it is being iterated.
    for var in list(main_block.vars.values()):
        if var.type == core.VarDesc.VarType.READER:
            main_block._rename_var(
                str(var.name), str(_get_test_reader_name(var.name)))

    for op in main_block.ops:
        # Update the main-program op being visited. The previous code
        # mutated `test_op`, a stale reference to the last op appended to
        # the startup block in step 2.
        if op.type == root_reader_op.type:
            op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            op.set_attr("pass_num", 1)

    startup_program._sync_with_cpp()
    program._sync_with_cpp()

    return program
...@@ -35,7 +35,7 @@ if len(sys.argv) == 1: ...@@ -35,7 +35,7 @@ if len(sys.argv) == 1:
word_dict = paddle.dataset.imdb.word_dict() word_dict = paddle.dataset.imdb.word_dict()
else: else:
word_dict = load_vocab(sys.argv[1]) word_dict = load_vocab(sys.argv[1])
word_dict["<unk>"] = len(word_dict) word_dict["<unk>"] = len(word_dict)
print "Dict dim = ", len(word_dict) print "Dict dim = ", len(word_dict)
# input text data # input text data
...@@ -50,7 +50,7 @@ feeder = fluid.DataFeeder(feed_list=[data, label], place=fluid.CPUPlace()) ...@@ -50,7 +50,7 @@ feeder = fluid.DataFeeder(feed_list=[data, label], place=fluid.CPUPlace())
BATCH_SIZE = 128 BATCH_SIZE = 128
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=10000), paddle.dataset.imdb.train(word_dict), buf_size=25000),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
test_reader = paddle.batch( test_reader = paddle.batch(
......
...@@ -19,7 +19,7 @@ import sys ...@@ -19,7 +19,7 @@ import sys
TRAIN_FILES = ['train.recordio'] TRAIN_FILES = ['train.recordio']
TEST_FILES = ['test.recordio'] TEST_FILES = ['test.recordio']
DICT_DIM = 89528 DICT_DIM = 5147
# embedding dim # embedding dim
emb_dim = 128 emb_dim = 128
...@@ -27,58 +27,46 @@ emb_dim = 128 ...@@ -27,58 +27,46 @@ emb_dim = 128
# hidden dim # hidden dim
hid_dim = 128 hid_dim = 128
# hidden dim2
hid_dim2 = 96
# class num # class num
class_dim = 2 class_dim = 2
# epoch num
epoch_num = 10
def network_cfg(is_train, pass_num=100):
with fluid.unique_name.guard():
train_file_obj = fluid.layers.open_files(
filenames=TRAIN_FILES,
pass_num=pass_num,
shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'])
test_file_obj = fluid.layers.open_files( def build_program(is_train):
filenames=TEST_FILES, file_obj_handle = fluid.layers.io.open_files(
pass_num=1, filenames=TRAIN_FILES if is_train else TEST_FILES,
shapes=[[-1, 1], [-1, 1]], shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0], lod_levels=[1, 0],
dtypes=['int64', 'int64']) dtypes=['int64', 'int64'])
if is_train: file_obj = fluid.layers.io.double_buffer(file_obj_handle)
file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000)
else:
file_obj = test_file_obj
file_obj = fluid.layers.double_buffer( with fluid.unique_name.guard():
file_obj,
name="train_double_buffer" if is_train else 'test_double_buffer')
data, label = fluid.layers.read_file(file_obj) data, label = fluid.layers.read_file(file_obj)
emb = fluid.layers.embedding(input=data, size=[DICT_DIM, emb_dim]) emb = fluid.layers.embedding(input=data, size=[DICT_DIM, emb_dim])
# sequence conv with window size = 3
win_size = 3
conv_3 = fluid.nets.sequence_conv_pool( conv_3 = fluid.nets.sequence_conv_pool(
input=emb, input=emb,
num_filters=hid_dim, num_filters=hid_dim,
filter_size=win_size, filter_size=3,
act="tanh", act="tanh",
pool_type="max") pool_type="sqrt")
# fc layer after conv conv_4 = fluid.nets.sequence_conv_pool(
fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2) input=emb,
num_filters=hid_dim,
filter_size=4,
act="tanh",
pool_type="sqrt")
# probability of each class prediction = fluid.layers.fc(input=[conv_3, conv_4],
prediction = fluid.layers.fc(input=[fc_1],
size=class_dim, size=class_dim,
act="softmax") act="softmax")
# cross entropy loss # cross entropy loss
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
...@@ -88,58 +76,62 @@ def network_cfg(is_train, pass_num=100): ...@@ -88,58 +76,62 @@ def network_cfg(is_train, pass_num=100):
if is_train: if is_train:
# SGD optimizer # SGD optimizer
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.01) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return { return {'loss': avg_cost, 'log': [avg_cost, acc], 'file': file_obj_handle}
'loss': avg_cost,
'log': [avg_cost, acc],
'file': train_file_obj if is_train else test_file_obj
}
def main(): def main():
train = fluid.Program() train = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
test = fluid.Program()
with fluid.program_guard(train, startup): with fluid.program_guard(train, startup):
train_args = network_cfg(is_train=True) train_args = build_program(is_train=True)
test = fluid.Program()
with fluid.program_guard(test, fluid.Program()): with fluid.program_guard(test, startup):
test_args = network_cfg(is_train=False) test_args = build_program(is_train=False)
use_cuda = fluid.core.is_compiled_with_cuda()
# startup # startup
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
exe.run(startup) exe.run(startup)
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=train_args['loss'].name, main_program=train) use_cuda=use_cuda,
loss_name=train_args['loss'].name,
main_program=train)
test_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, main_program=test, share_vars_from=train_exe)
fetch_var_list = [var.name for var in train_args['log']] fetch_var_list = [var.name for var in train_args['log']]
for i in xrange(sys.maxint): for epoch_id in range(epoch_num):
result = map(numpy.array, # train
train_exe.run(fetch_list=fetch_var_list try:
if i % 1000 == 0 else [])) batch_id = 0
if len(result) != 0: while True:
print 'Train: ', result loss, acc = map(numpy.array,
train_exe.run(fetch_list=fetch_var_list))
if i % 1000 == 0: print 'Train epoch', epoch_id, 'batch', batch_id, 'loss:', loss, 'acc:', acc
test_exe = fluid.ParallelExecutor( batch_id += 1
use_cuda=True, main_program=test, share_vars_from=train_exe) except fluid.core.EOFException:
print 'End of epoch', epoch_id
train_args['file'].reset()
# test
loss = [] loss = []
acc = [] acc = []
try: try:
while True: while True:
loss_np, acc_np = map( loss_np, acc_np = map(numpy.array,
numpy.array, test_exe.run(fetch_list=fetch_var_list)) test_exe.run(fetch_list=fetch_var_list))
loss.append(loss_np[0]) loss.append(loss_np[0])
acc.append(acc_np[0]) acc.append(acc_np[0])
except: except:
test_args['file'].reset() test_args['file'].reset()
print 'TEST: ', numpy.mean(loss), numpy.mean(acc) print 'Test loss:', numpy.mean(loss), 'acc:', numpy.mean(acc)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -107,44 +107,24 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -107,44 +107,24 @@ class TestMNIST(TestParallelExecutorBase):
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
return img, label return img, label
# simple_fc def _compare_reduce_and_allreduce(self, model, use_cuda, random_data=True):
def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
return return
self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
self.check_network_convergence( self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, allow_op_delay=True) model, use_cuda=use_cuda, use_reduce=True)
img, label = self._init_data()
self.check_network_convergence( self.check_network_convergence(
simple_fc_net, model, use_cuda=use_cuda, allow_op_delay=True, use_reduce=True)
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=use_reduce)
def check_simple_fc_convergence_with_Reduce(self, use_cuda): img, label = self._init_data(random_data)
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, use_reduce=True)
self.check_network_convergence(
simple_fc_net,
use_cuda=use_cuda,
allow_op_delay=True,
use_reduce=True)
img, label = self._init_data()
all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence( all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
simple_fc_net, model,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda, use_cuda=use_cuda,
use_reduce=False) use_reduce=False)
reduce_first_loss, reduce_last_loss = self.check_network_convergence( reduce_first_loss, reduce_last_loss = self.check_network_convergence(
simple_fc_net, model,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda, use_cuda=use_cuda,
...@@ -153,7 +133,24 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -153,7 +133,24 @@ class TestMNIST(TestParallelExecutorBase):
for loss in zip(all_reduce_first_loss, reduce_first_loss): for loss in zip(all_reduce_first_loss, reduce_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(all_reduce_last_loss, reduce_last_loss): for loss in zip(all_reduce_last_loss, reduce_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) self.assertAlmostEquals(loss[0], loss[1], delta=1e-4)
# simple_fc
def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
img, label = self._init_data()
self.check_network_convergence(
simple_fc_net,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=use_reduce)
def test_simple_fc(self): def test_simple_fc(self):
# use_cuda # use_cuda
...@@ -162,8 +159,8 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -162,8 +159,8 @@ class TestMNIST(TestParallelExecutorBase):
def test_simple_fc_with_new_strategy(self): def test_simple_fc_with_new_strategy(self):
# use_cuda, use_reduce # use_cuda, use_reduce
self.check_simple_fc_convergence_with_Reduce(True) self._compare_reduce_and_allreduce(simple_fc_net, True)
self.check_simple_fc_convergence_with_Reduce(False) self._compare_reduce_and_allreduce(simple_fc_net, False)
def check_simple_fc_parallel_accuracy(self, use_cuda): def check_simple_fc_parallel_accuracy(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
...@@ -209,39 +206,13 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -209,39 +206,13 @@ class TestMNIST(TestParallelExecutorBase):
"label": label}, "label": label},
use_cuda=use_cuda) use_cuda=use_cuda)
def check_batchnorm_fc_convergence_use_reduce(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(
fc_with_batchnorm, use_cuda=use_cuda, use_reduce=True)
img, label = self._init_data()
all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=False)
reduce_first_loss, reduce_last_loss = self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=True)
for loss in zip(all_reduce_first_loss, reduce_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(all_reduce_last_loss, reduce_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-4)
def test_batchnorm_fc(self): def test_batchnorm_fc(self):
self.check_batchnorm_fc_convergence(True) self.check_batchnorm_fc_convergence(True)
self.check_batchnorm_fc_convergence(False) self.check_batchnorm_fc_convergence(False)
def test_batchnorm_fc_with_new_strategy(self): def test_batchnorm_fc_with_new_strategy(self):
self.check_batchnorm_fc_convergence_use_reduce(True) self._compare_reduce_and_allreduce(fc_with_batchnorm, True)
self.check_batchnorm_fc_convergence_use_reduce(False) self._compare_reduce_and_allreduce(fc_with_batchnorm, False)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -887,7 +887,8 @@ class DistributeTranspiler(object): ...@@ -887,7 +887,8 @@ class DistributeTranspiler(object):
# create table optimize block in pserver program # create table optimize block in pserver program
table_opt_op = [ table_opt_op = [
op for op in self.optimize_ops op for op in self.optimize_ops
if op.input("Param")[0] == self.table_name if 'Param' in op.input_names and op.input("Param")[0] ==
self.table_name
][0] ][0]
table_opt_block = pserver_program.create_block(pre_block_idx) table_opt_block = pserver_program.create_block(pre_block_idx)
# only support sgd now # only support sgd now
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册