diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index c7a5a49dd02d0db022fabff5c3ae1c7800bac25c..6697952051c4b1997ca6b550da17a52e64cb3454 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -64,7 +64,8 @@ class OpConverter { (*it)(op, scope, test_mode); } - // convert fluid block to tensorrt network + // Convert a fluid block to tensorrt network, NOTE it just convert operators, + // the INetwork's inputs and outputs should specified in some other modules. void ConvertBlock(const framework::proto::BlockDesc& block, const std::unordered_set& parameters, const framework::Scope& scope, TensorRTEngine* engine) { diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index b60f00de9fa5fc8f8f4537379bf9ee9c8bb6f31c..b06a9bbc6758ae9410b2fce99ef2b1a9e7ab98c0 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -51,11 +51,12 @@ class TensorRTEngine : public EngineBase { nvinfer1::Weights w_; }; - TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream, + TensorRTEngine(int max_batch, int max_workspace, + cudaStream_t* stream = nullptr, nvinfer1::ILogger& logger = NaiveLogger::Global()) : max_batch_(max_batch), max_workspace_(max_workspace), - stream_(stream), + stream_(stream ? stream : &default_stream_), logger_(logger) {} virtual ~TensorRTEngine(); @@ -121,6 +122,8 @@ class TensorRTEngine : public EngineBase { // the max memory size the engine uses int max_workspace_; cudaStream_t* stream_; + // If stream_ is not set from outside, hold its own stream. + cudaStream_t default_stream_; nvinfer1::ILogger& logger_; std::vector buffers_; @@ -165,20 +168,31 @@ class TensorRTEngine : public EngineBase { */ class TRT_EngineManager { public: - TensorRTEngine* Create(int max_batch, int max_workspace, - cudaStream_t* stream) { - engines_.emplace_back(new TensorRTEngine(max_batch, max_workspace, stream)); - return engines_.back().get(); + bool HasEngine(const std::string& name) const { + return engines_.count(name) != 0; + } + + // Get an engine called `name`. + TensorRTEngine* Get(const std::string& name) const { + return engines_.at(name).get(); + } + + // Create or get an engine called `name` + TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream, + const std::string& name) { + auto* p = new TensorRTEngine(max_batch, max_workspace, stream); + engines_[name].reset(p); + return p; } void DeleteALl() { - for (auto& ptr : engines_) { - ptr.reset(nullptr); + for (auto& item : engines_) { + item.second.reset(nullptr); } } private: - std::vector> engines_; + std::unordered_map> engines_; }; } // namespace tensorrt diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 790c012fdbe5764bb9fe374dd2fcbb44e6522c98..59f6e38d01d3c0c4348adf65f7219c18513457d8 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -252,15 +252,14 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "Output of Softshrink operator"); AddAttr("lambda", "non-negative offset").SetDefault(0.5f); AddComment(R"DOC( -Softshrink Activation Operator. +:strong:`Softshrink Activation Operator` -$$ -out = \begin{cases} - x - \lambda, \text{if } x > \lambda \\ - x + \lambda, \text{if } x < -\lambda \\ - 0, \text{otherwise} - \end{cases} -$$ +.. math:: + out = \begin{cases} + x - \lambda, \text{if } x > \lambda \\ + x + \lambda, \text{if } x < -\lambda \\ + 0, \text{otherwise} + \end{cases} )DOC"); } @@ -271,18 +270,18 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "Input of HardShrink operator"); AddOutput("Out", "Output of HardShrink operator"); - AddAttr("threshold", "The value of threshold for HardShrink") + AddAttr("threshold", + "The value of threshold for HardShrink. [default: 0.5]") .SetDefault(0.5f); AddComment(R"DOC( -HardShrink Activation Operator. +:strong:`HardShrink activation operator` -$$ -out = \begin{cases} - x, \text{if } x > \lambda \\ - x, \text{if } x < -\lambda \\ - 0, \text{otherwise} - \end{cases} -$$ +.. math:: + out = \begin{cases} + x, \text{if } x > \lambda \\ + x, \text{if } x < -\lambda \\ + 0, \text{otherwise} + \end{cases} )DOC"); } @@ -394,18 +393,18 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "Input of ThresholdedRelu operator"); AddOutput("Out", "Output of ThresholdedRelu operator"); - AddAttr("threshold", "The threshold location of activation") + AddAttr("threshold", + "The threshold location of activation. [default 1.0].") .SetDefault(1.0f); AddComment(R"DOC( -ThresholdedRelu Activation Operator. +:strong:`ThresholdedRelu activation operator` -$$ -out = \begin{cases} - x, \text{if } x > threshold \\ - 0, \text{otherwise} - \end{cases} -$$ +.. math:: + out = \begin{cases} + x, \text{if } x > threshold \\ + 0, \text{otherwise} + \end{cases} )DOC"); } }; diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index 3a4819f3dec9704a4a7c8910dd22e80fda082335..f40b1ba338d429c248103eeb930ac7e1bb690218 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -23,30 +23,26 @@ class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { OpComment comment; - AddInput("X", - string::Sprintf("(LoDTensor) the left hand operand of %s operator", - comment.type)); - AddInput("Y", string::Sprintf( - "(LoDTensor) the right hand operand of %s operator", - comment.type)); + AddInput("X", string::Sprintf("the left hand operand of %s operator", + comment.type)); + AddInput("Y", string::Sprintf("the right hand operand of %s operator", + comment.type)); AddAttr("force_cpu", - "(bool, default false) Force fill output variable to cpu " + "Force fill output variable to cpu " "memory. Otherwise, fill output variable to the running " - "device") - .SetDefault(false); - AddOutput("Out", string::Sprintf( - "(LoDTensor) n-dim bool tensor. Each element is %s", - comment.equation)); - AddComment(string::Sprintf(R"DOC(%s Operator - + "device [default true].") + .SetDefault(true); + AddOutput("Out", string::Sprintf("n-dim bool tensor. Each element is %s", + comment.equation)); + AddComment(string::Sprintf(R"DOC( It operates element-wise on X and Y, and returns the Out. Each of them is a N-dim tensor. X and Y could be any type. The each element of the Out tensor is -calculated by %s +calculated by $%s$ )DOC", - comment.type, comment.equation)); - AddAttr("axis", - "(int, default -1). The start dimension index " - "for broadcasting Y onto X.") + comment.equation)); + AddAttr( + "axis", + "The start dimension index for broadcasting Y onto X. [default -1]") .SetDefault(-1) .EqualGreaterThan(-1); } diff --git a/paddle/fluid/operators/cumsum_op.cc b/paddle/fluid/operators/cumsum_op.cc index 92bb835e8f18e17ae1355fdec29f43b8ffb70460..5302b822d6b9f232e9ccd0d03cc549d7d5044ebf 100644 --- a/paddle/fluid/operators/cumsum_op.cc +++ b/paddle/fluid/operators/cumsum_op.cc @@ -30,19 +30,19 @@ class CumOp : public framework::OperatorWithKernel { class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "Input of Cumsum operator"); - AddOutput("Out", "Output of Cumsum operator"); + AddInput("X", "Input of cumsum operator"); + AddOutput("Out", "Output of cumsum operator"); AddAttr("axis", - "(int, default -1). The dimenstion to accumulate along. " - "-1 means the last dimenstion") + "The dimenstion to accumulate along. -1 means the last " + "dimenstion [default -1].") .SetDefault(-1) .EqualGreaterThan(-1); AddAttr("exclusive", - "bool, default false). Whether to perform exclusive cumsum") + "Whether to perform exclusive cumsum. [default false].") .SetDefault(false); AddAttr("reverse", - "bool, default false). If true, the cumsum is performed in " - "the reversed direction") + "If true, the cumsum is performed in the reversed direction. " + "[default false].") .SetDefault(false); AddComment(R"DOC( The cumulative sum of the elements along a given axis. diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index ab097d31e9ab5eafa788539170e7e405df697625..14ce1da2e97186a50ed8bd52223a500c4c57b328 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -62,36 +62,33 @@ class LayerNormOp : public framework::OperatorWithKernel { class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(LoDTensor) The input tensor."); + AddInput("X", "The input tensor."); AddInput("Scale", - "(Tensor, optional) Scale is a 1-dimensional tensor of size " + "(optional) Scale is a 1-dimensional tensor of size " "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." "It is applied to the output.") .AsDispensable(); AddInput("Bias", - "(Tensor, optional) Bias is a 1-dimensional tensor of size " + "(optional) Bias is a 1-dimensional tensor of size " "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." "It is applied to the output.") .AsDispensable(); - AddOutput("Y", "(LoDTensor) Result after normalization."); - AddOutput("Mean", "(Tensor) Mean of the current mini batch.") - .AsIntermediate(); - AddOutput("Variance", "(Tensor) Variance of the current mini batch.") + AddOutput("Y", "Result after normalization."); + AddOutput("Mean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("Variance", "Variance of the current mini batch.") .AsIntermediate(); AddAttr("epsilon", - "(float, default 1e-5) Constant for " - "numerical stability") + "Constant for numerical stability [default 1e-5].") .SetDefault(1e-5) .AddCustomChecker([](const float &epsilon) { PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, "'epsilon' should be between 0.0 and 0.001."); }); AddAttr("begin_norm_axis", - "(int default:1), the " - "axis of `begin_norm_axis ... Rank(X) - 1` will be " + "the axis of `begin_norm_axis ... Rank(X) - 1` will be " "normalized. `begin_norm_axis` splits the tensor(`X`) to a " - "matrix [N,H].") + "matrix [N,H]. [default 1].") .SetDefault(1) .AddCustomChecker([](const int &begin_norm_axis) { PADDLE_ENFORCE_GT(begin_norm_axis, 0, @@ -99,10 +96,14 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { }); AddComment(R"DOC( -Layer Normalization. -Layer Norm has been implemented as discussed in the paper: -https://arxiv.org/abs/1607.06450 -... +Assume feature vectors exist on dimensions +:attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics +along these dimensions for each feature vector :math:`a` with size +:math:`H`, then normalize each feature vector using the corresponding +statistics. After that, apply learnable gain and bias on the normalized +tensor to scale and shift if :attr:`scale` and :attr:`shift` are set. + +Refer to `Layer Normalization `_ )DOC"); } }; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 4d12278799f66f2fb92b7580ba0c43e845aa4d3a..57c2ce457791d830e4230aa25e1c5b358f476782 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -348,7 +348,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { }; void SignalHandler::StopAndExit(int signal_num) { - VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit"; + // Do not use VLOG here for the device for printing maybe already released. + // exit will release interal allocated resoureces. exit(0); } diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 4881cff4a368ffae9b030f04b7fff01d6ee7d26e..9e0bebd17c02a3ce010b77142757b8789cfbcdd9 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -33,12 +33,10 @@ class MeanOp : public framework::OperatorWithKernel { class MeanOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op").Reuse("X"); + AddInput("X", "(Tensor) The input of mean op"); + AddOutput("Out", "(Tensor) The output of mean op").Reuse("X"); AddComment(R"DOC( -Mean Operator. - -Out is a scalar which is the mean of all elements in X. +Mean Operator calculates the mean of all elements in X. )DOC"); } diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc index a4363fd25d57edb5c2509904a1f55634832613be..18ad46cb5eeeab2169136e40cebdaa53c0bfd587 100644 --- a/paddle/fluid/operators/multiplex_op.cc +++ b/paddle/fluid/operators/multiplex_op.cc @@ -62,26 +62,46 @@ class MultiplexOp : public framework::OperatorWithKernel { class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("Ids", "The index tensor of multiplex operator."); - AddInput("X", "The candidate tensors of multiplex operator.") + AddInput("Ids", + "Tensor, index variable which is a 2-D tensor with shape " + "[M, 1] where M is the batch size."); + AddInput("X", + "A list of variables to gather from. All variables have the same " + "shape and the rank is at least 2.") .AsDuplicable(); AddOutput("Out", "The output tensor of multiplex operator."); AddComment(R"DOC( -Multiplex Operator. - -Multiplex multiple tensors according to the index provided by the index tensor. - -Ids: the index tensor. -X[0 : N - 1]: the candidate tensors for output (N >= 2). -For each index i from 0 to batchSize - 1, the output is the i-th row of the +Referring to the given index variable, this layer selects rows from the +input variables to construct a multiplex variable. Assuming that there are +:math:`m` input variables and :math:`I_i` represents the i-th input +variable and :math:`i` is in [0, :math:`m`). All input variables are +tensors with same shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. +Please note that rank of the input tensor should be at least 2. Each input +variable will be treated as a 2-D matrix with shape [:math:`M`, :math:`N`] +where :math:`M` for :math:`d_0` and :math:`N` for :math:`d_1` * :math:`d_2` +* ... * :math:`d_R`. Let :math:`I_i[j]` be the j-th row of the i-th input +variable. The given index variable should be a 2-D tensor with shape +[:math:`M`, 1]. Let `ID[i]` be the i-th index value of the index variable. +Then the output variable will be a tensor with shape [:math:`d_0`, +:math:`d_1`, ..., :math:`d_R`]. If we treat the output tensor as a 2-D +matrix with shape [:math:`M`, :math:`N`] and let :math:`O[i]` be the i-th +row of the matrix, then `O[i]` is equal to :math:`I_{ID[i]}[i]`. + +* Ids: the index tensor. + +* X[0 : N - 1]: the candidate tensors for output (N >= 2). + +* For each index i from 0 to batchSize - 1, the output is the i-th row of the the (Ids[i])-th tensor. For i-th row of the output tensor: -$$y[i] = x_{k}[i]$$ +$$ +y[i] = x_{k}[i] +$$ -where `y` is the output tensor, `x_{k}` is the k-th input tensor, -and `k = Ids[i]`. +where $y$ is the output tensor, $x_{k}$ is the k-th input tensor, +and $k = Ids[i]$. )DOC"); } diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 282ec3f36b98e7aa62d71fb04f72721a5464e21c..559827f08494af6730aafa1e67c46a47c21dedf6 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -78,11 +78,15 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { class CreateRecordIOReaderOpMaker : public FileReaderMakerBase { protected: void Apply() override { - AddAttr("filename", "The filename of record io reader"); + AddAttr( + "filename", + "The filename of record file. This file will given to reader."); AddComment(R"DOC( - CreateRecordIOReader Operator +Open a recordio file and return the reader object. The returned reader object +is thread-safe. - Create a reader from a record io file +NOTE: This is a very low-level API. It is used for debugging data file or +training. Please use `open_files` instead of this API for production usage. )DOC"); } }; diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 612e1f5eca3a4836db1fd167fc6bb63400d20177..e11256a49ffa6adc9410376cc8a71fa017df7e9c 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -54,7 +54,7 @@ std::unique_ptr CreateReaderByFileName( } void FileReaderMakerBase::Make() { - AddOutput("Out", "(ReaderHolder) The created random reader.").AsDuplicable(); + AddOutput("Out", "(ReaderHolder): The created random reader.").AsDuplicable(); AddAttr>("shape_concat", "The concat of all data's shapes."); AddAttr>( "ranks", diff --git a/paddle/fluid/operators/row_conv_op.cc b/paddle/fluid/operators/row_conv_op.cc index 20f140f962c3aac364a1239a663d5f340bbeb6b2..10b1b0c899d833d70fa6afe51998fe210899e3c3 100644 --- a/paddle/fluid/operators/row_conv_op.cc +++ b/paddle/fluid/operators/row_conv_op.cc @@ -78,23 +78,23 @@ class RowConvOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(LoDTensor), the input(X) is a LodTensor, which supports " + "the input(X) is a LodTensor, which supports " "variable time-length input sequences. The underlying tensor " "in this LoDTensor is a matrix with shape (T x N), where T " "is the total time steps in this mini-batch and N is the input " "data dimension."); AddInput("Filter", - "(Tensor), the input(Filter) is a learnable parameter. It " + "the input(Filter) is a learnable parameter. It " "is a 2-D tensor with shape (future_context x N), where, " "future_context is the future context length and N is the data " "dimension."); AddOutput("Out", - "(LoDTensor), the output(Out) is a LodTensor, which supports " + "the output(Out) is a LodTensor, which supports " "variable time-length input sequences. The underlying tensor " "in this LodTensor is a matrix with shape T x N, i.e., the " "same shape as X."); AddComment(R"DOC( -Row-convolution Operator. +:strong:`Row-convolution operator` The row convolution is called lookahead convolution. This operator was introduced in the following paper for DeepSpeech2: @@ -114,9 +114,23 @@ and a filter ($W$) of size $context \times d$, the output sequence is convolved as: $$ -out_{i, :} = \sum_{j=i}^{i + context} in_{j,:} \dot W_{i-j, :} +out_{i, :} = \\sum_{j=i}^{i + context} in_{j,:} \\cdot W_{i-j, :} $$ +In the above equation: + +* $Out_{i}$: The i-th row of output variable with shape [1, D]. + +* $\\tau$: Future context size. + +* $X_{j}$: The j-th row of input variable with shape [1, D]. + +* $W_{i-j}$: The (i-j)-th row of parameters with shape [1, D]. + +More details about row_conv please refer to +the design document +https://github.com/PaddlePaddle/Paddle/issues/2228#issuecomment-303903645 . + )DOC"); } }; diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index 4b1208c4376b48e25866fc510f3a6d2ea06e7610..0ea273af9d5a5c8f1ae112232a9187675031b360 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -66,17 +66,25 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector &shape) { } // namespace template -void paddle::operators::TensorRTEngineKernel::Prepare( +void TensorRTEngineKernel::Prepare( const framework::ExecutionContext &context) const { VLOG(4) << "Prepare engine"; // Get the ProgramDesc and pass to convert. framework::proto::BlockDesc block_desc; block_desc.ParseFromString(context.Attr("subgraph")); - max_batch_ = context.Attr("max_batch"); + int max_batch = context.Attr("max_batch"); auto max_workspace = context.Attr("max_workspace"); - engine_ = Singleton::Global().Create( - max_batch_, max_workspace, &stream_); - engine_->InitNetwork(); + auto params = context.Attr>("parameters"); + std::unordered_set parameters; + for (const auto ¶m : params) { + parameters.insert(param); + } + + // TODO(Superjomn) replace this with a different stream + auto *engine = Singleton::Global().Create( + max_batch, max_workspace, nullptr /*engine hold its own stream*/, + context.Attr("engine_uniq_key")); + engine->InitNetwork(); framework::BlockDesc block(nullptr /*programdesc*/, &block_desc); // Add inputs @@ -87,24 +95,23 @@ void paddle::operators::TensorRTEngineKernel::Prepare( PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, "TensorRT engine only takes LoDTensor as input"); auto shape = var->GetShape(); - engine_->DeclareInput( + engine->DeclareInput( input, FluidDataType2TRT( var->Proto()->type().lod_tensor().tensor().data_type()), Vec2TRT_Dims(var->GetShape())); } - // TODO(Superjomn) parameters should be passed after analysised from outside. inference::Singleton::Global().ConvertBlock( - block_desc, {}, context.scope(), engine_); + block_desc, parameters, context.scope(), engine); // Add outputs VLOG(4) << "declare outputs"; for (auto &output : context.Outputs("Ys")) { VLOG(4) << "declare output " << output; - engine_->DeclareOutput(output); + engine->DeclareOutput(output); } - engine_->FreezeNetwork(); + engine->FreezeNetwork(); } class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { @@ -113,6 +120,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Xs", "A list of inputs.").AsDuplicable(); AddOutput("Ys", "A list of outputs").AsDuplicable(); AddAttr("subgraph", "the subgraph."); + AddAttr("engine_uniq_key", "unique key for the TRT engine."); AddAttr("max_batch", "the maximum batch size."); AddAttr("max_workspace", "the maximum batch size."); AddComment("TensorRT engine operator."); diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 4b089601ff76eedd87bb3a52a38c4d22d4a94bf6..8455d24ddf47382b235edda10cb9b2e8934c5f06 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -19,10 +19,14 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/engine.h" namespace paddle { namespace operators { +using inference::Singleton; +using inference::tensorrt::TRT_EngineManager; + class TensorRTEngineOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -47,16 +51,18 @@ template class TensorRTEngineKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - if (!engine_) { + auto engine_name = context.Attr("engine_uniq_key"); + if (!Singleton::Global().HasEngine(engine_name)) { Prepare(context); } + auto* engine = Singleton::Global().Get(engine_name); auto input_names = context.op().Inputs("Xs"); PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs"); // Try to determine a batch_size auto& tensor0 = inference::analysis::GetFromScope( context.scope(), input_names.front()); int batch_size = tensor0.dims()[0]; - PADDLE_ENFORCE_LE(batch_size, max_batch_); + PADDLE_ENFORCE_LE(batch_size, context.Attr("max_batch")); // Convert input tensor from fluid to engine. for (const auto& x : context.Inputs("Xs")) { @@ -64,20 +70,20 @@ class TensorRTEngineKernel : public framework::OpKernel { auto& t = inference::analysis::GetFromScope( context.scope(), x); if (platform::is_cpu_place(t.place())) { - engine_->SetInputFromCPU(x, static_cast(t.data()), - t.memory_size()); + engine->SetInputFromCPU(x, static_cast(t.data()), + t.memory_size()); } else { - engine_->SetInputFromGPU(x, static_cast(t.data()), - t.memory_size()); + engine->SetInputFromGPU(x, static_cast(t.data()), + t.memory_size()); } } // Execute the engine. PADDLE_ENFORCE_GT(batch_size, 0); - engine_->Execute(batch_size); + engine->Execute(batch_size); // Convert output tensor from engine to fluid for (const auto& y : context.Outputs("Ys")) { // convert output and copy to fluid. - nvinfer1::ITensor* trt_t = engine_->GetITensor(y); + nvinfer1::ITensor* trt_t = engine->GetITensor(y); auto dims = trt_t->getDimensions(); // Use the output ITensor's dims to reshape the Fluid Tensor. std::vector ddim(dims.d, dims.d + dims.nbDims); @@ -89,27 +95,22 @@ class TensorRTEngineKernel : public framework::OpKernel { auto size = inference::analysis::AccuDims(dims.d, dims.nbDims); if (platform::is_cpu_place(fluid_t->place())) { // TODO(Superjomn) change this float to dtype size. - engine_->GetOutputInCPU( + engine->GetOutputInCPU( y, fluid_t->mutable_data(platform::CPUPlace()), size * sizeof(float)); } else { - engine_->GetOutputInGPU( + engine->GetOutputInGPU( y, fluid_t->mutable_data(platform::CUDAPlace()), size * sizeof(float)); } } - cudaStreamSynchronize(stream_); + cudaStreamSynchronize(*engine->stream()); } protected: // Build the engine. void Prepare(const framework::ExecutionContext& context) const; - - private: - mutable cudaStream_t stream_; - mutable inference::tensorrt::TensorRTEngine* engine_{nullptr}; - mutable int max_batch_{0}; }; } // namespace operators diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc index 6f383de259b270038c32296b59007f6c7d895f12..85330958cdba94f6721e3132c36caca43064c0e3 100644 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -79,6 +79,17 @@ void SetAttr(framework::proto::OpDesc* op, const std::string& name, attr->set_type(paddle::framework::proto::AttrType::LONG); attr->set_l(data); } +template <> +void SetAttr>(framework::proto::OpDesc* op, + const std::string& name, + const std::vector& data) { + auto* attr = op->add_attrs(); + attr->set_name(name); + attr->set_type(paddle::framework::proto::AttrType::STRINGS); + for (const auto& s : data) { + attr->add_strings(s.c_str()); + } +} } // namespace @@ -123,11 +134,15 @@ TEST(TensorRTEngineOp, manual) { engine_op_desc.SetOutput("Ys", std::vector({"z0"})); SetAttr(engine_op_desc.Proto(), "subgraph", block_->SerializeAsString()); - SetAttr(engine_op_desc.Proto(), "max_batch", 30); + SetAttr(engine_op_desc.Proto(), "max_batch", 100); SetAttr(engine_op_desc.Proto(), "max_workspace", 1 << 10); + SetAttr(engine_op_desc.Proto(), "engine_uniq_key", "a_engine"); + SetAttr>(engine_op_desc.Proto(), "parameters", + std::vector({})); LOG(INFO) << "create engine op"; auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto()); + LOG(INFO) << "engine_op " << engine_op.get(); framework::Scope scope; platform::CPUPlace place; @@ -145,6 +160,88 @@ TEST(TensorRTEngineOp, manual) { engine_op->Run(scope, place); } +void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { + framework::ProgramDesc program; + framework::Scope scope; + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + + auto* block_ = program.Proto()->add_blocks(); + block_->set_idx(0); + block_->set_parent_idx(-1); + + using shape_t = std::vector; + + LOG(INFO) << "create block desc"; + framework::BlockDesc block_desc(&program, block_); + + auto AddFCLayer = [&](const std::string& x_name, const std::string& y_name, + const std::string& z_name, bool x_created, + const shape_t& x_shape, const shape_t& y_shape, + const shape_t& z_shape) { + + LOG(INFO) << "create fc op"; + auto* fc = block_desc.AppendOp(); + fc->SetType("mul"); + fc->SetInput("X", std::vector({x_name})); + fc->SetInput("Y", std::vector({y_name})); + fc->SetOutput("Out", std::vector({z_name})); + + // Set inputs' variable shape in BlockDesc + if (!x_created) { + AddTensorToBlockDesc(block_, x_name, + std::vector({batch_size, input_dim, 1, 1})); + } + AddTensorToBlockDesc(block_, y_name, + std::vector({input_dim, output_dim})); + AddTensorToBlockDesc(block_, z_name, + std::vector({batch_size, output_dim})); + + // Prepare variables. + if (!x_created) { + CreateCPUTensor(&scope, x_name, std::vector(x_shape)); + } + CreateCPUTensor(&scope, y_name, std::vector(y_shape)); + CreateCPUTensor(&scope, z_name, std::vector(z_shape)); + + // It is wired, need to copy manually. + *block_->add_ops() = *fc->Proto(); + }; + + // Test with 4 layer FC + AddFCLayer("x0", "y0", "z0", false, {batch_size, input_dim}, + {input_dim, output_dim}, {batch_size, output_dim}); + AddFCLayer("z0", "y1", "z1", true, {}, {output_dim, output_dim}, + {batch_size, output_dim}); + AddFCLayer("z1", "y2", "z2", true, {}, {output_dim, output_dim}, + {batch_size, output_dim}); + AddFCLayer("z2", "y3", "z3", true, {}, {output_dim, output_dim}, + {batch_size, output_dim}); + + LOG(INFO) << "create tensorrt desc"; + framework::OpDesc engine_op_desc(nullptr); + engine_op_desc.SetType("tensorrt_engine"); + engine_op_desc.SetInput("Xs", std::vector({"x0"})); + engine_op_desc.SetOutput("Ys", std::vector({"z3"})); + + SetAttr(engine_op_desc.Proto(), "subgraph", + block_->SerializeAsString()); + SetAttr(engine_op_desc.Proto(), "max_batch", batch_size); + SetAttr(engine_op_desc.Proto(), "max_workspace", 2 << 10); + SetAttr>( + engine_op_desc.Proto(), "parameters", + std::vector({"y0", "y1", "y2", "y3"})); + SetAttr(engine_op_desc.Proto(), "engine_uniq_key", "b_engine"); + + auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto()); + + // Execute them. + engine_op->Run(scope, place); +} + +// Test with a larger FC layer. +TEST(TensorRTEngineOp, fc) { Execute(40, 256, 256); } + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 137ea91caedabc3167146d91b063dbe9e2e2b931..edd1baa4ace4e246190afcd12b0716f1dd38e243 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -86,32 +86,24 @@ class UniformRandomOp : public framework::OperatorWithKernel { class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddOutput("Out", "(Tensor) The output tensor of uniform random op"); + AddOutput("Out", "The output tensor of uniform random op"); AddComment(R"DOC( -Uniform random operator. - This operator initializes a tensor with random values sampled from a -uniform distribution. +uniform distribution. The random result is in set [min, max]. )DOC"); - AddAttr>("shape", - "(vector) The shape of the output tensor"); - AddAttr("min", - "(float, default -1.0) " - "Minimum value of uniform random") + AddAttr>("shape", "The shape of the output tensor"); + AddAttr("min", "Minimum value of uniform random. [default -1.0].") .SetDefault(-1.0f); - AddAttr("max", - "(float, default 1.0) " - "Maximun value of uniform random") + AddAttr("max", "Maximun value of uniform random. [default 1.0].") .SetDefault(1.0f); AddAttr("seed", - "(int, default 0) " "Random seed used for generating samples. " "0 means use a seed generated by the system." "Note that if seed is not 0, this operator will always " - "generate the same random numbers every time.") + "generate the same random numbers every time. [default 0].") .SetDefault(0); - AddAttr("dtype", "(int, default 5(FP32)) Output tensor data type") + AddAttr("dtype", "Output tensor data type. [default 5(FP32)].") .SetDefault(framework::proto::VarType::FP32); } }; diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index ba5cb19257871f6c3c69398c188c4f52f468b3d7..11b2b3269ab00b7646a371d2a1f7605637a802a5 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -20,6 +20,7 @@ from ..framework import Program, Variable, Operator from ..layer_helper import LayerHelper, unique_name from ..initializer import force_init_on_cpu from ops import logical_and, logical_not, logical_or +import numpy __all__ = [ 'split_lod_tensor', @@ -706,7 +707,7 @@ def lod_rank_table(x, level=0): .. code-block:: python x = fluid.layers.data(name='x', shape=[10], - dtype='float32', lod_level=1) + dtype='float32', lod_level=1) out = layers.lod_rank_table(x=x, level=0) """ helper = LayerHelper("lod_rank_table", **locals()) @@ -917,37 +918,40 @@ def create_array(dtype): dtype=dtype) -def less_than(x, y, force_cpu=True, cond=None, **ignored): +@templatedoc() +def less_than(x, y, force_cpu=None, cond=None, **ignored): """ - **Less than** + ${comment} - This layer returns the truth value of :math:`x < y` elementwise. + >>> import paddle.fluid as fluid + >>> less = fluid.layers.less_than(x=label, y=limit) Args: - x(Variable): First operand of *less_than* - y(Variable): Second operand of *less_than* - force_cpu(Bool|True): The output data will be on CPU if set true. + x(${x_type}): ${x_comment}. + y(${y_type}): ${y_comment}. + force_cpu(${force_cpu_type}): ${force_cpu_comment}. cond(Variable|None): Optional output variable to store the result of *less_than* Returns: - Variable: The tensor variable storing the output of *less_than*. - - Examples: - .. code-block:: python - - less = fluid.layers.less_than(x=label, y=limit) + ${out_comment}. """ helper = LayerHelper("less_than", **locals()) if cond is None: cond = helper.create_tmp_variable(dtype='bool') cond.stop_gradient = True + attrs = dict() + if force_cpu is not None: + attrs['force_cpu'] = force_cpu + elif force_init_on_cpu(): + attrs['force_cpu'] = force_init_on_cpu() + helper.append_op( type='less_than', inputs={'X': [x], 'Y': [y]}, outputs={'Out': [cond]}, - attrs={'force_cpu': force_cpu or force_init_on_cpu()}) + attrs=attrs) return cond @@ -1012,8 +1016,28 @@ def array_read(array, i): def shrink_memory(x, i, table): """ - This function creates an operator to shrink_rnn_memory using the RankTable + This function creates an operator to shrink rnn memory using the RankTable as mentioned in the input parameter. + + NOTE: This API is very low-level API. It is used by DynamicRNN only. + + Since the Dynamic RNN uses no-padding way to implement RNN. The sequence + will be sorted by order, and the length of valid memory will be shrink after + each time step. + + Args: + x(Variable): The memory object in the previous time step. + i(Variable): The step count variable. A int scalar as LoDTensor. + table(Variable): The RNNRankTable object. + + Returns: + the memory variable after shrink. + + Examples: + + Since this API is very low level API. The example is not provided. + Please reference the implementation of class DynamicRNN for detail + usage. """ helper = LayerHelper('shrink_memory', **locals()) out = helper.create_tmp_variable(dtype=x.dtype) @@ -1354,6 +1378,38 @@ class IfElse(object): class DynamicRNN(object): + """ + The dynamic RNN can process a batch of sequence data. The length of each + sample sequence can be different. This API automatically process them in + batch. + + The input lod must be set. Please reference `lod_tensor` + + >>> import paddle.fluid as fluid + >>> data = fluid.layers.data(name='sentence', dtype='int64', lod_level=1) + >>> embedding = fluid.layers.embedding(input=data, size=[65535, 32], + >>> is_sparse=True) + >>> + >>> drnn = fluid.layers.DynamicRNN() + >>> with drnn.block(): + >>> word = drnn.step_input(embedding) + >>> prev = drnn.memory(shape=[200]) + >>> hidden = fluid.layers.fc(input=[word, prev], size=200, act='relu') + >>> drnn.update_memory(prev, hidden) # set prev to hidden + >>> drnn.output(hidden) + >>> + >>> # last is the last time step of rnn. It is the encoding result. + >>> last = fluid.layers.sequence_last_step(drnn()) + + The dynamic RNN will unfold sequence into timesteps. Users need to define + how to process each time step during the :code:`with` block. + + The `memory` is used staging data cross time step. The initial value of + memory can be zero or another variable. + + The dynamic RNN can mark multiple variables as its output. Use `drnn()` to + get the output sequence. + """ BEFORE_RNN = 0 IN_RNN = 1 AFTER_RNN = 2 @@ -1376,6 +1432,15 @@ class DynamicRNN(object): self.mem_link = [] def step_input(self, x): + """ + Mark a sequence as a dynamic RNN input. + Args: + x(Variable): The input sequence. + + Returns: + The current timestep in the input sequence. + + """ self._assert_in_rnn_block_("step_input") if not isinstance(x, Variable): raise TypeError( @@ -1419,6 +1484,15 @@ class DynamicRNN(object): return array_read(array=input_array, i=self.step_idx) def static_input(self, x): + """ + Mark a variable as a RNN input. The input will not be scattered into + time steps. + Args: + x(Variable): The input variable. + + Returns: + The input variable that can access in RNN. + """ self._assert_in_rnn_block_("static_input") if not isinstance(x, Variable): raise TypeError( @@ -1440,6 +1514,10 @@ class DynamicRNN(object): @contextlib.contextmanager def block(self): + """ + The block for user to define operators in RNN. See the class docstring + for more details. + """ if self.status != DynamicRNN.BEFORE_RNN: raise ValueError("rnn.block() can only be invoke once") self.step_idx = fill_constant( @@ -1466,6 +1544,9 @@ class DynamicRNN(object): x=each_array, table=self.lod_rank_table)) def __call__(self, *args, **kwargs): + """ + Get the output of RNN. This API should only be invoked after RNN.block() + """ if self.status != DynamicRNN.AFTER_RNN: raise ValueError(("Output of the dynamic RNN can only be visited " "outside the rnn block.")) @@ -1480,6 +1561,70 @@ class DynamicRNN(object): value=0.0, need_reorder=False, dtype='float32'): + """ + Create a memory variable for dynamic rnn. + + If the :code:`init` is not None, :code:`memory` will be initialized by + this variable. The :code:`need_reorder` is used to reorder the memory as + the input variable. It should be set to true when the initialized memory + depends on the input sample. + + For example, + + >>> import paddle.fluid as fluid + >>> sentence = fluid.layers.data( + >>> name='sentence', dtype='float32', shape=[32]) + >>> boot_memory = fluid.layers.data( + >>> name='boot', dtype='float32', shape=[10]) + >>> + >>> drnn = fluid.layers.DynamicRNN() + >>> with drnn.block(): + >>> word = drnn.step_input(sentence) + >>> memory = drnn.memory(init=boot_memory, need_reorder=True) + >>> hidden = fluid.layers.fc( + >>> input=[word, memory], size=10, act='tanh') + >>> drnn.update_memory(ex_mem=memory, new_mem=hidden) + >>> drnn.output(hidden) + >>> rnn_output = drnn() + + + Otherwise, if :code:`shape`, :code:`value`, :code:`dtype` are set, the + :code:`memory` will be initialized by this :code:`value`. + + For example, + + >>> import paddle.fluid as fluid + >>> sentence = fluid.layers.data( + >>> name='sentence', dtype='float32', shape=[32]) + >>> + >>> drnn = fluid.layers.DynamicRNN() + >>> with drnn.block(): + >>> word = drnn.step_input(sentence) + >>> memory = drnn.memory(shape=[10], dtype='float32', value=0) + >>> hidden = fluid.layers.fc( + >>> input=[word, memory], size=10, act='tanh') + >>> drnn.update_memory(ex_mem=memory, new_mem=hidden) + >>> drnn.output(hidden) + >>> rnn_output = drnn() + + + Args: + init(Variable|None): The initialized variable. + + shape(list|tuple): The memory shape. NOTE the shape does not contain + batch_size. + + value(float): the initalized value. + + need_reorder(bool): True if the initialized memory depends on the + input sample. + + dtype(str|numpy.dtype): The data type of the initialized memory. + + Returns: + the memory variable. + + """ self._assert_in_rnn_block_('memory') if init is not None: if not isinstance(init, Variable): @@ -1547,6 +1692,16 @@ class DynamicRNN(object): return self.memory(init=init) def update_memory(self, ex_mem, new_mem): + """ + Update the memory from ex_mem to new_mem. NOTE that the shape and data + type of :code:`ex_mem` and :code:`new_mem` must be same. + Args: + ex_mem(Variable): the memory variable. + new_mem(Variable): the plain variable generated in RNN block. + + Returns: + None + """ self._assert_in_rnn_block_('update_memory') if not isinstance(ex_mem, Variable): raise TypeError("The input arg `ex_mem` of update_memory() must " @@ -1564,6 +1719,15 @@ class DynamicRNN(object): self.mem_link.append((new_mem, mem_array)) def output(self, *outputs): + """ + mark the RNN output variables. + + Args: + outputs: The output variables. + + Returns: + None + """ self._assert_in_rnn_block_('output') parent_block = self._parent_block_() for each in outputs: diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 3a83db12fd13651578deeac6b562bac2f1e4e4b6..edf528a5950ee84be4a3e2097cee36cb5ad8c68e 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -210,53 +210,68 @@ def bipartite_match(dist_matrix, dist_threshold=None, name=None): """ - **Bipartite matchint operator** - - This operator is a greedy bipartite matching algorithm, which is used to - obtain the matching with the maximum distance based on the input + This operator implements a greedy bipartite matching algorithm, which is + used to obtain the matching with the maximum distance based on the input distance matrix. For input 2D matrix, the bipartite matching algorithm can - find the matched column for each row, also can find the matched row for - each column. And this operator only calculate matched indices from column - to row. For each instance, the number of matched indices is the number of - of columns of the input ditance matrix. - - There are two outputs to save matched indices and distance. - A simple description, this algothrim matched the best (maximum distance) + find the matched column for each row (matched means the largest distance), + also can find the matched row for each column. And this operator only + calculate matched indices from column to row. For each instance, + the number of matched indices is the column number of the input distance + matrix. + + There are two outputs, matched indices and distance. + A simple description, this algorithm matched the best (maximum distance) row entity to the column entity and the matched indices are not duplicated in each row of ColToRowMatchIndices. If the column entity is not matched any row entity, set -1 in ColToRowMatchIndices. - Please note that the input DistMat can be LoDTensor (with LoD) or Tensor. + NOTE: the input DistMat can be LoDTensor (with LoD) or Tensor. If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size. If Tensor, the height of ColToRowMatchIndices is 1. + NOTE: This API is a very low level API. It is used by :code:`ssd_loss` + layer. Please consider to use :code:`ssd_loss` instead. + Args: dist_matrix(Variable): This input is a 2-D LoDTensor with shape [K, M]. It is pair-wise distance matrix between the entities represented by each row and each column. For example, assumed one entity is A with shape [K], another entity is B with shape [M]. The - dist_matirx[i][j] is the distance between A[i] and B[j]. The bigger - the distance is, the better macthing the pairs are. Please note, - This tensor can contain LoD information to represent a batch of - inputs. One instance of this batch can contain different numbers of - entities. + dist_matrix[i][j] is the distance between A[i] and B[j]. The bigger + the distance is, the better matching the pairs are. + + NOTE: This tensor can contain LoD information to represent a batch + of inputs. One instance of this batch can contain different numbers + of entities. match_type(string|None): The type of matching method, should be - 'bipartite' or 'per_prediction', 'bipartite' by defalut. + 'bipartite' or 'per_prediction'. [default 'bipartite']. dist_threshold(float|None): If `match_type` is 'per_prediction', this threshold is to determine the extra matching bboxes based - on the maximum distance, 0.5 by defalut. + on the maximum distance, 0.5 by default. Returns: - match_indices(Variable): A 2-D Tensor with shape [N, M] in int type. - N is the batch size. If match_indices[i][j] is -1, it - means B[j] does not match any entity in i-th instance. - Otherwise, it means B[j] is matched to row - match_indices[i][j] in i-th instance. The row number of - i-th instance is saved in match_indices[i][j]. - match_distance(Variable): A 2-D Tensor with shape [N, M] in float type. - N is batch size. If match_indices[i][j] is -1, - match_distance[i][j] is also -1.0. Otherwise, assumed - match_distance[i][j] = d, and the row offsets of each instance - are called LoD. Then match_distance[i][j] = dist_matrix[d+LoD[i]][j]. + tuple: a tuple with two elements is returned. The first is + matched_indices, the second is matched_distance. + + The matched_indices is a 2-D Tensor with shape [N, M] in int type. + N is the batch size. If match_indices[i][j] is -1, it + means B[j] does not match any entity in i-th instance. + Otherwise, it means B[j] is matched to row + match_indices[i][j] in i-th instance. The row number of + i-th instance is saved in match_indices[i][j]. + + The matched_distance is a 2-D Tensor with shape [N, M] in float type + . N is batch size. If match_indices[i][j] is -1, + match_distance[i][j] is also -1.0. Otherwise, assumed + match_distance[i][j] = d, and the row offsets of each instance + are called LoD. Then match_distance[i][j] = + dist_matrix[d+LoD[i]][j]. + + Examples: + + >>> x = fluid.layers.data(name='x', shape=[4], dtype='float32') + >>> y = fluid.layers.data(name='y', shape=[4], dtype='float32') + >>> iou = fluid.layers.iou_similarity(x=x, y=y) + >>> matched_indices, matched_dist = fluid.layers.bipartite_match(iou) """ helper = LayerHelper('bipartite_match', **locals()) match_indices = helper.create_tmp_variable(dtype='int32') @@ -364,7 +379,7 @@ def ssd_loss(location, normalize=True, sample_size=None): """ - **Multi-box loss layer for object dection algorithm of SSD** + **Multi-box loss layer for object detection algorithm of SSD** This layer is to compute dection loss for SSD given the location offset predictions, confidence predictions, prior boxes and ground-truth boudding @@ -372,21 +387,35 @@ def ssd_loss(location, is a weighted sum of the localization loss (or regression loss) and confidence loss (or classification loss) by performing the following steps: - 1. Find matched boundding box by bipartite matching algorithm. + 1. Find matched bounding box by bipartite matching algorithm. + 1.1 Compute IOU similarity between ground-truth boxes and prior boxes. + 1.2 Compute matched boundding box by bipartite matching algorithm. + 2. Compute confidence for mining hard examples + 2.1. Get the target label based on matched indices. + 2.2. Compute confidence loss. + 3. Apply hard example mining to get the negative example indices and update the matched indices. + 4. Assign classification and regression targets + 4.1. Encoded bbox according to the prior boxes. + 4.2. Assign regression targets. + 4.3. Assign classification targets. + 5. Compute the overall objective loss. + 5.1 Compute confidence loss. + 5.1 Compute localization loss. + 5.3 Compute the overall weighted loss. Args: @@ -421,39 +450,36 @@ def ssd_loss(location, mining_type (str): The hard example mining type, should be 'hard_example' or 'max_negative', now only support `max_negative`. normalize (bool): Whether to normalize the SSD loss by the total number - of output locations, True by defalut. + of output locations, True by default. sample_size (int): The max sample size of negative box, used only when mining_type is 'hard_example'. Returns: - Variable: The weighted sum of the localization loss and confidence loss, - with shape [N * Np, 1], N and Np are the same as they are - in `location`. + The weighted sum of the localization loss and confidence loss, with \ + shape [N * Np, 1], N and Np are the same as they are in `location`. Raises: - ValueError: If mining_type is 'hard_example', now only support - mining type of `max_negative`. + ValueError: If mining_type is 'hard_example', now only support mining \ + type of `max_negative`. Examples: - .. code-block:: python - - pb = layers.data( - name='prior_box', - shape=[10, 4], - append_batch_size=False, - dtype='float32') - pbv = layers.data( - name='prior_box_var', - shape=[10, 4], - append_batch_size=False, - dtype='float32') - loc = layers.data(name='target_box', shape=[10, 4], dtype='float32') - scores = layers.data(name='scores', shape=[10, 21], dtype='float32') - gt_box = layers.data( - name='gt_box', shape=[4], lod_level=1, dtype='float32') - gt_label = layers.data( - name='gt_label', shape=[1], lod_level=1, dtype='float32') - loss = layers.ssd_loss(loc, scores, gt_box, gt_label, pb, pbv) + >>> pb = fluid.layers.data( + >>> name='prior_box', + >>> shape=[10, 4], + >>> append_batch_size=False, + >>> dtype='float32') + >>> pbv = fluid.layers.data( + >>> name='prior_box_var', + >>> shape=[10, 4], + >>> append_batch_size=False, + >>> dtype='float32') + >>> loc = fluid.layers.data(name='target_box', shape=[10, 4], dtype='float32') + >>> scores = fluid.layers.data(name='scores', shape=[10, 21], dtype='float32') + >>> gt_box = fluid.layers.data( + >>> name='gt_box', shape=[4], lod_level=1, dtype='float32') + >>> gt_label = fluid.layers.data( + >>> name='gt_label', shape=[1], lod_level=1, dtype='float32') + >>> loss = fluid.layers.ssd_loss(loc, scores, gt_box, gt_label, pb, pbv) """ helper = LayerHelper('ssd_loss', **locals()) diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index b004312d2ddf7588107d01d99e17a1c4a0749518..2dd82e35320ef79ba754070b6369e00413647b53 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -22,9 +22,9 @@ from ..executor import global_scope from layer_function_generator import generate_layer_fn, templatedoc __all__ = [ - 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer', - 'random_data_generator', 'Preprocessor', 'load' + 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv', + 'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch', + 'double_buffer', 'random_data_generator', 'Preprocessor', 'load' ] @@ -177,18 +177,17 @@ class ListenAndServ(object): }) -def Send(endpoints, send_vars, get_vars=None): +def Send(endpoints, send_vars, sync=True): """ - Send layer + Send variables to the server side, and get vars from server + side when server have finished running server side program. Args: - endpoints: comma seperated IP:PORT pairs in the order + endpoints (str): comma seperated IP:PORT pairs in the order of send_vars to send - send_vars: vars to send - get_vars: vars to get from server after send completes. - - Send variables to the server side, and get vars from server - side when server have finished running server side program. + send_vars (list): variables to send to server + sync (bool): whether to wait the request finish + """ assert (type(send_vars) == list) @@ -196,40 +195,33 @@ def Send(endpoints, send_vars, get_vars=None): endpoints = list(set(epmap)) helper = LayerHelper("Send", **locals()) - if not get_vars: - get_vars = [] - for s in send_vars: - v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True) - get_vars.append(v) rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName() helper.append_op( type="send", inputs={"X": send_vars}, - outputs={"Out": get_vars}, attrs={ "endpoints": endpoints, "epmap": epmap, rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC }) - - return get_vars + if sync: + helper.append_op(type="send_barrier", attrs={"endpoints": endpoints}) -def Recv(endpoints, get_vars): +def Recv(endpoints, get_vars, sync=True): """ - Recv layer + Receive variables from server side Args: - endpoints: comma seperated IP:PORT pairs in the order + endpoints (str): comma seperated IP:PORT pairs in the order of send_vars to send - send_vars: vars to send - get_vars: vars to get from server after send completes. + get_vars (list): vars to get from server after send completes. + sync (bool): whether to wait the request finish - Send variables to the server side, and get vars from server - side when server have finished running server side program. + Returns: + list: list of received variables """ - assert (type(send_vars) == list) assert (type(get_vars) == list) epmap = endpoints.split(",") @@ -242,6 +234,9 @@ def Recv(endpoints, get_vars): outputs={"Out": get_vars}, attrs={"endpoints": endpoints, "epmap": epmap}) + if sync: + helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints}) + return get_vars def monkey_patch_reader_methods(reader): @@ -292,6 +287,7 @@ def _copy_reader_create_op_(block, op): return new_op +@templatedoc(op_type='create_recordio_file_reader') def open_recordio_file(filename, shapes, lod_levels, @@ -299,34 +295,30 @@ def open_recordio_file(filename, pass_num=1, for_parallel=True): """ - Open a RecordIO file - - This layer takes a RecordIO file to read from and returns a Reader Variable. - Via the Reader Variable, we can get data from the given RecordIO file. + ${comment} Args: - filename(str): The RecordIO file's name. + filename(${filename_type}): ${filename_comment}. shapes(list): List of tuples which declaring data shapes. - lod_levels(list): List of ints which declaring data lod_level. + lod_levels(${lod_levels_type}): ${lod_levels_comment}. dtypes(list): List of strs which declaring data type. pass_num(int): Number of passes to run. for_parallel(Bool): Set it as True if you are going to run subsequent operators in parallel. Returns: - Variable: A Reader Variable via which we can get RecordIO file data. + ${out_comment}. Examples: - .. code-block:: python - reader = fluid.layers.io.open_recordio_file( - filename='./data.recordio', - shapes=[(3,224,224), (1)], - lod_levels=[0, 0], - dtypes=['float32', 'int64']) - - # Via the reader, we can use 'read_file' layer to get data: - image, label = fluid.layers.io.read_file(reader) + >>> import paddle.fluid as fluid + >>> reader = fluid.layers.io.open_recordio_file( + >>> filename='./data.recordio', + >>> shapes=[(3,224,224), (1)], + >>> lod_levels=[0, 0], + >>> dtypes=['float32', 'int64']) + >>> # Via the reader, we can use 'read_file' layer to get data: + >>> image, label = fluid.layers.io.read_file(reader) """ dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] shape_concat = [] @@ -544,6 +536,9 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None): def shuffle(reader, buffer_size): + """ + Shuffle the reader. + """ return __create_unshared_decorated_reader__( 'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)}) @@ -589,6 +584,29 @@ def batch(reader, batch_size): def double_buffer(reader, place=None, name=None): + """ + Wrap a double buffer reader. The data will copy to target place with a + double buffer queue. If the target place is None, the place that executor + perform on will be used. + + Args: + reader(Variable): the reader variable need to be wrapped. + place(Place): the place of target data. Default is the sample place of + executor perform. + + name(str): Variable name. None if the user does not care. + + Returns: + wrapped reader with double buffer. + + Examples: + + >>> reader = fluid.layers.open_files(filenames=['somefile'], + >>> shapes=[[-1, 784], [-1, 1]], + >>> dtypes=['float32', 'int64']) + >>> reader = fluid.layers.double_buffer(reader) + >>> img, label = fluid.layers.read_file(reader) + """ attrs = dict() if place is not None: attrs['place'] = str(place).upper() diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py index cb60a3aec9a5a69f1eed281eb017384a621c66a8..7a95afa9a6c370adbc13f4cb77526e316033899a 100644 --- a/python/paddle/fluid/layers/layer_function_generator.py +++ b/python/paddle/fluid/layers/layer_function_generator.py @@ -44,6 +44,11 @@ def _type_to_str_(tp): return framework_pb2.AttrType.Name(tp) +_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$") +_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$") +_two_bang_pattern_ = re.compile(r"!!([^!]+)!!") + + def _generate_doc_string_(op_proto): """ Generate docstring by OpProto @@ -55,22 +60,26 @@ def _generate_doc_string_(op_proto): str: the document string """ + def escape_math(text): + return _two_bang_pattern_.sub( + r'$$\1$$', + _single_dollar_pattern_.sub( + r':math:`\1`', _two_dollar_pattern_.sub(r"!!\1!!", text))) + if not isinstance(op_proto, framework_pb2.OpProto): raise TypeError("OpProto should be `framework_pb2.OpProto`") buf = cStringIO.StringIO() - buf.write(op_proto.comment) + buf.write(escape_math(op_proto.comment)) buf.write('\nArgs:\n') for each_input in op_proto.inputs: line_begin = ' {0}: '.format(_convert_(each_input.name)) buf.write(line_begin) - buf.write(each_input.comment) - buf.write('\n') - buf.write(' ' * len(line_begin)) - buf.write('Duplicable: ') - buf.write(str(each_input.duplicable)) - buf.write(' Optional: ') - buf.write(str(each_input.dispensable)) + buf.write(escape_math(each_input.comment)) + if each_input.duplicable: + buf.write(" Duplicatable.") + if each_input.dispensable: + buf.write(" Optional.") buf.write('\n') skip_attrs = OpProtoHolder.generated_op_attr_names() @@ -83,7 +92,7 @@ def _generate_doc_string_(op_proto): buf.write(' (') buf.write(_type_to_str_(each_attr.type)) buf.write('): ') - buf.write(each_attr.comment) + buf.write(escape_math(each_attr.comment)) buf.write('\n') if len(op_proto.outputs) != 0: @@ -92,7 +101,7 @@ def _generate_doc_string_(op_proto): for each_opt in op_proto.outputs: if not each_opt.intermediate: break - buf.write(each_opt.comment) + buf.write(escape_math(each_opt.comment)) return buf.getvalue() diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 56a230914677adbdd624d8bd243a26085e79e808..d768aa4b625b4602ea4d7894fc70cf3529f495de 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -225,11 +225,11 @@ def embedding(input, have two elements which indicate the size of the dictionary of embeddings and the size of each embedding vector respectively. is_sparse(bool): The flag indicating whether to use sparse update. - is_distributed (bool): Whether to run lookup table from remote parameter server. + is_distributed(bool): Whether to run lookup table from remote parameter server. padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup. Otherwise the given :attr:`padding_idx` indicates padding the output with zeros whenever lookup encounters it in :attr:`input`. If - :math:`padding_idx < 0`, the padding_idx to use in lookup is + :math:`padding_idx < 0`, the :attr:`padding_idx` to use in lookup is :math:`size[0] + dim`. param_attr(ParamAttr): Parameters for this layer dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, float_16, int etc @@ -1235,14 +1235,17 @@ def conv2d(input, act=None, name=None): """ - **Convlution2D Layer** - The convolution2D layer calculates the output based on the input, filter - and strides, paddings, dilations, groups parameters. Input(Input) and - Output(Output) are in NCHW format. Where N is batch size, C is the number of + and strides, paddings, dilations, groups parameters. Input and + Output are in NCHW format, where N is batch size, C is the number of channels, H is the height of the feature, and W is the width of the feature. - The details of convolution layer, please refer UFLDL's `convolution, - `_ . + Filter is in MCHW format, where M is the number of output image channels, + C is the number of input image channels, H is the height of the filter, + and W is the width of the filter. If the groups is greater than 1, + C will equal the number of input image channels divided by the groups. + Please refer to UFLDL's `convolution + `_ + for more detials. If bias attribution and activation type are provided, bias is added to the output of the convolution, and the corresponding activation function is applied to the final result. @@ -1253,15 +1256,14 @@ def conv2d(input, Out = \sigma (W \\ast X + b) - In the above equation: + Where: * :math:`X`: Input value, a tensor with NCHW format. * :math:`W`: Filter value, a tensor with MCHW format. * :math:`\\ast`: Convolution operation. * :math:`b`: Bias value, a 2-D tensor with shape [M, 1]. * :math:`\\sigma`: Activation function. - * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be - different. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. Example: @@ -1272,6 +1274,7 @@ def conv2d(input, Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)` - Output: + Output shape: :math:`(N, C_{out}, H_{out}, W_{out})` Where @@ -1283,7 +1286,7 @@ def conv2d(input, Args: input (Variable): The input image with [N, C, H, W] format. - num_filters(int): The number of filter. It is as same as the output + num_filters(int): The number of filter. It is as same as the output image channel. filter_size (int|tuple|None): The filter size. If filter_size is a tuple, it must contain two integers, (filter_size_H, filter_size_W). @@ -1306,7 +1309,8 @@ def conv2d(input, bias_attr (ParamAttr): Bias parameter for the Conv2d layer. Default: None use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn library is installed. Default: True - use_mkldnn (bool): Use mkldnn kernels or not. + use_mkldnn (bool): Use mkldnn kernels or not, it is valid only when compiled + with mkldnn library. Default: False act (str): Activation type. Default: None name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -1974,6 +1978,7 @@ def batch_norm(input, return helper.append_activation(batch_norm_out) +@templatedoc() def layer_norm(input, scale=True, shift=True, @@ -1984,20 +1989,11 @@ def layer_norm(input, act=None, name=None): """ - **Layer Normalization** - - Assume feature vectors exist on dimensions - :attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics - along these dimensions for each feature vector :math:`a` with size - :math:`H`, then normalize each feature vector using the corresponding - statistics. After that, apply learnable gain and bias on the normalized - tensor to scale and shift if :attr:`scale` and :attr:`shift` are set. - - Refer to `Layer Normalization `_ + ${comment} The formula is as follows: - .. math:: + .. math:: \\mu & = \\frac{1}{H}\\sum_{i=1}^{H} a_i @@ -2005,6 +2001,15 @@ def layer_norm(input, h & = f(\\frac{g}{\\sigma}(a - \\mu) + b) + * :math:`a`: the vector representation of the summed inputs to the neurons + in that layer. + + * :math:`H`: the number of hidden units in a layers + + * :math:`g`: the trainable scale parameter. + + * :math:`b`: the trainable bias parameter. + Args: input(Variable): The input tensor variable. scale(bool): Whether to learn the adaptive gain :math:`g` after @@ -2023,14 +2028,13 @@ def layer_norm(input, name (str): The name of this layer. It is optional. Returns: - Variable: A tensor variable with the same shape as the input. + ${y_comment} Examples: - .. code-block:: python - data = fluid.layers.data( - name='data', shape=[3, 32, 32], dtype='float32') - x = fluid.layers.layer_norm(input=data, begin_norm_axis=1) + >>> data = fluid.layers.data(name='data', shape=[3, 32, 32], + >>> dtype='float32') + >>> x = fluid.layers.layer_norm(input=data, begin_norm_axis=1) """ helper = LayerHelper('layer_norm', **locals()) dtype = helper.input_dtype() @@ -3739,29 +3743,13 @@ def im2sequence(input, filter_size=1, stride=1, padding=0, name=None): return out +@templatedoc() def row_conv(input, future_context_size, param_attr=None, act=None): - """Row Conv Operator. This layer will apply lookahead convolution to - **input**. The input variable should be a 2D LoDTensor with shape [T, D]. - Parameters with shape [future_context_size + 1, D] will be created. The math - equation of row convolution is as follows: - - .. math:: - Out_{i} = \sum_{j = i} ^ {i + \\tau} X_{j} \odot W_{i - j} - - In the above equation: - - * :math:`Out_{i}`: The i-th row of output variable with shape [1, D]. - * :math:`\\tau`: Future context size. - * :math:`X_{j}`: The j-th row of input variable with shape [1, D]. - * :math:`W_{i-j}`: The (i-j)-th row of parameters with shape [1, D]. - - More details about row_conv please refer to the paper \ - (http://www.cs.cmu.edu/~dyogatam/papers/wang+etal.iclrworkshop2016.pdf) and - the design document \ - (https://github.com/PaddlePaddle/Paddle/issues/2228#issuecomment-303903645). + """ + ${comment} Args: - input (Variable): Input variable, a 2D LoDTensor with shape [T, D]. + input (${x_type}): ${x_comment}. future_context_size (int): Future context size. Please note, the shape of convolution kernel is [future_context_size + 1, D]. param_attr (ParamAttr): Attributes of parameters, including @@ -3769,14 +3757,13 @@ def row_conv(input, future_context_size, param_attr=None, act=None): act (str): Non-linear activation to be applied to output variable. Returns: - Variable: The output tensor with same shape as input tensor. + ${out_comment}. Examples: - .. code-block:: python - - x = fluid.layers.data(name='x', shape=[16], - dtype='float32', lod_level=1) - out = fluid.layers.row_conv(input=x, future_context_size=2) + >>> import paddle.fluid as fluid + >>> x = fluid.layers.data(name='x', shape=[16], + >>> dtype='float32', lod_level=1) + >>> out = fluid.layers.row_conv(input=x, future_context_size=2) """ helper = LayerHelper('row_conv', **locals()) dtype = helper.input_dtype() @@ -3792,42 +3779,23 @@ def row_conv(input, future_context_size, param_attr=None, act=None): return helper.append_activation(out) +@templatedoc() def multiplex(inputs, index): """ - **Multiplex Layer** - - Referring to the given index variable, this layer selects rows from the - input variables to construct a multiplex variable. Assuming that there are - :math:`m` input variables and :math:`I_i` represents the i-th input - variable and :math:`i` is in [0, :math:`m`). All input variables are - tensors with same shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. - Please note that rank of the input tensor should be at least 2. Each input - variable will be treated as a 2-D matrix with shape [:math:`M`, :math:`N`] - where :math:`M` for :math:`d_0` and :math:`N` for :math:`d_1` * :math:`d_2` - * ... * :math:`d_R`. Let :math:`I_i[j]` be the j-th row of the i-th input - variable. The given index variable should be a 2-D tensor with shape - [:math:`M`, 1]. Let `ID[i]` be the i-th index value of the index variable. - Then the output variable will be a tensor with shape [:math:`d_0`, - :math:`d_1`, ..., :math:`d_R`]. If we treat the output tensor as a 2-D - matrix with shape [:math:`M`, :math:`N`] and let :math:`O[i]` be the i-th - row of the matrix, then `O[i]` is equal to :math:`I_{ID[i]}[i]`. + ${comment} + + >>> import paddle.fluid as fluid + >>> x1 = fluid.layers.data(name='x1', shape=[4], dtype='float32') + >>> x2 = fluid.layers.data(name='x2', shape=[4], dtype='float32') + >>> index = fluid.layers.data(name='index', shape=[1], dtype='int32') + >>> out = fluid.layers.multiplex(inputs=[x1, x2], index=index) Args: - inputs (list): A list of variables to gather from. All variables have the - same shape and the rank is at least 2. - index (Variable): Tensor, index variable which is a 2-D tensor - with shape [M, 1] where M is the batch size. + inputs (list): ${x_comment}. + index (${ids_type}): ${ids_comment}. Returns: - Variable: Multiplex variable gathered from input variables. - - Examples: - .. code-block:: python - - x1 = fluid.layers.data(name='x1', shape=[4], dtype='float32') - x2 = fluid.layers.data(name='x2', shape=[4], dtype='float32') - index = fluid.layers.data(name='index', shape=[1], dtype='int32') - out = fluid.layers.multiplex(inputs=[x1, x2], index=index) + ${out_comment}. """ helper = LayerHelper('multiplex', **locals()) diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 98f169e8f0881fbba6aecb45b43a52c8fd51132d..6f404c5cc608abda91c1d042d405f109dedc55c9 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -40,8 +40,6 @@ __activations__ = [ 'relu6', 'pow', 'stanh', - 'hard_shrink', - 'thresholded_relu', 'hard_sigmoid', 'swish', ] @@ -64,11 +62,9 @@ __all__ = [ 'logical_or', 'logical_xor', 'logical_not', - 'uniform_random', 'uniform_random_batch_size_like', 'gaussian_random', 'gaussian_random_batch_size_like', - 'cumsum', 'scatter', 'sum', 'slice', @@ -79,3 +75,88 @@ __all__ = [ for _OP in set(__all__): globals()[_OP] = generate_layer_fn(_OP) + +__all__ += ["uniform_random"] + +_uniform_random_ = generate_layer_fn('uniform_random') + + +def uniform_random(shape, dtype=None, min=None, max=None, seed=None): + kwargs = dict() + for name in locals(): + val = locals()[name] + if val is not None: + kwargs[name] = val + return _uniform_random_(**kwargs) + + +uniform_random.__doc__ = _uniform_random_.__doc__ + """ +Examples: + + >>> result = fluid.layers.uniform_random(shape=[32, 784]) +""" + +__all__ += ['hard_shrink'] + +_hard_shrink_ = generate_layer_fn('hard_shrink') + + +def hard_shrink(x, threshold=None): + kwargs = dict() + for name in locals(): + val = locals()[name] + if val is not None: + kwargs[name] = val + return _hard_shrink_(**kwargs) + + +hard_shrink.__doc__ = _hard_shrink_.__doc__ + """ +Examples: + + >>> data = fluid.layers.data(name="input", shape=[784]) + >>> result = fluid.layers.hard_shrink(x=data, threshold=0.3) +""" + +__all__ += ['cumsum'] + +_cum_sum_ = generate_layer_fn('cumsum') + + +def cumsum(x, axis=None, exclusive=None, reverse=None): + kwargs = dict() + for name in locals(): + val = locals()[name] + if val is not None: + kwargs[name] = val + + return _cum_sum_(**kwargs) + + +cumsum.__doc__ = _cum_sum_.__doc__ + """ +Examples: + + >>> data = fluid.layers.data(name="input", shape=[32, 784]) + >>> result = fluid.layers.cumsum(data, axis=0) +""" + +__all__ += ['thresholded_relu'] + +_thresholded_relu_ = generate_layer_fn('thresholded_relu') + + +def thresholded_relu(x, threshold=None): + kwargs = dict() + for name in locals(): + val = locals()[name] + if val is not None: + kwargs[name] = val + + _thresholded_relu_(**kwargs) + + +thresholded_relu.__doc__ = _thresholded_relu_.__doc__ + """ +Examples: + + >>> data = fluid.layers.data(name="input", shape=[1]) + >>> result = fluid.layers.thresholded_relu(data, threshold=0.4) +""" diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 81f42ff47032c7e5b6f356720c938a0797d60a40..2637dfe5e5ca413e74bb8a2ecd795b8e63fd71ce 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -6,7 +6,7 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software +# Unlessf required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and @@ -51,7 +51,12 @@ def create_parameter(shape, is_bias=False, default_initializer=None): """ - Create a parameter + Create a parameter. The parameter is a learnable variable, which can have + gradient, and can be optimized. + + NOTE: this is a very low-level API. This API is useful when you create + operator by your self. instead of using layers. + Args: shape(list[int]): shape of the parameter dtype(string): element type of the parameter @@ -63,7 +68,12 @@ def create_parameter(shape, default_initializer(Initializer): initializer for the parameter Returns: - Parameter: the created parameter + the created parameter. + + Examples: + >>> W = fluid.layers.create_parameter(shape=[784, 200], dtype='float32') + >>> data = fluid.layers.data(name="img", shape=[64, 784], append_batch_size=False) + >>> hidden = fluid.layers.matmul(x=data, y=W) """ helper = LayerHelper("create_parameter", **locals()) if attr is None: @@ -207,6 +217,7 @@ def assign(input, output): Examples: .. code-block:: python + out = fluid.layers.create_tensor(dtype='float32') hidden = fluid.layers.fc(input=data, size=10) fluid.layers.assign(hidden, out) diff --git a/python/paddle/fluid/tests/unittests/test_dist_train.py b/python/paddle/fluid/tests/unittests/test_dist_train.py index 2314bb2ed8a4eeb34752fd5d040f8a8476798aa6..562e66b0625083fe840d64967249f0215cfda1f9 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_train.py +++ b/python/paddle/fluid/tests/unittests/test_dist_train.py @@ -16,6 +16,7 @@ import os import time import unittest from multiprocessing import Process +import signal import numpy @@ -24,9 +25,6 @@ import paddle.fluid.layers as layers class TestSendOp(unittest.TestCase): - @unittest.skip( - "This test is buggy. We cannot use time.sleep to sync processes, the connection may fail in unittest." - ) def test_send(self): # Run init_serv in a thread place = fluid.CPUPlace() @@ -35,7 +33,9 @@ class TestSendOp(unittest.TestCase): p.daemon = True p.start() - time.sleep(10) + self.ps_timeout = 5 + self._wait_ps_ready(p.pid) + with open("/tmp/paddle.%d.port" % p.pid, "r") as fn: selected_port = int(fn.readlines()[0]) self.init_client(place, selected_port) @@ -44,9 +44,23 @@ class TestSendOp(unittest.TestCase): self.assertTrue(numpy.allclose(self.local_out, self.dist_out)) # FIXME(typhoonzero): find a way to gracefully shutdown the server. - os.system("kill -9 %d" % p.pid) + os.kill(p.pid, signal.SIGKILL) p.join() + def _wait_ps_ready(self, pid): + start_left_time = self.ps_timeout + sleep_time = 0.5 + while True: + assert start_left_time >= 0, "wait ps ready failed" + time.sleep(sleep_time) + try: + # the listen_and_serv_op would touch a file which contains the listen port + # on the /tmp directory until it was ready to process all the RPC call. + os.stat("/tmp/paddle.%d.port" % pid) + return + except os.error: + start_left_time -= sleep_time + def init_serv(self, place): main = fluid.Program() @@ -84,7 +98,10 @@ class TestSendOp(unittest.TestCase): dtype="float32", persistable=False, shape=[32, 32]) - o = layers.Send("127.0.0.1:%d" % port, [x], [get_var]) + fluid.initializer.Constant(value=2.3)(get_var, main.global_block()) + layers.Send("127.0.0.1:%d" % port, [x]) + o = layers.Recv("127.0.0.1:%d" % port, [get_var]) + exe = fluid.Executor(place) self.dist_out = exe.run(main, fetch_list=o) # o is a list diff --git a/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py b/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py index d1d709551c77908db88be6fda7ac74d4e922138e..9dec2acb1d7101f8f00565c56e0469edb143d0c6 100644 --- a/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py +++ b/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py @@ -57,17 +57,18 @@ class TestListenAndServOp(OpTest): def setUp(self): self.ps_timeout = 5 self.ip = "127.0.0.1" - self.port = "6173" + self.port = "0" self.trainers = 1 - self.trainer_id = 1 + self.trainer_id = 0 def _start_pserver(self, use_cuda, sync_mode): p = Process( target=run_pserver, args=(use_cuda, sync_mode, self.ip, self.port, self.trainers, self.trainer_id)) + p.daemon = True p.start() - return p.pid + return p def _wait_ps_ready(self, pid): start_left_time = self.ps_timeout @@ -89,18 +90,20 @@ class TestListenAndServOp(OpTest): def test_handle_signal_in_serv_op(self): # run pserver on CPU in sync mode - pid = self._start_pserver(False, True) - self._wait_ps_ready(pid) + p1 = self._start_pserver(False, True) + self._wait_ps_ready(p1.pid) # raise SIGTERM to pserver - os.kill(pid, signal.SIGTERM) + os.kill(p1.pid, signal.SIGKILL) + p1.join() # run pserver on CPU in async mode - pid = self._start_pserver(False, False) - self._wait_ps_ready(pid) + p2 = self._start_pserver(False, False) + self._wait_ps_ready(p2.pid) # raise SIGTERM to pserver - os.kill(pid, signal.SIGTERM) + os.kill(p2.pid, signal.SIGKILL) + p2.join() if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py index 163975555ec2cea5c169cc1da3c4324d91ba3616..1ea7a6a5682318fb5f4ef8b3a08911df3cd44acf 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py @@ -173,6 +173,7 @@ class TestCRFModel(unittest.TestCase): pe.run(feed=feeder.feed(cur_batch), fetch_list=[avg_cost.name]))[0] + @unittest.skip(reason="CI hangs") def test_update_sparse_parameter_all_reduce(self): build_strategy = fluid.BuildStrategy() build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce @@ -181,6 +182,7 @@ class TestCRFModel(unittest.TestCase): self.check_network_convergence( is_sparse=True, build_strategy=build_strategy, use_cuda=False) + @unittest.skip(reason="CI hangs") def test_update_dense_parameter_all_reduce(self): build_strategy = fluid.BuildStrategy() build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce @@ -189,6 +191,7 @@ class TestCRFModel(unittest.TestCase): self.check_network_convergence( is_sparse=False, build_strategy=build_strategy, use_cuda=False) + @unittest.skip(reason="CI hangs") def test_update_sparse_parameter_reduce(self): build_strategy = fluid.BuildStrategy() build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce @@ -197,6 +200,7 @@ class TestCRFModel(unittest.TestCase): self.check_network_convergence( is_sparse=True, build_strategy=build_strategy, use_cuda=False) + @unittest.skip(reason="CI hangs") def test_update_dense_parameter_reduce(self): build_strategy = fluid.BuildStrategy() build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 9ff0ae6fca27d4681891b2033e2f8f95bd825942..8bfb554845d9b128f000d6c90cf626416a198eef 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -157,9 +157,11 @@ class ControlFlowGraph(object): if op.type() == "fill_constant" and op.attr("force_cpu") == True: self._skip_opt.update(op.output_arg_names()) - def release_memory(self): + def release_memory(self, skip_opt_set=None): self._dataflow_analyze() self._update_skip_opt_set() + if skip_opt_set: + self._skip_opt.update(skip_opt_set) fwd_id = 0 bwd_id = 0 for i in range(self.op_size): @@ -183,7 +185,7 @@ class ControlFlowGraph(object): else: bwd_id += 1 - def memory_optimize(self, level=0): + def memory_optimize(self, skip_opt_set=None, level=0): def compare_shape(x_shape, cache_shape, opt_level): if opt_level == 0: return x_shape == cache_shape @@ -200,6 +202,9 @@ class ControlFlowGraph(object): self._dataflow_analyze() self._update_skip_opt_set() + # update skip set to meet users' demand + if skip_opt_set: + self._skip_opt.update(skip_opt_set) self.pool = [] for i in range(self.op_size): op = self._ops[i] @@ -358,7 +363,7 @@ def _get_cfgs(input_program): return cfgs -def memory_optimize(input_program, print_log=False, level=0): +def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0): """Optimize memory by reusing var memory. Note: it doesn't not support subblock nested in subblock. @@ -374,10 +379,10 @@ def memory_optimize(input_program, print_log=False, level=0): PRINT_LOG = print_log cfgs = _get_cfgs(input_program) for cfg in cfgs: - cfg.memory_optimize(level) + cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) -def release_memory(input_program): +def release_memory(input_program, skip_opt_set=None): cfgs = _get_cfgs(input_program) for cfg in cfgs: - cfg.release_memory() + cfg.release_memory(skip_opt_set=skip_opt_set)