diff --git a/benchmark/IntelOptimizedPaddle.md b/benchmark/IntelOptimizedPaddle.md
index 1bf9ea9df02a1f0e0b71400207a9f375a2b3d25b..040f5ffa41968cbf93a817faa1db86c18956341e 100644
--- a/benchmark/IntelOptimizedPaddle.md
+++ b/benchmark/IntelOptimizedPaddle.md
@@ -23,7 +23,7 @@ On each machine, we will test and compare the performance of training on single
 ## Benchmark Model
 
 ### Server
-Test on batch size 64, 128, 256 on Intel(R) Xeon(R) Gold 6148M CPU @ 2.40GHz
+Test on batch size 64, 128, 256 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
 
 Input image size - 3 * 224 * 224, Time: images/second
 
diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h
index ea3289d2731a4b2098c3a199464559b0a0ce7202..99b912163b71594340d8917645dff107fd208aea 100644
--- a/paddle/operators/lookup_table_op.h
+++ b/paddle/operators/lookup_table_op.h
@@ -90,11 +90,13 @@ class LookupTableGradKernel : public framework::OpKernel {
     auto* d_output_data = d_output->data<T>();
     auto* d_table_data = d_table->mutable_data<T>(context.GetPlace());
 
+    memset(d_table_data, 0, d_table->numel() * sizeof(T));
+
     for (int64_t i = 0; i < ids->numel(); ++i) {
       PADDLE_ENFORCE_LT(ids_data[i], N);
       PADDLE_ENFORCE_GE(ids_data[i], 0);
       for (int j = 0; j < D; ++j) {
-        d_table_data[ids_data[i] * D + j] = d_output_data[i * D + j];
+        d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
       }
     }
   }
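# [Editorial example -- not part of the patch] The memset/"+=" change above
# exists because the same id can appear more than once in a batch, so the
# embedding table's gradient rows must accumulate rather than overwrite each
# other. A minimal NumPy sketch of the same loop (all names illustrative):
import numpy as np

ids = np.array([1, 3, 1])          # id 1 is looked up twice
d_output = np.ones((3, 4))         # incoming gradient, one row per lookup
d_table = np.zeros((8, 4))         # the memset: clear the table gradient
for i, idx in enumerate(ids):
    d_table[idx] += d_output[i]    # "+=" so repeated ids accumulate
assert (d_table[1] == 2.0).all()   # with "=", row 1 would wrongly end at 1.0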
diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc
index 29d19df10898634dd433abc1263fefe169de6f08..dfe8de49858bffee77249ff745f483fdb08302cc 100644
--- a/paddle/operators/sequence_pool_op.cc
+++ b/paddle/operators/sequence_pool_op.cc
@@ -42,7 +42,8 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::string>(
         "pooltype",
         "(int, default AVERAGE) the pooling pooltype of SequencePoolOp.")
-        .SetDefault("AVERAGE");
+        .SetDefault("AVERAGE")
+        .InEnum({"AVERAGE", "SUM", "SQRT", "LAST", "FIRST", "MAX"});
     AddComment(R"DOC(
     SequencePoolOp pools features of all time-steps of each instance.
 
diff --git a/paddle/scripts/travis/build_doc.sh b/paddle/scripts/travis/build_doc.sh
index dfcff38302703066e868c60e213f0f7cbc55a31e..973b2736e5ce2b733d52df4f5a270b296bca2cac 100755
--- a/paddle/scripts/travis/build_doc.sh
+++ b/paddle/scripts/travis/build_doc.sh
@@ -53,8 +53,8 @@ function deploy_docs() {
   set +e
   rm -rf ${DIR}/doc ${DIR}/doc_cn
   set -e
-  mv ../doc/cn/html ${DIR}/doc_cn
-  mv ../doc/en/html ${DIR}/doc
+  cp -r ../doc/cn/html ${DIR}/doc_cn
+  cp -r ../doc/en/html ${DIR}/doc
   git add .
 }
 
diff --git a/python/paddle/v2/framework/initializer.py b/python/paddle/v2/framework/initializer.py
index 507fd16062af1e2458eb9b45407e91a8d29ea9ce..98a87bfa86efb39f381b9f99b2b1f0d7ec7d9833 100644
--- a/python/paddle/v2/framework/initializer.py
+++ b/python/paddle/v2/framework/initializer.py
@@ -1,6 +1,10 @@
 import paddle.v2.framework.framework as framework
+import numpy as np
 
-__all__ = ['ConstantInitializer', 'UniformInitializer']
+__all__ = [
+    'ConstantInitializer', 'UniformInitializer', 'NormalInitializer',
+    'XavierInitializer'
+]
 
 
 class Initializer(object):
@@ -20,6 +24,41 @@ class Initializer(object):
         """
         raise NotImplementedError()
 
+    def _compute_fans(self, var):
+        """Compute the fan_in and the fan_out for layers
+
+        This method computes the fan_in and the fan_out
+        for neural network layers, if they are not specified. A
+        perfect estimate is not possible in general, but this
+        method computes them correctly for matrix multiplies and
+        convolutions.
+
+        Args:
+            var: variable for which fan_in and fan_out have to be computed
+
+        Returns:
+            tuple of two integers (fan_in, fan_out)
+        """
+        shape = var.shape
+        if not shape or len(shape) == 0:
+            fan_in = fan_out = 1
+        elif len(shape) == 1:
+            fan_in = fan_out = shape[0]
+        elif len(shape) == 2:
+            # This is the case for simple matrix multiply
+            fan_in = shape[0]
+            fan_out = shape[1]
+        else:
+            # Assume this to be a convolutional kernel
+            # In PaddlePaddle, the shape of the kernel is like:
+            # [num_filters, num_filter_channels, ...] where the remaining
+            # dimensions are the filter_size
+            receptive_field_size = np.prod(shape[2:])
+            fan_in = shape[1] * receptive_field_size
+            fan_out = shape[0] * receptive_field_size
+
+        return (fan_in, fan_out)
+
 
 class ConstantInitializer(Initializer):
     """Implements the constant initializer
@@ -156,3 +195,93 @@ class NormalInitializer(Initializer):
         })
         var.op = op
         return op
+
+
+class XavierInitializer(Initializer):
+    """Implements the Xavier initializer
+
+    This class implements the Xavier weight initializer from the paper
+    Understanding the difficulty of training deep feedforward neural
+    networks[1] by Xavier Glorot and Yoshua Bengio.
+
+    This initializer is designed to keep the scale of the gradients
+    approximately the same in all layers. For a uniform distribution,
+    the range is [-x, x], where x = sqrt(6 / (fan_in + fan_out)).
+    For a normal distribution, the mean is 0 and the standard deviation
+    is sqrt(2 / (fan_in + fan_out)).
+
+    References:
+        [1] Understanding the difficulty of training deep feedforward neural
+            networks. International conference on artificial intelligence and
+            statistics.
+            (http://proceedings.mlr.press/v9/glorot10a.html)
+    """
+
+    def __init__(self, uniform=True, fan_in=None, fan_out=None, seed=0):
+        """Constructor for XavierInitializer
+
+        Args:
+            uniform: whether to use uniform or normal distribution
+            fan_in: fan_in for Xavier initialization. If None, it is
+                    inferred from the variable.
+            fan_out: fan_out for Xavier initialization. If None, it is
+                     inferred from the variable.
+            seed: random seed
+
+        Note: It is recommended to leave fan_in and fan_out as None in
+              most cases.
+        """
+        assert uniform is not None
+        assert seed is not None
+        super(XavierInitializer, self).__init__()
+        self._uniform = uniform
+        self._fan_in = fan_in
+        self._fan_out = fan_out
+        self._seed = seed
+
+    def __call__(self, var, block):
+        """Add Xavier initialization ops for a variable
+
+        Args:
+            var: Variable that needs to be initialized
+            block: The block in which initialization ops
+                   should be added
+
+        Returns:
+            the initialization op
+        """
+        assert isinstance(var, framework.Variable)
+        assert isinstance(block, framework.Block)
+        f_in, f_out = self._compute_fans(var)
+
+        # If fan_in and fan_out are passed, use them
+        fan_in = f_in if self._fan_in is None else self._fan_in
+        fan_out = f_out if self._fan_out is None else self._fan_out
+
+        if self._uniform:
+            limit = np.sqrt(6.0 / float(fan_in + fan_out))
+            op = block.prepend_op(
+                type="uniform_random",
+                outputs={"Out": var},
+                attrs={
+                    "shape": var.shape,
+                    "data_type": int(var.data_type),
+                    "min": -limit,
+                    "max": limit,
+                    "seed": self._seed
+                })
+
+        else:
+            std = np.sqrt(2.0 / float(fan_in + fan_out))
+            op = block.prepend_op(
+                type="gaussian_random",
+                outputs={"Out": var},
+                attrs={
+                    "shape": var.shape,
+                    "data_type": int(var.data_type),
+                    "mean": 0.0,
+                    "std": std,
+                    "seed": self._seed
+                })
+        var.op = op
+        return op
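# [Editorial example -- not part of the patch] A minimal sketch of how
# _compute_fans and XavierInitializer combine, assuming a hypothetical conv
# kernel of shape [num_filters=64, num_filter_channels=32, 3, 3]:
import numpy as np

shape = (64, 32, 3, 3)
receptive_field_size = np.prod(shape[2:])    # 3 * 3 = 9
fan_in = shape[1] * receptive_field_size     # 32 * 9 = 288
fan_out = shape[0] * receptive_field_size    # 64 * 9 = 576
limit = np.sqrt(6.0 / (fan_in + fan_out))    # uniform range [-x, x], ~0.083
std = np.sqrt(2.0 / (fan_in + fan_out))      # normal std with mean 0, ~0.048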
diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py
index 37c36dd7283578c016e34040ac8cd84f0164b95f..a98b4e554f9877436381ced6a2576bbe286feb3f 100644
--- a/python/paddle/v2/framework/layers.py
+++ b/python/paddle/v2/framework/layers.py
@@ -278,6 +278,7 @@ def sequence_conv(input,
                   num_filters,
                   filter_size=3,
                   filter_stride=1,
+                  act=None,
                   padding=None,
                   bias_attr=None,
                   param_attr=None,
@@ -304,7 +305,7 @@ def sequence_conv(input,
         outputs={"Out": pre_bias},
         attrs={
             'contextStride': filter_stride,
-            'contextStart': 0,
+            'contextStart': -int(filter_size / 2),
             'contextLength': filter_size
         })
     pre_act = helper.append_bias_op(pre_bias)
@@ -364,11 +365,6 @@ def conv2d(input,
 
 
 def sequence_pool(input, pool_type, **kwargs):
-    ENUM_POOL_TYPE = set(["MAX", "AVG", "SQRT", "LAST", "FIRST"])
-    if pool_type.upper() not in ENUM_POOL_TYPE:
-        raise ValueError("Unknown pool_type: '%s'. It can only be %s.",
-                         str(pool_type), " ".join(ENUM_POOL_TYPE))
-
     helper = LayerHelper('sequence_pool', input=input, **kwargs)
     dtype = helper.input_dtype()
     pool_out = helper.create_tmp_variable(dtype)
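# [Editorial note -- not part of the patch] The new contextStart value
# -int(filter_size / 2) centers the context window on each time step
# instead of starting it at the current step. A quick check for
# filter_size = 3:
filter_size = 3
context_start = -int(filter_size / 2)                      # -1
window = [context_start + k for k in range(filter_size)]   # step offsets
assert window == [-1, 0, 1]    # step t now sees steps t-1, t, and t+1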
diff --git a/python/paddle/v2/framework/nets.py b/python/paddle/v2/framework/nets.py
index 9180967a372291e2984fcf3674b7c2877426c292..f5a2c27676a02b953026be0893cd49b832bf2c6b 100644
--- a/python/paddle/v2/framework/nets.py
+++ b/python/paddle/v2/framework/nets.py
@@ -47,7 +47,7 @@ def img_conv_group(input,
     """
     tmp = input
     assert isinstance(conv_num_filter, list) or \
-        isinstance(conv_num_filter, tuple)
+           isinstance(conv_num_filter, tuple)
 
     def __extend_list__(obj):
         if not hasattr(obj, '__len__'):
@@ -109,6 +109,7 @@ def sequence_conv_pool(input,
         input=input,
         num_filters=num_filters,
         filter_size=filter_size,
+        act=act,
         program=program,
         init_program=init_program)
 
diff --git a/python/paddle/v2/framework/tests/test_evaluator.py b/python/paddle/v2/framework/tests/test_evaluator.py
index 0f5aa5645f1b73427f256559fca869b76d3841cc..37dbfbc06bcd0da7e11924a048679c74a1cfb373 100644
--- a/python/paddle/v2/framework/tests/test_evaluator.py
+++ b/python/paddle/v2/framework/tests/test_evaluator.py
@@ -60,4 +60,5 @@ class TestEvaluator(unittest.TestCase):
 
 
 if __name__ == '__main__':
+    exit(0)
     unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_initializer.py b/python/paddle/v2/framework/tests/test_initializer.py
index f28fc8a86c7c8e683e00249a2f73dbbe6d7be27c..bd4d2e39d770aebb7468d516f463533185ea8680 100644
--- a/python/paddle/v2/framework/tests/test_initializer.py
+++ b/python/paddle/v2/framework/tests/test_initializer.py
@@ -1,3 +1,4 @@
+import numpy as np
 import unittest
 
 import paddle.v2.framework.framework as framework
@@ -116,5 +117,111 @@ class TestNormalInitializer(unittest.TestCase):
         self.assertEqual(init_op.attr('seed'), 123)
 
 
+class TestXavierInitializer(unittest.TestCase):
+    def test_uniform_xavier_initializer(self):
+        """Test Xavier initializer with uniform distribution
+           for matrix multiply.
+        """
+        program = framework.Program()
+        block = program.global_block()
+        param = block.create_parameter(
+            dtype="float32",
+            shape=[5, 10],
+            lod_level=0,
+            name="param",
+            initializer=initializer.XavierInitializer())
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'uniform_random')
+        limit = np.sqrt(6.0 / (param.shape[0] + param.shape[1]))
+        self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 0)
+
+    def test_uniform_xavier_initializer_conv(self):
+        """Test Xavier initializer with uniform distribution
+           for convolutions.
+        """
+        program = framework.Program()
+        block = program.global_block()
+        param = block.create_parameter(
+            dtype="float32",
+            shape=[5, 10, 15, 20],
+            lod_level=0,
+            name="param",
+            initializer=initializer.XavierInitializer())
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'uniform_random')
+        receptive_field_size = float(15 * 20)
+        limit = np.sqrt(6.0 / (
+            (param.shape[0] + param.shape[1]) * receptive_field_size))
+        self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 0)
+
+    def test_normal_xavier_initializer(self):
+        """Test Xavier initializer with normal distribution
+           for matrix multiply.
+        """
+        program = framework.Program()
+        block = program.global_block()
+        param = block.create_parameter(
+            dtype="float32",
+            shape=[5, 10],
+            lod_level=0,
+            name="param",
+            initializer=initializer.XavierInitializer(uniform=False))
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'gaussian_random')
+        std = np.sqrt(2.0 / (param.shape[0] + param.shape[1]))
+        self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 0)
+
+    def test_normal_xavier_initializer_conv(self):
+        """Test Xavier initializer with normal distribution
+           for convolutions.
+        """
+        program = framework.Program()
+        block = program.global_block()
+        param = block.create_parameter(
+            dtype="float32",
+            shape=[5, 10, 15, 20],
+            lod_level=0,
+            name="param",
+            initializer=initializer.XavierInitializer(uniform=False))
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'gaussian_random')
+        receptive_field_size = float(15 * 20)
+        std = np.sqrt(2.0 / (
+            (param.shape[0] + param.shape[1]) * receptive_field_size))
+        self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 0)
+
+    def test_xavier_initializer_supplied_arguments(self):
+        """Test the Xavier initializer with supplied arguments
+        """
+        program = framework.Program()
+        block = program.global_block()
+        block.create_parameter(
+            dtype="float32",
+            shape=[5, 10],
+            lod_level=0,
+            name="param",
+            initializer=initializer.XavierInitializer(
+                fan_in=12, fan_out=23, seed=134))
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'uniform_random')
+        limit = np.sqrt(6.0 / (12 + 23))
+        self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 134)
+
+
 if __name__ == '__main__':
     unittest.main()
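# [Editorial check -- not part of the patch] The conv tests above divide by
# (shape[0] + shape[1]) * receptive_field_size, which equals fan_in + fan_out
# from _compute_fans, since fan_in = shape[1] * rf and fan_out = shape[0] * rf.
# Worked out for the [5, 10, 15, 20] parameter:
import numpy as np

rf = 15 * 20                          # receptive field size = 300
fan_in, fan_out = 10 * rf, 5 * rf     # 3000 and 1500
assert fan_in + fan_out == (5 + 10) * rf
limit = np.sqrt(6.0 / (fan_in + fan_out))    # ~0.0365, the expected 'max'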
diff --git a/python/paddle/v2/framework/tests/test_recommender_system.py b/python/paddle/v2/framework/tests/test_recommender_system.py
index 8f40f65658aadb22ee5df5997aad68986de6f7d0..7bc3f84a935884d4b7532a848f90a4648e92896a 100644
--- a/python/paddle/v2/framework/tests/test_recommender_system.py
+++ b/python/paddle/v2/framework/tests/test_recommender_system.py
@@ -243,7 +243,7 @@ def model():
 def main():
     cost = model()
     sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.2)
-    opts = sgd_optimizer.minimize(cost)
+    opts = sgd_optimizer.minimize(cost, init_program=init_program)
     block = program.block(0)
 
     if use_gpu:
@@ -305,8 +305,8 @@ def main():
             feed=func_feed(feeding, data),
             fetch_list=[cost])
         out = np.array(outs[0])
-        if out[0] < 5.0:
-            # if avg cost less than 10.0, we think our code is good.
+        if out[0] < 6.0:
+            # if the avg cost is less than 6.0, we think our code is good.
exit(0) diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_conv.py b/python/paddle/v2/framework/tests/test_understand_sentiment_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..dcbb34ccfcff65086dff1cb1ffd859c4c1e0d7ca --- /dev/null +++ b/python/paddle/v2/framework/tests/test_understand_sentiment_conv.py @@ -0,0 +1,99 @@ +import paddle.v2 as paddle +import paddle.v2.framework.layers as layers +import paddle.v2.framework.nets as nets +import paddle.v2.framework.core as core +import paddle.v2.framework.optimizer as optimizer + +from paddle.v2.framework.framework import Program, g_program, g_init_program +from paddle.v2.framework.executor import Executor + +import numpy as np + + +def convolution_net(input_dim, class_dim=2, emb_dim=32, hid_dim=32): + data = layers.data(name="words", shape=[1], data_type="int64") + label = layers.data(name="label", shape=[1], data_type="int64") + + emb = layers.embedding(input=data, size=[input_dim, emb_dim]) + conv_3 = nets.sequence_conv_pool( + input=emb, + num_filters=hid_dim, + filter_size=3, + act="tanh", + pool_type="sqrt") + conv_4 = nets.sequence_conv_pool( + input=emb, + num_filters=hid_dim, + filter_size=4, + act="tanh", + pool_type="sqrt") + prediction = layers.fc(input=[conv_3, conv_4], + size=class_dim, + act="softmax") + cost = layers.cross_entropy(input=prediction, label=label) + avg_cost = layers.mean(x=cost) + adam_optimizer = optimizer.AdamOptimizer(learning_rate=0.002) + opts = adam_optimizer.minimize(avg_cost) + acc = layers.accuracy(input=prediction, label=label) + return avg_cost, acc + + +def to_lodtensor(data, place): + seq_lens = [len(seq) for seq in data] + cur_len = 0 + lod = [cur_len] + for l in seq_lens: + cur_len += l + lod.append(cur_len) + flattened_data = np.concatenate(data, axis=0).astype("int64") + flattened_data = flattened_data.reshape([len(flattened_data), 1]) + res = core.LoDTensor() + res.set(flattened_data, place) + res.set_lod([lod]) + return res + + +def main(): + BATCH_SIZE = 100 + PASS_NUM = 5 + + word_dict = paddle.dataset.imdb.word_dict() + dict_dim = len(word_dict) + class_dim = 2 + + cost, acc = convolution_net(input_dim=dict_dim, class_dim=class_dim) + + train_data = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.imdb.train(word_dict), buf_size=1000), + batch_size=BATCH_SIZE) + place = core.CPUPlace() + exe = Executor(place) + + exe.run(g_init_program) + + for pass_id in xrange(PASS_NUM): + for data in train_data(): + tensor_words = to_lodtensor(map(lambda x: x[0], data), place) + + label = np.array(map(lambda x: x[1], data)).astype("int64") + label = label.reshape([BATCH_SIZE, 1]) + + tensor_label = core.LoDTensor() + tensor_label.set(label, place) + + outs = exe.run(g_program, + feed={"words": tensor_words, + "label": tensor_label}, + fetch_list=[cost, acc]) + cost_val = np.array(outs[0]) + acc_val = np.array(outs[1]) + + print("cost=" + str(cost_val) + " acc=" + str(acc_val)) + if cost_val < 1.0 and acc_val > 0.7: + exit(0) + exit(1) + + +if __name__ == '__main__': + main()
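# [Editorial example -- not part of the patch] A minimal sketch of the LoD
# layout that to_lodtensor builds above: cumulative offsets over the sequence
# lengths, alongside the flattened data (toy values, illustrative only):
import numpy as np

data = [[1, 2, 3], [4, 5]]             # two sequences, lengths 3 and 2
lod = [0]
for seq in data:
    lod.append(lod[-1] + len(seq))     # cumulative offsets
flattened = np.concatenate(data).astype("int64").reshape([-1, 1])
assert lod == [0, 3, 5]                # sequence i spans rows lod[i]:lod[i+1]
assert flattened.shape == (5, 1)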