diff --git a/doc/v2/api/fluid/optimizer.rst b/doc/v2/api/fluid/optimizer.rst index 9b165f870459b4f9ef2efe24f5604a3fcb96f7f3..2f820595c35c2bccd6a5c8a20c60d796c04c8e97 100644 --- a/doc/v2/api/fluid/optimizer.rst +++ b/doc/v2/api/fluid/optimizer.rst @@ -47,3 +47,10 @@ DecayedAdagrad :members: :noindex: +Adadelta +-------------- + +.. autoclass:: paddle.fluid.optimizer.AdadeltaOptimizer + :members: + :noindex: + diff --git a/doc/v2/howto/cluster/index_en.rst b/doc/v2/howto/cluster/index_en.rst index c965d30d54e71339cf10d4b05f25e740c81adbf9..31eda57c4fb3947d92df45ea8dbb9274c9814140 100644 --- a/doc/v2/howto/cluster/index_en.rst +++ b/doc/v2/howto/cluster/index_en.rst @@ -2,6 +2,9 @@ Distributed Training ==================== The effectiveness of the deep learning model is often directly related to the scale of the data: it can generally achieve better results after increasing the size of the dataset on the same model. However, it can not fit in one single computer when the amount of data increases to a certain extent. At this point, using multiple computers for distributed training is a natural solution. In distributed training, the training data is divided into multiple copies (sharding), and multiple machines participating in the training read their own data for training and collaboratively update the parameters of the overall model. + +Distributed training generally has framwork as shown below: + .. image:: src/ps_en.png :width: 500 diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 1074ed6acc22a81f46c466d917ef973945a12898..e4436549f6185ba04a5f270893596a6dcb11e89b 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel { } }; -template class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { public: DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -73,7 +72,6 @@ are set equal to their corresponding inputs. } }; -template class DropoutOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, - ops::DropoutOpGrad); +REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, + ops::DropoutOpGrad); REGISTER_OP_CPU_KERNEL( - dropout, - ops::CPUDropoutKernel); + dropout, ops::CPUDropoutKernel); REGISTER_OP_CPU_KERNEL( dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index d6f9c04359d733cb4f3f0586e9239ee67deb7078..f6c85a2a537b37feb20e6d62729dc5075af2a5d9 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -18,17 +18,18 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -template +template __global__ void RandomGenerator(const size_t n, const int seed, - const AttrType dropout_prob, const T* src, + const float dropout_prob, const T* src, T* mask_data, T* dst) { thrust::minstd_rand rng; rng.seed(seed); - thrust::uniform_real_distribution dist(0, 1); + thrust::uniform_real_distribution dist(0, 1); int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < n; idx += blockDim.x * gridDim.x) { @@ -44,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed, // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. -template +template class GPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* y = context.Output("Out"); y->mutable_data(context.GetPlace()); - AttrType dropout_prob = context.Attr("dropout_prob"); + float dropout_prob = context.Attr("dropout_prob"); auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); @@ -70,11 +71,11 @@ class GPUDropoutKernel : public framework::OpKernel { int threads = 512; int grid = (x->numel() + threads - 1) / threads; - RandomGenerator<<>>( + RandomGenerator< + T><<>>( size, seed, dropout_prob, x_data, mask_data, y_data); } else { - Y.device(place) = X * (1.0f - dropout_prob); + Y.device(place) = X * static_cast(1.0f - dropout_prob); } } }; @@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - dropout, - ops::GPUDropoutKernel); -REGISTER_OP_CUDA_KERNEL( - dropout_grad, - ops::DropoutGradKernel); + dropout, ops::GPUDropoutKernel, + ops::GPUDropoutKernel); +REGISTER_OP_CUDA_KERNEL(dropout_grad, + ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 209e4dec1756dc65fbf147c4dbbf0913d3c6ef7e..b5ee86ae2d11dfc835e1a3a6826ce016baf38a29 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -25,7 +25,7 @@ template using EigenMatrix = framework::EigenMatrix; -template +template class CPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 4253300788462a3704076fc79241a864f2f130a0..a594de67e05acd28ffedc5407beecfaea1281444 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -24,6 +24,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/proto_desc.h" +#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/simple_block_queue.h" @@ -89,6 +90,10 @@ class ListenAndServOp : public framework::OperatorBase { auto *block = Attr(kOptimizeBlock); auto *program = block->Program(); + int num_blocks = program->Size(); + PADDLE_ENFORCE_GE(num_blocks, 2, + "server program should have at least 2 blocks"); + framework::Executor executor(dev_place); // TODO(typhoonzero): change this to a while_op for every cluster-batch. @@ -132,12 +137,36 @@ class ListenAndServOp : public framework::OperatorBase { rpc_service_->ShutDown(); break; } - try { - executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ - false /*create_local_scope*/, false /*create_vars*/); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); + + // put optimize blocks in the thread pool to start run, the last block + // should be global ops. + // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads + // and this will still work. + std::vector> fs; + // block0 contains only listen_and_serv op, start run from block1. + for (int blkid = 1; blkid < num_blocks - 1; ++blkid) { + fs.push_back(framework::Async([&executor, &program, &recv_scope, + blkid]() { + int run_block = blkid; // thread local + try { + executor.Run(*program, &recv_scope, run_block, + false /*create_local_scope*/, false /*create_vars*/); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + })); + } + for (int i = 0; i < num_blocks - 2; ++i) fs[i].wait(); + // Run global block at final step, or block1 if there are only 2 blocks + if (num_blocks >= 2) { + try { + executor.Run(*program, &recv_scope, num_blocks - 1, + false /*create_local_scope*/, false /*create_vars*/); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } } + // Reset the received sparse variables, the sum operator would not // sum the input sparse variables which rows is empty at the next // mini-batch. diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 35d251f71a0cb631d5900498ea3188b5ddeae334..17e576a9d5c8f50fbe84b066a93460f03ae6bb08 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -371,6 +371,8 @@ template struct RowwiseAdd; template struct ColwiseSum; template struct ColwiseSum; +template struct ColwiseSum; +template struct ColwiseSum; template struct RowwiseSum; template struct RowwiseSum; diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 3abbcdb71d03eaf6f8eba3d97150d27ac5a5405e..c6ca2693a053360ce5dc44765acf1520a11cce2c 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -422,6 +422,8 @@ struct RowwiseAdd { template struct RowwiseAdd; template struct RowwiseAdd; template struct ColwiseSum; +template struct ColwiseSum; +template struct ColwiseSum; // template struct ColwiseSum; // The ColwiseSum failed in debug mode, // and only failed for this case. So reimplemented it. 
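The float16 support added to the dropout GPU kernel above is exercised further down in this patch by TestFP16DropoutOp, which only covers the inference path (is_test=True), where no mask is generated and the input is simply scaled. A minimal numpy sketch of that reference computation (shape and probability taken from the test):

import numpy as np

# Inference-mode dropout: no mask, output is the input scaled by (1 - dropout_prob).
# This mirrors the expected output used by TestFP16DropoutOp below.
x = np.random.random((32, 64)).astype("float16")
dropout_prob = 0.35
out = x * (1.0 - dropout_prob)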
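The new scheduling in ListenAndServOp above runs optimize blocks 1 .. num_blocks-2 concurrently through framework::Async and only then executes the last block, which holds the global ops. A rough Python analogy of that control flow (not Paddle code; run_block stands in for executor.Run on one block):

from concurrent.futures import ThreadPoolExecutor

def run_optimize_blocks(run_block, num_blocks):
    # Launch blocks 1 .. num_blocks-2 in parallel, like framework::Async.
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(run_block, blkid)
                   for blkid in range(1, num_blocks - 1)]
        for f in futures:
            f.result()  # equivalent to fs[i].wait()
    # The global-ops block runs last, after all parallel blocks finish
    # (with only 2 blocks this is simply block 1).
    run_block(num_blocks - 1)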
diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index d0de092947eb04a1b7d06dedea919f6b1094dd06..bd0bb2ee3b0252f47318c59d9940d8dd478723de 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -48,20 +48,24 @@ class DoubleBufferReader : public framework::DecoratedReader { void start_thread() { buffer_ = framework::MakeChannel(kDoubleBufferSize); - std::thread prefetch([this] { PrefetchThreadFunc(); }); - prefetch.detach(); + prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); } void ReadNext(std::vector* out) override; void ReInit() override; - ~DoubleBufferReader() { buffer_->Close(); } + ~DoubleBufferReader() { + buffer_->Close(); + prefetcher_.join(); + delete buffer_; + } bool HasNext() const override; private: void PrefetchThreadFunc(); + std::thread prefetcher_; framework::Channel* buffer_; platform::Place place_; std::vector> ctxs_; @@ -134,6 +138,8 @@ void DoubleBufferReader::ReadNext(std::vector* out) { void DoubleBufferReader::ReInit() { reader_->ReInit(); buffer_->Close(); + prefetcher_.join(); + delete buffer_; start_thread(); } @@ -159,11 +165,12 @@ void DoubleBufferReader::PrefetchThreadFunc() { if (!buffer_->Send(&batch)) { VLOG(5) << "WARNING: The double buffer channel has been closed. The " - "prefetch thread terminates."; + "prefetch thread will terminate."; break; } } buffer_->Close(); + VLOG(5) << "Prefetch thread terminates."; } bool DoubleBufferReader::HasNext() const { diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 70e2f587dc414a850ddc341b98f26ae54636755c..3a1f3805a0483c2f5eabdc7432556051d8308964 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -34,6 +34,9 @@ class ShuffleReader : public framework::DecoratedReader { } void ReadNext(std::vector* out) override { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } if (iteration_pos_ >= buffer_.size()) { VLOG(10) << "Resetting shuffle buffer"; ReadIntoBuffers(); @@ -50,7 +53,6 @@ class ShuffleReader : public framework::DecoratedReader { buffer_.clear(); buffer_.reserve(buffer_size_); iteration_pos_ = 0; - PADDLE_ENFORCE(reader_->HasNext()); for (size_t i = 0; i < buffer_size_; ++i) { if (!reader_->HasNext()) { break; diff --git a/paddle/fluid/operators/sequence_expand_op.cc b/paddle/fluid/operators/sequence_expand_op.cc index a5d84d629b2e50763dac9bc571ac490414a8a406..786fe63e7580ce16b946d5049a490eed2c3c6ced 100644 --- a/paddle/fluid/operators/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_expand_op.cc @@ -17,7 +17,7 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -using framework::Tensor; +using framework::LoDTensor; class SequenceExpandOp : public framework::OperatorWithKernel { public: @@ -25,15 +25,71 @@ class SequenceExpandOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X")); - PADDLE_ENFORCE(ctx->HasOutput("Out")); - PADDLE_ENFORCE(ctx->HasInput("Y")); - framework::DDim out_dim; - auto y_dim = ctx->GetInputDim("Y"); - out_dim = ctx->GetInputDim("X"); - out_dim[0] = y_dim[0]; - ctx->ShareLoD("Y", "Out"); - ctx->SetOutputDim("Out", out_dim); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceExpandOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of SequenceExpandOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceExpandOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto out_dims = x_dims; + int ref_level = ctx->Attrs().Get("ref_level"); + + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "Dimension number of Input(X) should be at least 2."); + + if (ctx->IsRuntime()) { + framework::Variable* x_var = + boost::get(ctx->GetInputVarPtrs("X")[0]); + framework::Variable* y_var = + boost::get(ctx->GetInputVarPtrs("Y")[0]); + + auto& x_lod = x_var->Get().lod(); + auto& y_lod = y_var->Get().lod(); + + PADDLE_ENFORCE_LE(x_lod.size(), 1, + "Level number of Input(X)'s lod should not be " + "greater than 1."); + PADDLE_ENFORCE_GT(y_lod.size(), 0, + "Level number of Input(Y)'s lod should be " + "greater than 0."); + PADDLE_ENFORCE( + ref_level == -1 || + (ref_level >= 0 && ref_level < static_cast(y_lod.size())), + "Invlid `ref_level`, which should be either equal to -1 " + "or in [0, %d)", + y_lod.size()); + + if (ref_level == -1) ref_level = y_lod.size() - 1; + + if (x_lod.size() > 0) { + PADDLE_ENFORCE(x_lod[0].size() == y_lod[ref_level].size(), + "Level number of Input(X)'s lod could be 0. Otherwise " + "size of Input(X)'s first level lod should be equal to " + "size of Input(Y)'s referred level lod."); + } + + int64_t out_first_dim = 0; + if (y_lod[ref_level].size() <= 1) { + out_first_dim = x_dims[0]; + } else { + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int x_seq_len = 1; + if (x_lod.size() == 1) { + x_seq_len = x_lod[0][i] - x_lod[0][i - 1]; + } + out_first_dim += + (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len; + } + } + out_dims[0] = out_first_dim; + ctx->SetOutputDim("Out", out_dims); + } else { + out_dims[0] = -1; + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } } }; @@ -42,83 +98,81 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker { SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "(Tensor or LoDTensor) The input(X) of this operator can be a " - "LoDTensor or a base Tensor."); + "(LoDTensor, default LoDTensor) A 2-D LoDTensor whose lod " + "level is at most 1."); AddInput("Y", - "(LoDTensor)The reference input(Y) of sequence_expand op." - "It must be a LoDTensor with k-level(k>0)." - "The input(X) will be expanded according to LOD of input(Y)." - "The element numbers of last level in input(Y) " - "must be equal to dims[0] of input(X)."); + "(LoDTensor, default LoDTensor) Referred LoDTensor whose " + "lod (specified level) is referred by Input(X)."); AddOutput("Out", - "(LodTensor)The output of sequence_expand op." 
- "The lod of output will be as same as input(Y)'s lod."); + "(LodTensor, default LoDTensor) Output LoDTensor which is " + "generated from Input(X) by referring lod of Input(Y)."); + AddAttr("ref_level", "Specify lod level of Input(Y).").SetDefault(-1); AddComment(R"DOC( Sequence Expand Operator. -This operator expands input(X) according to LOD of input(Y). +This operator expands `X` according to specified level lod of `Y`. Current +implementation constaints that lod level of `X` should be at most 1. Attribute +`ref_level` is used to specify which level lod of `Y` is referred to expand `X`. +If set `ref_level` to -1, then last level lod of `Y` would be referred. +Please note, rank of `X` should be at least 2, when the rank exceeds 2, `X` +would be viewed as a 2-D tensor. + Following are cases to better explain how this works: + Case 1: -Given a 2-level LoDTensor input(X) - X.lod = [[0, 2, 3], - [0, 1, 3, 4]] - X.data = [a, b, c, d] +Given a 1-level LoDTensor input(X) + X.lod = [[0, 2, 4]] + X.data = [[a], [b], [c], [d]] X.dims = [4, 1] and input(Y) Y.lod = [[0, 2, 4], [0, 3, 6, 7, 8]] -with condition len(Y.lod[-1]) -1 == X.dims[0] -then we get 2-level LoDTensor - Out.lod = [[0, 2, 4], - [0, 3, 6, 7, 8]] - Out.data = [a, a, a, b, b, b, c, d] +ref_level: 0 +then we get 1-level LoDTensor + Out.lod = [[0, 2, 4, 6, 8]] + Out.data = [[a], [b], [a], [b], [c], [d], [c], [d]] Out.dims = [8, 1] Case 2: +Given 1-level LoDTensor input(X) + X.lod = [[0, 1, 4]] + X.data = [[a], [b], [c], [d]] + X.dims = [4, 1] +and input(Y) + Y.lod = [[0, 2, 4], + [0, 3, 6, 6, 8]] +ref_level: 0 +then we get 1-level LoDTensor + Out.lod = [[0, 1, 2, 5, 8]] + Out.data = [[a], [a], [b], [c], [d], [b], [c], [d]] + Out.dims = [8, 1] + +Case 3: + Given a common Tensor input(X) - X.data = [a, b, c] + X.data = [[a], [b], [c]] X.dims = [3, 1] and input(Y) Y.lod = [[0, 2, 3, 6]] -with condition len(Y.lod[-1]) -1 == X.dims[0] -then we get 1-level LoDTensor - Out.lod = [[0, 2, 3, 6]] - Out.data = [a, a, b, c, c, c] +ref_level: -1 +then we get a common Tensor + Out.data = [[a], [a], [b], [c], [c], [c]] Out.dims = [6, 1] -Case 3: +Case 4: Given a common Tensor input(X) X.data = [[a, b], [c, d], [e, f]] X.dims = [3, 2] and input(Y) Y.lod = [[0, 2, 3, 6]] -with condition len(Y.lod[-1]) -1 == X.dims[0] -then we get 1-level LoDTensor - Out.lod = [[0, 2, 3, 6]] - Out.data = [[a,b], [a,b] [c,d], [e, f], [e, f], [e, f]] +ref_level: 0 +then we get a common LoDTensor + Out.data = [[a, b], [a, b] [c, d], [e, f], [e, f], [e, f]] Out.dims = [6, 2] -Case 4: - -Given 2-level a LoDTensor input(X) - X.lod = [[0, 2, 3], - [0, 1, 3, 4]] - X.data = [a, b, c, d] - X.dims = [4, 1] -and input(Y) - Y.lod = [[0, 2, 4], - [0, 3, 6, 6, 8]] -with condition len(Y.lod[-1]) -1 == X.dims[0] -then we get 2-level LoDTensor - Out.lod = [[0, 2, 4], - [0, 3, 6, 6, 8]] - Out.data = [a, a, a, b, b, b, d, d] - Out.dims = [8, 1] - - )DOC"); } }; @@ -129,12 +183,14 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X")); - PADDLE_ENFORCE(ctx->HasInput("Out")); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "The input(Out@GRAD) should not be null"); + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); auto x_grad_name = framework::GradVarName("X"); + if 
(ctx->HasOutput(x_grad_name)) { ctx->SetOutputDim(x_grad_name, x_dims); } @@ -149,7 +205,13 @@ REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker, sequence_expand_grad, ops::SequenceExpandOpGrad); REGISTER_OP_CPU_KERNEL( sequence_expand, - ops::SequenceExpandKernel); + ops::SequenceExpandKernel, + ops::SequenceExpandKernel, + ops::SequenceExpandKernel, + ops::SequenceExpandKernel); REGISTER_OP_CPU_KERNEL( sequence_expand_grad, - ops::SequenceExpandGradKernel); + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel); diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu index 26622d23afa1c703e237628bcb11db8f1da73210..bb51bb2902eea797de3449fcb6c8b52b4f0e7fbf 100644 --- a/paddle/fluid/operators/sequence_expand_op.cu +++ b/paddle/fluid/operators/sequence_expand_op.cu @@ -18,7 +18,14 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( sequence_expand, - ops::SequenceExpandKernel); + ops::SequenceExpandKernel, + ops::SequenceExpandKernel, + ops::SequenceExpandKernel, + ops::SequenceExpandKernel); REGISTER_OP_CUDA_KERNEL( sequence_expand_grad, - ops::SequenceExpandGradKernel); + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel); diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h index 76dde976db2d19e307ae7406be8280f9b4987187..db7d8bd6821fabd9714a160970558291ec47197f 100644 --- a/paddle/fluid/operators/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_expand_op.h @@ -16,45 +16,75 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memcpy.h" -#include "unsupported/Eigen/CXX11/Tensor" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { using LoDTensor = framework::LoDTensor; +template +using EigenMatrix = framework::EigenMatrix; template class SequenceExpandKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); - auto* out = context.Output("Out"); - const T* x_data = x->data(); - auto x_dims = x->dims(); auto* y = context.Input("Y"); - PADDLE_ENFORCE(!y->lod().empty(), "y should have lod"); - PADDLE_ENFORCE_EQ(static_cast(x_dims[0]), - y->lod().back().size() - 1, - "The size of last lod level in Input(Y)" - "must be equal to dims[0] of Input(X)."); - out->set_lod(y->lod()); - auto* place = - context.template device_context().eigen_device(); - size_t element_len = framework::product(x_dims) / x_dims[0]; - T* out_data = out->mutable_data(context.GetPlace()); - auto out_starts = out->lod().back(); - - for (size_t i = 0; i < out_starts.size() - 1; i++) { - int scale = out_starts[i + 1] - out_starts[i]; - Eigen::TensorMap< - Eigen::Tensor> - x_t(x_data, 1, element_len); - Eigen::TensorMap> - out_t(out_data, scale, element_len); - Eigen::array cast({{scale, 1}}); - out_t.device(*place) = x_t.broadcast(cast); - x_data += element_len; - out_data += element_len * scale; + auto* out = context.Output("Out"); + + int ref_level = context.Attr("ref_level"); + auto& x_lod = x->lod(); + auto& y_lod = y->lod(); + + if (ref_level == -1) ref_level = y_lod.size() - 1; + + out->mutable_data(context.GetPlace()); + + if (y_lod[ref_level].size() <= 1) { + framework::TensorCopy(*x, 
context.GetPlace(), out); + return; + } + + auto& out_lod = *out->mutable_lod(); + if (x_lod.size() == 1) { + out_lod.resize(1); + out_lod[0] = {0}; + } + + int out_offset = 0; + auto& eigen_place = + *context.template device_context().eigen_device(); + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + int x_start = i - 1; + int x_end = i; + if (x_lod.size() == 1) { + x_start = x_lod[0][i - 1]; + x_end = x_lod[0][i]; + } + int x_seq_len = x_end - x_start; + if (repeat_num > 0) { + auto x_sub_tensor = x->Slice(x_start, x_end); + x_sub_tensor.Resize({1, x_sub_tensor.numel()}); + int out_start = out_offset; + if (x_lod.size() == 1) { + out_start = out_lod[0][out_offset]; + } + auto out_sub_tensor = + out->Slice(out_start, out_start + x_seq_len * repeat_num); + out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]}); + EigenMatrix::From(out_sub_tensor).device(eigen_place) = + EigenMatrix::From(x_sub_tensor) + .broadcast(Eigen::array({{repeat_num, 1}})); + } + for (int j = 0; j < repeat_num; ++j) { + if (x_lod.size() == 1) { + out_lod[0].push_back(out_lod[0].back() + x_seq_len); + } + out_offset++; + } } } }; @@ -75,27 +105,51 @@ template class SequenceExpandGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* d_out = context.Input(framework::GradVarName("Out")); + auto* g_out = context.Input(framework::GradVarName("Out")); auto* x = context.Input("X"); - auto* out = context.Input("Out"); - auto* d_x = context.Output(framework::GradVarName("X")); - auto out_last_level = out->lod().back(); - d_x->set_lod(x->lod()); - const T* d_out_data = d_out->data(); - T* d_x_data = d_x->mutable_data(context.GetPlace()); - size_t element_len = d_out->numel() / d_out->dims()[0]; - for (size_t i = 0; i < out_last_level.size() - 1; ++i) { - size_t repeat = out_last_level[i + 1] - out_last_level[i]; - Eigen::TensorMap< - Eigen::Tensor> - d_out_t(d_out_data, static_cast(repeat), element_len); - Eigen::TensorMap> - d_x_t(d_x_data, static_cast(element_len)); - auto place = - context.template device_context().eigen_device(); - d_x_t.device(*place) = d_out_t.sum(Eigen::array({{0}})); - d_out_data += (repeat * element_len); - d_x_data += element_len; + auto* y = context.Input("Y"); + auto* g_x = context.Output(framework::GradVarName("X")); + int ref_level = context.Attr("ref_level"); + + g_x->mutable_data(context.GetPlace()); + g_x->set_lod(x->lod()); + + auto& x_lod = x->lod(); + auto& y_lod = y->lod(); + + if (ref_level == -1) ref_level = y_lod.size() - 1; + + // just copy the gradient + if (y_lod[ref_level].size() <= 1) { + framework::TensorCopy(*g_out, context.GetPlace(), g_x); + return; + } + + auto& dev_ctx = context.template device_context(); + + math::SetConstant set_zero; + set_zero(dev_ctx, g_x, static_cast(0)); + + int g_out_offset = 0; + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + if (repeat_num > 0) { + int x_start = i - 1; + int x_end = i; + if (x_lod.size() == 1) { + x_start = x_lod[0][i - 1]; + x_end = x_lod[0][i]; + } + int x_seq_len = x_end - x_start; + auto g_x_sub = g_x->Slice(x_start, x_end); + g_x_sub.Resize(flatten_to_1d(g_x_sub.dims())); + int g_out_end = g_out_offset + repeat_num * x_seq_len; + auto g_out_sub = g_out->Slice(g_out_offset, g_out_end); + g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]}); + math::ColwiseSum col_sum; + col_sum(dev_ctx, g_out_sub, 
&g_x_sub); + g_out_offset += repeat_num * x_seq_len; + } } } }; diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 52fb8c2531357ad7a2b2f8613e5c7fbcef52c6bb..d3312a47f479160439d720dd993ee25a56d732fe 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -483,8 +483,123 @@ DEVICE inline bool operator>=(const half& a, const half& b) { #endif // PADDLE_CUDA_FP16 -// Arithmetic operators on ARMv8.2-A CPU -#if defined(PADDLE_WITH_NATIVE_FP16) +// Arithmetic operators for float16 on GPU +#if defined(PADDLE_CUDA_FP16) +HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return float16(__hadd(half(a), half(b))); +#else + return float16(float(a) + float(b)); +#endif +} + +HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return float16(__hsub(half(a), half(b))); +#else + return float16(float(a) - float(b)); +#endif +} + +HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return float16(__hmul(half(a), half(b))); +#else + return float16(float(a) * float(b)); +#endif +} + +HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 + // TODO(kexinzhao): check which cuda version starts to support __hdiv + float num = __half2float(half(a)); + float denom = __half2float(half(b)); + return float16(num / denom); +#else + return float16(float(a) / float(b)); +#endif +} + +HOSTDEVICE inline float16 operator-(const float16& a) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return float16(__hneg(half(a))); +#else + float16 res; + res.x = a.x ^ 0x8000; + return res; +#endif +} + +HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) { + a = a + b; + return a; +} + +HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) { + a = a - b; + return a; +} + +HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) { + a = a * b; + return a; +} + +HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) { + a = a / b; + return a; +} + +HOSTDEVICE inline bool operator==(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __heq(half(a), half(b)); +#else + return float(a) == float(b); +#endif +} + +HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hne(half(a), half(b)); +#else + return float(a) != float(b); +#endif +} + +HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hlt(half(a), half(b)); +#else + return float(a) < float(b); +#endif +} + +HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hle(half(a), half(b)); +#else + return float(a) <= float(b); +#endif +} + +HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hgt(half(a), half(b)); +#else + return float(a) > float(b); +#endif +} + +HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hge(half(a), half(b)); +#else + return float(a) >= float(b); +#endif +} + +// Arithmetic operators for float16 on 
ARMv8.2-A CPU +#elif defined(PADDLE_WITH_NATIVE_FP16) HOST inline float16 operator+(const float16& a, const float16& b) { float16 res; asm volatile( @@ -668,71 +783,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) { return (res & 0xffff) != 0; } -// Arithmetic operators, software emulated on other CPU +// Arithmetic operators for float16, software emulated on other CPU #else -HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) { +HOST inline float16 operator+(const float16& a, const float16& b) { return float16(float(a) + float(b)); } -HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) { +HOST inline float16 operator-(const float16& a, const float16& b) { return float16(float(a) - float(b)); } -HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) { +HOST inline float16 operator*(const float16& a, const float16& b) { return float16(float(a) * float(b)); } -HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { +HOST inline float16 operator/(const float16& a, const float16& b) { return float16(float(a) / float(b)); } -HOSTDEVICE inline float16 operator-(const float16& a) { +HOST inline float16 operator-(const float16& a) { float16 res; res.x = a.x ^ 0x8000; return res; } -HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) { +HOST inline float16& operator+=(float16& a, const float16& b) { a = float16(float(a) + float(b)); return a; } -HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) { +HOST inline float16& operator-=(float16& a, const float16& b) { a = float16(float(a) - float(b)); return a; } -HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) { +HOST inline float16& operator*=(float16& a, const float16& b) { a = float16(float(a) * float(b)); return a; } -HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) { +HOST inline float16& operator/=(float16& a, const float16& b) { a = float16(float(a) / float(b)); return a; } -HOSTDEVICE inline bool operator==(const float16& a, const float16& b) { +HOST inline bool operator==(const float16& a, const float16& b) { return float(a) == float(b); } -HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { +HOST inline bool operator!=(const float16& a, const float16& b) { return float(a) != float(b); } -HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { +HOST inline bool operator<(const float16& a, const float16& b) { return float(a) < float(b); } -HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { +HOST inline bool operator<=(const float16& a, const float16& b) { return float(a) <= float(b); } -HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { +HOST inline bool operator>(const float16& a, const float16& b) { return float(a) > float(b); } -HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { +HOST inline bool operator>=(const float16& a, const float16& b) { return float(a) >= float(b); } #endif diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 3d3a6c116eeb39fb7236d0e9707415cdd6b828bd..ad655ee96cee0744e7bedb17163faf7d8d1d8877 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -307,15 +307,57 @@ class DistributeTranspiler: # Iterate through the ops, and if an op and the optimize ops # which located on current pserver are in one set, then # append it into the sub program. 
- for _, op in enumerate(self.optimize_ops): - for _, opt_op in enumerate(opt_op_on_pserver): - if ufind.is_connected(op, opt_op): - if self._is_opt_op(op): - self._append_pserver_ops(optimize_block, op, endpoint, - default_main_program()) - else: - self._append_pserver_non_opt_ops(optimize_block, op) - break + + # We try to put optimization program run parallelly, assume + # optimization program always looks like: + # + # prevop -> prevop -> opt op -> following op -> following op; -> + # prevop -> prevop -> opt op -> following op -> following op; -> + # global op -> global op + # + # we put operators that can run parallelly to many program blocks. + # in above example, we seperate ops by the ";". Global ops must run + # after all the optimize ops finished. + + global_ops = [] + # HACK: optimization global ops only used to scale beta1 and beta2 + # replace it with dependency engine. + for op in self.optimize_ops: + if op.type == "scale": + for in_name in op.input_arg_names: + if in_name.startswith("beta1_pow_acc") or\ + in_name.startswith("beta2_pow_acc"): + global_ops.append(op) + + def __append_optimize_op__(op, block): + if self._is_opt_op(op): + self._append_pserver_ops(block, op, endpoint, + default_main_program()) + else: + self._append_pserver_non_opt_ops(block, op) + + # append op to the current block + per_opt_block = optimize_block + for _, opt_op in enumerate(opt_op_on_pserver): + for _, op in enumerate(self.optimize_ops): + # optimizer is connected to itself + if ufind.is_connected(op, opt_op) and \ + op not in global_ops: + __append_optimize_op__(op, per_opt_block) + per_opt_block = pserver_program.create_block(0) + + # append global ops + for glb_op in global_ops: + __append_optimize_op__(glb_op, per_opt_block) + + # NOT USED: single block version: + # + # for _, op in enumerate(self.optimize_ops): + # for _, opt_op in enumerate(opt_op_on_pserver): + # if ufind.is_connected(op, opt_op): + # __append_optimize_op__(glb_op, optimize_block) + # break + # step5 append the listen_and_serv op pserver_program.global_block().append_op( type="listen_and_serv", @@ -660,10 +702,22 @@ class DistributeTranspiler: # If one op's input is another op's output or # one op's output is another op's input, we say # the two operator is connected. - op1_input_names = op1.desc.input_arg_names() + def _append_inname_remove_beta(varname_list): + op_input_names = [] + for in_name in varname_list: + # HACK: remove beta1 and beta2 to avoid let all + # ops connected. + if in_name.startswith("beta2_pow_acc") or \ + in_name.startswith("beta1_pow_acc"): + continue + else: + op_input_names.append(in_name) + return op_input_names + + op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names()) op1_output_names = op1.desc.output_arg_names() - op2_input_names = op2.desc.input_arg_names() + op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names()) op2_output_names = op2.desc.output_arg_names() if set(op1_output_names) & set(op2_input_names) or \ diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9656dcf94f14ad9250bb7e79c1330c9bdd44d9d6..75d3d895081e29e25fd5cf29d19e4b8459035ffb 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1809,52 +1809,52 @@ def conv2d_transpose(input, return out -def sequence_expand(x, y, name=None): +def sequence_expand(x, y, ref_level=-1, name=None): """Sequence Expand Layer. This layer will expand the input variable **x** - according to LoD information of **y**. 
And the following examples will - explain how sequence_expand works: + according to specified level lod of **y**. Please note that lod level of + **x** is at most 1 and rank of **x** is at least 2. When rank of **x** + is greater than 2, then it would be viewed as a 2-D tensor. + Following examples will explain how sequence_expand works: .. code-block:: text * Case 1 x is a LoDTensor: - x.lod = [[0, 2, 3], - [0, 1, 3, 4]] - x.data = [a, b, c, d] + x.lod = [[0, 2, 4]] + x.data = [[a], [b], [c], [d]] x.dims = [4, 1] y is a LoDTensor: y.lod = [[0, 2, 4], [0, 3, 6, 7, 8]] - with condition len(y.lod[-1]) - 1 == x.dims[0] + ref_level: 0 - then output is a 2-level LoDTensor: - out.lod = [[0, 2, 4], - [0, 3, 6, 7, 8]] - out.data = [a, a, a, b, b, b, c, d] + then output is a 1-level LoDTensor: + out.lod = [[0, 2, 4, 6, 8]] + out.data = [[a], [b], [a], [b], [c], [d], [c], [d]] out.dims = [8, 1] * Case 2 x is a Tensor: - x.data = [a, b, c] + x.data = [[a], [b], [c]] x.dims = [3, 1] y is a LoDTensor: - y.lod = [[0, 2, 3, 6]] + y.lod = [[0, 2, 2, 5]] - with condition len(y.lod[-1]) - 1 == x.dims[0] - - then output is a 1-level LoDTensor: - out.lod = [[0, 2, 3, 6]] - out.data = [a, a, b, c, c, c] - out.dims = [6, 1] + ref_level: -1 + then output is a Tensor: + out.data = [[a], [a], [c], [c], [c]] + out.dims = [5, 1] Args: x (Variable): The input variable which is a Tensor or LoDTensor. y (Variable): The input variable which is a LoDTensor. + ref_level (int): Lod level of `y` to be referred by `x`. If set to -1, + refer the last level of lod. name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. + will be named automatically. Returns: Variable: The expanded variable which is a LoDTensor. @@ -1865,14 +1865,17 @@ def sequence_expand(x, y, name=None): x = fluid.layers.data(name='x', shape=[10], dtype='float32') y = fluid.layers.data(name='y', shape=[10, 20], dtype='float32', lod_level=1) - out = layers.sequence_expand(x=x, y=y) + out = layers.sequence_expand(x=x, y=y, ref_level=0) """ helper = LayerHelper('sequence_expand', input=x, **locals()) dtype = helper.input_dtype() tmp = helper.create_tmp_variable(dtype) helper.append_op( - type='sequence_expand', inputs={'X': x, - 'Y': y}, outputs={'Out': tmp}) + type='sequence_expand', + inputs={'X': x, + 'Y': y}, + outputs={'Out': tmp}, + attrs={'ref_level': ref_level}) return tmp diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 4993fe39e0cdda6908166b51682cd2b58a90ffac..badac5ca5e13dd77edd3c70fa6957cba1cfe903d 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -27,7 +27,7 @@ from contextlib import contextmanager __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', - 'ModelAverage' + 'Adadelta', 'ModelAverage' ] @@ -591,6 +591,88 @@ class DecayedAdagradOptimizer(Optimizer): return decayed_adagrad_op +class AdadeltaOptimizer(Optimizer): + """ + **Adadelta Optimizer** + Simple Adadelta optimizer with average squared grad state and + average squared update state. + The details of adadelta please refer to this + `ADADELTA: AN ADAPTIVE LEARNING RATE METHOD + `_. + + .. 
math:: + + E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2 \\\\ + learning\\_rate &= sqrt( ( E(dx_{t-1}^2) + \\epsilon ) / ( \\ + E(g_t^2) + \\epsilon ) ) \\\\ + E(dx_t^2) &= \\rho * E(dx_{t-1}^2) + (1-\\rho) * (-g*learning\\_rate)^2 + + Args: + learning_rate(float): global leraning rate + rho(float): rho in equation + epsilon(float): epsilon in equation + + Examples: + .. code-block:: python + + optimizer = fluid.optimizer.Adadelta( + learning_rate=0.0003, epsilon=1.0e-6, rho=0.95) + _, params_grads = optimizer.minimize(cost) + """ + + _avg_squared_grad_acc_str = "_avg_squared_grad" + _avg_squared_update_acc_str = "_avg_squared_update" + + def __init__(self, learning_rate, epsilon=1.0e-6, rho=0.95, **kwargs): + if learning_rate is None: + raise ValueError("learning_rate is not set.") + if epsilon is None: + raise ValueError("epsilon is not set.") + if rho is None: + raise ValueError("rho is not set.") + super(AdadeltaOptimizer, self).__init__( + learning_rate=learning_rate, **kwargs) + self.type = "adadelta" + self._epsilon = epsilon + self._rho = rho + + def _create_accumulators(self, block, parameters): + if not isinstance(block, framework.Block): + raise TypeError("block is not instance of framework.Block.") + + for p in parameters: + self._add_accumulator(self._avg_squared_grad_acc_str, p) + self._add_accumulator(self._avg_squared_update_acc_str, p) + + def _append_optimize_op(self, block, param_and_grad): + if not isinstance(block, framework.Block): + raise TypeError("block is not instance of framework.Block.") + + avg_squared_grad_acc = self._get_accumulator( + self._avg_squared_grad_acc_str, param_and_grad[0]) + avg_squared_update_acc = self._get_accumulator( + self._avg_squared_update_acc_str, param_and_grad[0]) + + # Create the adadelta optimizer op + adadelta_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "AvgSquaredGrad": avg_squared_grad_acc, + "AvgSquaredUpdate": avg_squared_update_acc + }, + outputs={ + "ParamOut": param_and_grad[0], + "AvgSquaredGradOut": avg_squared_grad_acc, + "AvgSquaredUpdateOut": avg_squared_update_acc + }, + attrs={"epsilon": self._epsilon, + "rho": self._rho}) + + return adadelta_op + + # We short the class name, since users will use the optimizer with the package # name. 
The sample code: # @@ -605,6 +687,7 @@ Adagrad = AdagradOptimizer Adam = AdamOptimizer Adamax = AdamaxOptimizer DecayedAdagrad = DecayedAdagradOptimizer +Adadelta = AdadeltaOptimizer class ModelAverage(Optimizer): diff --git a/python/paddle/fluid/tests/book/test_machine_translation.py b/python/paddle/fluid/tests/book/test_machine_translation.py index fa38bd3762423497b82c3b421b3a1db4cd87525b..3a1a0859ecfd4ac5337e2112f8b22e32d8474f22 100644 --- a/python/paddle/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/test_machine_translation.py @@ -118,12 +118,12 @@ def decoder_decode(context, is_sparse): is_sparse=is_sparse) # use rnn unit to update rnn - current_state = pd.fc(input=[pre_ids_emb, pre_state_expanded], + current_state = pd.fc(input=[pre_state_expanded, pre_ids_emb], size=decoder_size, act='tanh') - + current_state_with_lod = pd.lod_reset(x=current_state, y=pre_score) # use score to do beam search - current_score = pd.fc(input=current_state, + current_score = pd.fc(input=current_state_with_lod, size=target_dict_dim, act='softmax') topk_scores, topk_indices = pd.topk(current_score, k=50) diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 60930a612c128cbf18e89711b9246d148e41ec58..eaa3435a86462236a99489749abe877648677053 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import paddle.fluid.core as core from op_test import OpTest @@ -82,5 +83,37 @@ class TestDropoutOp5(OpTest): self.check_output() +class TestFP16DropoutOp(OpTest): + def setUp(self): + self.op_type = "dropout" + self.init_test_case() + + x = np.random.random(self.input_size).astype("float16") + out = x * (1.0 - self.prob) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = { + 'dropout_prob': self.prob, + 'fix_seed': self.fix_seed, + 'is_test': True + } + self.outputs = {'Out': out} + + def init_test_case(self): + self.input_size = [32, 64] + self.prob = 0.35 + self.fix_seed = True + + def test_check_output(self): + if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"): + self.check_output_with_place(core.CUDAPlace(0), atol=1e-3) + + +class TestFP16DropoutOp2(TestFP16DropoutOp): + def init_test_case(self): + self.input_size = [32, 64, 3] + self.prob = 0.75 + self.fix_seed = False + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 744a762ae7664f1f28713c505f9112ba712fd41d..b5fd59cf3a1bea50b799c3ace8f3b9cea088b9d5 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -181,8 +181,8 @@ class TestBook(unittest.TestCase): with program_guard(program): x = layers.data(name='x', shape=[10], dtype='float32') y = layers.data( - name='y', shape=[10, 20], dtype='float32', lod_level=1) - self.assertIsNotNone(layers.sequence_expand(x=x, y=y)) + name='y', shape=[10, 20], dtype='float32', lod_level=2) + self.assertIsNotNone(layers.sequence_expand(x=x, y=y, ref_level=1)) print(str(program)) def test_lstm_unit(self): diff --git a/python/paddle/fluid/tests/unittests/test_sequence_expand.py b/python/paddle/fluid/tests/unittests/test_sequence_expand.py index 957fa5d2c4a795cfd01047c1b7845674e4c1d549..7feb509c4d6f5768552fc2515081f7e68f420967 100644 --- 
a/python/paddle/fluid/tests/unittests/test_sequence_expand.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_expand.py @@ -27,12 +27,36 @@ class TestSequenceExpand(OpTest): def compute(self): x = self.inputs['X'] x_data, x_lod = x if type(x) == tuple else (x, None) - n = 1 + x_data.shape[0] if not x_lod else len(x_lod[0]) y_data, y_lod = self.inputs['Y'] - repeats = [((y_lod[-1][i + 1] - y_lod[-1][i])) - for i in range(len(y_lod[-1]) - 1)] - out = x_data.repeat(repeats, axis=0) - self.outputs = {'Out': out} + + if hasattr(self, 'attrs'): + ref_level = self.attrs['ref_level'] + else: + ref_level = len(y_lod) - 1 + + out = np.zeros(shape=((0, ) + x_data.shape[1:]), dtype=x_data.dtype) + + if x_lod is None: + x_idx = [i for i in xrange(x_data.shape[0] + 1)] + else: + x_idx = x_lod[0] + out_lod = [[0]] + + for i in xrange(1, len(y_lod[ref_level])): + repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1] + x_len = x_idx[i] - x_idx[i - 1] + if repeat_num > 0: + x_sub = x_data[x_idx[i - 1]:x_idx[i], :] + x_sub = np.repeat(x_sub, repeat_num, axis=0) + out = np.vstack((out, x_sub)) + if x_lod is not None: + for j in xrange(repeat_num): + out_lod[0].append(out_lod[0][-1] + x_len) + + if x_lod is None: + self.outputs = {'Out': out} + else: + self.outputs = {'Out': (out, out_lod)} def setUp(self): self.op_type = 'sequence_expand' @@ -52,7 +76,8 @@ class TestSequenceExpandCase1(TestSequenceExpand): x_lod = [[0, 2, 5]] y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32') y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]] - self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + self.inputs = {'X': x_data, 'Y': (y_data, y_lod)} + self.attrs = {'ref_level': 0} class TestSequenceExpandCase2(TestSequenceExpand): @@ -60,8 +85,9 @@ class TestSequenceExpandCase2(TestSequenceExpand): x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32') x_lod = [[0, 1]] y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32') - y_lod = [[0, 2]] + y_lod = [[0, 2], [0, 2]] self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + self.attrs = {'ref_level': 0} class TestSequenceExpandCase3(TestSequenceExpand): @@ -75,14 +101,9 @@ class TestSequenceExpandCase3(TestSequenceExpand): class TestSequenceExpandCase4(TestSequenceExpand): def set_data(self): - x_data = np.array( - [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape( - [2, 5]).astype('float32') - x_lod = [[ - 0, - 1, - 2, - ]] + data = [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3] + x_data = np.array(data).reshape([5, 2]).astype('float32') + x_lod = [[0, 2, 5]] y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32') y_lod = [[0, 1, 2], [0, 1, 2]] self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
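For reference, the updated sequence_expand usage with the new ref_level attribute, matching the example in the revised layer docstring (shapes are illustrative):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[10], dtype='float32')
y = fluid.layers.data(name='y', shape=[10, 20], dtype='float32', lod_level=1)
# Expand x by level 0 of y's lod; ref_level=-1 (the default) uses y's last lod level.
out = fluid.layers.sequence_expand(x=x, y=y, ref_level=0)

Following Case 1 in the operator comment: with x.data = [[a], [b], [c], [d]], x.lod = [[0, 2, 4]] and y.lod = [[0, 2, 4], [0, 3, 6, 7, 8]], ref_level 0 gives out.data = [[a], [b], [a], [b], [c], [d], [c], [d]] with out.lod = [[0, 2, 4, 6, 8]].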
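The equations in the new AdadeltaOptimizer docstring can also be written out as a small numpy helper for reference (illustrative only; in Fluid the work is done by the C++ adadelta op, used as optimizer = fluid.optimizer.Adadelta(learning_rate=0.0003, epsilon=1.0e-6, rho=0.95) followed by optimizer.minimize(cost)):

import numpy as np

def adadelta_accumulators(g, avg_squared_grad, avg_squared_update,
                          rho=0.95, epsilon=1.0e-6):
    # E(g_t^2) = rho * E(g_{t-1}^2) + (1 - rho) * g^2
    avg_squared_grad = rho * avg_squared_grad + (1.0 - rho) * g * g
    # per-element step scale: sqrt((E(dx_{t-1}^2) + eps) / (E(g_t^2) + eps))
    lr = np.sqrt((avg_squared_update + epsilon) / (avg_squared_grad + epsilon))
    # E(dx_t^2) = rho * E(dx_{t-1}^2) + (1 - rho) * (-g * lr)^2
    avg_squared_update = rho * avg_squared_update + (1.0 - rho) * (g * lr) ** 2
    return lr, avg_squared_grad, avg_squared_update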