Commit bf5ce626 authored by Yibing Liu

Merge branch 'develop' of upstream into fix_docs

@@ -64,7 +64,8 @@ class OpConverter {
     (*it)(op, scope, test_mode);
   }

-  // convert fluid block to tensorrt network
+  // Convert a fluid block to a TensorRT network. NOTE it only converts operators;
+  // the INetwork's inputs and outputs should be specified in other modules.
   void ConvertBlock(const framework::proto::BlockDesc& block,
                     const std::unordered_set<std::string>& parameters,
                     const framework::Scope& scope, TensorRTEngine* engine) {
...
@@ -51,11 +51,12 @@ class TensorRTEngine : public EngineBase {
     nvinfer1::Weights w_;
   };

-  TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
+  TensorRTEngine(int max_batch, int max_workspace,
+                 cudaStream_t* stream = nullptr,
                  nvinfer1::ILogger& logger = NaiveLogger::Global())
       : max_batch_(max_batch),
         max_workspace_(max_workspace),
-        stream_(stream),
+        stream_(stream ? stream : &default_stream_),
         logger_(logger) {}

   virtual ~TensorRTEngine();
@@ -121,6 +122,8 @@ class TensorRTEngine : public EngineBase {
   // the max memory size the engine uses
   int max_workspace_;

   cudaStream_t* stream_;
+  // If stream_ is not set from outside, hold its own stream.
+  cudaStream_t default_stream_;

   nvinfer1::ILogger& logger_;

   std::vector<Buffer> buffers_;
@@ -165,20 +168,31 @@ class TensorRTEngine : public EngineBase {
  */
 class TRT_EngineManager {
  public:
-  TensorRTEngine* Create(int max_batch, int max_workspace,
-                         cudaStream_t* stream) {
-    engines_.emplace_back(new TensorRTEngine(max_batch, max_workspace, stream));
-    return engines_.back().get();
+  bool HasEngine(const std::string& name) const {
+    return engines_.count(name) != 0;
+  }
+
+  // Get an engine called `name`.
+  TensorRTEngine* Get(const std::string& name) const {
+    return engines_.at(name).get();
+  }
+
+  // Create or get an engine called `name`.
+  TensorRTEngine* Create(int max_batch, int max_workspace,
+                         cudaStream_t* stream, const std::string& name) {
+    auto* p = new TensorRTEngine(max_batch, max_workspace, stream);
+    engines_[name].reset(p);
+    return p;
   }

   void DeleteALl() {
-    for (auto& ptr : engines_) {
-      ptr.reset(nullptr);
+    for (auto& item : engines_) {
+      item.second.reset(nullptr);
     }
   }

  private:
-  std::vector<std::unique_ptr<TensorRTEngine>> engines_;
+  std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
 };

 }  // namespace tensorrt
...
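The change above replaces the anonymous vector of engines with a name-keyed registry, so a kernel can find the engine it built on an earlier run via its `engine_uniq_key`. A minimal Python sketch of the same create-or-get pattern (class and names here are illustrative, not Paddle API):

```python
class EngineManager(object):
    """Name-keyed registry mirroring TRT_EngineManager's semantics."""

    def __init__(self):
        self._engines = {}  # name -> engine

    def has_engine(self, name):
        return name in self._engines

    def get(self, name):
        # Raises KeyError for unknown names, like unordered_map::at().
        return self._engines[name]

    def create(self, name, max_batch, max_workspace):
        # Creating under an existing name replaces the old engine,
        # matching engines_[name].reset(p) above.
        engine = ("engine", max_batch, max_workspace)  # stand-in object
        self._engines[name] = engine
        return engine


manager = EngineManager()
if not manager.has_engine("a_engine"):
    manager.create("a_engine", max_batch=100, max_workspace=1 << 10)
engine = manager.get("a_engine")
```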
@@ -112,7 +112,7 @@
 __attribute__((unused)) constexpr char LogSigmoidDoc[] = R"DOC(
 Logsigmoid Activation Operator

-$$out = \log \frac{1}{1 + e^{-x}}$$
+$$out = \\log \\frac{1}{1 + e^{-x}}$$

 )DOC";
@@ -252,15 +252,14 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "Output of Softshrink operator");
     AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
     AddComment(R"DOC(
-Softshrink Activation Operator.
+:strong:`Softshrink Activation Operator`

-$$
-out = \begin{cases}
-    x - \lambda, \text{if } x > \lambda \\
-    x + \lambda, \text{if } x < -\lambda \\
-    0,  \text{otherwise}
-    \end{cases}
-$$
+..  math::
+    out = \begin{cases}
+        x - \lambda, \text{if } x > \lambda \\
+        x + \lambda, \text{if } x < -\lambda \\
+        0,  \text{otherwise}
+        \end{cases}

 )DOC");
   }
...
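As a quick numeric check of the Softshrink formula above, a standalone NumPy sketch (not the Paddle kernel):

```python
import numpy as np

def softshrink(x, lam=0.5):
    # out = x - lam if x > lam; x + lam if x < -lam; 0 otherwise.
    return np.where(x > lam, x - lam, np.where(x < -lam, x + lam, 0.0))

x = np.array([-1.0, -0.3, 0.0, 0.4, 2.0])
print(softshrink(x))  # [-0.5  0.   0.   0.   1.5]
```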
@@ -106,23 +106,36 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
               "and M represents the number of decoded boxes.");

     AddComment(R"DOC(
-Bounding Box Coder Operator.
+Bounding Box Coder.
+
 Encode/Decode the target bounding box with the priorbox information.
+
 The Encoding schema described below:
-ox = (tx - px) / pw / pxv
-oy = (ty - py) / ph / pyv
-ow = log(abs(tw / pw)) / pwv
-oh = log(abs(th / ph)) / phv
+
+    ox = (tx - px) / pw / pxv
+
+    oy = (ty - py) / ph / pyv
+
+    ow = log(abs(tw / pw)) / pwv
+
+    oh = log(abs(th / ph)) / phv
+
 The Decoding schema described below:
-ox = (pw * pxv * tx + px) - tw / 2
-oy = (ph * pyv * ty + py) - th / 2
-ow = exp(pwv * tw) * pw + tw / 2
-oh = exp(phv * th) * ph + th / 2
-where tx, ty, tw, th denote the target box's center coordinates, width and
-height respectively. Similarly, px, py, pw, ph denote the priorbox's(anchor)
-center coordinates, width and height. pxv, pyv, pwv, phv denote the variance
-of the priorbox and ox, oy, ow, oh denote the encoded/decoded coordinates,
-width and height.
+
+    ox = (pw * pxv * tx + px) - tw / 2
+
+    oy = (ph * pyv * ty + py) - th / 2
+
+    ow = exp(pwv * tw) * pw + tw / 2
+
+    oh = exp(phv * th) * ph + th / 2
+
+where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width
+and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the
+priorbox's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`,
+`phv` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh` denote the
+encoded/decoded coordinates, width and height.
 )DOC");
   }
 };
...
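To make the encoding schema concrete, a small NumPy sketch (not the Paddle op; all variances set to 1 for simplicity):

```python
import numpy as np

def encode(target, prior, prior_var):
    # Boxes given as (center_x, center_y, width, height);
    # prior_var is (pxv, pyv, pwv, phv).
    tx, ty, tw, th = target
    px, py, pw, ph = prior
    pxv, pyv, pwv, phv = prior_var
    ox = (tx - px) / pw / pxv
    oy = (ty - py) / ph / pyv
    ow = np.log(abs(tw / pw)) / pwv
    oh = np.log(abs(th / ph)) / phv
    return ox, oy, ow, oh

# A target box centered one unit right of the prior, same size:
print(encode((5.0, 4.0, 2.0, 2.0), (4.0, 4.0, 2.0, 2.0), (1., 1., 1., 1.)))
# -> (0.5, 0.0, 0.0, 0.0)
```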
@@ -36,11 +36,12 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
   void Apply() override {
     AddAttr<float>("mean",
                    "(float, default 0.0) "
-                   "mean of random tensor.")
+                   "The mean (or center) of the gaussian distribution.")
         .SetDefault(.0f);
     AddAttr<float>("std",
                    "(float, default 1.0) "
-                   "std of random tensor.")
+                   "The standard deviation (std, or spread) of the "
+                   "gaussian distribution.")
         .SetDefault(1.0f);
     AddAttr<int>("seed",
                  "(int, default 0) "
@@ -55,9 +56,11 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
         .SetDefault(framework::proto::VarType::FP32);
     AddComment(R"DOC(
 GaussianRandom Operator.

 Used to initialize tensors with gaussian random generator.
+The default mean of the distribution is 0, and the default standard
+deviation (std) of the distribution is 1. Users can set mean and std
+via input arguments.
 )DOC");
   }
 };
...
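For intuition, this is roughly what the op samples, sketched in NumPy (illustrative only; the real op takes its batch dimension from a reference input):

```python
import numpy as np

mean, std, seed = 0.0, 1.0, 0
rng = np.random.RandomState(seed)
# Shape (batch_size, dim); batch_size would come from the "like" input.
sample = rng.normal(loc=mean, scale=std, size=(8, 16)).astype('float32')
print(sample.mean())  # roughly 0
print(sample.std())   # roughly 1
```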
@@ -348,7 +348,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
 };

 void SignalHandler::StopAndExit(int signal_num) {
-  VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
+  // Do not use VLOG here, for the device used for printing may already be
+  // released. exit will release internal allocated resources.
   exit(0);
 }
...
@@ -33,12 +33,10 @@ class MeanOp : public framework::OperatorWithKernel {
 class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddInput("X", "The input of mean op");
-    AddOutput("Out", "The output of mean op").Reuse("X");
+    AddInput("X", "(Tensor) The input of mean op");
+    AddOutput("Out", "(Tensor) The output of mean op").Reuse("X");
     AddComment(R"DOC(
-Mean Operator.
-
-Out is a scalar which is the mean of all elements in X.
+Mean Operator calculates the mean of all elements in X.
 )DOC");
   }
...
@@ -66,17 +66,25 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
 }  // namespace

 template <typename DeviceContext, typename T>
-void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
+void TensorRTEngineKernel<DeviceContext, T>::Prepare(
     const framework::ExecutionContext &context) const {
   VLOG(4) << "Prepare engine";
   // Get the ProgramDesc and pass to convert.
   framework::proto::BlockDesc block_desc;
   block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
-  max_batch_ = context.Attr<int>("max_batch");
+  int max_batch = context.Attr<int>("max_batch");
   auto max_workspace = context.Attr<int>("max_workspace");
-  engine_ = Singleton<TRT_EngineManager>::Global().Create(
-      max_batch_, max_workspace, &stream_);
-  engine_->InitNetwork();
+  auto params = context.Attr<std::vector<std::string>>("parameters");
+  std::unordered_set<std::string> parameters;
+  for (const auto &param : params) {
+    parameters.insert(param);
+  }
+  // TODO(Superjomn) replace this with a different stream
+  auto *engine = Singleton<TRT_EngineManager>::Global().Create(
+      max_batch, max_workspace, nullptr /*engine holds its own stream*/,
+      context.Attr<std::string>("engine_uniq_key"));
+  engine->InitNetwork();

   framework::BlockDesc block(nullptr /*programdesc*/, &block_desc);
   // Add inputs
@@ -87,24 +95,23 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare(
     PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
                       "TensorRT engine only takes LoDTensor as input");
     auto shape = var->GetShape();
-    engine_->DeclareInput(
+    engine->DeclareInput(
         input, FluidDataType2TRT(
                    var->Proto()->type().lod_tensor().tensor().data_type()),
         Vec2TRT_Dims(var->GetShape()));
   }

-  // TODO(Superjomn) parameters should be passed after analysised from outside.
   inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
-      block_desc, {}, context.scope(), engine_);
+      block_desc, parameters, context.scope(), engine);

   // Add outputs
   VLOG(4) << "declare outputs";
   for (auto &output : context.Outputs("Ys")) {
     VLOG(4) << "declare output " << output;
-    engine_->DeclareOutput(output);
+    engine->DeclareOutput(output);
   }

-  engine_->FreezeNetwork();
+  engine->FreezeNetwork();
 }

 class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -113,6 +120,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Xs", "A list of inputs.").AsDuplicable();
     AddOutput("Ys", "A list of outputs").AsDuplicable();
     AddAttr<std::string>("subgraph", "the subgraph.");
+    AddAttr<std::string>("engine_uniq_key", "unique key for the TRT engine.");
     AddAttr<int>("max_batch", "the maximum batch size.");
     AddAttr<int>("max_workspace", "the maximum workspace size.");
     AddComment("TensorRT engine operator.");
...
@@ -19,10 +19,14 @@

 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"

 namespace paddle {
 namespace operators {

+using inference::Singleton;
+using inference::tensorrt::TRT_EngineManager;
+
 class TensorRTEngineOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -47,16 +51,18 @@ template <typename DeviceContext, typename T>
 class TensorRTEngineKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    if (!engine_) {
+    auto engine_name = context.Attr<std::string>("engine_uniq_key");
+    if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
       Prepare(context);
     }
+    auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name);
     auto input_names = context.op().Inputs("Xs");
     PADDLE_ENFORCE(!input_names.empty(), "should pass at least one input");
     // Try to determine a batch_size
     auto& tensor0 = inference::analysis::GetFromScope<framework::LoDTensor>(
         context.scope(), input_names.front());
     int batch_size = tensor0.dims()[0];
-    PADDLE_ENFORCE_LE(batch_size, max_batch_);
+    PADDLE_ENFORCE_LE(batch_size, context.Attr<int>("max_batch"));

     // Convert input tensor from fluid to engine.
     for (const auto& x : context.Inputs("Xs")) {
@@ -64,20 +70,20 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
       auto& t = inference::analysis::GetFromScope<framework::LoDTensor>(
           context.scope(), x);
       if (platform::is_cpu_place(t.place())) {
-        engine_->SetInputFromCPU(x, static_cast<const void*>(t.data<void>()),
-                                 t.memory_size());
+        engine->SetInputFromCPU(x, static_cast<const void*>(t.data<void>()),
+                                t.memory_size());
       } else {
-        engine_->SetInputFromGPU(x, static_cast<const void*>(t.data<void>()),
-                                 t.memory_size());
+        engine->SetInputFromGPU(x, static_cast<const void*>(t.data<void>()),
+                                t.memory_size());
       }
     }

     // Execute the engine.
     PADDLE_ENFORCE_GT(batch_size, 0);
-    engine_->Execute(batch_size);
+    engine->Execute(batch_size);

     // Convert output tensor from engine to fluid
     for (const auto& y : context.Outputs("Ys")) {
       // convert output and copy to fluid.
-      nvinfer1::ITensor* trt_t = engine_->GetITensor(y);
+      nvinfer1::ITensor* trt_t = engine->GetITensor(y);
       auto dims = trt_t->getDimensions();
       // Use the output ITensor's dims to reshape the Fluid Tensor.
       std::vector<int> ddim(dims.d, dims.d + dims.nbDims);
@@ -89,27 +95,22 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
       auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
       if (platform::is_cpu_place(fluid_t->place())) {
         // TODO(Superjomn) change this float to dtype size.
-        engine_->GetOutputInCPU(
+        engine->GetOutputInCPU(
             y, fluid_t->mutable_data<float>(platform::CPUPlace()),
             size * sizeof(float));
       } else {
-        engine_->GetOutputInGPU(
+        engine->GetOutputInGPU(
             y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
             size * sizeof(float));
       }
     }

-    cudaStreamSynchronize(stream_);
+    cudaStreamSynchronize(*engine->stream());
   }

  protected:
   // Build the engine.
   void Prepare(const framework::ExecutionContext& context) const;
-
- private:
-  mutable cudaStream_t stream_;
-  mutable inference::tensorrt::TensorRTEngine* engine_{nullptr};
-  mutable int max_batch_{0};
 };

 }  // namespace operators
...
@@ -79,6 +79,17 @@ void SetAttr<int64_t>(framework::proto::OpDesc* op, const std::string& name,
   attr->set_type(paddle::framework::proto::AttrType::LONG);
   attr->set_l(data);
 }
+template <>
+void SetAttr<std::vector<std::string>>(framework::proto::OpDesc* op,
+                                       const std::string& name,
+                                       const std::vector<std::string>& data) {
+  auto* attr = op->add_attrs();
+  attr->set_name(name);
+  attr->set_type(paddle::framework::proto::AttrType::STRINGS);
+  for (const auto& s : data) {
+    attr->add_strings(s.c_str());
+  }
+}

 }  // namespace
@@ -123,11 +134,15 @@ TEST(TensorRTEngineOp, manual) {
   engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));

   SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
                        block_->SerializeAsString());
-  SetAttr<int>(engine_op_desc.Proto(), "max_batch", 30);
+  SetAttr<int>(engine_op_desc.Proto(), "max_batch", 100);
   SetAttr<int>(engine_op_desc.Proto(), "max_workspace", 1 << 10);
+  SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
+  SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
+                                    std::vector<std::string>({}));

   LOG(INFO) << "create engine op";
   auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
+  LOG(INFO) << "engine_op " << engine_op.get();

   framework::Scope scope;
   platform::CPUPlace place;
@@ -145,6 +160,88 @@ TEST(TensorRTEngineOp, manual) {
   engine_op->Run(scope, place);
 }

+void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
+  framework::ProgramDesc program;
+  framework::Scope scope;
+  platform::CPUPlace place;
+  platform::CPUDeviceContext ctx(place);
+
+  auto* block_ = program.Proto()->add_blocks();
+  block_->set_idx(0);
+  block_->set_parent_idx(-1);
+
+  using shape_t = std::vector<int64_t>;
+
+  LOG(INFO) << "create block desc";
+  framework::BlockDesc block_desc(&program, block_);
+
+  auto AddFCLayer = [&](const std::string& x_name, const std::string& y_name,
+                        const std::string& z_name, bool x_created,
+                        const shape_t& x_shape, const shape_t& y_shape,
+                        const shape_t& z_shape) {
+    LOG(INFO) << "create fc op";
+    auto* fc = block_desc.AppendOp();
+    fc->SetType("mul");
+    fc->SetInput("X", std::vector<std::string>({x_name}));
+    fc->SetInput("Y", std::vector<std::string>({y_name}));
+    fc->SetOutput("Out", std::vector<std::string>({z_name}));
+
+    // Set inputs' variable shape in BlockDesc
+    if (!x_created) {
+      AddTensorToBlockDesc(block_, x_name,
+                           std::vector<int64_t>({batch_size, input_dim, 1, 1}));
+    }
+    AddTensorToBlockDesc(block_, y_name,
+                         std::vector<int64_t>({input_dim, output_dim}));
+    AddTensorToBlockDesc(block_, z_name,
+                         std::vector<int64_t>({batch_size, output_dim}));
+
+    // Prepare variables.
+    if (!x_created) {
+      CreateCPUTensor(&scope, x_name, std::vector<int64_t>(x_shape));
+    }
+    CreateCPUTensor(&scope, y_name, std::vector<int64_t>(y_shape));
+    CreateCPUTensor(&scope, z_name, std::vector<int64_t>(z_shape));
+
+    // It is weird, need to copy manually.
+    *block_->add_ops() = *fc->Proto();
+  };
+
+  // Test with 4-layer FC
+  AddFCLayer("x0", "y0", "z0", false, {batch_size, input_dim},
+             {input_dim, output_dim}, {batch_size, output_dim});
+  AddFCLayer("z0", "y1", "z1", true, {}, {output_dim, output_dim},
+             {batch_size, output_dim});
+  AddFCLayer("z1", "y2", "z2", true, {}, {output_dim, output_dim},
+             {batch_size, output_dim});
+  AddFCLayer("z2", "y3", "z3", true, {}, {output_dim, output_dim},
+             {batch_size, output_dim});
+
+  LOG(INFO) << "create tensorrt desc";
+  framework::OpDesc engine_op_desc(nullptr);
+  engine_op_desc.SetType("tensorrt_engine");
+  engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"}));
+  engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z3"}));
+
+  SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
+                       block_->SerializeAsString());
+  SetAttr<int>(engine_op_desc.Proto(), "max_batch", batch_size);
+  SetAttr<int>(engine_op_desc.Proto(), "max_workspace", 2 << 10);
+  SetAttr<std::vector<std::string>>(
+      engine_op_desc.Proto(), "parameters",
+      std::vector<std::string>({"y0", "y1", "y2", "y3"}));
+  SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "b_engine");
+
+  auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
+
+  // Execute them.
+  engine_op->Run(scope, place);
+}
+
+// Test with a larger FC layer.
+TEST(TensorRTEngineOp, fc) { Execute(40, 256, 256); }

 }  // namespace operators
 }  // namespace paddle
...
@@ -15,11 +15,13 @@
 import framework
 import numpy as np
 import contextlib
+from framework import convert_np_dtype_to_dtype_
+from core import VarDesc

 __all__ = [
-    'Constant', 'Uniform', 'Normal', 'Xavier', 'force_init_on_cpu',
+    'Constant', 'Uniform', 'Normal', 'Xavier', 'Bilinear', 'force_init_on_cpu',
     'init_on_cpu', 'ConstantInitializer', 'UniformInitializer',
-    'NormalInitializer', 'XavierInitializer'
+    'NormalInitializer', 'XavierInitializer', 'BilinearInitializer'
 ]

 _force_init_on_cpu_ = False
@@ -422,6 +424,101 @@ class MSRAInitializer(Initializer):
         return op

+class BilinearInitializer(Initializer):
+    """Implements the bilinear initializer.
+
+    This initializer can be used in the transposed convolution operator to
+    act as upsampling. Users can upsample a feature map with shape of
+    (B, C, H, W) by any integer factor. The usage is:
+
+    >>> factor = 2
+    >>> w_attr = ParamAttr(learning_rate=0., regularizer=L2Decay(0.),
+    >>>                    initializer=Bilinear())
+    >>> conv_up = fluid.layers.conv2d_transpose(
+    >>>     input,
+    >>>     num_filters=C,
+    >>>     output_size=None,
+    >>>     filter_size=2 * factor - factor % 2,
+    >>>     padding=ceil((factor - 1) / 2.),
+    >>>     stride=factor,
+    >>>     groups=C,
+    >>>     param_attr=w_attr,
+    >>>     bias_attr=False)
+
+    Where `num_filters=C` and `groups=C` mean this is a channel-wise transposed
+    convolution. The filter shape will be (C, 1, K, K) where K is `filter_size`.
+    This initializer will set a (K, K) interpolation kernel for every channel
+    of the filter identically. The resulting shape of the output feature map
+    will be (B, C, factor * H, factor * W). Note that the learning rate and the
+    weight decay are set to 0 in order to keep the coefficient values of bilinear
+    interpolation unchanged during training.
+    """
+
+    def __init__(self):
+        """Constructor for BilinearInitializer.
+        """
+        super(BilinearInitializer, self).__init__()
+
+    def __call__(self, var, block):
+        """Add bilinear initialization ops for a variable
+
+        Args:
+            var (Variable): Variable that needs to be initialized.
+            block (Block): The block in which initialization ops should
+                           be added.
+
+        Returns:
+            the initialization op
+
+        Raises:
+            ValueError: If the type of `var` or `block` is not right, if
+                        the length of the shape of `var` is not 4, or if
+                        var.shape[2] != var.shape[3].
+        """
+        if not isinstance(var, framework.Variable):
+            raise ValueError("var must be framework.Variable.")
+        if not isinstance(block, framework.Block):
+            raise ValueError("block must be framework.Block.")
+
+        shape = var.shape
+        if len(shape) != 4:
+            raise ValueError("the length of shape must be 4.")
+        if shape[2] != shape[3]:
+            raise ValueError("shape[2] must be equal to shape[3].")
+
+        weight = np.zeros(np.prod(var.shape), dtype='float32')
+        size = shape[3]
+        # factor
+        f = np.ceil(size / 2.)
+        # center
+        c = (2 * f - 1 - f % 2) / (2. * f)
+        for i in range(np.prod(shape)):
+            x = i % size
+            y = (i / size) % size
+            weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
+        weight = np.reshape(weight, shape)
+
+        if var.dtype == VarDesc.VarType.FP32:
+            value_name = "fp32_values"
+            values = [float(v) for v in weight.flat]
+        else:
+            raise ValueError("Unsupported dtype %s", var.dtype)
+        if np.prod(shape) > 1024 * 1024:
+            raise ValueError("The size of input is too big.")
+        op = block.append_op(
+            type='assign_value',
+            outputs={'Out': [var]},
+            attrs={
+                'dtype': var.dtype,
+                'shape': list(shape),
+                value_name: values
+            })
+        var.op = op
+        return op

 # We short the class name, since users will use the initializer with the package
 # name. The sample code:
 #
@@ -436,3 +533,4 @@ Uniform = UniformInitializer
 Normal = NormalInitializer
 Xavier = XavierInitializer
 MSRA = MSRAInitializer
+Bilinear = BilinearInitializer
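To see what this initializer actually writes, here is the kernel computation above pulled out into a standalone sketch for factor = 2 (so K = 2 * factor - factor % 2 = 4):

```python
import numpy as np

size = 4                            # K for factor = 2
f = np.ceil(size / 2.)              # 2.0
c = (2 * f - 1 - f % 2) / (2. * f)  # 0.75
weight = np.zeros((size, size), dtype='float32')
for y in range(size):
    for x in range(size):
        weight[y, x] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
print(weight)
# The outer product of [0.25, 0.75, 0.75, 0.25] with itself: the classic
# 4x4 bilinear upsampling kernel, replicated identically for every channel.
```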
@@ -707,7 +707,7 @@ def lod_rank_table(x, level=0):
         .. code-block:: python

             x = fluid.layers.data(name='x', shape=[10],
                                   dtype='float32', lod_level=1)
             out = layers.lod_rank_table(x=x, level=0)
     """
     helper = LayerHelper("lod_rank_table", **locals())
...
@@ -22,9 +22,9 @@ from ..executor import global_scope
 from layer_function_generator import generate_layer_fn, templatedoc

 __all__ = [
-    'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
-    'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
-    'random_data_generator', 'Preprocessor', 'load'
+    'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
+    'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
+    'double_buffer', 'random_data_generator', 'Preprocessor', 'load'
 ]
@@ -177,18 +177,17 @@ class ListenAndServ(object):
         })

-def Send(endpoints, send_vars, get_vars=None):
+def Send(endpoints, send_vars, sync=True):
     """
-    Send layer
+    Send variables to the server side.

     Args:
-        endpoints: comma seperated IP:PORT pairs in the order
+        endpoints (str): comma separated IP:PORT pairs in the order
            of send_vars to send
-        send_vars: vars to send
-        get_vars: vars to get from server after send completes.
-
-        Send variables to the server side, and get vars from server
-        side when server have finished running server side program.
+        send_vars (list): variables to send to the server
+        sync (bool): whether to wait for the request to finish
     """
     assert (type(send_vars) == list)
@@ -196,40 +195,33 @@ def Send(endpoints, send_vars, sync=True):
     endpoints = list(set(epmap))

     helper = LayerHelper("Send", **locals())
-    if not get_vars:
-        get_vars = []
-        for s in send_vars:
-            v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
-            get_vars.append(v)
     rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()

     helper.append_op(
         type="send",
         inputs={"X": send_vars},
-        outputs={"Out": get_vars},
         attrs={
             "endpoints": endpoints,
             "epmap": epmap,
             rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
         })
-
-    return get_vars
+    if sync:
+        helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})

-def Recv(endpoints, get_vars):
+def Recv(endpoints, get_vars, sync=True):
     """
-    Recv layer
+    Receive variables from the server side.

     Args:
-        endpoints: comma seperated IP:PORT pairs in the order
+        endpoints (str): comma separated IP:PORT pairs in the order
            of send_vars to send
-        send_vars: vars to send
-        get_vars: vars to get from server after send completes.
-
-        Send variables to the server side, and get vars from server
-        side when server have finished running server side program.
+        get_vars (list): vars to get from the server after the send completes
+        sync (bool): whether to wait for the request to finish
+
+    Returns:
+        list: list of received variables
     """
-    assert (type(send_vars) == list)
     assert (type(get_vars) == list)

     epmap = endpoints.split(",")
@@ -242,6 +234,9 @@ def Recv(endpoints, get_vars, sync=True):
         outputs={"Out": get_vars},
         attrs={"endpoints": endpoints,
                "epmap": epmap})
+    if sync:
+        helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints})
+    return get_vars
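With the barrier-based form above, a round trip pairs an explicit Send with a Recv, as the updated test later in this commit does. A hedged usage sketch (the endpoint and variables are illustrative):

```python
# Assumes a listen_and_serv server is already running on 127.0.0.1:6174
# and that x / get_var are variables in the current program.
layers.Send("127.0.0.1:6174", [x])              # push, then send_barrier
out = layers.Recv("127.0.0.1:6174", [get_var])  # pull, then fetch_barrier
```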
 def monkey_patch_reader_methods(reader):
@@ -383,16 +378,16 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
         Variable: A Reader Variable from which we can get random data.

     Examples:
-        .. code-block:: python

-            reader = fluid.layers.io.random_data_generator(
-                                             low=0.0,
-                                             high=1.0,
-                                             shapes=[(3,224,224), (1)],
-                                             lod_levels=[0, 0])
-            # Via the reader, we can use 'read_file' layer to get data:
-            image, label = fluid.layers.io.read_file(reader)
+        .. code-block:: python
+
+            reader = fluid.layers.random_data_generator(
+                                             low=0.0,
+                                             high=1.0,
+                                             shapes=[[3,224,224], [1]],
+                                             lod_levels=[0, 0])
+
+            # Via the reader, we can use 'read_file' layer to get data:
+            image, label = fluid.layers.read_file(reader)
     """
     dtypes = [core.VarDesc.VarType.FP32] * len(shapes)
     shape_concat = []
@@ -541,6 +536,9 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):

 def shuffle(reader, buffer_size):
+    """
+    Shuffle the reader.
+    """
     return __create_unshared_decorated_reader__(
         'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
...
@@ -44,6 +44,11 @@ def _type_to_str_(tp):
     return framework_pb2.AttrType.Name(tp)

+_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
+_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
+_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")

 def _generate_doc_string_(op_proto):
     """
     Generate docstring by OpProto
@@ -55,22 +60,26 @@ def _generate_doc_string_(op_proto):
         str: the document string
     """

+    def escape_math(text):
+        return _two_bang_pattern_.sub(
+            r'$$\1$$',
+            _single_dollar_pattern_.sub(
+                r':math:`\1`', _two_dollar_pattern_.sub(r"!!\1!!", text)))

     if not isinstance(op_proto, framework_pb2.OpProto):
         raise TypeError("OpProto should be `framework_pb2.OpProto`")

     buf = cStringIO.StringIO()
-    buf.write(op_proto.comment)
+    buf.write(escape_math(op_proto.comment))
     buf.write('\nArgs:\n')
     for each_input in op_proto.inputs:
         line_begin = '    {0}: '.format(_convert_(each_input.name))
         buf.write(line_begin)
-        buf.write(each_input.comment)
-        buf.write('\n')
-        buf.write(' ' * len(line_begin))
-        buf.write('Duplicable: ')
-        buf.write(str(each_input.duplicable))
-        buf.write('  Optional: ')
-        buf.write(str(each_input.dispensable))
+        buf.write(escape_math(each_input.comment))
+        if each_input.duplicable:
+            buf.write("  Duplicatable.")
+        if each_input.dispensable:
+            buf.write("  Optional.")
         buf.write('\n')

     skip_attrs = OpProtoHolder.generated_op_attr_names()
@@ -83,7 +92,7 @@ def _generate_doc_string_(op_proto):
             buf.write(' (')
             buf.write(_type_to_str_(each_attr.type))
             buf.write('): ')
-            buf.write(each_attr.comment)
+            buf.write(escape_math(each_attr.comment))
             buf.write('\n')

     if len(op_proto.outputs) != 0:
@@ -92,7 +101,7 @@ def _generate_doc_string_(op_proto):
         for each_opt in op_proto.outputs:
             if not each_opt.intermediate:
                 break
-            buf.write(each_opt.comment)
+            buf.write(escape_math(each_opt.comment))

     return buf.getvalue()
...
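The three regexes cooperate to convert inline `$...$` into reST `:math:` roles while leaving display `$$...$$` blocks untouched: display math is first hidden behind `!!...!!`, inline math is rewritten, then the display blocks are restored. Running the function standalone shows the effect:

```python
import re

_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")

def escape_math(text):
    return _two_bang_pattern_.sub(
        r'$$\1$$',
        _single_dollar_pattern_.sub(
            r':math:`\1`', _two_dollar_pattern_.sub(r"!!\1!!", text)))

print(escape_math("$$out = x^2$$ holds for $x > 0$"))
# -> $$out = x^2$$ holds for :math:`x > 0`
```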
@@ -225,11 +225,11 @@ def embedding(input,
             have two elements which indicate the size of the dictionary of
             embeddings and the size of each embedding vector respectively.
         is_sparse(bool): The flag indicating whether to use sparse update.
-        is_distributed (bool): Whether to run lookup table from remote parameter server.
+        is_distributed(bool): Whether to run the lookup table from a remote parameter server.
         padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup.
             Otherwise the given :attr:`padding_idx` indicates padding the output
             with zeros whenever lookup encounters it in :attr:`input`. If
-            :math:`padding_idx < 0`, the padding_idx to use in lookup is
+            :math:`padding_idx < 0`, the :attr:`padding_idx` to use in lookup is
             :math:`size[0] + dim`.
         param_attr(ParamAttr): Parameters for this layer
         dtype(np.dtype|core.VarDesc.VarType|str): The type of data: float32, float_16, int etc.
@@ -364,8 +364,7 @@ def dynamic_lstm(input,
         cell_activation(str): The activation for cell output. Choices = ["sigmoid",
                               "tanh", "relu", "identity"], default "tanh".
         candidate_activation(str): The activation for candidate hidden state.
-                              Choices = ["sigmoid", "tanh",
-                                         "relu", "identity"],
+                              Choices = ["sigmoid", "tanh", "relu", "identity"],
                               default "tanh".
         dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
         name(str|None): A name for this layer(optional). If set None, the layer
@@ -540,27 +539,31 @@ def dynamic_lstmp(input,
         cell_activation(str): The activation for cell output. Choices = ["sigmoid",
                               "tanh", "relu", "identity"], default "tanh".
         candidate_activation(str): The activation for candidate hidden state.
-                              Choices = ["sigmoid", "tanh",
-                                         "relu", "identity"],
+                              Choices = ["sigmoid", "tanh", "relu", "identity"],
                               default "tanh".
         proj_activation(str): The activation for projection output.
-                              Choices = ["sigmoid", "tanh",
-                                         "relu", "identity"],
+                              Choices = ["sigmoid", "tanh", "relu", "identity"],
                               default "tanh".
         dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
         name(str|None): A name for this layer(optional). If set None, the layer
                         will be named automatically.

     Returns:
-        tuple: The projection of hidden state, and cell state of LSTMP. The \
-               shape of projection is (T x P), for the cell state which is \
-               (T x D), and both LoD is the same with the `input`.
+        tuple: A tuple of two output variables: the projection of hidden state \
+               and the cell state of LSTMP. The shape of the projection is \
+               (T x P), that of the cell state is (T x D), and the LoD of \
+               both is the same as that of the `input`.

     Examples:
+
         .. code-block:: python

+            dict_dim, emb_dim = 128, 64
+            data = fluid.layers.data(name='sequence', shape=[1],
+                                     dtype='int32', lod_level=1)
+            emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
             hidden_dim, proj_dim = 512, 256
-            fc_out = fluid.layers.fc(input=input_seq, size=hidden_dim * 4,
+            fc_out = fluid.layers.fc(input=emb, size=hidden_dim * 4,
                                      act=None, bias_attr=None)
             proj_out, _ = fluid.layers.dynamic_lstmp(input=fc_out,
                                                      size=hidden_dim * 4,
@@ -626,10 +629,10 @@ def dynamic_gru(input,
                 candidate_activation='tanh',
                 h_0=None):
     """
-    **Dynamic GRU Layer**
+    **Gated Recurrent Unit (GRU) Layer**

     Refer to `Empirical Evaluation of Gated Recurrent Neural Networks on
-    Sequence Modeling <https://arxiv.org/abs/1412.3555>`_
+    Sequence Modeling <https://arxiv.org/abs/1412.3555>`_ .

     The formula is as follows:
@@ -676,17 +679,25 @@ def dynamic_gru(input,
             Choices = ["sigmoid", "tanh", "relu", "identity"], default "sigmoid".
         candidate_activation(str): The activation for candidate hidden state.
             Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
-        h_0 (Variable): The hidden output of the first time step.
+        h_0 (Variable): The initial hidden state. If not set, it defaults to
+            zero. This is a tensor with shape (N x D), where N is the number of
+            total time steps of the input mini-batch feature and D is the
+            hidden size.

     Returns:
         Variable: The hidden state of GRU. The shape is :math:`(T \\times D)`, \
-            and lod is the same with the input.
+            and the sequence length is the same as that of the input.

     Examples:
+
         .. code-block:: python

+            dict_dim, emb_dim = 128, 64
+            data = fluid.layers.data(name='sequence', shape=[1],
+                                     dtype='int32', lod_level=1)
+            emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
             hidden_dim = 512
-            x = fluid.layers.fc(input=data, size=hidden_dim * 3)
+            x = fluid.layers.fc(input=emb, size=hidden_dim * 3)
             hidden = fluid.layers.dynamic_gru(input=x, dim=hidden_dim)
     """
@@ -924,13 +935,13 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
     Drop or keep each element of `x` independently. Dropout is a regularization
     technique for reducing overfitting by preventing neuron co-adaption during
-    training. The dropout operator randomly set (according to the given dropout
+    training. The dropout operator randomly sets (according to the given dropout
     probability) the outputs of some units to zero, while others remain
     unchanged.

     Args:
-        x (Variable): The input tensor.
+        x (Variable): The input tensor variable.
         dropout_prob (float): Probability of setting units to zero.
         is_test (bool): A flag indicating whether it is in test phase or not.
         seed (int): A Python integer used to create random seeds. If this
             parameter is set to None, a random seed is used.
@@ -940,13 +951,14 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
             will be named automatically.

     Returns:
-        Variable: A tensor variable.
+        Variable: A tensor variable with the same shape as `x`.

     Examples:
+
         .. code-block:: python

             x = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
-            droped = fluid.layers.dropout(input=x, dropout_rate=0.5)
+            dropped = fluid.layers.dropout(x, dropout_prob=0.5)
     """
     helper = LayerHelper('dropout', **locals())
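A rough NumPy sketch of the dropout semantics the docstring describes (zero with probability `dropout_prob` at train time, scale by the keep probability at test time); this is illustrative, not the Paddle kernel:

```python
import numpy as np

def dropout(x, dropout_prob, is_test=False, seed=None):
    if is_test:
        # At test time, scale by the keep probability instead of masking.
        return x * (1.0 - dropout_prob)
    rng = np.random.RandomState(seed)
    mask = (rng.uniform(size=x.shape) >= dropout_prob).astype(x.dtype)
    return x * mask

x = np.ones((2, 4), dtype='float32')
print(dropout(x, 0.5, seed=0))        # some entries zeroed
print(dropout(x, 0.5, is_test=True))  # all entries 0.5
```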
@@ -1235,14 +1247,17 @@ def conv2d(input,
                act=None,
                name=None):
     """
-    **Convlution2D Layer**
-
     The convolution2D layer calculates the output based on the input, filter
-    and strides, paddings, dilations, groups parameters. Input(Input) and
-    Output(Output) are in NCHW format. Where N is batch size, C is the number of
+    and strides, paddings, dilations, groups parameters. Input and
+    Output are in NCHW format, where N is batch size, C is the number of
     channels, H is the height of the feature, and W is the width of the feature.
-    The details of convolution layer, please refer UFLDL's `convolution,
-    <http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_ .
+    Filter is in MCHW format, where M is the number of output image channels,
+    C is the number of input image channels, H is the height of the filter,
+    and W is the width of the filter. If the groups is greater than 1,
+    C will equal the number of input image channels divided by the groups.
+    Please refer to UFLDL's `convolution
+    <http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_
+    for more details.
     If bias attribution and activation type are provided, bias is added to the
     output of the convolution, and the corresponding activation function is
     applied to the final result.
@@ -1253,15 +1268,14 @@ def conv2d(input,

         Out = \sigma (W \\ast X + b)

-    In the above equation:
+    Where:

     * :math:`X`: Input value, a tensor with NCHW format.
     * :math:`W`: Filter value, a tensor with MCHW format.
     * :math:`\\ast`: Convolution operation.
     * :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
     * :math:`\\sigma`: Activation function.
-    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be
-      different.
+    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.

     Example:
@@ -1272,6 +1286,7 @@ def conv2d(input,
           Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`

         - Output:
+
           Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`

         Where
@@ -1283,7 +1298,7 @@ def conv2d(input,
     Args:
         input (Variable): The input image with [N, C, H, W] format.
         num_filters(int): The number of filters. It is the same as the number
             of output image channels.
         filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
            it must contain two integers, (filter_size_H, filter_size_W).
@@ -1306,7 +1321,8 @@ def conv2d(input,
         bias_attr (ParamAttr): Bias parameter for the Conv2d layer. Default: None
         use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
             library is installed. Default: True
-        use_mkldnn (bool): Use mkldnn kernels or not.
+        use_mkldnn (bool): Use mkldnn kernels or not, it is valid only when compiled
+            with mkldnn library. Default: False
         act (str): Activation type. Default: None
         name (str|None): A name for this layer(optional). If set None, the layer
             will be named automatically.
@@ -2987,32 +3003,33 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
     norm. For a 1-D tensor (`dim` is fixed to 0), this layer computes

     .. math::
-        y = \frac{x}{ \sqrt{\sum {x^2} + epsilon }}
+
+        y = \\frac{x}{ \sqrt{\sum {x^2} + epsilon }}

     For `x` with more dimensions, this layer independently normalizes each 1-D
     slice along dimension `axis`.

     Args:
         x(Variable|list): The input tensor to l2_normalize layer.
-        axis(int): The axis on which to apply normalization. If `axis < 0`,
+        axis(int): The axis on which to apply normalization. If `axis < 0`, \
             the dimension to normalization is rank(X) + axis. -1 is the
             last dimension.
-        epsilon(float): The epsilon value is used to avoid division by zero,
+        epsilon(float): The epsilon value is used to avoid division by zero, \
             the default value is 1e-10.
-        name(str|None): A name for this layer(optional). If set None, the layer
+        name(str|None): A name for this layer(optional). If set None, the layer \
             will be named automatically.

     Returns:
-        Variable: The output tensor variable.
+        Variable: The output tensor variable with the same shape as `x`.

     Examples:

         .. code-block:: python

             data = fluid.layers.data(name="data",
                                      shape=(3, 17, 13),
                                      dtype="float32")
             normed = fluid.layers.l2_normalize(x=data, axis=1)
     """

     if len(x.shape) == 1:
...
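A quick NumPy check of the l2_normalize formula above along `axis=1` (sketch only, not the Paddle op):

```python
import numpy as np

x = np.array([[3.0, 4.0]], dtype='float32')
norm = np.sqrt(np.sum(np.square(x), axis=1, keepdims=True) + 1e-12)
print(x / norm)  # [[0.6 0.8]]
```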
@@ -214,6 +214,7 @@ def assign(input, output):
     Examples:

         .. code-block:: python
+
             out = fluid.layers.create_tensor(dtype='float32')
             hidden = fluid.layers.fc(input=data, size=10)
             fluid.layers.assign(hidden, out)
@@ -509,11 +510,27 @@ def save_combine(x, file_path, overwrite=True):
     Saves a list of variables into a single file.

     Args:
-        x(list): A list of Tensor/LoDTensor to be saved together in a single file.
+        x(list): A list of Tensor/LoDTensor variables to be saved together in
+            a single file.
         file_path(str): The file path where variables will be saved.
         overwrite(bool): Whether or not to overwrite the given file if it
             already exists. If set to 'False' and the file exists, a runtime
             error will be thrown.
+
+    Returns:
+        There is no return value.
+
+    Examples:
+        .. code-block:: python
+
+            v1 = fluid.layers.data(name="data_1",
+                                   shape=(4, 6),
+                                   dtype="float32")
+            v2 = fluid.layers.data(name="data_2",
+                                   shape=(6, 8, 4),
+                                   dtype="float32")
+            fluid.layers.save_combine([v1, v2], file_path="output")
     """
     helper = LayerHelper("save_combine", **locals())
     helper.append_op(
...
@@ -16,6 +16,7 @@ import os
 import time
 import unittest
 from multiprocessing import Process
+import signal

 import numpy
@@ -24,9 +25,6 @@ import paddle.fluid.layers as layers

 class TestSendOp(unittest.TestCase):
-    @unittest.skip(
-        "This test is buggy. We cannot use time.sleep to sync processes, the connection may fail in unittest."
-    )
     def test_send(self):
         # Run init_serv in a thread
         place = fluid.CPUPlace()
@@ -35,7 +33,9 @@ class TestSendOp(unittest.TestCase):
         p.daemon = True
         p.start()

-        time.sleep(10)
+        self.ps_timeout = 5
+        self._wait_ps_ready(p.pid)

         with open("/tmp/paddle.%d.port" % p.pid, "r") as fn:
             selected_port = int(fn.readlines()[0])
         self.init_client(place, selected_port)
@@ -44,9 +44,23 @@ class TestSendOp(unittest.TestCase):
         self.assertTrue(numpy.allclose(self.local_out, self.dist_out))

         # FIXME(typhoonzero): find a way to gracefully shutdown the server.
-        os.system("kill -9 %d" % p.pid)
+        os.kill(p.pid, signal.SIGKILL)
         p.join()

+    def _wait_ps_ready(self, pid):
+        start_left_time = self.ps_timeout
+        sleep_time = 0.5
+        while True:
+            assert start_left_time >= 0, "wait ps ready failed"
+            time.sleep(sleep_time)
+            try:
+                # listen_and_serv_op touches a file under /tmp containing the
+                # listen port once it is ready to process RPC calls.
+                os.stat("/tmp/paddle.%d.port" % pid)
+                return
+            except os.error:
+                start_left_time -= sleep_time

     def init_serv(self, place):
         main = fluid.Program()
@@ -84,7 +98,10 @@ class TestSendOp(unittest.TestCase):
                 dtype="float32",
                 persistable=False,
                 shape=[32, 32])
-            o = layers.Send("127.0.0.1:%d" % port, [x], [get_var])
+            fluid.initializer.Constant(value=2.3)(get_var, main.global_block())
+            layers.Send("127.0.0.1:%d" % port, [x])
+            o = layers.Recv("127.0.0.1:%d" % port, [get_var])

             exe = fluid.Executor(place)
             self.dist_out = exe.run(main, fetch_list=o)  # o is a list
...
@@ -364,5 +364,22 @@ class TestMSRAInitializer(unittest.TestCase):
         self.assertEqual(init_op.attr('seed'), 134)

+class TestBilinearInitializer(unittest.TestCase):
+    def test_bilinear_initializer(self):
+        """Test the bilinear initializer with supplied arguments
+        """
+        program = framework.Program()
+        block = program.global_block()
+        block.create_parameter(
+            dtype="float32",
+            shape=[8, 1, 3, 3],
+            lod_level=0,
+            name="param",
+            initializer=initializer.BilinearInitializer())
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'assign_value')

 if __name__ == '__main__':
     unittest.main()
@@ -57,17 +57,18 @@ class TestListenAndServOp(OpTest):
     def setUp(self):
         self.ps_timeout = 5
         self.ip = "127.0.0.1"
-        self.port = "6173"
+        self.port = "0"
         self.trainers = 1
-        self.trainer_id = 1
+        self.trainer_id = 0

     def _start_pserver(self, use_cuda, sync_mode):
         p = Process(
             target=run_pserver,
             args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
                   self.trainer_id))
+        p.daemon = True
         p.start()
-        return p.pid
+        return p

     def _wait_ps_ready(self, pid):
         start_left_time = self.ps_timeout
@@ -89,18 +90,20 @@ class TestListenAndServOp(OpTest):
     def test_handle_signal_in_serv_op(self):
         # run pserver on CPU in sync mode
-        pid = self._start_pserver(False, True)
-        self._wait_ps_ready(pid)
+        p1 = self._start_pserver(False, True)
+        self._wait_ps_ready(p1.pid)

-        # raise SIGTERM to pserver
-        os.kill(pid, signal.SIGTERM)
+        # raise SIGKILL to pserver
+        os.kill(p1.pid, signal.SIGKILL)
+        p1.join()

         # run pserver on CPU in async mode
-        pid = self._start_pserver(False, False)
-        self._wait_ps_ready(pid)
+        p2 = self._start_pserver(False, False)
+        self._wait_ps_ready(p2.pid)

-        # raise SIGTERM to pserver
-        os.kill(pid, signal.SIGTERM)
+        # raise SIGKILL to pserver
+        os.kill(p2.pid, signal.SIGKILL)
+        p2.join()

 if __name__ == '__main__':
...
@@ -173,6 +173,7 @@ class TestCRFModel(unittest.TestCase):
                 pe.run(feed=feeder.feed(cur_batch),
                        fetch_list=[avg_cost.name]))[0]

+    @unittest.skip(reason="CI hangs")
     def test_update_sparse_parameter_all_reduce(self):
         build_strategy = fluid.BuildStrategy()
         build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
@@ -181,6 +182,7 @@ class TestCRFModel(unittest.TestCase):
         self.check_network_convergence(
             is_sparse=True, build_strategy=build_strategy, use_cuda=False)

+    @unittest.skip(reason="CI hangs")
     def test_update_dense_parameter_all_reduce(self):
         build_strategy = fluid.BuildStrategy()
         build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
@@ -189,6 +191,7 @@ class TestCRFModel(unittest.TestCase):
         self.check_network_convergence(
             is_sparse=False, build_strategy=build_strategy, use_cuda=False)

+    @unittest.skip(reason="CI hangs")
     def test_update_sparse_parameter_reduce(self):
         build_strategy = fluid.BuildStrategy()
         build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
@@ -197,6 +200,7 @@ class TestCRFModel(unittest.TestCase):
         self.check_network_convergence(
             is_sparse=True, build_strategy=build_strategy, use_cuda=False)

+    @unittest.skip(reason="CI hangs")
     def test_update_dense_parameter_reduce(self):
         build_strategy = fluid.BuildStrategy()
         build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
...