diff --git a/doc/design/fluid-compiler.graffle b/doc/design/fluid-compiler.graffle new file mode 100644 index 0000000000000000000000000000000000000000..c933df2cb855462c52b2d25f7f9a99b95652961d Binary files /dev/null and b/doc/design/fluid-compiler.graffle differ diff --git a/doc/design/fluid-compiler.png b/doc/design/fluid-compiler.png new file mode 100644 index 0000000000000000000000000000000000000000..1b0ffed2039c91a3a00bbb719da08c91c3acf7bb Binary files /dev/null and b/doc/design/fluid-compiler.png differ diff --git a/doc/design/fluid.md b/doc/design/fluid.md new file mode 100644 index 0000000000000000000000000000000000000000..585dc8ef39c0cfb30f470d79f7b27a59ceb5e940 --- /dev/null +++ b/doc/design/fluid.md @@ -0,0 +1,122 @@ +# Design Doc: PaddlePaddle Fluid + +## Why Fluid + +When Baidu developed PaddlePaddle in 2013, the only well-known open source deep learning system at the time was Caffe. However, when PaddlePaddle was open-sourced in 2016, many other choices were available. There was a challenge -- what is the need for open sourcing yet another deep learning framework? + +Fluid is the answer. Fluid is similar to PyTorch and TensorFlow Eager Execution, which describes the "process" of training or inference using the concept of a model. In fact in PyTorch, TensorFlow Eager Execution and Fluid, there is no concept of a model at all. The details are covered in the sections below. Fluid is currently more extreme in the above mentioned idea than PyTorch and Eager Execution, and we are trying to push Fluid towards the directions of a compiler and a new programming language for deep learning. + +## The Evolution of Deep Learning Systems + +Deep learning infrastructure is one of the fastest evolving technologies. Within four years, there have already been three generations of technologies invented. + +| Existed since | model as sequence of layers | model as graph of operators | No model | +|--|--|--|--| +| 2013 | Caffe, Theano, Torch, PaddlePaddle | | | +| 2015 | | TensorFlow, MxNet, Caffe2, ONNX, n-graph | | +| 2016 | | | PyTorch, TensorFlow Eager Execution, PaddlePaddle Fluid | + +From the above table, we see that the deep learning technology is evolving towards getting rid of the concept of a model. To understand the reasons behind this direction, a comparison of the *programming paradigms* or the ways to program deep learning applications using these systems, would be helpful. The following section goes over these. + +## Deep Learning Programming Paradigms + +With the systems listed as the first or second generation, e.g., Caffe or TensorFlow, an AI application training program looks like the following: + +```python +x = layer.data("image") +l = layer.data("label") +f = layer.fc(x, W) +s = layer.softmax(f) +c = layer.mse(l, s) + +for i in xrange(1000): # train for 1000 iterations + m = read_minibatch() + forward({input=x, data=m}, minimize=c) + backward(...) + +print W # print the trained model parameters. +``` + +The above program includes two parts: + +1. The first part describes the model, and +2. The second part describes the training process (or inference process) for the model. + +This paradigm has a well-known problem that limits the productivity of programmers. If the programmer made a mistake in configuring the model, the error messages wouldn't show up until the second part is executed and `forward` and `backward` propagations are performed. This makes it difficult for the programmer to debug and locate a mistake that is located blocks away from the actual error prompt. + +This problem of being hard to debug and re-iterate fast on a program is the primary reason that programmers, in general, prefer PyTorch over the older systems. Using PyTorch, we would write the above program as following: + +```python +W = tensor(...) + +for i in xrange(1000): # train for 1000 iterations + m = read_minibatch() + x = m["image"] + l = m["label"] + f = layer.fc(x, W) + s = layer.softmax(f) + c = layer.mse(l, s) + backward() + +print W # print the trained model parameters. +``` + +We can see that the main difference is the moving the model configuration part (the first step) into the training loop. This change would allow the mistakes in model configuration to be reported where they actually appear in the programming block. This change also represents the model better, or its forward pass, by keeping the configuration process in the training loop. + +## Describe Arbitrary Models for the Future + +Describing the process instead of the model also brings Fluid, the flexibility to define different non-standard models that haven't been invented yet. + +As we write out the program for the process, we can write an RNN as a loop, instead of an RNN as a layer or as an operator. A PyTorch example would look like the following: + +```python +for i in xrange(1000): + m = read_minibatch() + x = m["sentence"] + for t in xrange x.len(): + h[t] = the_step(x[t]) +``` + +With Fluid, the training loop and the RNN in the above program are not really Python loops, but just a "loop structure" provided by Fluid and implemented in C++ as the following: + +```python +train_loop = layers.While(cond) +with train_loop.block(): + m = read_minibatch() + x = m["sentence"] + rnn = layers.While(...) + with rnn.block(): + h[t] = the_step(input[t]) +``` + +An actual Fluid example is described [here](https://github.com/PaddlePaddle/Paddle/blob/a91efdde6910ce92a78e3aa7157412c4c88d9ee8/python/paddle/v2/fluid/tests/test_while_op.py#L36-L44). + +From the example, the Fluid programs look very similar to their PyTorch equivalent programs, except that Fluid's loop structure, wrapped with Python's `with` statement, could run much faster than just a Python loop. + +We have more examples of the [`if-then-else`](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/if_else_op.md) structure of Fluid. + +## Turing Completeness + +In computability theory, a system of data-manipulation rules, such as a programming language, is said to be Turing complete if it can be used to simulate any Turing machine. For a programming language, if it provides if-then-else and loop, it is Turing complete. From the above examples, Fluid seems to be Turing complete; however, it is noteworthy to notice that there is a slight difference between the `if-then-else` of Fluid and that of a programming language. The difference being that the former runs both of its branches and splits the input mini-batch into two -- one for the True condition and another for the False condition. This hasn't been researched in depth if this is equivalent to the `if-then-else` in programming languages that makes them Turing-complete. Based on a conversation with [Yuang Yu](https://research.google.com/pubs/104812.html), it seems to be the case but this needs to be looked into in-depth. + +## The Execution of a Fluid Program + +There are two ways to execute a Fluid program. When a program is executed, it creates a protobuf message [`ProgramDesc`](https://github.com/PaddlePaddle/Paddle/blob/a91efdde6910ce92a78e3aa7157412c4c88d9ee8/paddle/framework/framework.proto#L145) that describes the process and is conceptually like an [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree). + +There is a C++ class [`Executor`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/executor.h), which runs a `ProgramDesc`, similar to how an interpreter runs a Python program. + +Fluid is moving towards the direction of a compiler, which is explain in more detail later in this article. + +## Backward Compatibility of Fluid + +Given all the advantages from the removal of the concept of a *model*, hardware manufacturers might still prefer the existence of the concept of a model, so it would be easier for them to support multiple frameworks all at once and could run a trained model during inference. For example, Nervana, a startup company acquired by Intel, has been working on an XPU that reads the models in the format known as [n-graph](https://github.com/NervanaSystems/ngraph). Similarly, [Movidius](https://www.movidius.com/) is producing a mobile deep learning chip that reads and runs graphs of operators. The well-known [ONNX](https://github.com/onnx/onnx) is also a file format of graphs of operators. + +For Fluid, we can write a converter that extracts the parts in the `ProgramDesc` protobuf message, converts them into a graph of operators, and exports the graph into the ONNX or n-graph format. + +## Towards a Deep Learning Language and the Compiler + +We can change the `if-then-else` and loop structure a little bit in the above Fluid example programs, to make it into a new programming language, different than Python. + +Even if we do not invent a new language, as long as we get the `ProgramDesc` message filled in, we can write a transpiler, which translates each invocation to an operator, into a C++ call to a kernel function of that operator. For example, a transpiler that weaves the CUDA kernels outputs an NVIDIA-friendly C++ program, which can be built using `nvcc`. Another transpiler could generate MKL-friendly code that should be built using `icc` from Intel. More interestingly, we can translate a Fluid program into its distributed version of two `ProgramDesc` messages, one for running on the trainer process, and the other one for the parameter server. For more details of the last example, the [concurrent programming design](concurrent_programming.md) document would be a good pointer. The following figure explains the proposed two-stage process: + +![](fluid-compiler.png) diff --git a/doc/design/support_new_device.md b/doc/design/support_new_device.md index 92443e43927ddb184ed51199a3a8c548ae607b3f..fd23dc211a35fdc9d87bc9233fcf4e90254da748 100644 --- a/doc/design/support_new_device.md +++ b/doc/design/support_new_device.md @@ -1,33 +1,33 @@ -# Design Doc: Support new Device/Library +# Design Doc: Supporting new Device/Library ## Background -Deep learning has a high demand for computing resources. New high-performance device and computing library are coming constantly. The deep learning framework has to integrate these high-performance device and computing library flexibly. +Deep learning has a high demand for computing resources. New high-performance devices and computing libraries are appearing very frequently. Deep learning frameworks have to integrate these high-performance devices and computing libraries flexibly and efficiently. -On the one hand, hardware and computing library are not usually one-to-one coresponding relations. For example, in Intel CPU, there are Eigen and MKL computing library. And in Nvidia GPU, there are Eigen and cuDNN computing library. We have to implement specific kernels for an operator for each computing library. +On one hand, hardware and computing libraries usually do not have a one-to-one correspondence. For example,Intel CPUs support Eigen and MKL computing libraries while Nvidia GPUs support Eigen and cuDNN computing libraries. We have to implement operator specific kernels for each computing library. -On the other hand, users usually do not want to care about the low-level hardware and computing library when writing a neural network configuration. In Fluid, `Layer` is exposed in `Python`, and `Operator` is exposed in `C++`. Both `Layer` and `Operator` are independent on hardwares. +On the other hand, users usually do not want to care about the low-level hardware and computing libraries when writing a neural network configuration. In Fluid, `Layer` is exposed in `Python`, and `Operator` is exposed in `C++`. Both `Layer` and `Operator` are hardware independent. So, how to support a new Device/Library in Fluid becomes a challenge. ## Basic: Integrate A New Device/Library -For a general overview of fluid, please refer to [overview doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/read_source.md). +For a general overview of fluid, please refer to the [overview doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/read_source.md). -There are mainly there parts we have to consider in integrating a new device/library: +There are mainly three parts that we have to consider while integrating a new device/library: - Place and DeviceContext: indicates the device id and manages hardware resources - Memory and Tensor: malloc/free data on certain device -- Math Functor and OpKernel: implement computing unit on certain device/library +- Math Functor and OpKernel: implement computing unit on certain devices/libraries ### Place and DeviceContext #### Place -Fluid use class [Place](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55) to represent specific device and computing library. There are inheritance relationships between different kinds of `Place`. +Fluid uses class [Place](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55) to represent different devices and computing libraries. There are inheritance relationships between different kinds of `Place`. ``` | CPUPlace --> MKLDNNPlace @@ -43,7 +43,7 @@ typedef boost::variant Place; #### DeviceContext -Fluid use class [DeviceContext](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/device_context.h#L30) to manage the resources in certain hardware, such as CUDA stream in `CDUADeviceContext`. There are also inheritance relationships between different kinds of `DeviceContext`. +Fluid uses class [DeviceContext](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/device_context.h#L30) to manage the resources in different hardwares, such as CUDA stream in `CDUADeviceContext`. There are also inheritance relationships between different kinds of `DeviceContext`. ``` @@ -52,7 +52,7 @@ DeviceContext ----> CUDADeviceContext --> CUDNNDeviceContext \-> FPGADeviceContext ``` -A example of Nvidia GPU is as follows: +An example of Nvidia GPU is as follows: - DeviceContext @@ -93,7 +93,7 @@ class CUDNNDeviceContext : public CUDADeviceContext { #### memory module -Fluid provide following [memory interfaces](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/memory/memory.h#L36): +Fluid provides the following [memory interfaces](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/memory/memory.h#L36): ``` template @@ -106,12 +106,12 @@ template size_t Used(Place place); ``` -To implementing these interfaces, we have to implement MemoryAllocator for specific Device +To implementing these interfaces, we have to implement MemoryAllocator for different Devices #### Tensor -[Tensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/tensor.h#L36) holds data with some shape in certain Place. +[Tensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/tensor.h#L36) holds data with some shape in a specific Place. ```cpp class Tensor { @@ -168,7 +168,7 @@ t.mutable_data(place); ### Math Functor and OpKernel -Fluid implements computing unit based on different DeviceContext. Some computing unit is shared between operators. These common part will be put in operators/math directory as basic Functors. +Fluid implements computing units based on different DeviceContexts. Some computing units are shared between operators. This common part will be put in operators/math directory as basic Functors. Let's take [MaxOutFunctor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/math/maxouting.h#L27) as an example: @@ -183,7 +183,7 @@ class MaxOutFunctor { }; ``` -CPU implement in .cc file +CPU implemention is in .cc file ``` template @@ -197,7 +197,7 @@ class MaxOutFunctor { }; ``` -CUDA implement in .cu file +CUDA implemention is in .cu file ``` template @@ -212,11 +212,11 @@ class MaxOutFunctor { ``` -We get computing handle from concrete DeviceContext, and make compution on tensors. +We get computing handle from a concrete DeviceContext, and make compution on tensors. -The implement of `OpKernel` is similar to math functors, the extra thing we need to do is registering the OpKernel to global map. +The implemention of `OpKernel` is similar to math functors, the extra thing we need to do is to register the OpKernel in a global map. -Fluid provides different register interface in op_registry.h +Fluid provides different register interfaces in op_registry.h Let's take [Crop](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/crop_op.cc#L134) operator as an example: @@ -240,7 +240,7 @@ REGISTER_OP_CUDA_KERNEL( ## Advanced topics: How to switch between different Device/Library -Generally, we will impelement OpKernel for all Device/Library of an Operator. We can easily train a Convolutional Neural Network in GPU. However, some OpKernel is not sutibale in a specific Device. For example, crf operator can be only run at CPU, whereas most other operators can be run at GPU. To achieve high performance in such circumstance, we have to switch between different Device/Library. +Generally, we will impelement OpKernel for all Device/Library of an Operator. We can easily train a Convolutional Neural Network in GPU. However, some OpKernel is not sutibale on a specific Device. For example, crf operator can only run on CPU, whereas most other operators can run at GPU. To achieve high performance in such circumstance, we have to switch between different Device/Library. We will discuss how to implement an efficient OpKernel switch policy. diff --git a/doc/faq/build_and_install/index_cn.rst b/doc/faq/build_and_install/index_cn.rst index f1677e216f31d79b53ac29a0afbf6fbb886a0dcd..a2bdeead7841393fdfe90c78e5b91d9e61678a24 100644 --- a/doc/faq/build_and_install/index_cn.rst +++ b/doc/faq/build_and_install/index_cn.rst @@ -14,7 +14,7 @@ $ export CUDA_SO="$(\ls usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')" $ export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}') - $ docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddlepaddle:latest-gpu + $ docker run ${CUDA_SO} ${DEVICES} -it paddlepaddle/paddle:latest-gpu 更多关于Docker的安装与使用, 请参考 `PaddlePaddle Docker 文档 `_ 。 diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst index f78b1fb0e11aa028a4b7abb5270740b97f8039e9..1eb06e4182d40c3be20d71e37b34009905eaf9d6 100644 --- a/doc/getstarted/build_and_install/docker_install_cn.rst +++ b/doc/getstarted/build_and_install/docker_install_cn.rst @@ -114,7 +114,7 @@ PaddlePaddle Book是为用户和开发者制作的一个交互式的Jupyter Note .. code-block:: bash - nvidia-docker run -it -v $PWD:/work paddledev/paddle:latest-gpu /bin/bash + nvidia-docker run -it -v $PWD:/work paddlepaddle/paddle:latest-gpu /bin/bash **注: 如果没有安装nvidia-docker,可以尝试以下的方法,将CUDA库和Linux设备挂载到Docker容器内:** @@ -122,7 +122,7 @@ PaddlePaddle Book是为用户和开发者制作的一个交互式的Jupyter Note export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')" export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}') - docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:latest-gpu + docker run ${CUDA_SO} ${DEVICES} -it paddlepaddle/paddle:latest-gpu **关于AVX:** diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index d7acc7aeb744b19d83acb520d07c8551168dd096..5a46c598f2248c7912169a9e77b16851230c1d2e 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -122,7 +122,7 @@ GPU driver installed before move on. .. code-block:: bash - nvidia-docker run -it -v $PWD:/work paddledev/paddle:latest-gpu /bin/bash + nvidia-docker run -it -v $PWD:/work paddlepaddle/paddle:latest-gpu /bin/bash **NOTE: If you don't have nvidia-docker installed, try the following method to mount CUDA libs and devices into the container.** @@ -130,7 +130,7 @@ GPU driver installed before move on. export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')" export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}') - docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:latest-gpu + docker run ${CUDA_SO} ${DEVICES} -it paddlepaddle/paddle:latest-gpu **About AVX:** diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 749258183ba058cf0ed8d91c4406813694314b85..d2de4e80f751d4938ac9cad60871b470fccf225c 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -261,8 +261,12 @@ class GemmConvGradKernel : public framework::OpKernel { if (input_grad) { input_grad->mutable_data(context.GetPlace()); - set_zero(dev_ctx, input_grad, static_cast(0)); + // if is_expand is false, the operation of set_zero is unnecessary, + // because math::matmul will reset input_grad. + if (is_expand) { + set_zero(dev_ctx, input_grad, static_cast(0)); + } math::Col2VolFunctor col2vol; math::Col2ImFunctor col2im; diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index 80600b53614994ba0c740aed0d75c9944333fecc..1171b0435fd2b1abe541043e8283a8fc09dc13c7 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -225,7 +225,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { if (input_grad) { input_grad->mutable_data(context.GetPlace()); - set_zero(dev_ctx, input_grad, static_cast(0)); } if (filter_grad) { // filter size (m, c, k_h, k_w) filter_grad->mutable_data(context.GetPlace()); diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 1b560a7e2d29c1b63a25d4ec9bbd82d5960a279d..e33070c40fbfa7f2794426247ef77b8fcaee4ec6 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -273,6 +273,13 @@ void set_constant_with_place( TensorSetConstantGPU(context, tensor, value)); } +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, + float value) { + set_constant_with_place(context, tensor, value); +} + template struct RowwiseAdd; template struct RowwiseAdd; template struct ColwiseSum; diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc index b754637bf29225615f129d7423d60518e053ca18..fedc2a5c37ff84ffdf8ebd2f19296db92e256e5b 100644 --- a/paddle/operators/reduce_op.cc +++ b/paddle/operators/reduce_op.cc @@ -37,18 +37,23 @@ class ReduceOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_LT( dim, x_rank, "The dim should be in the range [-rank(input), rank(input))."); - bool keep_dim = ctx->Attrs().Get("keep_dim"); - auto dims_vector = vectorize(x_dims); - if (keep_dim || x_rank == 1) { - dims_vector[dim] = 1; + bool reduce_all = ctx->Attrs().Get("reduce_all"); + if (reduce_all) { + ctx->SetOutputDim("Out", {1}); } else { - dims_vector.erase(dims_vector.begin() + dim); - } - auto out_dims = framework::make_ddim(dims_vector); - ctx->SetOutputDim("Out", out_dims); - if (dim != 0) { - // Only pass LoD when not reducing on the first dim. - ctx->ShareLoD("X", /*->*/ "Out"); + bool keep_dim = ctx->Attrs().Get("keep_dim"); + auto dims_vector = vectorize(x_dims); + if (keep_dim || x_rank == 1) { + dims_vector[dim] = 1; + } else { + dims_vector.erase(dims_vector.begin() + dim); + } + auto out_dims = framework::make_ddim(dims_vector); + ctx->SetOutputDim("Out", out_dims); + if (dim != 0) { + // Only pass LoD when not reducing on the first dim. + ctx->ShareLoD("X", /*->*/ "Out"); + } } } }; @@ -95,11 +100,16 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) " "If true, retain the reduced dimension with length 1.") .SetDefault(false); + AddAttr("reduce_all", + "(bool, default false) " + "If true, output a scalar reduced along all dimensions.") + .SetDefault(false); comment_ = R"DOC( {ReduceOp} Operator. This operator computes the {reduce} of input tensor along the given dimension. The result tensor has 1 fewer dimension than the input unless keep_dim is true. +If reduce_all is true, just reduce along all dimensions and output a scalar. )DOC"; AddComment(comment_); diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h index 47ce910f2821467c701a7f5e22a8dbe5c8c95c92..7bd99cb1e6d532963ef648202f460f363baad9b5 100644 --- a/paddle/operators/reduce_op.h +++ b/paddle/operators/reduce_op.h @@ -26,10 +26,12 @@ using DDim = framework::DDim; template using EigenTensor = framework::EigenTensor; - template using EigenScalar = framework::EigenScalar; +template +using EigenVector = framework::EigenVector; struct SumFunctor { template @@ -95,26 +97,41 @@ template class ReduceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - int rank = context.Input("X")->dims().size(); - switch (rank) { - case 1: - ReduceCompute<1>(context); - break; - case 2: - ReduceCompute<2>(context); - break; - case 3: - ReduceCompute<3>(context); - break; - case 4: - ReduceCompute<4>(context); - break; - case 5: - ReduceCompute<5>(context); - break; - case 6: - ReduceCompute<6>(context); - break; + bool reduce_all = context.Attr("reduce_all"); + if (reduce_all) { + // Flatten and reduce 1-D tensor + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + auto x = EigenVector::Flatten(*input); + auto out = EigenScalar::From(*output); + auto& place = + *context.template device_context().eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + Functor functor; + functor(place, x, out, reduce_dim); + } else { + int rank = context.Input("X")->dims().size(); + switch (rank) { + case 1: + ReduceCompute<1>(context); + break; + case 2: + ReduceCompute<2>(context); + break; + case 3: + ReduceCompute<3>(context); + break; + case 4: + ReduceCompute<4>(context); + break; + case 5: + ReduceCompute<5>(context); + break; + case 6: + ReduceCompute<6>(context); + break; + } } } @@ -157,26 +174,46 @@ template class ReduceGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - int rank = context.Input("X")->dims().size(); - switch (rank) { - case 1: - ReduceGradCompute<1>(context); - break; - case 2: - ReduceGradCompute<2>(context); - break; - case 3: - ReduceGradCompute<3>(context); - break; - case 4: - ReduceGradCompute<4>(context); - break; - case 5: - ReduceGradCompute<5>(context); - break; - case 6: - ReduceGradCompute<6>(context); - break; + bool reduce_all = context.Attr("reduce_all"); + if (reduce_all) { + auto* input0 = context.Input("X"); + auto* input1 = context.Input("Out"); + auto* input2 = context.Input(framework::GradVarName("Out")); + auto* output = context.Output(framework::GradVarName("X")); + output->mutable_data(context.GetPlace()); + auto x = EigenVector::Flatten(*input0); + auto x_reduce = EigenVector::From(*input1); + auto x_reduce_grad = EigenVector::From(*input2); + auto x_grad = EigenVector::Flatten(*output); + auto& place = + *context.template device_context().eigen_device(); + auto broadcast_dim = + Eigen::array({{static_cast(input0->numel())}}); + Functor functor; + functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim, + broadcast_dim[0]); + } else { + int rank = context.Input("X")->dims().size(); + switch (rank) { + case 1: + ReduceGradCompute<1>(context); + break; + case 2: + ReduceGradCompute<2>(context); + break; + case 3: + ReduceGradCompute<3>(context); + break; + case 4: + ReduceGradCompute<4>(context); + break; + case 5: + ReduceGradCompute<5>(context); + break; + case 6: + ReduceGradCompute<6>(context); + break; + } } } diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 2c7f96421621b9a34d1ec96c13d9c354a0d4012c..1c72b5055971e73c7aa560a61ca9d3c48dc56fbc 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -125,6 +125,22 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } +CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place) + : CUDADeviceContext(place), place_(place) { + PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream())); +} + +CudnnDeviceContext::~CudnnDeviceContext() { + SetDeviceId(place_.device); + Wait(); + PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); +} + +Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); } + +cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; } + #endif } // namespace platform diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 596d9d0bba420a47fc10cc9dd96a755daa35dbac..f67194993db1f4160bd6894b2c845a82f4da2354 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext { cublasHandle_t cublas_handle_; }; +class CudnnDeviceContext : public CUDADeviceContext { + public: + explicit CudnnDeviceContext(CudnnPlace place); + virtual ~CudnnDeviceContext(); + + /*! \brief Return place in the device context. */ + Place GetPlace() const final; + + /*! \brief Return cudnn handle in the device context. */ + cudnnHandle_t cudnn_handle() const; + + private: + cudnnHandle_t cudnn_handle_; + CudnnPlace place_; +}; + #endif } // namespace platform diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 4893cd92f6a74f7992c279ebd51232049f29e853..be3b2af5af09cb18f5156412ff60a7fc15a16487 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -46,3 +46,19 @@ TEST(Device, CUDADeviceContext) { delete device_context; } } + +TEST(Device, CudnnDeviceContext) { + using paddle::platform::CudnnDeviceContext; + using paddle::platform::CudnnPlace; + if (paddle::platform::dynload::HasCUDNN()) { + int count = paddle::platform::GetCUDADeviceCount(); + for (int i = 0; i < count; ++i) { + CudnnDeviceContext* device_context = + new CudnnDeviceContext(CudnnPlace(i)); + cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); + ASSERT_NE(nullptr, cudnn_handle); + ASSERT_NE(nullptr, device_context->stream()); + delete device_context; + } + } +} diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 70acc5b2d4454f43ffa34fc205fe7236bcdd709d..5a1ce528006301a4d28c560a63db604076843a01 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -51,6 +51,11 @@ struct GPUPlace { int device; }; +struct CudnnPlace : public GPUPlace { + CudnnPlace() : GPUPlace() {} + explicit CudnnPlace(int d) : GPUPlace(d) {} +}; + struct IsGPUPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } bool operator()(const MKLDNNPlace &) const { return false; } @@ -67,7 +72,7 @@ struct IsMKLDNNPlace : public boost::static_visitor { // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 -typedef boost::variant Place; +typedef boost::variant Place; // static check number of place types is less equal than // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) diff --git a/paddle/scripts/cluster_train_v2/openmpi/docker_cluster/Dockerfile b/paddle/scripts/cluster_train_v2/openmpi/docker_cluster/Dockerfile index 1a2d19e823541750830fcaa25f65b2f8e1ea2b49..c2f631bdf4ed52a5dfa3fbcf1157d0abbdeadb9b 100644 --- a/paddle/scripts/cluster_train_v2/openmpi/docker_cluster/Dockerfile +++ b/paddle/scripts/cluster_train_v2/openmpi/docker_cluster/Dockerfile @@ -1,7 +1,7 @@ # Build this image: docker build -t mpi . # -FROM paddledev/paddle:0.10.0rc3 +FROM paddlepaddle/paddle:0.10.0rc3 ENV DEBIAN_FRONTEND noninteractive diff --git a/paddle/scripts/tools/build_docs/build_docs.sh b/paddle/scripts/tools/build_docs/build_docs.sh index c6cbbc4eef94fb2e2fc3c1ce71734fbb23fc22d7..f9bc8bf63ae9afdfca1ff660bc83e62e71f03005 100755 --- a/paddle/scripts/tools/build_docs/build_docs.sh +++ b/paddle/scripts/tools/build_docs/build_docs.sh @@ -5,4 +5,4 @@ docker run --rm \ -e "WITH_AVX=ON" \ -e "WITH_DOC=ON" \ -e "WOBOQ=ON" \ - ${1:-"paddledev/paddle:dev"} + ${1:-"paddlepaddle/paddle:latest-dev"} diff --git a/python/paddle/v2/fluid/tests/test_reduce_op.py b/python/paddle/v2/fluid/tests/test_reduce_op.py index 70359d60cbe656150877673c63e81eae92d8ab9a..a021d4dd91bb9cc1e5d85411b3813b966ef5b296 100644 --- a/python/paddle/v2/fluid/tests/test_reduce_op.py +++ b/python/paddle/v2/fluid/tests/test_reduce_op.py @@ -85,5 +85,19 @@ class Test1DReduce(OpTest): self.check_grad(['X'], 'Out') +class TestReduceAll(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")} + self.attrs = {'reduce_all': True} + self.outputs = {'Out': self.inputs['X'].sum()} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index bd97dc1199fedc8ac91c1c6086957e8cce88bdc4..7b7d1a1d1672802e0e91a857100604758683224e 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -383,19 +383,22 @@ class Parameters(object): params.deserialize(param_name, f) return params - def init_from_tar(self, f): + def init_from_tar(self, f, exclude_params=[]): """ Different from `from_tar`, this interface can be used to init partial network parameters from another saved model. :param f: the initialized model file. :type f: tar file + :param exclude_params: the names of parameters that should + not be initialized from the model file. + :type exclude_params: list of strings :return: Nothing. """ tar_param = Parameters.from_tar(f) for pname in tar_param.names(): - if pname in self.names(): + if pname in self.names() and pname not in exclude_params: self.set(pname, tar_param.get(pname))