diff --git a/CMakeLists.txt b/CMakeLists.txt index 48e52961a95d50264b201eec50ccb3a462f39c54..317f7f9eb46a96e9f6ea393abf82d608af50fc4b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -138,12 +138,6 @@ else() set(THIRD_PARTY_BUILD_TYPE Release) endif() -if(WITH_MKL) - option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF) - if (MKL_SPLIT_GEMM) - add_definitions(-DPADDLE_MKL_SPLIT_GEMM) - endif() -endif() set(WITH_MKLML ${WITH_MKL}) if (NOT DEFINED WITH_MKLDNN) if (WITH_MKL AND AVX2_FOUND) diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index 260985cc8aa4ad0f231798666c048703b64c6d15..baf253df2755657b01b67c410f63b7d8422d4df3 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -54,7 +54,7 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} DEPENDS ${MKLDNN_DEPENDS} GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" - GIT_TAG "a29d8487a63afca3d5b8c5bbdbb473cf8ccc6e51" + GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944" PREFIX ${MKLDNN_SOURCES_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} diff --git a/doc/fluid/design/dist_train/dist_train_nccl2.md b/doc/fluid/design/dist_train/dist_train_nccl2.md index aa7455ec5de0d46d7c2b0cef3b7ebf4754af3cb1..b8b8427811cddcddf872db5badfd37c96a76c3e3 100644 --- a/doc/fluid/design/dist_train/dist_train_nccl2.md +++ b/doc/fluid/design/dist_train/dist_train_nccl2.md @@ -1,7 +1,7 @@ # Distributed Training with NCCL2 We design a pattern that can enable training with `ParallelExecutor` and -using [NCCL2](https://developer.nvidia.com/nccl) as it's collective +use [NCCL2](https://developer.nvidia.com/nccl) as it's collective communication library. In `ParallelExecutor` we can use `AllReduce` or `Reduce` and `Broadcast` @@ -9,14 +9,14 @@ to do multi GPU training. And if we initialize NCCL2 communicators as ranks in a distributed environment, we can simply run the `ParallelExecutor` as a distributed program! The only thing that may be different than in the single node version is that we need to broadcast the NCCL unique ID -to all the nodes, and initialize communicators using that ID, so NCCL2 -will know each other as ranks. +to all the nodes and initialize communicators using that ID, so NCCL2 +can know each other as ranks. To achieve this feature, we introduce a new operator: `gen_nccl_id` op, so we are ***not*** "bind to" running NCCL2 with MPI, we can run it in -what ever platform you like. +whatever platform you like. -It have two running modes: +It has two running modes: 1. Generate and broadcast mode, which should be used on trainer 0; 1. Listen and fetch mode, which should be used on trainers other than 0. @@ -29,7 +29,7 @@ initialize NCCL communicator objects. The above figure indicates the general process when training with NCCL2 -distributed. Each trainer have the number of communicators equal to the +distributed. Each trainer has the number of communicators equal to the number of GPUs, but the ranks should match the global ranks number: here we have total 8 GPUs, so `nranks==8`, for each trainer, the ranks should be from 0 ~ 3 on trainer 0 and 4 ~ 7 on trainer 1. diff --git a/doc/fluid/dev/new_op_cn.md b/doc/fluid/dev/new_op_cn.md index c00f73be955e0fb54bb01ffa9a61b3f27c112f75..ff7408111fa20a7a6a3a2fe9f9ba20835918f399 100644 --- a/doc/fluid/dev/new_op_cn.md +++ b/doc/fluid/dev/new_op_cn.md @@ -36,19 +36,19 @@ OpProtoMake定义 -`.cc`文件,Backward Op不需要定义OpProtoMake +.cc 文件,Backward Op不需要定义OpProtoMake Op定义 - `.cc`文件 + .cc 文件 Kernel实现 - CPU、CUDA共享Kernel实现在`.h`文件中,否则,CPU 实现在`.cc`文件中,CUDA 实现在`.cu`文件中。 + CPU、CUDA共享Kernel实现在.h 文件中,否则,CPU 实现在.cc 文件中,CUDA 实现在.cu 文件中。 注册Op - Op注册实现在`.cc`文件;Kernel注册CPU实现在`.cc`文件中,CUDA实现在`.cu`文件中 + Op注册实现在.cc 文件;Kernel注册CPU实现在.cc 文件中,CUDA实现在.cu 文件中 @@ -391,7 +391,7 @@ PADDLE_ENFORCE(ctx->HasInput("X"), ""); ``` 问题示例2 :提示信息过于简单 ``` -PADDLE_ENFORCE(i != nullptr, "I must be set"); // I是什么? +PADDLE_ENFORCE(i != nullptr, "i must be set"); // i是什么? ``` 2. 在报错信息中使用开发人员定义的变量缩写,不易理解! diff --git a/doc/fluid/howto/cluster/nccl2_rdma_training.md b/doc/fluid/howto/cluster/nccl2_rdma_training.md index cecd5c3a7a7339e3be6772543a534728ec132105..8adaf324fccb4cda7af16b9bace559c0642ae444 100644 --- a/doc/fluid/howto/cluster/nccl2_rdma_training.md +++ b/doc/fluid/howto/cluster/nccl2_rdma_training.md @@ -1,12 +1,12 @@ # Distributed Training with NCCL2 and RDMA -When doing distributed multi-GPU training, network bandwith often becomes the -bottle neck. We introduce a way to use NCCL2 to do such training job to -achieve best performace. +When doing distributed multi-GPU training, network bandwidth often becomes the +bottleneck. We introduce a way to use NCCL2 to do such training job to +achieve best performance. -## Prepare Hardwares with RDMA and Multiple GPUs +## Prepare Hardware with RDMA and Multiple GPUs -I'm using two Linux servers each of them is installed with 8 GPUs and +I'm using two Linux servers each of them installed with 8 GPUs and one 100Gb RDMA card. Base environment is: @@ -25,7 +25,7 @@ In general, the steps including: 1. Use docker to run tests and make sure GPUs and RDMA can work inside the container. -I'll ommit section "Install GPU drivers" because we can find it easily +I'll omit the section "Install GPU drivers" because we can find it easily somewhere else. ### Install RDMA drivers @@ -33,7 +33,7 @@ somewhere else. For my case, I've got two machines with device "Mellanox Technologies MT27700 Family [ConnectX-4]" installed. The OS was "CentOS 7.4" and I updated the kernel to version 4.4 so that docker can -work with latest overlay2 filesystem. +work with the latest overlay2 filesystem. ***NOTE: before you start, make sure you have a way to get a console of the server other than ssh because we may need to re-configure the @@ -45,14 +45,14 @@ network device.*** 1. Run `./mlnxofedinstall --add-kernel-support` in the software package. 1. Run `/etc/init.d/openibd restart` to make everything work, note that this operation may cause the network goes down if you are using this - RDMA device as default network device and use ssh to login the server. + RDMA device as default network device and use ssh to log in the server. 1. Re-configure the network interface, for example: `ifconfig eth2 192.168.16.30/20 up`, then add routes if needed: `ip route add default via 192.168.16.1 dev eth2`. 1. Do the same thing on the other node. 1. Use `ping` to test if the two nodes have typical ICMP connection. 1. Use either `udaddy` or `ib_write_bw` to test the network connection is - ready and have the desired bandwith. + ready and have the desired bandwidth. ### Prepare Docker Image to Run RDMA Programs @@ -60,7 +60,7 @@ network device.*** package in it. 1. Start a docker container and mount GPU driver libs into it (you can skip this step if you are using nvidia-docker). -1. Mount RDMA dirvers and libs into the docker image (see below section), +1. Mount RDMA drivers and libs into the docker image (see below section), also `udaddy` and `ib_write_bw` if needed. 1. Mount GPU devices and RDMA devices into the container using `--device` or just use privileged mode `--privileged`. diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index bbf1623c39c8010f7d48fc8d7f653c34e92fb99a..8c5cc44528a754f7612a23b1de09c247ca3f0c8e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -162,6 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) +paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/framework/array.h b/paddle/fluid/framework/array.h new file mode 100644 index 0000000000000000000000000000000000000000..be9efcd74924a2050a2fd9ab83059590a1a2a2fd --- /dev/null +++ b/paddle/fluid/framework/array.h @@ -0,0 +1,48 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace framework { +template +class Array { + static_assert(N > 0, "The size of array must be larger than 0"); + + public: + HOSTDEVICE Array() {} + + HOSTDEVICE explicit Array(const T &val) { + for (size_t i = 0; i < N; ++i) data_[i] = val; + } + + HOSTDEVICE const T *Get() const { return data_; } + + HOSTDEVICE T *GetMutable() { return data_; } + + HOSTDEVICE T &operator[](size_t index) { return data_[index]; } + + HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; } + + HOSTDEVICE constexpr size_t size() const { return N; } + + private: + T data_[N]; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 9c289243c5a27839f628f3e143ce0363bf75a0b1..2288c7fe6609a765612b468d69ad35101b92b384 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -129,10 +129,6 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, "Optimized for variable") .SetDefault({}); - AddAttr>(OpCreationCallstackAttrName(), - "Callstack for Op Creatation.") - .SetDefault({}); - Validate(); } diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index cb9c8ab1704ab867182079db31a34125669c645b..80970291c9c234f1306162f4ffa3c2528f88c35f 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -39,7 +39,6 @@ class OpProtoAndCheckerMaker { public: static const char *OpRoleAttrName() { return "op_role"; } static const char *OpRoleVarAttrName() { return "op_role_var"; } - static const char *OpCreationCallstackAttrName() { return "op_callstack"; } void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 9f8cdf1aeba43d30676cb2adf80a77cab86547a8..d04f7744961b2561977f4d36d0073a97557043da 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -11,17 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/operator.h" +#include +#include + #include -#include -#include -#include -#include "gflags/gflags.h" -#include "glog/logging.h" + #include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/platform/profiler.h" @@ -129,48 +127,19 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { } void OperatorBase::Run(const Scope& scope, const platform::Place& place) { - try { - if (VLOG_IS_ON(4)) { - VLOG(4) << place << " " << DebugStringEx(&scope); - } - if (platform::is_gpu_place(place)) { + VLOG(4) << place << " " << DebugStringEx(&scope); + if (platform::is_gpu_place(place)) { #ifndef PADDLE_WITH_CUDA - PADDLE_THROW("Cannot run operator on place %s", place); + PADDLE_THROW("Cannot run operator on place %s", place); #else - auto dev_id = boost::get(place).device; - platform::SetDeviceId(dev_id); + auto dev_id = boost::get(place).device; + platform::SetDeviceId(dev_id); #endif - } - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::RecordEvent record_event(Type(), pool.Get(place)); - RunImpl(scope, place); - if (VLOG_IS_ON(3)) { - VLOG(3) << place << " " << DebugStringEx(&scope); - } - } catch (platform::EnforceNotMet exception) { - if (Attrs().count("sub_block") != 0) { - throw exception; - } - - auto& callstack = Attr>( - OpProtoAndCheckerMaker::OpCreationCallstackAttrName()); - - if (callstack.empty()) { - throw exception; - } - std::ostringstream sout; - sout << "Invoke operator " << Type() << " error.\n"; - sout << "Python Callstacks: \n"; - for (auto& line : callstack) { - sout << line; - } - sout << "C++ Callstacks: \n"; - sout << exception.err_str_; - exception.err_str_ = sout.str(); - throw exception; - } catch (...) { - std::rethrow_exception(std::current_exception()); } + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::RecordEvent record_event(Type(), pool.Get(place)); + RunImpl(scope, place); + VLOG(3) << place << " " << DebugStringEx(&scope); } bool OperatorBase::HasInputs(const std::string& name) const { @@ -198,7 +167,7 @@ const std::vector& OperatorBase::Inputs( } bool OperatorBase::HasOutputs(const std::string& name) const { - if (outputs_.end() != outputs_.find(name)) { + if (outputs_.find(name) != outputs_.end()) { return true; } else { return false; diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index 56bb9142dabe0d5546e321e675a5acba7bf4d306..d61dbb98a235ca9af089d35318b7f4c68cb125cc 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -31,7 +31,8 @@ size_t Tensor::memory_size() const { return holder_ == nullptr ? 0UL : holder_->size() - offset_; } -void* Tensor::mutable_data(platform::Place place, std::type_index type) { +void* Tensor::mutable_data(platform::Place place, std::type_index type, + size_t requested_size) { if (holder_ != nullptr) { holder_->set_type(type); } @@ -39,7 +40,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) { "When calling this method, the Tensor's numel must be " "equal or larger than zero. " "Please check Tensor::Resize has been called first."); - int64_t size = numel() * SizeOfType(type); + size_t size = requested_size ? requested_size : numel() * SizeOfType(type); /* some versions of boost::variant don't have operator!= */ if (holder_ == nullptr || !(holder_->place() == place) || holder_->size() < size + offset_) { @@ -68,10 +69,10 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) { offset_); } -void* Tensor::mutable_data(platform::Place place) { +void* Tensor::mutable_data(platform::Place place, size_t requested_size) { PADDLE_ENFORCE(this->holder_ != nullptr, "Cannot invoke mutable data if current hold nothing."); - return mutable_data(place, holder_->type()); + return mutable_data(place, holder_->type(), requested_size); } Tensor& Tensor::ShareDataWith(const Tensor& src) { diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 0bbfd66148e9bc9080654bf1b0b34477115a0e6b..4cf95fa0ae07823289fbf337062190f05e6c6bcf 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -89,22 +89,24 @@ class Tensor { * @note If not exist, then allocation. */ template - T* mutable_data(platform::Place place); + T* mutable_data(platform::Place place, size_t requested_size = 0); - void* mutable_data(platform::Place place, std::type_index type); + void* mutable_data(platform::Place place, std::type_index type, + size_t requested_size = 0); - void* mutable_data(platform::Place place); + void* mutable_data(platform::Place place, size_t requested_size = 0); /** * @brief Return a pointer to mutable memory block. * - * @param[in] dims The dimensions of the memory block. - * @param[in] place The place of the memory block. + * @param[in] dims The dimensions of the memory block. + * @param[in] place The place of the memory block. + * @param[in] requested_size The size of the block in bytes. * * @note If not exist, then allocation. */ template - T* mutable_data(DDim dims, platform::Place place); + T* mutable_data(DDim dims, platform::Place place, size_t requested_size = 0); /*! Return the dimensions of the memory block. */ const DDim& dims() const; diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index b7b62eef23ec351686378c913d18fc72308fd7b2..6d3047c95d6cf30c2a5308d4f69ded367066d78c 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -46,16 +46,17 @@ inline T* Tensor::data() { } template -inline T* Tensor::mutable_data(DDim dims, platform::Place place) { +inline T* Tensor::mutable_data(DDim dims, platform::Place place, + size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); Resize(dims); - return mutable_data(place); + return mutable_data(place, requested_size); } template -inline T* Tensor::mutable_data(platform::Place place) { +inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); - return reinterpret_cast(mutable_data(place, typeid(T))); + return reinterpret_cast(mutable_data(place, typeid(T), requested_size)); } inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1cb65346ee2b755b48f8dd8f1456a32861c3a0b6 --- /dev/null +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -0,0 +1,422 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/attention_lstm_op.h" +#include +#include +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { + +void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(C0) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), + "Input(LSTMWeight) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), + "Input(LSTMBias) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"), + "Input(AttentionWeight) of AttentionLSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(Hidden) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Cell"), + "Output(Cell) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"), + "Output(AttentionedX) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"), + "Output(AttentionFCOut) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("LSTMX"), + "Output(LSTMX) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"), + "Output(LSTMOUT) of AttentionLSTM should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + const int M = x_dims[1]; + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + + auto w_dims = ctx->GetInputDim("LSTMWeight"); + const int D = w_dims[1] / 4; + PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); + PADDLE_ENFORCE_EQ(w_dims[0], D + M, + "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D); + + auto b_dims = ctx->GetInputDim("LSTMBias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D); + PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D); + + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); + PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); + if (ctx->HasInput("H0")) { + auto h_dims = ctx->GetInputDim("H0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } + + auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); + PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2, + "Input(AttentionWeight)'s rank must be 2."); + PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D, + "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, + "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + if (ctx->HasInput("AttentionBias")) { + auto atten_b_dims = ctx->GetInputDim("AttentionBias"); + PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, + "Input(AttentionBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(atten_b_dims[0], 1, + "AttentionBias shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(atten_b_dims[1], 1, + "AttentionBias shapes must be 1 * 1."); + } + + if (ctx->HasInput("AttentionScalar")) { + auto dims = ctx->GetInputDim("AttentionScalar"); + PADDLE_ENFORCE_EQ(dims.size(), 2, + "Input(AttentionScalar)'s rank must be 2."); + PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); + } + + if (ctx->HasInput("AttentionScalarBias")) { + auto dims = ctx->GetInputDim("AttentionScalarBias"); + PADDLE_ENFORCE( + ctx->HasInput("AttentionScalar"), + "AttentionScalar should not be null when have AttentionScalarBias."); + PADDLE_ENFORCE_EQ(dims.size(), 2, + "Input(AttentionScalarBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1."); + } + + framework::DDim out_dims({x_dims[0], D}); + ctx->SetOutputDim("Hidden", out_dims); + ctx->SetOutputDim("Cell", out_dims); + ctx->SetOutputDim("AttentionedX", {x_dims[0], 1}); + ctx->SetOutputDim("LSTMX", {1, M}); + ctx->SetOutputDim("LSTMOUT", {1, 4 * D}); + // AttentionFCOut should be reshape as (maxseqlen,1) in runtime + ctx->ShareLoD("X", "Hidden"); + ctx->ShareLoD("X", "Cell"); +} + +framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); +} + +void AttentionLSTMOpMaker::Make() { + AddInput("X", + "(LoDTensor) the input is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTensor is a matrix with shape (T X M), where T is the " + "total time steps in this mini-batch, M is the dim size of x."); + AddInput("C0", + "(Tensor) LSTM C0" + "This is a tensor with shape (N x D), where N is the batch size, D " + "is the gate size." + "C0 is necessary because of attention."); + AddInput("H0", + "(Tensor, optional) LSTM H0" + "This is a tensor with shape (N x D), where N is the " + "batch size and D is the gate size.") + .AsDispensable(); + AddInput("AttentionWeight", + "(Tensor) the weights of attention fc. Always relu the fc result." + "The shape is ((M+D) x 1), where M is the dim size of x, D is the " + "gate size of LSTM."); + AddInput("AttentionBias", + "(Tensor, optional) the bias of attention fc." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("AttentionScalar", + "(Tensor, optional) the scalar on the result of attentioned fc. " + "Always relu the Scalar." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("AttentionScalarBias", + "(Tensor, optional) the scalar bias of attention fc." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("LSTMWeight", + "(Tensor) the combined weight of LSTM" + " - The shape is ((D+M) x 4D), where D is the hidden gate size, M " + "is the dim size of x" + " - Weight = {W_forget, W_input, W_output, W_cell}"); + AddInput("LSTMBias", + "(Tensor) the combined bias of LSTM, shape (1x4D)." + "Note: we should add the bias of hidden and context accorindg to " + "the same gate: " + "{B_forget, B_input, B_output, B_cell}"); + AddOutput("Hidden", + "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("AttentionedX", + "(Tensor) shape is (T x 1), the result after X * AttentionWeight," + " where T is the total time steps in this mini-batch," + " D is the hidden size.") + .AsIntermediate(); + AddOutput("AttentionFCOut", + "(Tensor) (max_seq_len, 1), compute at each step.") + .AsIntermediate(); + AddOutput("LSTMX", + "(Tensor) the input X of LSTM for each step." + "Shape is (1 x M), where M is the x frame size") + .AsIntermediate(); + AddOutput( + "LSTMOUT", + "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step." + "Shape is (1 x 4D), where M is the x frame size") + .AsIntermediate(); + AddAttr("gate_activation", + "(string, default: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by default.") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("cell_activation", + "(string, default: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("candidate_activation", + "(string, default: tanh)" + "The activation for candidate hidden state, " + "`tanh` by default.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddComment(R"DOC( +Attention Long-Short Term Memory (LSTM) Operator. + +Attention part: +concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D)) + +tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu + +fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu + +dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M) + +LSTM part: +use lstm_x_t as input and compute as standard LSTM. + +)DOC"); +} + +// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0; +template +inline void bias_relu(const int n, const T* x, const T* bias, T* y) { + if (bias) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] + bias[0]; + } + math::vec_relu(n, y, y); + } else { + math::vec_relu(n, x, y); + } +} + +template +inline void vec_softmax(const math::BlasT& blas, const int n, + const T* x, T* y) { + T scalar = x[0]; + // max + for (int i = 1; i < n; ++i) { + scalar = scalar < x[i] ? x[i] : scalar; + } + + // sub + for (int i = 0; i < n; ++i) { + y[i] = x[i] - scalar; + } + + // exp + blas.VEXP(n, y, y); + + // sum + scalar = T(0); + for (int i = 0; i < n; ++i) { + scalar += y[i]; + } + + // scale + blas.SCAL(n, static_cast(1) / scalar, y); +} + +template +class AttentionLSTMKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using DeviceContext = paddle::platform::CPUDeviceContext; + + auto* x = ctx.Input("X"); + auto* h0 = ctx.Input("H0"); + auto* c0 = ctx.Input("C0"); + auto* atten_w = ctx.Input("AttentionWeight"); + auto* atten_b = ctx.Input("AttentionBias"); + auto* atten_scalar = ctx.Input("AttentionScalar"); + auto* atten_scalar_bias = ctx.Input("AttentionScalarBias"); + auto* lstm_w = ctx.Input("LSTMWeight"); + auto* lstm_b = ctx.Input("LSTMBias"); + + auto* hidden_out = ctx.Output("Hidden"); + auto* cell_out = ctx.Output("Cell"); + auto* atted_x = ctx.Output("AttentionedX"); + auto* fc_out = ctx.Output("AttentionFCOut"); + auto* lstm_x = ctx.Output("LSTMX"); + auto* lstm_out = ctx.Output("LSTMOUT"); + + // some shape should be reshape here since infershape can not get lod info + auto x_lod = x->lod(); + const int N = x_lod[0].size() - 1; // batch size + auto x_dims = x->dims(); // T x M + auto w_dims = lstm_w->dims(); // (D+M) x 4D + const int total_T = x_dims[0]; + const int M = x_dims[1]; // x frame size + const int D = w_dims[1] / 4; // gate frame size + const int D2 = D * 2; + const int D3 = D * 3; + const int D4 = w_dims[1]; + int max_seq_len = x_lod[0][1]; + for (int i = 1; i < N; ++i) { + int len = x_lod[0][i + 1] - x_lod[0][i]; + max_seq_len = max_seq_len < len ? len : max_seq_len; + } + PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1."); + PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); + fc_out->Resize({max_seq_len, 1}); + + math::VecActivations act_functor; + std::function act_gate, act_cell, act_cand; + act_gate = act_functor(ctx.Attr("gate_activation")); + act_cell = act_functor(ctx.Attr("cell_activation")); + act_cand = act_functor(ctx.Attr("candidate_activation")); + + const T* x_data = x->data(); + const T* h0_data = h0 ? h0->data() : NULL; + const T* c0_data = c0->data(); + const T* lstm_w_data = lstm_w->data(); + const T* lstm_b_data = lstm_b->data(); + const T* atten_w_data = atten_w->data(); + const T* atten_b_data = atten_b ? atten_b->data() : NULL; + const T* atten_scalar_data = atten_scalar ? atten_scalar->data() : NULL; + const T* atten_scalar_bias_data = + atten_scalar_bias ? atten_scalar_bias->data() : NULL; + + T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); + T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); + T* atted_x_data = atted_x->mutable_data(ctx.GetPlace()); + T* fc_out_data = fc_out->mutable_data(ctx.GetPlace()); + T* lstm_x_data = lstm_x->mutable_data(ctx.GetPlace()); + T* lstm_out_data = lstm_out->mutable_data(ctx.GetPlace()); + + // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 + auto blas = math::GetBlas(ctx); + math::FCCompute(blas, total_T, 1, M, x_data, atten_w_data, + atted_x_data, atten_b_data); + + const T* cur_atten_x_data = atted_x_data; + const T* cur_x_data = x_data; + const T* prev_cell_data = NULL; + const T* prev_hidden_data = NULL; + T* cur_cell_out_data = cell_out_data; + T* cur_hidden_out_data = hidden_out_data; + for (int i = 0; i < N; ++i) { + int seq_len = x_lod[0][i + 1] - x_lod[0][i]; + prev_cell_data = c0_data + i * D; + prev_hidden_data = h0_data ? h0_data + i * D : NULL; + for (int step = 0; step < seq_len; ++step) { + /// 1. compute attention vector + // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt + T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M); + // 1b. add cell bias and relu + bias_relu(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data); + // 1c. fc scalar + if (atten_scalar_data) { + blas.SCAL(seq_len, *atten_scalar_data, fc_out_data); + bias_relu(seq_len, fc_out_data, atten_scalar_bias_data, + fc_out_data); + } + // 1d. softmax + vec_softmax(blas, seq_len, fc_out_data, fc_out_data); + // mul x(seq_len*M) and sum pool + math::FCCompute(blas, 1, M, seq_len, fc_out_data, + cur_x_data, lstm_x_data); + + /// 2. compute LSTM step + // lstm weight : concat[forget , input , output , tilde] + // shape : (D + M) x (4 * D) + // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D + blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data); + if (prev_hidden_data) { + blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), + prev_hidden_data, D, lstm_w_data, D4, static_cast(1), + lstm_out_data, D4); + } + // since input is 1xM, so can use add bias + blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data); + + // gate act: sigmoid + act_gate(D3, lstm_out_data, lstm_out_data); + // candicate act: tanh + act_cand(D, lstm_out_data + D3, lstm_out_data + D3); + + // a = forget * prev_cell + blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data); + + // b = input * tilde + blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D); + + // cell_out = a + b + blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data); + + // state act tanh(cell_out) * output_gate + act_cell(D, cur_cell_out_data, lstm_out_data); + blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data); + + prev_hidden_data = cur_hidden_out_data; + prev_cell_data = cur_cell_out_data; + cur_cell_out_data = cur_cell_out_data + D; + cur_hidden_out_data = cur_hidden_out_data + D; + } + cur_x_data = cur_x_data + seq_len * M; + cur_atten_x_data = cur_atten_x_data + seq_len; + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp, + ops::AttentionLSTMOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel, + ops::AttentionLSTMKernel); diff --git a/paddle/fluid/operators/attention_lstm_op.h b/paddle/fluid/operators/attention_lstm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..6ede3a7f3c96dd2d13d7c5c19816647e16a3c8d0 --- /dev/null +++ b/paddle/fluid/operators/attention_lstm_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +class AttentionLSTMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class AttentionLSTMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 969f75544fa42b948e982569c3d6105d3ce282d6..5912a1a17cbd29c3ebd83f37133c044f0905c8bd 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Variance", "The global variance (for training) " "or estimated Variance (for testing)"); - AddOutput("Y", "result after normalization"); + AddOutput("Y", "result after normalization").Reuse("X"); AddOutput("MeanOut", "Share memory with Mean. " "Store the global mean when training") diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 527a87db533ac25c3170fbb3ae6a9b9aff589b3d..c5cbadc892904dc064b49ebc461944c4671a69da 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -53,6 +53,18 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { key_ += "-BWD"; } + size_t GetDstMemorySize() const { + return conv_pd_->dst_primitive_desc().get_size(); + } + + size_t GetDiffWeightsMemorySize() const { + return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size(); + } + + size_t GetDiffSourceMemorySize() const { + return conv_bwd_data_pd_->diff_src_primitive_desc().get_size(); + } + std::shared_ptr AcquireSrcMemoryFromWeightsPrimitive( const std::shared_ptr user_memory_p, std::vector& pipeline) { // NOLINT @@ -294,7 +306,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const T* input_data = input->data(); const T* filter_data = filter->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); std::vector src_tz = paddle::framework::vectorize2int(input->dims()); std::vector weights_tz = @@ -354,6 +365,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto user_weights_memory_p = handler.AcquireWeightsMemory( user_weights_md, to_void_cast(filter_data)); + T* output_data = + output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); // create reorder primitive if the input format is not the preferred one auto src_memory_p = handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); @@ -476,13 +489,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { T* input_grad_data = nullptr; T* filter_grad_data = nullptr; - if (input_grad) { - input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - } - if (filter_grad) { - filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); - } - std::vector src_tz = paddle::framework::vectorize2int(input->dims()); std::vector weights_tz = paddle::framework::vectorize2int(filter->dims()); @@ -568,6 +574,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { handler.AcquireDiffDstMemoryFromWeightsPrimitive( user_diff_dst_memory_p, pipeline); + const size_t size = handler.GetDiffWeightsMemorySize(); + filter_grad_data = filter_grad->mutable_data(ctx.GetPlace(), size); + auto diff_weights_memory_p = handler.AcquireDiffWeightsMemoryFromWeightsPrimitive( reinterpret_cast(filter_grad_data)); @@ -590,6 +599,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p, pipeline); + const size_t size = handler.GetDiffSourceMemorySize(); + input_grad_data = input_grad->mutable_data(ctx.GetPlace(), size); + auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive( reinterpret_cast(input_grad_data)); diff --git a/paddle/fluid/operators/fusion_lstm_op.h b/paddle/fluid/operators/fusion_lstm_op.h index 39dc09b4d116193399d8ac9a51e88dbc3e239918..7f79601602348ac454fc6c0cefcba0643ad8e6e2 100644 --- a/paddle/fluid/operators/fusion_lstm_op.h +++ b/paddle/fluid/operators/fusion_lstm_op.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -// #include #include "paddle/fluid/framework/op_registry.h" namespace paddle { diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 8dcf7c99f3860789dee834787eeb8b7ad4cc3530..da185d93c09f9b06bd5968b9c8e93176f9ef014b 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -90,6 +90,11 @@ class Blas { void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; + template + void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, + int ldc) const; + #ifdef PADDLE_WITH_MKLML template T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, @@ -109,6 +114,10 @@ class Blas { void GEMM_FREE(T* data) const; #endif + template + void MatMul(const int M, const int N, const int K, const T* A, const T* B, + T* C) const; + template void MatMul(const framework::Tensor& mat_a, bool trans_a, const framework::Tensor& mat_b, bool trans_b, T alpha, @@ -140,10 +149,19 @@ class Blas { template void VCOPY(int n, const T* x, T* y) const; + template + void VEXP(int n, const T* x, T* y) const; + template void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, T* C) const; + template + T DOT(int n, const T* x, const T* y) const; + + template + void SCAL(int n, const T a, T* x) const; + template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T* A, const T* B, T beta, T* C, @@ -215,11 +233,26 @@ class BlasT : private Blas { Base()->template VCOPY(args...); } + template + void VEXP(ARGS... args) const { + Base()->template VEXP(args...); + } + template void GEMV(ARGS... args) const { Base()->template GEMV(args...); } + template + T DOT(ARGS... args) const { + return Base()->template DOT(args...); + } + + template + void SCAL(ARGS... args) const { + Base()->template SCAL(args...); + } + template void BatchedGEMM(ARGS... args) const { Base()->template BatchedGEMM(args...); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index dc77b6d793702458a22a2f59b68e9d9f2c23b4ff..e1df78d11e41c5f74e244643f40c6d0581fa6a4a 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -73,6 +73,16 @@ struct CBlas { platform::dynload::cblas_sgemv(args...); } + template + static float DOT(ARGS... args) { + return platform::dynload::cblas_sdot(args...); + } + + template + static void SCAL(ARGS... args) { + platform::dynload::cblas_sscal(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_sgemm_batch(args...); @@ -87,6 +97,11 @@ struct CBlas { static void VMUL(ARGS... args) { platform::dynload::vsMul(args...); } + + template + static void VEXP(ARGS... args) { + platform::dynload::vsExp(args...); + } }; template <> @@ -138,6 +153,16 @@ struct CBlas { platform::dynload::cblas_dgemv(args...); } + template + static double DOT(ARGS... args) { + return platform::dynload::cblas_ddot(args...); + } + + template + static void SCAL(ARGS... args) { + platform::dynload::cblas_dscal(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_dgemm_batch(args...); @@ -152,6 +177,11 @@ struct CBlas { static void VMUL(ARGS... args) { platform::dynload::vdMul(args...); } + + template + static void VEXP(ARGS... args) { + platform::dynload::vdExp(args...); + } }; #else @@ -210,6 +240,9 @@ struct CBlas { PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } + static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); } + static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; + static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) { PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -217,64 +250,6 @@ struct CBlas { #endif }; -template -inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa, - bool transb, const T &alpha, const T &beta) { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom - constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - if (m * n * k > LIBXSMM_THRESHOLD || transa || transb || - std::abs(alpha - static_cast(1) > - std::numeric_limits::epsilon()) || - std::abs(beta) > std::numeric_limits::epsilon()) { - return false; - } else { - return true; - } -#endif - return false; -} - -template <> -inline bool UseXSMM(const int &m, const int &n, const int &k, - bool transa, bool transb, - const platform::float16 &alpha, - const platform::float16 &beta) { - return false; -} - -template -inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, - const T *A, int lda, const T *B, int ldb, T beta, T *C, - int ldc) { -#ifdef PADDLE_WITH_LIBXSMM - if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, - beta)) { - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, - &beta, C, &ldc); - return; - } -#endif - -#ifdef PADDLE_MKL_SPLIT_GEMM - constexpr int bs = 2; - if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) { - for (int off = 0; off < M; off += bs) { - CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha, - A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc); - } - return; - } -#endif - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - #ifdef PADDLE_WITH_MKLML template <> template @@ -319,8 +294,8 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; - GEMM_WARP(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template <> @@ -329,9 +304,20 @@ void Blas::GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T *A, int lda, const T *B, int ldb, T beta, T *C, int ldc) const { - GEMM_WARP(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); + CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); +} + +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, + int N, int K, T alpha, const T *A, + int lda, const T *B, int ldb, + T beta, T *C, int ldc) const { + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template @@ -399,6 +385,47 @@ void Blas::VMUL(int n, const T *x, const T *y, #endif } +template <> +template +void Blas::VEXP(int n, const T *x, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VEXP(n, x, y); +#else + // try to find if openblas support vexp + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +#endif +} + +template <> +template +T Blas::DOT(int n, const T *x, const T *y) const { +#ifdef PADDLE_WITH_MKLML + return CBlas::DOT(n, x, 1, y, 1); +#else + // try to find if openblas support cblas_dot + T sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] * y[i]; + } + return sum; +#endif +} + +template <> +template +void Blas::SCAL(int n, const T a, T *x) const { +#ifdef PADDLE_WITH_MKLML + CBlas::SCAL(n, a, x, 1); +#else + // try to find if openblas support cblas_scal + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +#endif +} + template <> template void Blas::GEMV(bool trans_a, int M, int N, T alpha, @@ -440,6 +467,42 @@ void Blas::BatchedGEMM( #endif } +template +template +void Blas::MatMul(const int M, const int N, const int K, + const T *A, const T *B, T *C) const { + this->template GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, + static_cast(1), A, K, B, N, static_cast(0), C, + N); +} + +template <> +template +void Blas::MatMul(const int M, const int N, + const int K, const T *A, + const T *B, T *C) const { +#ifdef PADDLE_WITH_LIBXSMM + // Refer to https://github.com/hfp/libxsmm/blob/master/README.md + // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; + + // Since the matrix is very small, + // so the unit of calculation is already very fast, + // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, + // use xsmm directly. + // Note: SMM use ColMajor + const char transa = 'N'; + const char transb = 'N'; + const T alpha = static_cast(1); + const T beta = static_cast(0); + CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, + C, &N); + return; +#endif + + CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, + static_cast(1), A, K, B, N, static_cast(0), C, N); +} + template template void Blas::MatMul(const framework::Tensor &mat_a, diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h new file mode 100644 index 0000000000000000000000000000000000000000..48c0da0e368a0fe6efcd758536e5659eeee26f7e --- /dev/null +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -0,0 +1,105 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace math { + +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 + +template +inline T sigmoid(T x) { + return 1. / (1. + exp(-x)); +} + +template +inline T tanh(T x) { + return 2. * sigmoid(2. * x) - 1.; +} + +template +inline void vec_identity(const int n, const T* x, T* y) { + // do nothing + return; +} + +template +inline void vec_sigmoid(const int n, const T* x, T* y) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = 1.0 / (1.0 + std::exp(-tmp)); + } +} + +template +inline void vec_tanh(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = tanh(x[i]); + } +} + +template +inline void vec_relu(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? x[i] : 0; + } +} + +template <> +inline void vec_relu(const int n, const float* x, + float* y) { + // TODO(TJ): complete me + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? x[i] : 0; + } +} + +template <> +inline void vec_relu(const int n, const float* x, + float* y) { + // TODO(TJ): complete me + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? x[i] : 0; + } +} + +template +class VecActivations { + public: + std::function operator()( + const std::string& type) { + if (type == "sigmoid") { + return vec_sigmoid; + } else if (type == "relu") { + return vec_relu; + } else if (type == "tanh") { + return vec_tanh; + } else if (type == "identity" || type == "") { + return vec_identity; + } + PADDLE_THROW("Not support type %s.", type); + } +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h index 8600fa9e2c4db9d54cbe0ffb68f82d52c086d4f7..1f5a49c0ab5a10b0d7dc1febd258ce76c467cb1c 100644 --- a/paddle/fluid/operators/math/fc_compute.h +++ b/paddle/fluid/operators/math/fc_compute.h @@ -25,17 +25,25 @@ namespace math { template inline void FCCompute(const BlasT& blas, const int M, const int N, const int K, const T* X, const T* W, T* Y, - const T* B = NULL) { - blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast(1), X, W, - static_cast(0), Y); - if (B) { + const T* B = NULL, bool relu = false) { + blas.MatMul(M, N, K, X, W, Y); + if (B == NULL) { + return; + } + #ifdef PADDLE_WITH_MKLML #pragma omp parallel for if (FLAGS_paddle_num_threads > 1) #endif - for (int i = 0; i < M; i++) { - blas.AXPY(N, static_cast(1), B, Y + i * N); - } + for (int i = 0; i < M; i++) { + blas.AXPY(N, static_cast(1), B, Y + i * N); } + + if (!relu) { + return; + } + + // TODO(TJ): fuse relu + LOG(FATAL) << "Not implemented!"; } } // namespace math diff --git a/paddle/fluid/operators/stack_op.cc b/paddle/fluid/operators/stack_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3f4b48bc7391def082c82ed451fc5a752009a2f1 --- /dev/null +++ b/paddle/fluid/operators/stack_op.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/stack_op.h" + +namespace plat = paddle::platform; +namespace ops = paddle::operators; +REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker, + ops::StackGradOpDescMaker); +REGISTER_OPERATOR(stack_grad, ops::StackOpGrad); + +REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel, + ops::StackKernel); + +REGISTER_OP_CPU_KERNEL(stack_grad, + ops::StackGradKernel, + ops::StackGradKernel); diff --git a/paddle/fluid/operators/stack_op.cu b/paddle/fluid/operators/stack_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..92c1bde2bcf089e5c715e90e564408e6ad37ba17 --- /dev/null +++ b/paddle/fluid/operators/stack_op.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/stack_op.h" + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel, + ops::StackKernel); + +REGISTER_OP_CUDA_KERNEL(stack_grad, + ops::StackGradKernel, + ops::StackGradKernel); diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c777d5feaec1c3a6216b01359a250072a674b700 --- /dev/null +++ b/paddle/fluid/operators/stack_op.h @@ -0,0 +1,278 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/for_range.h" + +#ifdef __NVCC__ +#include +#include "paddle/fluid/framework/array.h" +#endif + +namespace paddle { +namespace operators { + +class StackOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0, + "Number of Inputs(X) must be larger than 0"); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist."); + + auto input_dims = ctx->GetInputsDim("X"); + for (size_t i = 1; i < input_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0], + "Dims of all Inputs(X) must be the same"); + } + + // Only lod of X[0] would be shared with Y + ctx->ShareLoD("X", /*->*/ "Y"); + + int axis = ctx->Attrs().Get("axis"); + int rank = input_dims[0].size(); + PADDLE_ENFORCE( + axis >= -(rank + 1) && axis < rank + 1, + "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank); + if (axis < 0) axis += (rank + 1); + + auto vec = framework::vectorize2int(input_dims[0]); + vec.insert(vec.begin() + axis, input_dims.size()); + ctx->SetOutputDim("Y", framework::make_ddim(vec)); + } +}; + +class StackOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input of stack op.").AsDuplicable(); + AddOutput("Y", "The output of stack op."); + AddAttr("axis", + "The axis along which all of the Inputs(X) should be stacked.") + .SetDefault(0); + AddComment(R"DOC( + Stack Operator. + + Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same. + )DOC"); + } +}; + +template +struct StackFunctor { + HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post) + : x_(x), y_(y), n_(n), post_(post) {} + + HOSTDEVICE void operator()(int idx) { + int i = idx / (n_ * post_); + int which_x = idx / post_ - i * n_; + int x_index = i * post_ + idx % post_; + y_[idx] = x_[which_x][x_index]; + } + + private: + VecXType x_; + T *y_; + int n_; + int post_; +}; + +template +struct StackGradFunctor { + HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post) + : dx_(dx), dy_(dy), n_(n), post_(post) {} + + HOSTDEVICE void operator()(int idx) { + int i = idx / (n_ * post_); + int which_x = idx / post_ - i * n_; + int x_index = i * post_ + idx % post_; + dx_[which_x][x_index] = dy_[idx]; + } + + private: + VecDxType dx_; + const T *dy_; + int n_; + int post_; +}; + +template +static inline void StackFunctorForRange(const DeviceContext &ctx, + const VecXType &x, T *y, int total_num, + int n, int post) { + platform::ForRange for_range(ctx, total_num); + for_range(StackFunctor(x, y, n, post)); +} + +template +static inline void StackGradFunctorForRange(const DeviceContext &ctx, + const VecDxType &dx, const T *dy, + int total_num, int n, int post) { + platform::ForRange for_range(ctx, total_num); + for_range(StackGradFunctor(dx, dy, n, post)); +} + +template +class StackKernel : public framework::OpKernel { + using Tensor = framework::LoDTensor; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto x = ctx.MultiInput("X"); + auto *y = ctx.Output("Y"); + + int axis = ctx.Attr("axis"); + if (axis < 0) axis += (x[0]->dims().size() + 1); + + int n = static_cast(x.size()); + auto *y_data = y->mutable_data(ctx.GetPlace()); + std::vector x_datas(n); + for (int i = 0; i < n; i++) x_datas[i] = x[i]->data(); + + int pre = 1, post = 1; + auto &dim = x[0]->dims(); + for (auto i = 0; i < axis; ++i) pre *= dim[i]; + for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; + int total_num = pre * n * post; + + auto &dev_ctx = ctx.template device_context(); + constexpr auto kMaxThreshold = 16; + if (std::is_same::value || + n > kMaxThreshold) { +#ifdef __NVCC__ + VLOG(10) << "Stack more than " << kMaxThreshold + << " tensors on GPU may be slow."; + thrust::device_vector device_x_vec(x_datas); + auto x_data_arr = device_x_vec.data().get(); +#else + auto x_data_arr = x_datas.data(); +#endif + StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); +#ifdef __NVCC__ + // Wait() must be called because device_x_vec may be destructed before + // kernel ends + dev_ctx.Wait(); +#endif + } +#ifdef __NVCC__ + else { // NOLINT + framework::Array x_data_arr; + for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i]; + StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); + } +#endif + } +}; + +class StackOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@Grad) must exist."); + + int axis = ctx->Attrs().Get("axis"); + auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y")); + int rank = dy_dim.size(); + PADDLE_ENFORCE(axis >= -rank && axis < rank, + "Attr(axis) must be inside [-rank, rank), where rank = %d", + rank); + if (axis < 0) axis += rank; + + PADDLE_ENFORCE_EQ(ctx->Outputs(framework::GradVarName("X")).size(), + static_cast(dy_dim[axis]), + "Number of Outputs(X@Grad) is wrong"); + auto vec = framework::vectorize2int(dy_dim); + vec.erase(vec.begin() + axis); + ctx->SetOutputsDim( + framework::GradVarName("X"), + std::vector(dy_dim[axis], framework::make_ddim(vec))); + } +}; + +class StackGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("stack_grad"); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X", false)); + op->SetAttrMap(Attrs()); + return op; + } +}; + +template +class StackGradKernel : public framework::OpKernel { + using Tensor = framework::LoDTensor; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *dy = ctx.Input(framework::GradVarName("Y")); + auto dx = ctx.MultiOutput(framework::GradVarName("X")); + int axis = ctx.Attr("axis"); + if (axis < 0) axis += dy->dims().size(); + + int n = dy->dims()[axis]; + std::vector dx_datas(n); // NOLINT + for (int i = 0; i < n; i++) { + dx_datas[i] = dx[i]->mutable_data(ctx.GetPlace()); + } + auto dy_data = dy->data(); + + int pre = 1; + for (int i = 0; i < axis; ++i) pre *= dy->dims()[i]; + int total_num = dy->numel(); + int post = total_num / (n * pre); + + auto &dev_ctx = ctx.template device_context(); + constexpr auto kMaxThreshold = 16; + if (std::is_same::value || + n > kMaxThreshold) { +#ifdef __NVCC__ + VLOG(10) << "Stack more than " << kMaxThreshold + << " tensors on GPU may be slow."; + thrust::device_vector device_dx_vec(dx_datas); + auto dx_data_arr = device_dx_vec.data().get(); +#else + auto dx_data_arr = dx_datas.data(); +#endif + StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, + post); +#ifdef __NVCC__ + // Wait() must be called because device_dx_vec may be destructed before + // kernel ends + dev_ctx.Wait(); +#endif + } +#ifdef __NVCC__ + else { // NOLINT + framework::Array dx_data_arr; + for (int i = 0; i < n; ++i) dx_data_arr[i] = dx_datas[i]; + StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, + post); + } +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index 92a0697e27ba0da66fa3b0f5380e7bd52575640d..4a8ac441cfaf642fde58ee30865a22e83c065498 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -30,8 +30,6 @@ class TopkOp : public framework::OperatorWithKernel { "Output(Indices) of TopkOp should not be null."); auto input_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(input_dims.size(), 2, - "Rank of TopK op's input must be 2."); const int k = static_cast(ctx->Attrs().Get("k")); PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index 7d53a684d6068c79659719159696ef5aebfeaa2b..fcd658d67cf4551dbdb9696ef49b5ab3cc58bf95 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -103,15 +103,16 @@ size_t CUDAPinnedMaxChunkSize() { return CUDAPinnedMaxAllocSize() / 256; } -#ifdef PADDLE_WITH_XBYAK namespace jit { - +#ifdef PADDLE_WITH_XBYAK static Xbyak::util::Cpu cpu; bool MayIUse(const cpu_isa_t cpu_isa) { using namespace Xbyak::util; // NOLINT switch (cpu_isa) { case sse42: return cpu.has(Cpu::tSSE42); + case avx: + return cpu.has(Cpu::tAVX); case avx2: return cpu.has(Cpu::tAVX2); case avx512_common: @@ -134,8 +135,16 @@ bool MayIUse(const cpu_isa_t cpu_isa) { } return false; } +#else +bool MayIUse(const cpu_isa_t cpu_isa) { + if (cpu_isa == isa_any) { + return true; + } else { + return false; + } +} +#endif } // namespace jit -#endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index f5f67667594f1ab80058533e4c5d5b04c2592b60..5d17978dd7946596c490dc465dab51e7cf53a044 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -37,12 +37,11 @@ size_t CUDAPinnedMinChunkSize(); //! Get the maximum chunk size for buddy allocator. size_t CUDAPinnedMaxChunkSize(); -#ifdef PADDLE_WITH_XBYAK namespace jit { - typedef enum { isa_any, sse42, + avx, avx2, avx512_common, avx512_core, @@ -55,7 +54,6 @@ typedef enum { inline bool MayIUse(const cpu_isa_t cpu_isa); } // namespace jit -#endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 15ad4a3b40b1ad13a10dd37449c6f6f3e2029df6..aa20553ceffceded09447693c6e92f55fb48702d 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -66,10 +66,16 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_free); \ __macro(cblas_sgemm_batch); \ __macro(cblas_dgemm_batch); \ + __macro(cblas_sdot); \ + __macro(cblas_ddot); \ + __macro(cblas_sscal); \ + __macro(cblas_dscal); \ __macro(vsAdd); \ __macro(vdAdd); \ __macro(vsMul); \ __macro(vdMul); \ + __macro(vsExp); \ + __macro(vdExp); \ __macro(MKL_Set_Num_Threads) MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index a81715c3b317ac47d1aefd1752ef5b434a15494a..e4415ed15c791100a5b309e73d7deb5943f71b97 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -43,9 +43,6 @@ void BindConstValue(pybind11::module* m) { op_proto_and_checker_maker.def( "kOpRoleVarAttrName", framework::OpProtoAndCheckerMaker::OpRoleVarAttrName); - op_proto_and_checker_maker.def( - "kOpCreationCallstackAttrName", - framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName); } } // namespace pybind diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 1d7ff582c86a40c8c2086e0de16e89d69c94da60..ece4046f5b7a7eff5be724d6f890665be7f3344e 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -19,6 +19,7 @@ import hashlib import os import errno import shutil +import six import sys import importlib import paddle.dataset @@ -94,6 +95,8 @@ def download(url, module_name, md5sum, save_name=None): dl = 0 total_length = int(total_length) for data in r.iter_content(chunk_size=4096): + if six.PY2: + data = six.b(data) dl += len(data) f.write(data) done = int(50 * dl / total_length) diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index aa73bbaf7024ec873d9e921205536f12e097ff32..0a1cdaceaf3be48a06b1c0b5b979e90f50e9000c 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -35,6 +35,7 @@ import itertools import functools from .common import download import tarfile +import six import scipy.io as scio from paddle.dataset.image import * from paddle.reader import * @@ -45,10 +46,10 @@ from six.moves import cPickle as pickle from six.moves import zip __all__ = ['train', 'test', 'valid'] -DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' -LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat' -SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' -DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118' +DATA_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/102flowers.tgz' +LABEL_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/imagelabels.mat' +SETID_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/setid.mat' +DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' # In official 'readme', tstid is the flag of test data @@ -120,7 +121,10 @@ def reader_creator(data_file, file = file.strip() batch = None with open(file, 'rb') as f: - batch = pickle.load(f) + if six.PY2: + batch = pickle.load(f) + else: + batch = pickle.load(f, encoding='bytes') data = batch['data'] labels = batch['label'] for sample, label in zip(data, batch['label']): diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index e0ddd3b5ffecfd5c49b2da7dfa739e2c3e93838f..febb750ee1af26c71e6c1ae1e4c97fb02fb27a04 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -18,7 +18,6 @@ import collections import contextlib import re import six -import traceback import numpy as np @@ -506,10 +505,6 @@ class Operator(object): if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0: del op_attrs[role_var_name] - callstack_var_name = op_maker.kOpCreationCallstackAttrName() - op_attrs[callstack_var_name] = list( - reversed(traceback.format_stack()))[1:] - if len(self.desc.type()) != 0: return if type is None: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a815ba0f2f4a946f37da6baaafcd56fbb880adda..4bd260a00503c57b7f67b2706b4c25e43271c3f6 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -27,7 +27,6 @@ from . import utils import random from .. import unique_name from functools import reduce -import warnings __all__ = [ 'fc', @@ -104,6 +103,7 @@ __all__ = [ 'rank_loss', 'prelu', 'flatten', + 'stack', ] @@ -2047,7 +2047,7 @@ def batch_norm(input, param_attr(ParamAttr): The parameter attribute for Parameter `scale`. bias_attr(ParamAttr): The parameter attribute for Parameter `bias`. data_layout(string, default NCHW): NCHW|NHWC - in_place(bool, Default False): This argument is deprecated since 0.15.0. + in_place(bool, Default False): Make the input and output of batch norm reuse memory. use_mkldnn(bool, Default false): ${use_mkldnn_comment} name(string, Default None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2069,10 +2069,6 @@ def batch_norm(input, helper = LayerHelper('batch_norm', **locals()) dtype = helper.input_dtype() - if in_place: - raise warnings.warn("The argument in_place is deprecated since 0.15.0, " - "please do not set it True.") - input_shape = input.shape if data_layout == 'NCHW': channel_num = input_shape[1] @@ -2122,7 +2118,7 @@ def batch_norm(input, saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) - batch_norm_out = helper.create_tmp_variable(dtype) + batch_norm_out = input if in_place else helper.create_tmp_variable(dtype) helper.append_op( type="batch_norm", @@ -5522,3 +5518,17 @@ def flatten(x, axis=1, name=None): outputs={'Out': out}, attrs={"axis": axis}) return out + + +def stack(x, axis=0): + helper = LayerHelper('stack', **locals()) + axis = 0 if axis is None else axis + + if not isinstance(x, list) and not isinstance(x, tuple): + x = [x] + + out = helper.create_tmp_variable(x[0].dtype) + helper.append_op( + type='stack', inputs={'X': x}, outputs={'Y': out}, + attrs={'axis': axis}) + return out diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index 01563cbbb706d9a1c9c9d46ded71f7f48b5a9f04..051fe84364639ca6028326c0cb02b204a02531af 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -229,7 +229,7 @@ def img_conv_group(input, use_mkldnn=use_mkldnn) if conv_with_batchnorm[i]: - tmp = layers.batch_norm(input=tmp, act=conv_act) + tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True) drop_rate = conv_batchnorm_drop_rate[i] if abs(drop_rate) > 1e-5: tmp = layers.dropout(x=tmp, dropout_prob=drop_rate) diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index cd1e8cd682315ef4931e323536a57542f4b3bc26..9fe361425c128590da910128beaccb3336f8ba57 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -256,10 +256,7 @@ def main(net_type, use_cuda, is_local=True): save_dirname = "image_classification_" + net_type + ".inference.model" train(net_type, use_cuda, save_dirname, is_local) - - # There is bug in fluid.InferenceTranspiler for VGG. - if net_type == "resnet": - infer(use_cuda, save_dirname) + infer(use_cuda, save_dirname) class TestImageClassification(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a7382c2244ec3291c4e8f625cc2d15499e0acdac --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py @@ -0,0 +1,208 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from test_fusion_lstm_op import fc, ACTIVATION +from test_softmax_op import stable_softmax + + +def attention_lstm( + x, # T x M + lod, # 1 x N + h0, # N x D + c0, # N x D + fcws, # (M+D) x 1, 1x1 + fcbs, # 1 x 1, 1x1 + w, # (M+D) x 4D + b, # 1 x 4D + act_gate, + act_cell, + act_cand): + + T = sum(lod[0]) + N = len(lod[0]) + M = x.shape[1] + D = b.shape[1] / 4 + assert T == x.shape[0] + assert len(fcws) == len(fcbs) + hidden = [] + cell = [] + + start_offset = 0 + for bid in range(N): + seq_len = lod[0][bid] + xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len, + M) + prev_cell = np.copy(c0[bid]).reshape([1, D]) + prev_hidden = np.copy(h0[bid]).reshape([1, D]) + for step in range(seq_len): + expanded_cell = np.repeat(prev_cell, seq_len, axis=0) + tmp = np.concatenate((xi, expanded_cell), axis=1) + assert tmp.shape[0] == seq_len + assert tmp.shape[1] == M + D + for fcid in range(len(fcbs)): + tmp = fc(tmp, fcws[fcid], fcbs[fcid]) + tmp = ACTIVATION['relu'](tmp) + tmp = np.reshape(tmp, (1, seq_len)) + tmp = stable_softmax(tmp).reshape(seq_len, 1) + lstmx = xi * tmp # seq * M + lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M]) + lstmin = np.concatenate((prev_hidden, lstmx), axis=1) + lstmout = fc(lstmin, w, b).reshape([1, 4 * D]) + + g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1) + g_f = act_gate(g_f).reshape([1, D]) + g_i = act_gate(g_i).reshape([1, D]) + g_o = act_gate(g_o).reshape([1, D]) + cand = act_cand(cand).reshape([1, D]) + + cell_t = (prev_cell * g_f) + (g_i * cand) + hidden_t = g_o * act_cell(cell_t) + + hidden.append(hidden_t.flatten()) + cell.append(cell_t.flatten()) + + prev_cell = cell_t.reshape([1, D]) + prev_hidden = hidden_t.reshape([1, D]) + + start_offset += seq_len + + hidden = np.array(hidden).astype('float32').reshape([T, D]) + cell = np.array(cell).astype('float32').reshape([T, D]) + return hidden, cell + + +class TestAttentionLSTMOp(OpTest): + def set_conf(self): + pass + + def setUp(self): + self.op_type = 'attention_lstm' + self.lod = [[3]] + self.M = 30 + self.D = 15 + self.has_initial_hidden = True + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + self.set_conf() + + T = sum(self.lod[0]) + bs = len(self.lod[0]) + + x = np.random.normal(size=(T, self.M)).astype('float32') + c0 = np.random.normal(size=(bs, self.D)).astype('float32') + if self.has_initial_hidden: + h0 = np.random.normal(size=(bs, self.D)).astype('float32') + else: + h0 = np.zeros((bs, self.D)).astype('float32') + + fcw1 = np.random.normal(size=(self.M + self.D, 1)).astype('float32') + fcb1 = np.random.normal(size=(1, 1)).astype('float32') + fcw2 = np.random.normal(size=(1, 1)).astype('float32') + fcb2 = np.random.normal(size=(1, 1)).astype('float32') + + # lstm weight and bias + w = np.random.normal(size=(self.M + self.D, + self.D * 4)).astype('float32') + b = np.random.normal(size=(1, self.D * 4)).astype('float32') + + h, c = attention_lstm(x, self.lod, h0, c0, [fcw1, fcw2], [fcb1, fcb2], + w, b, ACTIVATION[self.act_gate], + ACTIVATION[self.act_cell], + ACTIVATION[self.act_cand]) + + self.inputs = { + 'X': (x, self.lod), + 'C0': c0, + 'AttentionWeight': fcw1, + 'AttentionBias': fcb1, + 'AttentionScalar': fcw2, + 'AttentionScalarBias': fcb2, + 'LSTMWeight': w, + 'LSTMBias': b + } + + if self.has_initial_hidden: + self.inputs['H0'] = h0 + + self.outputs = { + 'Hidden': (h, self.lod), + 'Cell': (c, self.lod), + } + self.attrs = { + 'gate_activation': self.act_gate, + 'cell_activation': self.act_cell, + 'candidate_activation': self.act_cand + } + + def test_check_output(self): + self.check_output() + + +class TestAttentionOpNonInit(TestAttentionLSTMOp): + def set_conf(self): + self.has_initial_hidden = False + + +class TestAttentionOpAct(TestAttentionLSTMOp): + def set_conf(self): + self.M = 3 + self.D = 2 + self.act_gate = 'relu' + self.act_cell = 'tanh' + self.act_cand = 'sigmoid' + + +class TestAttentionOpMD1(TestAttentionLSTMOp): + def set_conf(self): + self.M = 36 + self.D = 8 + + +class TestAttentionOpMD2(TestAttentionLSTMOp): + def set_conf(self): + self.M = 8 + self.D = 8 + + +class TestAttentionOpMD3(TestAttentionLSTMOp): + def set_conf(self): + self.M = 15 + self.D = 30 + + +class TestAttentionOpBS1(TestAttentionLSTMOp): + def set_conf(self): + self.lod = [[5]] + self.M = 16 + self.D = 32 + + +class TestAttentionOpBS2(TestAttentionLSTMOp): + def set_conf(self): + self.lod = [[3, 6]] + + +class TestAttentionOpBS5(TestAttentionLSTMOp): + def set_conf(self): + self.lod = [[3, 2, 4, 7, 5]] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index 3ac82680733feb4b82ab98669269160e4aad948f..6d01955993324498de42462b7f85ef6f8e444505 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -67,10 +67,7 @@ class TestOperator(unittest.TestCase): self.assertEqual(mul_op.output("Out"), ["mul.out"]) self.assertEqual( set(mul_op.attr_names), - set([ - "x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var", - "op_callstack" - ])) + set(["x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var"])) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr("x_num_col_dims"), 1) diff --git a/python/paddle/fluid/tests/unittests/test_program_code.py b/python/paddle/fluid/tests/unittests/test_program_code.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c2b928617dce3904ca119896ca81454256e82e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_program_code.py @@ -0,0 +1,81 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import unittest +from multiprocessing import Process +import signal + +import numpy + +import paddle.fluid as fluid +import paddle.fluid.layers as layers +from paddle.fluid.layers.io import ListenAndServ +from paddle.fluid.layers.io import Recv +from paddle.fluid.layers.io import Send + +from paddle.fluid.transpiler.details import program_to_code + + +class TestProgram2Code(unittest.TestCase): + def test_print(self): + place = fluid.CPUPlace() + self.init_serv(place) + self.init_client(place, 9123) + + def init_serv(self, place): + main = fluid.Program() + + with fluid.program_guard(main): + serv = ListenAndServ("127.0.0.1:0", ["X"], optimizer_mode=False) + with serv.do(): + out_var = main.global_block().create_var( + name="scale_0.tmp_0", + psersistable=True, + dtype="float32", + shape=[32, 32]) + x = layers.data( + shape=[32, 32], + dtype='float32', + name="X", + append_batch_size=False) + fluid.initializer.Constant(value=1.0)(x, main.global_block()) + layers.scale(x=x, scale=10.0, out=out_var) + + program_to_code(main) + + def init_client(self, place, port): + main = fluid.Program() + with fluid.program_guard(main): + x = layers.data( + shape=[32, 32], + dtype='float32', + name='X', + append_batch_size=False) + fluid.initializer.Constant(value=2.3)(x, main.global_block()) + get_var = main.global_block().create_var( + name="scale_0.tmp_0", # server side var + dtype="float32", + persistable=False, + shape=[32, 32]) + fluid.initializer.Constant(value=2.3)(get_var, main.global_block()) + Send("127.0.0.1:%d" % port, [x]) + o = Recv("127.0.0.1:%d" % port, [get_var]) + + program_to_code(main) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_stack_op.py b/python/paddle/fluid/tests/unittests/test_stack_op.py new file mode 100644 index 0000000000000000000000000000000000000000..defdeb5d70df4c39ed8e23247270e6eb3dd14a7a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_stack_op.py @@ -0,0 +1,92 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from op_test import OpTest +import numpy as np +import unittest + + +class TestStackOpBase(OpTest): + def initDefaultParameters(self): + self.num_inputs = 4 + self.input_dim = (5, 6, 7) + self.axis = 0 + self.dtype = 'float32' + + def initParameters(self): + pass + + def get_x_names(self): + x_names = [] + for i in range(self.num_inputs): + x_names.append('x{}'.format(i)) + return x_names + + def setUp(self): + self.initDefaultParameters() + self.initParameters() + self.op_type = 'stack' + self.x = [] + for i in range(self.num_inputs): + self.x.append( + np.random.random(size=self.input_dim).astype(self.dtype)) + + tmp = [] + x_names = self.get_x_names() + for i in range(self.num_inputs): + tmp.append((x_names[i], self.x[i])) + + self.inputs = {'X': tmp} + self.outputs = {'Y': np.stack(self.x, axis=self.axis)} + self.attrs = {'axis': self.axis} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(self.get_x_names(), 'Y') + + +class TestStackOp1(TestStackOpBase): + def initParameters(self): + self.num_inputs = 16 + + +class TestStackOp2(TestStackOpBase): + def initParameters(self): + self.num_inputs = 20 + + +class TestStackOp3(TestStackOpBase): + def initParameters(self): + self.axis = -1 + + +class TestStackOp4(TestStackOpBase): + def initParameters(self): + self.axis = -4 + + +class TestStackOp5(TestStackOpBase): + def initParameters(self): + self.axis = 1 + + +class TestStackOp6(TestStackOpBase): + def initParameters(self): + self.axis = 3 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py index 640dbf4bbed58edf746456419af18c75241fa03c..420ae6dfd4b75b507dd01bb947fa707bca5cdb08 100644 --- a/python/paddle/fluid/transpiler/details/program_utils.py +++ b/python/paddle/fluid/transpiler/details/program_utils.py @@ -16,6 +16,9 @@ from __future__ import print_function import six +from paddle.fluid import core +import paddle + def delete_ops(block, ops): try: @@ -39,3 +42,133 @@ def find_op_by_output_arg(block, arg_name): if arg_name in op.output_arg_names: return index return -1 + + +def get_indent_space(indent, space_num=4): + ret = "" + for i in range(0, indent * space_num): + ret += " " + + return ret + + +def variable_to_code(var): + """ + Get readable codes of fluid variable. + + Args: + var: A fluid operator. + + Returns: + string: The formatted string. + """ + + var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})".\ + format(i="{", e="}", name=var.name, type=var.type, shape=var.shape, dtype=var.dtype) + + if type(var) == paddle.fluid.framework.Parameter: + if var.trainable: + var_str = "trainable parameter " + var_str + else: + var_str = "parameter " + var_str + else: + var_str = "var " + var_str + + if var.persistable: + var_str = "persist " + var_str + + return var_str + + +def op_to_code(op): + """ + Get readable codes of fluid operator. + + Args: + op: A fluid operator. + + Returns: + string: The foramtted string. + """ + + outputs_str = "{" + for i in range(0, len(op.output_names)): + outputs_str += "{name}=".format(name=op.output_names[i]) + o = op.output(op.output_names[i]) + outputs_str += "{value}".format(value=o) + if i != len(op.output_names) - 1: + outputs_str += ", " + outputs_str += "}" + + inputs_str = "{" + for i in range(0, len(op.input_names)): + inputs_str += "{name}=".format(name=op.input_names[i]) + o = op.input(op.input_names[i]) + inputs_str += "{value}".format(value=o) + + if i != len(op.input_names) - 1: + inputs_str += ", " + inputs_str += "}" + + attrs_str = "" + for i in range(0, len(op.attr_names)): + name = op.attr_names[i] + + attr_type = op.desc.attr_type(name) + if attr_type == core.AttrType.BLOCK: + a = "{name} = block[{value}]".format( + name=name, type=attr_type, value=op.block_attr_id(name)) + attrs_str += a + continue + + if attr_type == core.AttrType.BLOCKS: + a = "{name} = blocks{value}".format( + name=name, type=attr_type, value=op.blocks_attr_ids(name)) + attrs_str += a + continue + + a = "{name} = {value}".format( + name=name, type=attr_type, value=op.desc.attr(name)) + attrs_str += a + if i != len(op.attr_names) - 1: + attrs_str += ", " + + if outputs_str != "{}": + op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\ + format(outputs = outputs_str, op_type=op.type, inputs=inputs_str, attrs=attrs_str) + else: + op_str = "{op_type}(inputs={inputs}, {attrs})".\ + format(op_type=op.type, inputs=inputs_str, attrs=attrs_str) + return op_str + + +def program_to_code(prog): + """ + Print readable codes of fluid program. + + Args: + prog : A fluid program. + + An example result like bellow: + https://github.com/PaddlePaddle/Paddle/pull/12673 + """ + indent = 0 + block_idx = 0 + for block in prog.blocks: + print("{0}{1} // block {2}".format( + get_indent_space(indent), '{', block_idx)) + indent += 1 + # sort all vars + all_vars = sorted(block.vars.iteritems(), key=lambda x: x[0]) + for var in all_vars: + print("{}{}".format( + get_indent_space(indent), variable_to_code(var[1]))) + + if len(all_vars) > 0: + print("") + + for op in block.ops: + print("{}{}".format(get_indent_space(indent), op_to_code(op))) + indent -= 1 + print("{0}{1}".format(get_indent_space(indent), '}')) + block_idx += 1