diff --git a/README.md b/README.md index ceeb6d9e5193763293d3fce76e464340fbce533f..577528e7aaf45ce002467590ec66b19afb145920 100644 --- a/README.md +++ b/README.md @@ -61,32 +61,32 @@ Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddl ## Installation It is recommended to check out the -[Docker installation guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html) +[Docker installation guide](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/docker_install_en.html) before looking into the -[build from source guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html). +[build from source guide](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/build_from_source_en.html). ## Documentation -We provide [English](http://doc.paddlepaddle.org/develop/doc/) and -[Chinese](http://doc.paddlepaddle.org/doc_cn/) documentation. +We provide [English](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html) and +[Chinese](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html) documentation. -- [Deep Learning 101](http://book.paddlepaddle.org/index.html) +- [Deep Learning 101](http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.html) You might want to start from this online interactive book that can run in a Jupyter Notebook. -- [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html) +- [Distributed Training](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/usage/cluster/cluster_train_en.html) You can run distributed training jobs on MPI clusters. -- [Distributed Training on Kubernetes](http://doc.paddlepaddle.org/develop/doc/howto/usage/k8s/k8s_en.html) +- [Distributed Training on Kubernetes](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/usage/cluster/k8s_en.html) You can also run distributed training jobs on Kubernetes clusters. -- [Python API](http://doc.paddlepaddle.org/develop/doc/api/index_en.html) +- [Python API](http://www.paddlepaddle.org/docs/develop/documentation/en/api/index_en.html) Our new API enables much shorter programs. -- [How to Contribute](http://doc.paddlepaddle.org/develop/doc/howto/dev/contribute_to_paddle_en.html) +- [How to Contribute](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/dev/contribute_to_paddle_en.html) We appreciate your contributions! diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst index 9f3669e11583a4ed6467f1a1bb509481fdf0b9d1..92ca1cf0f836a376387f3e6f2b5a24c78109323d 100644 --- a/doc/api/v2/fluid/layers.rst +++ b/doc/api/v2/fluid/layers.rst @@ -188,12 +188,6 @@ beam_search_decode :noindex: -lstm ---------- -.. autofunction:: paddle.v2.fluid.layers.lstm - :noindex: - - lod_rank_table --------- .. autofunction:: paddle.v2.fluid.layers.lod_rank_table @@ -300,7 +294,21 @@ conv2d_transpose .. autofunction:: paddle.v2.fluid.layers.conv2d_transpose :noindex: + sequence_expand --------- .. autofunction:: paddle.v2.fluid.layers.sequence_expand :noindex: + + +lstm_unit +--------- +.. autofunction:: paddle.v2.fluid.layers.lstm_unit + :noindex: + + +sequence_softmax +--------- +.. autofunction:: paddle.v2.fluid.layers.sequence_softmax + :noindex: + diff --git a/doc/howto/usage/cluster/cluster_train_cn.md b/doc/howto/usage/cluster/cluster_train_cn.md index c9f90538a669d4705d18c3cd9b6dbf4a535c35b8..659bae9c0ceaf2fb2df8446b9d406a822a9df0ea 100644 --- a/doc/howto/usage/cluster/cluster_train_cn.md +++ b/doc/howto/usage/cluster/cluster_train_cn.md @@ -1,4 +1,4 @@ -# PaddlePaddle分布式训练 +# 分布式训练 ## 概述 @@ -181,8 +181,8 @@ PaddlePaddle可以使用多种分布式计算平台构建分布式计算任务 ## 在不同集群中运行 - - [fabric](fabric_cn.md) - - [openmpi](openmpi_cn.md) - - [kubernetes](k8s_cn.md) - - [kubernetes distributed](k8s_distributed_cn.md) - - [kubernetes on AWS](k8s_aws_cn.md) + - [fabric集群](fabric_cn.md) + - [openmpi集群](openmpi_cn.md) + - [kubernetes单机](k8s_cn.md) + - [kubernetes distributed分布式](k8s_distributed_cn.md) + - [AWS上运行kubernetes集群训练](k8s_aws_cn.md) diff --git a/doc/howto/usage/cluster/cluster_train_en.md b/doc/howto/usage/cluster/cluster_train_en.md index f9819470c0c622b4bc0ea064303d742385603230..915405ca5b446981515e301ca4b7ee065a82a9ff 100644 --- a/doc/howto/usage/cluster/cluster_train_en.md +++ b/doc/howto/usage/cluster/cluster_train_en.md @@ -1,4 +1,4 @@ -# PaddlePaddle Distributed Training +# Distributed Training ## Introduction @@ -188,5 +188,4 @@ These cluster platforms provide API or environment variables for training proces - [fabric](fabric_en.md) - [openmpi](openmpi_en.md) - [kubernetes](k8s_en.md) - - kubernetes distributed - [kubernetes on AWS](k8s_aws_en.md) diff --git a/doc/howto/usage/cluster/k8s_cn.md b/doc/howto/usage/cluster/k8s_cn.md index be43acd826e402af49033ae6370283e5a7bef5e1..c1a11f7165a2f9da9dd044641274447e7943a597 100644 --- a/doc/howto/usage/cluster/k8s_cn.md +++ b/doc/howto/usage/cluster/k8s_cn.md @@ -1,6 +1,6 @@ # Kubernetes单机训练 -在这篇文档里,我们介绍如何在 Kubernetes 集群上启动一个单机使用CPU的Paddle训练作业。在下一篇中,我们将介绍如何启动分布式训练作业。 +在这篇文档里,我们介绍如何在 Kubernetes 集群上启动一个单机使用CPU的PaddlePaddle训练作业。在下一篇中,我们将介绍如何启动分布式训练作业。 ## 制作Docker镜像 @@ -104,7 +104,7 @@ spec: restartPolicy: Never ``` -### 创建Paddle Job +### 创建PaddlePaddle Job 使用上文创建的yaml文件创建Kubernetes Job,命令为: diff --git a/doc/howto/usage/cluster/k8s_en.md b/doc/howto/usage/cluster/k8s_en.md index 1a16046800f151500cf70136336faf7c1f3e0e0c..c374f00a495d705ceddf8d3d930768ceeb93282b 100644 --- a/doc/howto/usage/cluster/k8s_en.md +++ b/doc/howto/usage/cluster/k8s_en.md @@ -1,6 +1,6 @@ -# Paddle On Kubernetes +# PaddlePaddle On Kubernetes ->In this article, we will introduce how to run Paddle training job on single CPU machine using Kubernetes. In next article, we will introduce how to run Paddle training job on distributed cluster. +In this article, we will introduce how to run PaddlePaddle training job on single CPU machine using Kubernetes. In next article, we will introduce how to run PaddlePaddle training job on distributed cluster. ## Build Docker Image @@ -76,7 +76,7 @@ $ docker commit quick_start_data mypaddle/paddle:quickstart ## Use Kubernetes For Training ->We will use Kubernetes job for training process, following steps shows how to do the training with Kubernetes. +We will use Kubernetes job for training process, following steps shows how to do the training with Kubernetes. ### Create Yaml Files @@ -108,7 +108,7 @@ spec: restartPolicy: Never ``` -### Start Paddle Job +### Start PaddlePaddle Job Using the above yaml file to start the Kubernetes job. diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc index 18b9cdf2a39e8226c634194ff2cc56d169979774..b6eb33bafe50548502a0478d37842fd2dfdebda4 100644 --- a/paddle/operators/lstm_unit_op.cc +++ b/paddle/operators/lstm_unit_op.cc @@ -51,7 +51,10 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { LstmUnitOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "FC input before the non-linear activation."); + AddInput("X", + "Lstm unit only applies non-linear activations, please make sure" + "that linear tranformation has already been applied to `X`. " + "Linear tranformation can be applied by adding a `fc` layer"); AddInput( "C_prev", "The cell state tensor of last time-step in the Lstm Unit operator."); diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index 59986c9f0ca8e4b793463db0e8c5da0489654ee9..9b3792ee9e3e4c6f319b3e2b13c4aa3a05cce8be 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -16,12 +16,13 @@ import regularizer from param_attr import ParamAttr from data_feeder import DataFeeder from core import LoDTensor, CPUPlace, GPUPlace +import clip Tensor = LoDTensor __all__ = framework.__all__ + executor.__all__ + [ 'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward', 'regularizer', 'LoDTensor', 'CPUPlace', 'GPUPlace', 'Tensor', 'ParamAttr' - 'DataFeeder' + 'DataFeeder', 'clip' ] diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ec2fbe13fe6d9158345099b8668afc5c7d4571 --- /dev/null +++ b/python/paddle/v2/fluid/clip.py @@ -0,0 +1,61 @@ +import functools +import layers + +__all__ = ['GradientClipByValue', 'append_gradient_clip_ops'] + + +class BaseGradientClipAttr(object): + def process_context(self, context, p_g): + raise NotImplementedError() + + def create_operators(self, param, grad): + raise NotImplementedError() + + +class NullGradientClipAttr(BaseGradientClipAttr): + def process_context(self, context, p_g): + pass + + def create_operators(self, param, grad): + return param, grad + + +class GradientClipByValue(BaseGradientClipAttr): + def __init__(self, max, min=None): + max = float(max) + if min is None: + min = -max + else: + min = float(min) + self.max = max + self.min = min + + def process_context(self, context, p_g): + pass + + def create_operators(self, param, grad): + new_grad = layers.clip(x=grad, min=self.min, max=self.max) + return param, new_grad + + +def append_gradient_clip_ops(param_grad): + context = dict() + create_op_callbacks = [] + for p, g in param_grad: + clip_attr = getattr(p, 'clip_attr', NullGradientClipAttr()) + if clip_attr is None: + clip_attr = NullGradientClipAttr() + if not isinstance(clip_attr, BaseGradientClipAttr): + raise TypeError( + "clip attribute should be an instance of BaseGradientClippingAttr" + ) + + clip_attr.process_context(context=context, p_g=param_grad) + create_op_callbacks.append( + functools.partial( + clip_attr.create_operators, param=p, grad=g)) + + return [each_callback() for each_callback in create_op_callbacks] + + +ClipByValue = GradientClipByValue diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index bf0cd275b62ae2c4d7312592b8a730291c59a071..973672e6e469c7619ea5b4a166cce120655e4c6e 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -704,6 +704,7 @@ class Block(object): trainable=p.trainable, optimize_attr=p.optimize_attr, regularizer=p.regularizer, + clip_attr=p.clip_attr, name=v.name) self.vars[new_p.name] = new_p @@ -866,6 +867,8 @@ class Parameter(Variable): self.regularizer = kwargs.get('regularizer', None) + self.clip_attr = kwargs.get('clip_attr', None) + # program is a global instance. _main_program_ = Program() diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 5863957c5fb6f65ae299e2203bae324283c850e7..2c38c232240fbe3541ca5e0efc51d8f47c6e4190 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -5,12 +5,15 @@ All layers just related to the neural network. from ..layer_helper import LayerHelper from ..initializer import Normal, Constant from ..framework import Variable +from ..param_attr import ParamAttr +from tensor import concat __all__ = [ 'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy', 'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d', - 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand' + 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', + 'lstm_unit' ] @@ -761,7 +764,7 @@ def conv2d_transpose(input, return out -def sequence_expand(x, y, main_program=None, startup_program=None): +def sequence_expand(x, y): """Sequence Expand Layer. This layer will expand the input variable **x** according to LoD information of **y**. And the following examples will explain how sequence_expand works: @@ -805,8 +808,6 @@ def sequence_expand(x, y, main_program=None, startup_program=None): Args: x (Variable): The input variable which is a Tensor or LoDTensor. y (Variable): The input variable which is a LoDTensor. - main_program (Program): The main program. - startup_program (Program): The startup program. Returns: Variable: The expanded variable which is a LoDTensor. @@ -826,3 +827,111 @@ def sequence_expand(x, y, main_program=None, startup_program=None): type='sequence_expand', inputs={'X': x, 'Y': y}, outputs={'Out': tmp}) return tmp + + +def lstm_unit(x_t, + hidden_t_prev, + cell_t_prev, + forget_bias=0.0, + param_attr=None, + bias_attr=None): + """Lstm unit layer. The equation of a lstm step is: + + .. math:: + + i_t & = \sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i) + + f_t & = \sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f) + + c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c) + + o_t & = \sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o) + + h_t & = o_t tanh(c_t) + + The inputs of lstm unit includes :math:`x_t`, :math:`h_{t-1}` and + :math:`c_{t-1}`. The implementation separates the linear transformation + and non-linear transformation apart. Here, we take :math:`i_t` as an + example. The linear transformation is applied by calling a `fc` layer and + the equation is: + + .. math:: + + L_{i_t} = W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i + + The non-linear transformation is applied by calling `lstm_unit_op` and the + equation is: + + .. math:: + + i_t = \sigma(L_{i_t}) + + This layer has two outputs including :math:`h_t` and :math:`o_t`. + + Args: + x_t (Variable): The input value of current step. + hidden_t_prev (Variable): The hidden value of lstm unit. + cell_t_prev (Variable): The cell value of lstm unit. + forget_bias (float): The forget bias of lstm unit. + param_attr (ParamAttr): The attributes of parameter weights, used to set + initializer, name etc. + bias_attr (ParamAttr): The attributes of bias weights, if not False, + bias weights will be created and be set to default value. + + Returns: + tuple: The hidden value and cell value of lstm unit. + + Raises: + ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**\ + not be 2 or the 1st dimensions of **x_t**, **hidden_t_prev** \ + and **cell_t_prev** not be the same. + + Examples: + + .. code-block:: python + + x_t = fluid.layers.fc(input=x_t_data, size=10) + prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=20) + prev_cell = fluid.layers.fc(input=prev_cell_data, size=30) + hidden_value, cell_value = fluid.layers.lstm_unit(x_t=x_t, + hidden_t_prev=prev_hidden, + cell_t_prev=prev_cell) + """ + helper = LayerHelper('lstm_unit', **locals()) + + if len(x_t.shape) != 2: + raise ValueError("Rank of x_t must be 2.") + + if len(hidden_t_prev.shape) != 2: + raise ValueError("Rank of hidden_t_prev must be 2.") + + if len(cell_t_prev.shape) != 2: + raise ValueError("Rank of cell_t_prev must be 2.") + + if x_t.shape[0] != hidden_t_prev.shape[0] or x_t.shape[ + 0] != cell_t_prev.shape[0]: + raise ValueError("The 1s dimension of x_t, hidden_t_prev and " + "cell_t_prev must be the same.") + + if bias_attr is None: + bias_attr = ParamAttr() + + size = cell_t_prev.shape[1] + concat_out = concat(input=[x_t, hidden_t_prev], axis=1) + fc_out = fc(input=concat_out, + size=4 * size, + param_attr=param_attr, + bias_attr=bias_attr) + dtype = x_t.dtype + c = helper.create_tmp_variable(dtype) + h = helper.create_tmp_variable(dtype) + + helper.append_op( + type='lstm_unit', + inputs={"X": fc_out, + "C_prev": cell_t_prev}, + outputs={"C": c, + "H": h}, + attrs={"forget_bias": forget_bias}) + + return h, c diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index fa312ace60390e5fdd9637dccc71ccf8b247ca47..d2ff6841a317aaf6903edadc9213f69ef6c41216 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -2,7 +2,7 @@ from ..registry import register_layer __all__ = [ 'mean', 'mul', 'dropout', 'reshape', 'sigmoid', 'scale', 'transpose', 'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div', - 'elementwise_sub', 'elementwise_mul', 'clip', 'abs' + 'elementwise_sub', 'elementwise_mul', 'clip', 'abs', 'sequence_softmax' ] for _OP in set(__all__): diff --git a/python/paddle/v2/fluid/optimizer.py b/python/paddle/v2/fluid/optimizer.py index 9f03eeea83e6d212da5fbe3d090d82028fa378ac..84fcbcdc2f2868a84bad5e145a934e33485b1fef 100644 --- a/python/paddle/v2/fluid/optimizer.py +++ b/python/paddle/v2/fluid/optimizer.py @@ -6,6 +6,7 @@ from framework import unique_name, program_guard from initializer import Constant from layer_helper import LayerHelper from regularizer import append_regularization_ops +from clip import append_gradient_clip_ops __all__ = ['SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad'] @@ -197,9 +198,13 @@ class Optimizer(object): `create_optimization_pass()` into one. """ params_grads = append_backward_ops(loss, parameter_list, no_grad_set) + + params_grads = append_gradient_clip_ops(params_grads) + # Add regularization if any params_grads = append_regularization_ops(params_grads, self.regularization) + optimize_ops = self.create_optimization_pass(params_grads, loss, startup_program) return optimize_ops diff --git a/python/paddle/v2/fluid/param_attr.py b/python/paddle/v2/fluid/param_attr.py index 7952a5ea51c00f72664443fb26faa455e89da7be..f6f320c788e7e08d44df8aff5ad3792b237e103a 100644 --- a/python/paddle/v2/fluid/param_attr.py +++ b/python/paddle/v2/fluid/param_attr.py @@ -1,6 +1,8 @@ from initializer import Initializer, Xavier, Constant from regularizer import WeightDecayRegularizer +__all__ = ['ParamAttr'] + class ParamAttr(object): def __init__(self, @@ -8,12 +10,14 @@ class ParamAttr(object): initializer=None, learning_rate=1.0, regularizer=None, - trainable=True): + trainable=True, + clip=None): self.name = name self.initializer = initializer self.learning_rate = learning_rate self.regularizer = regularizer self.trainable = trainable + self.clip = clip def set_default_initializer(self, initializer): if initializer is None: @@ -56,7 +60,8 @@ class ParamAttr(object): 'name': self.name, 'learning_rate': self.learning_rate, 'regularizer': self.regularizer, - 'trainable': self.trainable + 'trainable': self.trainable, + 'clip_attr': self.clip } if with_initializer: kwargs['initializer'] = self.initializer diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index d77f19660ebcd470837e8b4e63509683de4a7a82..fc073f6be8563a363c0f98b9235ae267fa68562d 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -11,7 +11,9 @@ regularizer = fluid.regularizer.L2Decay(0.0005 * BATCH_SIZE) hidden1 = fluid.layers.fc(input=image, size=128, act='relu', - param_attr=regularizer) + param_attr=fluid.ParamAttr( + regularizer=regularizer, + clip=fluid.clip.ClipByValue(10))) hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu', diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 2286e94a90a4810dfb170ba6e929a7c4f3edaba1..9d2dcca56dd1361b9e2448be9f1d5403f8ee17e3 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -161,7 +161,7 @@ class TestBook(unittest.TestCase): x=dat, label=lbl)) print(str(program)) - def test_seq_expand(self): + def test_sequence_expand(self): program = Program() with program_guard(program): x = layers.data(name='x', shape=[10], dtype='float32') @@ -170,6 +170,32 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(layers.sequence_expand(x=x, y=y)) print(str(program)) + def test_lstm_unit(self): + program = Program() + with program_guard(program): + x_t_data = layers.data( + name='x_t_data', shape=[10, 10], dtype='float32') + x_t = layers.fc(input=x_t_data, size=10) + prev_hidden_data = layers.data( + name='prev_hidden_data', shape=[10, 20], dtype='float32') + prev_hidden = layers.fc(input=prev_hidden_data, size=20) + prev_cell_data = layers.data( + name='prev_cell', shape=[10, 30], dtype='float32') + prev_cell = layers.fc(input=prev_cell_data, size=30) + self.assertIsNotNone( + layers.lstm_unit( + x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell)) + print(str(program)) + + def test_sequence_softmax(self): + program = Program() + with program_guard(program): + seq_data = layers.data( + name='seq_data', shape=[10, 10], dtype='float32', lod_level=1) + seq = layers.fc(input=seq_data, size=20) + self.assertIsNotNone(layers.sequence_softmax(x=seq)) + print(str(program)) + if __name__ == '__main__': unittest.main()