diff --git a/doc/fluid/design/dynamic_rnn/rnn_design_en.md b/doc/fluid/design/dynamic_rnn/rnn_design_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..9493908f4f73b3e7d91f5f6364a2a3660257d508
--- /dev/null
+++ b/doc/fluid/design/dynamic_rnn/rnn_design_en.md
@@ -0,0 +1,175 @@
+# RNN Design for Variable-Length Sequences
+For learning from variable-length sequences, existing mainstream frameworks such as TensorFlow, PyTorch, Caffe2 and MXNet all use padding.
+
+Sequences of different lengths in a mini-batch are padded with zeros and transformed to the same length.
+
+PaddlePaddle's existing RNN implementation is `RecurrentLayerGroup`,
+which supports variable-length sequences without padding.
+This doc designs Fluid's RNN based on the same idea.
+
+## Multi-level sequence data format `LODTensor`
+At present, Paddle stores the data of one mini-batch in a one-dimensional array.
+
+`Argument.sequenceStartPositions` is used to store the start position of each sentence.
+
+In Paddle, `Argument.subSequenceStartPositions` is used to store 2 levels of sequence information; sequences with more levels cannot be supported.
+
+In order to support the storage of `N-level` sequences, we define sequence information as the following data structure.
+
+```c++
+std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
+```
+
+Or, defined more clearly:
+
+```c++
+typedef std::vector<int> level_t;
+std::vector<level_t> lod_start_pos;
+```
+Each `level_t` here stores one level of offset information, consistent with Paddle's current practice.
+
+In order to transmit sequence information more transparently, we introduce a new tensor called `LODTensor`[1].
+Its tensor-related interfaces all inherit directly from `Tensor`, but it also adds sequence-related interfaces.
+Thus, an ordinary `Op` can use a `LODTensor` directly as a `Tensor`,
+while an `Op` that operates on sequences will additionally use the `LODTensor`'s variable-length sequence interfaces.
+
+The definition of `LODTensor` is as follows:
+
+```c++
+class LODTensor : public Tensor {
+public:
+  size_t Levels() const { return lod_start_pos_->size(); }
+  size_t Elements(int level = 0) const {
+    return (*lod_start_pos_)[level].size();
+  }
+  // slice of level[elem_begin: elem_end]
+  // NOTE low performance in slicing lod_start_pos_.
+  // TODO should call Tensor's Slice.
+  LODTensor LODSlice(int level, int elem_begin, int elem_end) const;
+
+  // slice with tensor's data shared with this.
+  LODTensor LODSliceShared(int level, int elem_begin, int elem_end) const;
+
+  // copy other's lod_start_pos_, to share LOD info.
+  // NOTE the LOD info should not be changed.
+  void ShareConstLODFrom(const LODTensor &other) {
+    lod_start_pos_ = other.lod_start_pos_;
+  }
+  // copy other's lod_start_pos_'s content, free to mutate.
+  void ShareMutableLODFrom(const LODTensor &other) {
+    lod_start_pos_ = std::make_shared<std::vector<level_t>>(
+        other.lod_start_pos_->begin(), other.lod_start_pos_->end());
+  }
+
+private:
+  std::shared_ptr<std::vector<level_t>> lod_start_pos_;
+};
+```
+Here `lod_start_pos_` uses a `shared_ptr` to reduce the cost of storage and copying.
+`LODTensor` can be thought of as an extension of `Tensor` that is almost completely compatible with the original `Tensor`.
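+
+To make the offset format concrete, here is a minimal, hypothetical sketch (plain standard C++, not part of the proposed interface) of building one level of `lod_start_pos` for a mini-batch of three sequences of lengths 4, 2 and 3, and reading it back:
+
+```c++
+#include <iostream>
+#include <memory>
+#include <vector>
+
+typedef std::vector<int> level_t;
+
+int main() {
+  // One level of offsets: sequence i occupies [offsets[i], offsets[i+1])
+  // in the underlying one-dimensional data array.
+  auto lod_start_pos = std::make_shared<std::vector<level_t>>();
+  lod_start_pos->push_back({0, 4, 6, 9});  // lengths 4, 2, 3
+
+  const level_t &level0 = (*lod_start_pos)[0];
+  std::cout << "levels: " << lod_start_pos->size() << "\n";        // 1
+  std::cout << "sequences: " << level0.size() - 1 << "\n";         // 3
+  std::cout << "len of seq 0: " << level0[1] - level0[0] << "\n";  // 4
+  return 0;
+}
+```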
+
+## How to support the framework
+### Replace `Tensor` with `LoDTensor`
+To implement the passing of `LODTensor`, most `Tensor`s in the framework need to be replaced with `LODTensor`.
+A simple implementation is to directly **replace all previous `Tensor`s with `LODTensor`**; the `Tensor` interface created in `pybind.cc` can be modified directly.
+
+In addition, the user may need to be aware of sequences (for example, visualization needs to parse the output sequences of the model), so some of the sequence-operation APIs also need to be exposed to the Python layer.
+
+### Transmit `lod_start_pos` along the Op call chain
+`lod_start_pos` is passed along with the Op call chain.
+The framework needs to support the following features to implement this:
+
+1. Implement the transfer as a `shared_ptr`
+   - By convention, a consumer does not modify the contents of `lod_start_pos`; it only copies the `shared_ptr` passed to it.
+   - A producer creates its own independent memory to store its own modifications and exposes a `shared_ptr` to subsequent consumers.
+   - Since the transfer is implemented by copying the `shared_ptr`, the framework only needs to pass `lod_start_pos` once.
+
+2. Ordinary `Op`s are transparent and do not need to be aware of `lod_start_pos`.
+3. A producer `Op` that needs to modify `lod_start_pos` updates its own copy during `Run`.
+
+## Sorting by length
+After sorting by length, the batch size decreases monotonically over the forward time steps, so each step can be fed directly into the Net as a batch calculation.
+
+For example, the original input:
+
+```
+origin:
+xxxx
+xx
+xxx
+
+-> sorted:
+xxxx
+xxx
+xx
+```
+
+After `SegmentInputs`, there will be 4 time steps; the input of each time step is as follows (arranged vertically):
+
+```
+0    1    2    3
+x    x    x    x
+x    x    x
+x    x
+```
+
+In order to track the correspondence before and after sorting, we use
+
+```c++
+struct SortedSeqItem {
+   void *start{nullptr};
+   void *end{nullptr};
+};
+
+std::vector<SortedSeqItem> sorted_seqs;
+```
+to record the position of each sequence after sorting, and add a new interface
+
+```c++
+std::vector<SortedSeqItem> SortBySeqLen(const LODTensor& tensor);
+```
+Because the input sequences are reordered, the following existing interfaces need to be modified:
+
+- `InitMemories`: memory needs to be rearranged according to `sorted_seqs`
+- `SegmentInputs`
+- `ConcatOutputs`
+
+In addition, because `sorted_seqs` needs to be reused by `RecurrentGradientOp`, it becomes a new output of `RecurrentOp` and is passed as an input to `RecurrentGradientOp`.
+
+## InitMemories
+Due to the reordering, the order of the elements in the `boot_memories` batch also needs to be rearranged accordingly.
+
+## SegmentInputs
+
+`SegmentInputs` uses the information in `sorted_seqs` to slice the original sequences, in sorted order, into the inputs of each time step.
+
+The transition is as follows:
+```
+origin:
+xxxx
+xx
+xxx
+
+    |
+    |
+   \ /
+    v
+
+0    1    2    3
+x    x    x    x
+x    x    x
+x    x
+```
+## ConcatOutputs
+`ConcatOutputs` needs to
+
+- restore the output of each time step back to the original input sequence order (to prevent the order of the infer phase from being disturbed)
+- concatenate each sequence into the regular mini-batch representation
+
+A minimal sketch of the sort/restore bookkeeping is given below.
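+
+This sketch is illustrative only (the real implementation works on `LODTensor` offsets and `SortedSeqItem` pointers rather than plain lengths); it shows how sequence indices can be sorted by length and how the inverse mapping used when restoring the original order is derived:
+
+```c++
+#include <algorithm>
+#include <numeric>
+#include <vector>
+
+int main() {
+  std::vector<int> lengths = {4, 2, 3};  // original sequence lengths
+
+  // Sort sequence indices by length, longest first, so that the batch
+  // size shrinks monotonically over the time steps.
+  std::vector<size_t> order(lengths.size());
+  std::iota(order.begin(), order.end(), 0);
+  std::sort(order.begin(), order.end(),
+            [&](size_t a, size_t b) { return lengths[a] > lengths[b]; });
+  // order == {0, 2, 1}: sorted slot i holds original sequence order[i].
+
+  // Inverse mapping, used to put per-step outputs back in input order.
+  std::vector<size_t> restore(order.size());
+  for (size_t i = 0; i < order.size(); ++i) restore[order[i]] = i;
+  // restore == {0, 2, 1}: original sequence j sits at sorted slot restore[j].
+  return 0;
+}
+```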
+
+## References
+1. [Level of detail](https://en.wikipedia.org/wiki/Level_of_detail)
diff --git a/doc/fluid/getstarted/quickstart_cn.rst b/doc/fluid/getstarted/quickstart_cn.rst
deleted file mode 120000
index 93a9e4e37a8495c553cec257c27363ca8d062d39..0000000000000000000000000000000000000000
--- a/doc/fluid/getstarted/quickstart_cn.rst
+++ /dev/null
@@ -1 +0,0 @@
-../../v2/getstarted/quickstart_cn.rst
\ No newline at end of file
diff --git a/doc/fluid/getstarted/quickstart_cn.rst b/doc/fluid/getstarted/quickstart_cn.rst
new file mode 100644
index 0000000000000000000000000000000000000000..135beb75d0330f39d062753aa2aa83a077f36bb1
--- /dev/null
+++ b/doc/fluid/getstarted/quickstart_cn.rst
@@ -0,0 +1,45 @@
+快速开始
+========
+
+快速安装
+--------
+
+PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.04以及MacOS 10.12,并安装有Python2.7。
+执行下面的命令完成快速安装,版本为cpu_avx_openblas:
+
+  .. code-block:: bash
+
+     pip install paddlepaddle
+
+如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行:
+
+  .. code-block:: bash
+
+     pip install paddlepaddle-gpu
+
+更详细的安装和编译方法参考: :ref:`install_steps` 。
+
+快速使用
+--------
+
+创建一个 housing.py 并粘贴此Python代码:
+
+  .. code-block:: python
+
+     import paddle.dataset.uci_housing as uci_housing
+     import paddle.fluid as fluid
+
+     with fluid.scope_guard(fluid.core.Scope()):
+         # initialize executor with cpu
+         exe = fluid.Executor(place=fluid.CPUPlace())
+         # load inference model
+         [inference_program, feed_target_names, fetch_targets] = \
+             fluid.io.load_inference_model(uci_housing.fluid_model(), exe)
+         # run inference
+         result = exe.run(inference_program,
+                          feed={feed_target_names[0]: uci_housing.predict_reader()},
+                          fetch_list=fetch_targets)
+         # print the predicted price, e.g. $12,273.97
+         print 'Predicted price: ${:,.2f}'.format(result[0][0][0] * 1000)
+
+执行 :code:`python housing.py` 瞧!它应该打印出预测住房数据的清单。
diff --git a/doc/fluid/getstarted/quickstart_en.rst b/doc/fluid/getstarted/quickstart_en.rst
deleted file mode 120000
index 6e1894faa1176bb9e77f616e07df36191e54b782..0000000000000000000000000000000000000000
--- a/doc/fluid/getstarted/quickstart_en.rst
+++ /dev/null
@@ -1 +0,0 @@
-../../v2/getstarted/quickstart_en.rst
\ No newline at end of file
diff --git a/doc/fluid/getstarted/quickstart_en.rst b/doc/fluid/getstarted/quickstart_en.rst
new file mode 100644
index 0000000000000000000000000000000000000000..df6619cfd039fc1fdca8cde57db9cc6aebf8f029
--- /dev/null
+++ b/doc/fluid/getstarted/quickstart_en.rst
@@ -0,0 +1,49 @@
+Quick Start
+============
+
+Quick Install
+-------------
+
+You can use pip to install PaddlePaddle with a single command. It supports
+CentOS 6 and above, Ubuntu 14.04 and above, and macOS 10.12, with Python 2.7 installed.
+Simply run the following command to install the cpu_avx_openblas version:
+
+  .. code-block:: bash
+
+     pip install paddlepaddle
+
+If you need to install the GPU version (cuda7.5_cudnn5_avx_openblas), run:
+
+  .. code-block:: bash
+
+     pip install paddlepaddle-gpu
+
+For more details about installation and build, see :ref:`install_steps`.
+
+Quick Use
+---------
+
+Create a new file called housing.py, and paste this Python code:
+
+  .. code-block:: python
+
+     import paddle.dataset.uci_housing as uci_housing
+     import paddle.fluid as fluid
+
+     with fluid.scope_guard(fluid.core.Scope()):
+         # initialize executor with cpu
+         exe = fluid.Executor(place=fluid.CPUPlace())
+         # load inference model
+         [inference_program, feed_target_names, fetch_targets] = \
+             fluid.io.load_inference_model(uci_housing.fluid_model(), exe)
+         # run inference
+         result = exe.run(inference_program,
+                          feed={feed_target_names[0]: uci_housing.predict_reader()},
+                          fetch_list=fetch_targets)
+         # print the predicted price, e.g. $12,273.97
+         print 'Predicted price: ${:,.2f}'.format(result[0][0][0] * 1000)
+
+Run :code:`python housing.py` and voila! It should print out a list of predictions
+for the test housing data.
diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 946ee91a667496e2427304df4228334bb1061890..423449abff97dbf70d81314f852d9135e25f243f 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -66,7 +66,7 @@ void FetchOpHandle::RunImpl() { auto &t = var->Get<framework::LoDTensor>(); if (platform::is_gpu_place(t.place())) { #ifdef PADDLE_WITH_CUDA - TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]); + TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i], true); dev_ctxes_.at(t.place())->Wait(); #endif } else { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 10d39e779336e2001d66a55ac6d01ee768ddd4ff..3413467b149539bcff42d78a9a6fe315d6558bb4 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -78,6 +78,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, } } +bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, + OpDesc *send_op) const { + if (send_op == nullptr) { + return false; + } + + auto checker = [&](const std::vector<std::string> opvars, + const std::vector<std::string> sendvars) -> bool { + bool is_dist_train_op = false; + for (auto &var : opvars) { + if (var.find(".block") != std::string::npos && + std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) { + is_dist_train_op = true; + break; + } + } + return is_dist_train_op; + }; + + if (op.Type() == "split") { + return checker(op.OutputArgumentNames(), send_op->InputArgumentNames()); + } else if (op.Type() == "concat") { + return checker(op.InputArgumentNames(), send_op->OutputArgumentNames()); + } + return false; +} + std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { auto graph = new SSAGraph(); @@ -89,19 +116,30 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unordered_map>>>( places_.size()); + // Find the "send" op first, since "split" ops come before "send". + OpDesc *send_op = nullptr; + for (auto *op : program.Block(0).AllOps()) { + if (op->Type() == "send") { + send_op = op; + break; + } + } + bool is_forwarding = true; for (auto *op : program.Block(0).AllOps()) { if (op->Type() == "send") { // append send op if program is distributed trainer main program.
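// send talks to the parameter server, so unlike ordinary computational ops (replicated once per place below) it is created only once.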
// always use the first device CreateSendOp(&result, *op); + } else if (IsDistTrainOp(*op, send_op)) { + CreateComputationalOps(&result, *op, 1); } else if (IsScaleLossOp(*op)) { if (!skip_scale_loss_) { CreateScaleLossGradOp(&result); } is_forwarding = false; } else { - CreateComputationalOps(&result, *op); + CreateComputationalOps(&result, *op, places_.size()); if (!is_forwarding) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. But there are no @@ -199,8 +237,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { } void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, - const OpDesc &op) const { - for (size_t scope_idx = 0; scope_idx < places_.size(); ++scope_idx) { + const OpDesc &op, + size_t num_places) const { + for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; result->ops_.emplace_back(new ComputationOpHandle(op, s, p)); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 009c31b40c279ae3e924d2d7c67933e7444ed85c..dc3da70eda2abaa1a312c25aedf94fa7e427c78a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -65,7 +65,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { void CreateSendOp(SSAGraph *result, const OpDesc &op) const; - void CreateComputationalOps(SSAGraph *result, const OpDesc &op) const; + bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const; + + void CreateComputationalOps(SSAGraph *result, const OpDesc &op, + size_t num_places) const; void CreateScaleLossGradOp(SSAGraph *result) const; diff --git a/paddle/fluid/framework/init.cc b/paddle/fluid/framework/init.cc index 642b892105a83fa19f0a1a501ecaa7d00f6ac5d9..721364e4bd0bace3d11702d7a905c2b15ebc800c 100644 --- a/paddle/fluid/framework/init.cc +++ b/paddle/fluid/framework/init.cc @@ -34,7 +34,7 @@ std::once_flag p2p_init_flag; using paddle::platform::DeviceContextPool; -void Init(std::vector<std::string> &argv) { +void Init(std::vector<std::string> argv) { InitGflags(argv); // init devices std::vector<int> devices; @@ -46,7 +46,7 @@ void Init(std::vector<std::string> &argv) { InitDevices(FLAGS_init_p2p, devices); } -void InitGflags(std::vector<std::string> &argv) { +void InitGflags(std::vector<std::string> argv) { std::call_once(gflags_init_flag, [&]() { argv.push_back("dummy"); int argc = argv.size(); @@ -108,7 +108,7 @@ void InitP2P(std::vector<int> devices) { } void InitDevices(bool init_p2p) { - /*Init all avaiable devices by default */ + /*Init all available devices by default */ std::vector<platform::Place> places; places.emplace_back(platform::CPUPlace()); diff --git a/paddle/fluid/framework/init.h b/paddle/fluid/framework/init.h index cf792f18b7ac16808d09e86a982e741fd0dca968..05ce3376feaddcb926aa8ec0255d8d3f7a3b782d 100644 --- a/paddle/fluid/framework/init.h +++ b/paddle/fluid/framework/init.h @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include <mutex> +#include <mutex>  // NOLINT +#include <string> +#include <vector> #include "gflags/gflags.h" #include "glog/logging.h" @@ -20,9 +22,9 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -void Init(std::vector &argv); +void Init(std::vector argv); -void InitGflags(std::vector &argv); +void InitGflags(std::vector argv); void InitGLOG(const std::string &prog_name); diff --git a/paddle/fluid/framework/library_type.h b/paddle/fluid/framework/library_type.h index ea538731b469901a3357d624c5bb0fddc4058488..904cc013012b9c3ea8054816446844f6d2cda26b 100644 --- a/paddle/fluid/framework/library_type.h +++ b/paddle/fluid/framework/library_type.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include namespace paddle { namespace framework { @@ -67,5 +68,5 @@ inline std::ostream& operator<<(std::ostream& out, LibraryType l) { return out; } -} // namespace -} // framework +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index d1b01ae05b808b229309e9689165483a11530c84..d2e60ab1dd16758a91d22ef6872edc5053ef88b3 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -20,7 +20,7 @@ namespace paddle { namespace framework { void TensorCopy(const Tensor& src, const platform::Place& dst_place, - const platform::DeviceContext& ctx, Tensor* dst) { + const platform::DeviceContext& ctx, Tensor* dst, bool sync) { VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " << dst_place; src.check_memory_size(); @@ -47,9 +47,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); auto ctx_gpu_place = boost::get(ctx_place); PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); - memory::Copy( - dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, - reinterpret_cast(ctx).stream()); + auto stream = + sync ? nullptr + : reinterpret_cast(ctx) + .stream(); + memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); } else if (platform::is_cpu_place(src_place) && platform::is_gpu_place(dst_place)) { auto src_cpu_place = boost::get(src_place); @@ -58,18 +60,22 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); auto ctx_gpu_place = boost::get(ctx_place); PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place); - memory::Copy( - dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, - reinterpret_cast(ctx).stream()); + auto stream = + sync ? nullptr + : reinterpret_cast(ctx) + .stream(); + memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream); } else if (platform::is_gpu_place(src_place) && platform::is_gpu_place(dst_place)) { auto src_gpu_place = boost::get(src_place); auto dst_gpu_place = boost::get(dst_place); auto ctx_place = ctx.GetPlace(); PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); - memory::Copy( - dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, - reinterpret_cast(ctx).stream()); + auto stream = + sync ? 
nullptr + : reinterpret_cast(ctx) + .stream(); + memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); } #endif } diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 78b165ebed13cbae791b922e8820cd9551dfd198..3af68402dc56230171e858bf8f8f8c89c2bfe760 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -24,7 +24,8 @@ namespace paddle { namespace framework { void TensorCopy(const Tensor& src, const platform::Place& dst_place, - const platform::DeviceContext& ctx, Tensor* dst); + const platform::DeviceContext& ctx, Tensor* dst, + bool sync = false); void TensorCopy(const Tensor& src, const platform::Place& dst_place, Tensor* dst); diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index eddcaab8befda84dd14ed46c31ac025dfbcc7ca9..a177d4985fd0e2cca983b6873af89c60f526b811 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -32,7 +32,11 @@ void Copy( platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place, const void* src, size_t num, cudaStream_t stream) { platform::SetDeviceId(src_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost); + } } template <> @@ -40,7 +44,11 @@ void Copy( platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place, const void* src, size_t num, cudaStream_t stream) { platform::SetDeviceId(dst_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice); + } } template <> @@ -49,10 +57,19 @@ void Copy( const void* src, size_t num, cudaStream_t stream) { if (dst_place == src_place) { platform::SetDeviceId(src_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice); + } } else { - platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num, - stream); + if (stream) { + platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device, + num, stream); + } else { + platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device, + num); + } } } @@ -83,7 +100,11 @@ void Copy( platform::CUDAPlace src_place, const void* src, size_t num, cudaStream_t stream) { platform::SetDeviceId(src_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost); + } } template <> @@ -92,7 +113,11 @@ void Copy( platform::CUDAPinnedPlace src_place, const void* src, size_t num, cudaStream_t stream) { platform::SetDeviceId(dst_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice); + } } #endif diff --git a/paddle/fluid/operators/elementwise_op_function.h 
b/paddle/fluid/operators/elementwise_op_function.h index 415182201a7a9e11d8ea8c62b92849b5ea3bac3e..f0362ec606c994d69f31c7a2e1e9ad0d0108b621 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -356,8 +356,8 @@ __device__ T reduceSum(T val, int tid, int len) { // I use Warp-Level Parallelism and assume the Warp size // is 32 which may be different for different GPU, // but most card's warp size is 32. - __shared__ T shm[32]; const int warpSize = 32; + __shared__ T shm[warpSize]; unsigned mask = 0u; CREATE_SHFL_MASK(mask, tid < len); @@ -371,6 +371,7 @@ __device__ T reduceSum(T val, int tid, int len) { if (tid % warpSize == 0) { shm[tid / warpSize] = val; } + __syncthreads(); CREATE_SHFL_MASK(mask, tid < warpSize); diff --git a/paddle/fluid/operators/mul_mkldnn_op.cc b/paddle/fluid/operators/mul_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5f3a98f678a870d30eebfc4cf329de7c93266ee --- /dev/null +++ b/paddle/fluid/operators/mul_mkldnn_op.cc @@ -0,0 +1,197 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "mkldnn.hpp" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/mul_op.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; +using paddle::platform::MKLDNNDeviceContext; + +template +mkldnn::memory::desc type(const std::vector& dims, Format&& f) { + return platform::MKLDNNMemDesc(dims, mkldnn::memory::data_type::f32, f); +} + +template +class MulMKLDNNOpKernel : public paddle::framework::OpKernel { + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto& dev_ctx = ctx.template device_context(); + auto mkldnn_engine = dev_ctx.GetEngine(); + + auto input = ctx.Input("X"); + auto weight = ctx.Input("Y"); + + PADDLE_ENFORCE(input->dims().size() & (2 | 4), + "Input must be with 2 or 4 dimensions, i.e. NC or NCHW"); + PADDLE_ENFORCE(weight->dims().size() & (2 | 4), + "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW"); + + std::vector w_tz = paddle::framework::vectorize2int(weight->dims()); + std::vector src_tz = paddle::framework::vectorize2int(input->dims()); + + auto src_md = + src_tz.size() != 2 + ? type(src_tz, mkldnn::memory::format::nchw) + : type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc); + + auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc); + + auto weights_md = + src_tz.size() != 2 + ? 
type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]}, + mkldnn::memory::format::oihw) + : type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi); + + auto output = ctx.Output("Out"); + T* output_data = output->mutable_data(ctx.GetPlace()); + + const std::string key = ctx.op().Output("Out"); + const std::string key_fc_pd = key + "@mul_pd"; + + const T* input_data = input->data(); + const T* w_data = weight->data(); + + auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data); + + auto src_memory = mkldnn::memory({src_md, mkldnn_engine}, + platform::to_void_cast(input_data)); + + auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine}, + platform::to_void_cast(w_data)); + + auto pd = platform::MKLDNNFwdPrimitiveDesc( + mkldnn_engine, src_md, weights_md, dst_md); + + dev_ctx.SetBlob(key_fc_pd, pd); + + auto forward = mkldnn::inner_product_forward(*pd, src_memory, + weights_memory, dst_memory); + + std::vector pipeline = {forward}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + } +}; + +template +class MulMKLDNNGradOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto& dev_ctx = ctx.template device_context(); + auto mkldnn_engine = dev_ctx.GetEngine(); + + const Tensor* input = ctx.Input("X"); + const Tensor* w = ctx.Input("Y"); + + const Tensor* out_grad = ctx.Input(framework::GradVarName("Out")); + Tensor* input_grad = ctx.Output(framework::GradVarName("X")); + Tensor* w_grad = ctx.Output(framework::GradVarName("Y")); + + const std::string key = ctx.op().Input("Out"); + const std::string key_fc_pd = key + "@mul_pd"; + + const T* input_data = input->data(); + const T* w_data = w->data(); + const T* out_grad_data = out_grad->data(); + T* input_grad_data = nullptr; + T* w_grad_data = nullptr; + + if (input_grad) { + input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + } + if (w_grad) { + w_grad_data = w_grad->mutable_data(ctx.GetPlace()); + } + + std::vector src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector w_tz = paddle::framework::vectorize2int(w->dims()); + + auto src_md = + src_tz.size() != 2 + ? type(src_tz, mkldnn::memory::format::nchw) + : type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc); + + auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc); + + auto weights_md = + src_tz.size() != 2 + ? 
type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]}, + mkldnn::memory::format::oihw) + : type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi); + + auto src_memory = mkldnn::memory({src_md, mkldnn_engine}, + platform::to_void_cast(input_data)); + + auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, + platform::to_void_cast(out_grad_data)); + + auto weight_memory = mkldnn::memory({weights_md, mkldnn_engine}, + platform::to_void_cast(w_data)); + + auto pd = + std::static_pointer_cast( + dev_ctx.GetBlob(key_fc_pd)); + + PADDLE_ENFORCE(pd != nullptr, "Fail to find pd in device context"); + + if (w_grad) { + auto weights_grad_memory = mkldnn::memory( + {weights_md, mkldnn_engine}, platform::to_void_cast(w_grad_data)); + + auto bwd_weight_pd = platform::MKLDNNBwdPrimitiveDesc< + mkldnn::inner_product_backward_weights>(mkldnn_engine, *pd, src_md, + weights_md, dst_md); + + auto bwd_weights_prim = mkldnn::inner_product_backward_weights( + bwd_weight_pd, src_memory, dst_memory, weights_grad_memory); + + std::vector pipeline{bwd_weights_prim}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + } + + if (input_grad) { + auto src_grad_memory = mkldnn::memory( + {src_md, mkldnn_engine}, platform::to_void_cast(input_grad_data)); + + auto bwd_data_pd = + platform::MKLDNNBwdPrimitiveDesc( + mkldnn_engine, *pd, src_md, weights_md, dst_md); + + auto bwd_data_prim = mkldnn::inner_product_backward_data( + bwd_data_pd, dst_memory, weight_memory, src_grad_memory); + + std::vector pipeline{bwd_data_prim}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + } + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP_KERNEL(mul, MKLDNN, ::paddle::platform::CPUPlace, + paddle::operators::MulMKLDNNOpKernel); + +REGISTER_OP_KERNEL(mul_grad, MKLDNN, ::paddle::platform::CPUPlace, + paddle::operators::MulMKLDNNGradOpKernel); diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index bfb20fefba2b8d6e95750c6dc2bc44d606d2ddd1..c9fabc8d485b3bba2c8ae14b3616d0bdcae058a7 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -13,8 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/mul_op.h" +#include #include +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + namespace paddle { namespace operators { @@ -71,6 +76,22 @@ class MulOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); ctx->ShareLoD("X", /*->*/ "Out"); } + + private: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + } +#endif + framework::DataLayout layout{framework::DataLayout::kAnyLayout}; + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), + layout, library); + } }; class MulOpMaker : public framework::OpProtoAndCheckerMaker { @@ -100,6 +121,9 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { )DOC") .SetDefault(1) .EqualGreaterThan(1); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddAttr( "y_num_col_dims", R"DOC((int, default 1), The mul_op can take tensors with more than two, @@ -154,6 +178,22 @@ class MulGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(y_grad_name, y_dims); } } + + private: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + } +#endif + framework::DataLayout layout{framework::DataLayout::kAnyLayout}; + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), + layout, library); + } }; } // namespace operators diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 0b7c1d6af714558d35918dac62d92d9e0f86c970..4372f23fc1dbd85e43b04a9d644977392316c2e9 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -180,7 +180,8 @@ void DoubleBufferReader::PrefetchThreadFunc() { auto* gpu_ctx = ctxs_[cached_tensor_id].get(); gpu_batch.resize(cpu_batch.size()); for (size_t i = 0; i < cpu_batch.size(); ++i) { - framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]); + framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i], + true); gpu_batch[i].set_lod(cpu_batch[i].lod()); } } diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index aaebeb1353a13ab16fcf98f10da59d41fd2f5b48..4cee93f3a4224cb97327254cd1679021d197a1b1 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -127,11 +127,24 @@ void GpuMemcpyAsync(void *dst, const void *src, size_t count, "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync"); } -void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, - size_t count, cudaStream_t stream) { +void GpuMemcpySync(void *dst, const void *src, size_t count, + enum cudaMemcpyKind kind) { + PADDLE_ENFORCE(cudaMemcpy(dst, src, count, kind), + "cudaMemcpy failed in paddle::platform::GpuMemcpySync"); +} + +void GpuMemcpyPeerAsync(void *dst, int 
dst_device, const void *src, + int src_device, size_t count, cudaStream_t stream) { PADDLE_ENFORCE( cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream), - "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer"); + "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync"); +} + +void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src, + int src_device, size_t count) { + PADDLE_ENFORCE( + cudaMemcpyPeer(dst, dst_device, src, src_device, count), + "cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync"); } void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) { diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index 36345e17406e22970806fa274d5a73a703517c43..f4640d3eaa2165c35e8e14690d83e9e7e7168c0b 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -57,9 +57,17 @@ size_t GpuMaxChunkSize(); void GpuMemcpyAsync(void *dst, const void *src, size_t count, enum cudaMemcpyKind kind, cudaStream_t stream); -//! Copy memory from one device to another device. -void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, - size_t count, cudaStream_t stream); +//! Copy memory from address src to dst synchronously. +void GpuMemcpySync(void *dst, const void *src, size_t count, + enum cudaMemcpyKind kind); + +//! Copy memory from one device to another device asynchronously. +void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src, + int src_device, size_t count, cudaStream_t stream); + +//! Copy memory from one device to another device synchronously. +void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src, + int src_device, size_t count); //! Set memory dst with value count size asynchronously void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream); diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index de8056237fb022f62488e0fedf9a4f67e4601072..23f1d615daab91f0e4b353bc7d9a3ca7f5cec5ae 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include - -#include "mkldnn/include/mkldnn.hpp" #include "paddle/fluid/framework/operator.h" namespace paddle { @@ -34,6 +33,32 @@ typedef std::unique_ptr MKLDNNMemoryPtr; typedef std::unique_ptr MKLDNNPrimitivePtr; typedef std::unique_ptr MKLDNNPrimitiveDescPtr; +template +void* to_void_cast(const Type* t) { + return static_cast(const_cast(t)); +} + +template +using tf_desc = typename Type::desc; + +template +using tf_pd = typename Type::primitive_desc; + +template +std::shared_ptr> MKLDNNFwdPrimitiveDesc(const Engine& e, + Args&&... args) { + auto desc = tf_desc(mkldnn::prop_kind::forward, (args)...); + auto pd = new tf_pd(desc, e); + return std::shared_ptr>(pd); +} + +template +tf_pd MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p, + Args&&... 
args) { + auto desc = tf_desc(args...); + return tf_pd(desc, e, p); +} + inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector& dims, mkldnn::memory::data_type data_type, mkldnn::memory::format format) { diff --git a/paddle/gserver/dataproviders/PyDataProvider2.cpp b/paddle/gserver/dataproviders/PyDataProvider2.cpp index e3e4457f9b72c5edb8082fdf378ae662b4aee42f..b4215bb307cc31ce64bb724986b88fdc20bbbf45 100644 --- a/paddle/gserver/dataproviders/PyDataProvider2.cpp +++ b/paddle/gserver/dataproviders/PyDataProvider2.cpp @@ -390,9 +390,7 @@ private: if (this->loadThread_) { // wait poolActualSize < poolSize; std::unique_lock l(mtx_); - pushCV_.wait(l, [this, additionalBatchSize] { - return this->poolActualSize_ < poolSize_; - }); + pushCV_.wait(l, [this] { return this->poolActualSize_ < poolSize_; }); } { diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 3f46cc98cdef17d14c253c732814bcba005fd667..b8d4d28f0f309a5f7348605e8d35e160e7fd5552 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -52,7 +52,7 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config, } else { numDevices_ = 0; } - ParamInitCallback mainParamInitCb = [this](int paramId, Parameter* para) { + ParamInitCallback mainParamInitCb = [](int paramId, Parameter* para) { // only create buf for CPU parameters // GPU parameters will be created in each thread if (para->useGpu()) return; diff --git a/paddle/gserver/layers/RecurrentLayerGroup.cpp b/paddle/gserver/layers/RecurrentLayerGroup.cpp index 27e8b5868e6d85cf004945d7cb086d6d57487f9f..44b57185c5a5fa7703ca477b990a73cdad2c2aa1 100644 --- a/paddle/gserver/layers/RecurrentLayerGroup.cpp +++ b/paddle/gserver/layers/RecurrentLayerGroup.cpp @@ -72,7 +72,7 @@ void RecurrentLayerGroup::initSubNetwork( setNeedGradient(true); network_.reset(new RecurrentGradientMachine(config_.name(), rootNetwork)); - ParamInitCallback cb = [this, rootNetwork](int paramId, Parameter* para) { + ParamInitCallback cb = [rootNetwork](int paramId, Parameter* para) { para->enableSharedType( PARAMETER_VALUE, rootNetwork->getParameters()[paramId]->getBuf(PARAMETER_VALUE), diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp index cfdaf8998b04e0307bc442dec0df734452634c67..94522f718a0c19bfc704ca92eddef5c5a9cb6919 100644 --- a/paddle/parameter/Argument.cpp +++ b/paddle/parameter/Argument.cpp @@ -325,12 +325,12 @@ void Argument::concat(const std::vector& args, ->copyFrom(*src->subVec(srcStartRow, size), stream); }; - auto copyStrs = [batchSize, stream](SVectorPtr& dst, - const SVectorPtr& src, - int desStartRow, - int srcStartRow, - int size, - bool useGpu) { + auto copyStrs = [batchSize](SVectorPtr& dst, + const SVectorPtr& src, + int desStartRow, + int srcStartRow, + int size, + bool useGpu) { if (!src) { dst.reset(); return; @@ -413,7 +413,7 @@ void Argument::concat(const std::vector& args, dst->subVec(startRow, src->getSize())->copyFrom(*src, stream); }; - auto copyStrs = [batchSize, stream]( + auto copyStrs = [batchSize]( SVectorPtr& dst, const SVectorPtr& src, int startRow, bool useGpu) { if (!src) { dst.reset(); diff --git a/paddle/parameter/AverageOptimizer.cpp b/paddle/parameter/AverageOptimizer.cpp index 75998d81dd9c8be35fe45e903dc1cd69068f83c6..82a7fed6c6451b8908851f2d039f17b9dc513818 100644 --- a/paddle/parameter/AverageOptimizer.cpp +++ b/paddle/parameter/AverageOptimizer.cpp @@ -81,9 +81,9 @@ 
ParameterOptimizer::TraverseCallback AverageOptimizer::needSpecialTraversal( if (numUpdates_ % kMaxNumAccumulates == 0) { // Move the sum to a different buffer to avoid loss of precision // due to too many sums. - callbacks.emplace_back([this](const VectorPtr vecs[], - const ParameterConfig& config, - size_t sparseId) { + callbacks.emplace_back([](const VectorPtr vecs[], + const ParameterConfig& config, + size_t sparseId) { vecs[PARAMETER_SUM2]->add(*vecs[PARAMETER_SUM1]); vecs[PARAMETER_SUM1]->zeroMem(); }); @@ -94,9 +94,9 @@ ParameterOptimizer::TraverseCallback AverageOptimizer::needSpecialTraversal( if (auto callback = this->startCatchUpWith()) { callbacks.emplace_back(callback); } - callbacks.emplace_back([this](const VectorPtr vecs[], - const ParameterConfig& config, - size_t sparseId) { + callbacks.emplace_back([](const VectorPtr vecs[], + const ParameterConfig& config, + size_t sparseId) { vecs[PARAMETER_SUM3]->add(*vecs[PARAMETER_SUM1], *vecs[PARAMETER_SUM2]); vecs[PARAMETER_SUM1]->zeroMem(); vecs[PARAMETER_SUM2]->zeroMem(); diff --git a/paddle/parameter/FirstOrderOptimizer.cpp b/paddle/parameter/FirstOrderOptimizer.cpp index 5e280bcac3389179181d2eda58c08e579e867ecc..182e833405e8f8bc3a4c9ffddbf628040f9cceaa 100644 --- a/paddle/parameter/FirstOrderOptimizer.cpp +++ b/paddle/parameter/FirstOrderOptimizer.cpp @@ -145,9 +145,9 @@ AdagradParameterOptimizer::needSpecialTraversal( if (numUpdates_ % kMaxNumAccumulates == 0) { // Move the sum to a different buffer to avoid loss of precision // due to too many sums. - return [this](const VectorPtr vecs[], - const ParameterConfig& config, - size_t sparseId) { + return [](const VectorPtr vecs[], + const ParameterConfig& config, + size_t sparseId) { vecs[PARAMETER_GRADIENT_SQURESUM]->add( *vecs[PARAMETER_GRADIENT_SQURESUM1]); vecs[PARAMETER_GRADIENT_SQURESUM1]->zeroMem(); diff --git a/python/paddle/dataset/uci_housing.py b/python/paddle/dataset/uci_housing.py index 6a56e9d5563c76ab6f524ccea9191693dc227010..fbfa477d055eb5f484989eacce38cee8d617d729 100644 --- a/python/paddle/dataset/uci_housing.py +++ b/python/paddle/dataset/uci_housing.py @@ -19,7 +19,11 @@ https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and parse training set and test set into paddle reader creators. """ +import os + import numpy as np +import tempfile +import tarfile import os import paddle.dataset.common @@ -34,8 +38,9 @@ feature_names = [ UCI_TRAIN_DATA = None UCI_TEST_DATA = None -URL_MODEL = 'https://github.com/PaddlePaddle/book/raw/develop/01.fit_a_line/fit_a_line.tar' -MD5_MODEL = '52fc3da8ef3937822fcdd87ee05c0c9b' + +FLUID_URL_MODEL = 'https://github.com/PaddlePaddle/book/raw/develop/01.fit_a_line/fluid/fit_a_line.fluid.tar' +FLUID_MD5_MODEL = '6e6dd637ccd5993961f68bfbde46090b' def feature_range(maximums, minimums): @@ -113,6 +118,29 @@ def test(): return reader +def fluid_model(): + parameter_tar = paddle.dataset.common.download( + FLUID_URL_MODEL, 'uci_housing', FLUID_MD5_MODEL, 'fit_a_line.fluid.tar') + + tar = tarfile.TarFile(parameter_tar, mode='r') + dirpath = tempfile.mkdtemp() + tar.extractall(path=dirpath) + + return dirpath + + +def predict_reader(): + """ + It returns just one tuple data to do inference. 
+ + :return: one tuple data + :rtype: tuple + """ + global UCI_TEST_DATA + load_data(paddle.dataset.common.download(URL, 'uci_housing', MD5)) + return (UCI_TEST_DATA[0][:-1], ) + + def fetch(): paddle.dataset.common.download(URL, 'uci_housing', MD5) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 5e6abceb0a8c2a97a804d6563b5390a245208e3f..9a0c328033cdfdae39da050fc482abba17032dd9 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -159,67 +159,37 @@ def fc(input, dtype = helper.input_dtype() mul_results = [] - if use_mkldnn: - tmp = helper.create_tmp_variable(dtype) - input_shape = input.shape + for input_var, param_attr in helper.iter_inputs_and_params(): + input_shape = input_var.shape param_shape = [ reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1) ] + [size] w = helper.create_parameter( - attr=helper.param_attr, - shape=param_shape, - dtype=dtype, - is_bias=False) - if bias_attr is None or bias_attr is False: - bias_attr = False - else: - bias_attr = True + attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False) + tmp = helper.create_tmp_variable(dtype) helper.append_op( - type="fc", - inputs={"Input": input, - "W": w}, + type="mul", + inputs={"X": input_var, + "Y": w}, outputs={"Out": tmp}, attrs={ - "use_mkldnn": use_mkldnn, - "is_test": is_test, - "bias_attr": bias_attr + "x_num_col_dims": num_flatten_dims, + "y_num_col_dims": 1, + "use_mkldnn": use_mkldnn }) - return helper.append_activation(tmp) + mul_results.append(tmp) + + if len(mul_results) == 1: + pre_bias = mul_results[0] else: - for input_var, param_attr in helper.iter_inputs_and_params(): - input_shape = input_var.shape - param_shape = [ - reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1) - ] + [size] - - w = helper.create_parameter( - attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False) - tmp = helper.create_tmp_variable(dtype) - helper.append_op( - type="mul", - inputs={"X": input_var, - "Y": w}, - outputs={"Out": tmp}, - attrs={ - "x_num_col_dims": num_flatten_dims, - "y_num_col_dims": 1, - }) - mul_results.append(tmp) - - if len(mul_results) == 1: - pre_bias = mul_results[0] - else: - pre_bias = helper.create_tmp_variable(dtype) - helper.append_op( - type="sum", - inputs={"X": mul_results}, - outputs={"Out": pre_bias}) - # add bias - pre_activation = helper.append_bias_op( - pre_bias, dim_start=num_flatten_dims) - # add activation - return helper.append_activation(pre_activation) + pre_bias = helper.create_tmp_variable(dtype) + helper.append_op( + type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias}) + # add bias + pre_activation = helper.append_bias_op(pre_bias, dim_start=num_flatten_dims) + # add activation + return helper.append_activation(pre_activation) def embedding(input, @@ -3733,8 +3703,8 @@ def label_smooth(label, name=None): """ Label smoothing is a mechanism to regularize the classifier layer and is - called label-smoothing regularization (LSR). - + called label-smoothing regularization (LSR). + Label smoothing is proposed to encourage the model to be less confident, since optimizing the log-likelihood of the correct label directly may cause overfitting and reduce the ability of the model to adapt. Label @@ -3758,10 +3728,10 @@ def label_smooth(label, prior_dist(Variable): The prior distribution to be used to smooth labels. If not provided, an uniform distribution is used. The shape of :attr:`prior_dist` should - be :math:`(1, class\_num)`. + be :math:`(1, class\_num)`. 
epsilon(float): The weight used to mix up the original ground-truth distribution and the fixed distribution. - dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, + dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, float_64, int etc. name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_gradient_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_gradient_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f45381af8ac64d117eb27325f25763fbf6cae7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_elementwise_gradient_op.py @@ -0,0 +1,103 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import numpy as np + +import paddle.fluid.core as core +import paddle.fluid as fluid + + +class TestElementWiseAddOp(unittest.TestCase): + def __assert_close(self, tensor, np_array, msg, atol=1e-4): + self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + + def check_forward_backward(self): + def test_with_place(place): + out_grad = np.random.random_sample(self.x.shape).astype(np.float32) + x_grad = out_grad + sum_axis = range(0, len(self.x.shape)) + del sum_axis[self.axis] + y_grad = np.sum(out_grad, axis=tuple(sum_axis)) + + var_dict = locals() + var_dict['y'] = self.y + var_dict['x'] = self.x + var_dict['out'] = self.out + var_dict['y@GRAD'] = y_grad + var_dict['x@GRAD'] = x_grad + var_dict['out@GRAD'] = out_grad + + var_names = ['x', 'y', 'out', 'y@GRAD', 'x@GRAD', 'out@GRAD'] + ground_truth = {name: var_dict[name] for name in var_names} + + program = fluid.Program() + with fluid.program_guard(program): + block = program.global_block() + for name in ground_truth: + block.create_var( + name=name, + dtype='float32', + shape=ground_truth[name].shape) + elementwise_add_op = block.append_op( + type="elementwise_add", + inputs={ + "X": block.var('x'), + "Y": block.var('y'), + }, + outputs={"Out": block.var('out'), }, + attrs={"axis": self.axis, }) + + # generate backward op_desc + grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc( + elementwise_add_op.desc, set(), []) + grad_op_desc = grad_op_desc_list[0] + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(grad_op_desc) + for var_name in grad_op_desc.output_arg_names(): + block.desc.var(var_name.encode("ascii")) + grad_op_desc.infer_var_type(block.desc) + grad_op_desc.infer_shape(block.desc) + for arg in grad_op_desc.output_arg_names(): + grad_var = block.desc.find_var(arg.encode("ascii")) + grad_var.set_dtype(core.VarDesc.VarType.FP32) + + exe = fluid.Executor(place) + out = exe.run(program, + feed={ + name: var_dict[name] + for name in ['x', 'y', 'out@GRAD'] + }, + fetch_list=['x@GRAD', 'y@GRAD']) + self.__assert_close(x_grad, out[0], "x@GRAD") + self.__assert_close(y_grad, out[1], "y@GRAD", atol=1.4) + + places = [core.CPUPlace()] + if 
core.is_compiled_with_cuda() and core.op_support_gpu( + "elementwise_add"): + places.append(core.CUDAPlace(0)) + + for place in places: + test_with_place(place) + + def test_check_forward_backward_with_scale_and_bias(self): + np.random.seed(123) + self.x = np.random.random((4, 32, 220, 220)).astype(np.float32) + self.y = np.random.random((32)).astype(np.float32) + self.out = self.x + self.y.reshape(1, 32, 1, 1) + self.axis = 1 + self.check_forward_backward() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_mul_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..42d68ef376dc4a664a96ff5a24545c1997ee924a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_mul_mkldnn_op.py @@ -0,0 +1,44 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from test_mul_op import TestMulOp, TestMulOp2, TestFP16MulOp1, TestFP16MulOp2 + + +class TestMKLDNNMulOp(TestMulOp): + def init_op_test(self): + super(TestMKLDNNMulOp, self).setUp() + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNMulOp2(TestMulOp2): + def init_op_test(self): + super(TestMKLDNNMulOp2, self).setUp() + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNFP16MulOp1(TestFP16MulOp1): + def init_op_test(self): + super(TestMKLDNNFP16MulOp1, self).setUp() + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNFP16MulOp2(TestFP16MulOp2): + def init_op_test(self): + super(TestMKLDNNFP16MulOp2, self).setUp() + self.attrs = {"use_mkldnn": True} + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_mul_op.py b/python/paddle/fluid/tests/unittests/test_mul_op.py index 40440bea1267112b84b66002a0bf921be3029265..d984393c89f44f5b9679a22bf7bb6182599233e3 100644 --- a/python/paddle/fluid/tests/unittests/test_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_mul_op.py @@ -21,10 +21,12 @@ from op_test import OpTest class TestMulOp(OpTest): def setUp(self): self.op_type = "mul" + self.use_mkldnn = False self.inputs = { 'X': np.random.random((32, 84)).astype("float32"), 'Y': np.random.random((84, 100)).astype("float32") } + self.attrs = {'use_mkldnn': self.use_mkldnn} self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} def test_check_output(self): @@ -45,11 +47,16 @@ class TestMulOp(OpTest): class TestMulOp2(OpTest): def setUp(self): self.op_type = "mul" + self.use_mkldnn = False self.inputs = { 'X': np.random.random((15, 4, 12, 10)).astype("float32"), 'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32") } - self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2} + self.attrs = { + 'x_num_col_dims': 2, + 'y_num_col_dims': 2, + 'use_mkldnn': self.use_mkldnn + } result = np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10), self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9)) result = result.reshape(15, 4, 8, 2, 9) @@ -73,9 +80,11 @@ class TestMulOp2(OpTest): 
class TestFP16MulOp1(OpTest): def setUp(self): self.op_type = "mul" + self.use_mkldnn = False x = np.random.random((32, 84)).astype("float16") y = np.random.random((84, 100)).astype("float16") self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)} + self.attrs = {'use_mkldnn': self.use_mkldnn} self.outputs = {'Out': np.dot(x, y)} def test_check_output(self): @@ -88,12 +97,14 @@ class TestFP16MulOp1(OpTest): class TestFP16MulOp2(OpTest): def setUp(self): self.op_type = "mul" + self.use_mkldnn = False x = np.random.random((15, 4, 12, 10)).astype("float16") y = np.random.random((4, 30, 8, 2, 9)).astype("float16") self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)} self.attrs = { 'x_num_col_dims': 2, 'y_num_col_dims': 2, + 'use_mkldnn': self.use_mkldnn } result = np.dot( x.reshape(15 * 4, 12 * 10), y.reshape(4 * 30, 8 * 2 * 9)) diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index 649fabe4a0cdef4c665f8a6d3ebee1bb8232185f..779ae388f04496a7be9a6d5aa4e39b8245022925 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -62,7 +62,8 @@ class TestOperator(unittest.TestCase): self.assertEqual(mul_op.output_names, ["Out"]) self.assertEqual(mul_op.output("Out"), ["mul.out"]) self.assertEqual( - set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"])) + set(mul_op.attr_names), + set(["x_num_col_dims", "y_num_col_dims", "use_mkldnn"])) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr("x_num_col_dims"), 1)