Commit 0264ec39 authored by qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add-async-listen-and-serv-op

Use different clusters
======================

Users' cluster environments differ. To make deployment easier, we provide several cluster deployment methods for submitting cluster training jobs, introduced below.

`Kubernetes <http://kubernetes.io>`_ is Google's open-source container-cluster scheduling framework and provides a complete cluster solution for large-scale production environments. The following guidelines show PaddlePaddle's support for Kubernetes:

.. toctree::
   :maxdepth: 1

   k8s_cn.md
   k8s_distributed_cn.md

`OpenMPI <https://www.open-mpi.org>`_ is a mature high-performance parallel computing framework, widely used in the HPC field. The following guide describes how to use OpenMPI to build a PaddlePaddle cluster training job:

.. toctree::
   :maxdepth: 1

   openmpi_cn.md

`Fabric <http://www.fabfile.org>`_ is a convenient tool for program deployment and management. We provide a way to deploy and manage clusters with Fabric. If you want to know more about it, please read the following guidelines:

.. toctree::
   :maxdepth: 1

   fabric_cn.md

We also support deploying PaddlePaddle on AWS. Learn more in:

.. toctree::
   :maxdepth: 1

   k8s_aws_cn.md

The examples can be found under `cluster_train_v2 <https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/scripts/cluster_train_v2>`_ .
@@ -226,15 +226,15 @@ static bool has_fetch_operators(
 }
 void Executor::Run(const ProgramDesc& program, Scope* scope,
-                   std::map<std::string, const LoDTensor*>& feed_targets,
-                   std::map<std::string, LoDTensor*>& fetch_targets,
+                   std::map<std::string, const LoDTensor*>* feed_targets,
+                   std::map<std::string, LoDTensor*>* fetch_targets,
                    bool create_vars, const std::string& feed_holder_name,
                    const std::string& fetch_holder_name) {
   platform::RecordBlock b(kProgramId);
   bool has_feed_ops =
-      has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
+      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
   bool has_fetch_ops =
-      has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name);
+      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);
   ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
   if (!has_feed_ops || !has_fetch_ops) {
@@ -250,7 +250,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     feed_holder->SetPersistable(true);
     int i = 0;
-    for (auto& feed_target : feed_targets) {
+    for (auto& feed_target : (*feed_targets)) {
       std::string var_name = feed_target.first;
       VLOG(3) << "feed target's name: " << var_name;
@@ -273,7 +273,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     fetch_holder->SetPersistable(true);
     int i = 0;
-    for (auto& fetch_target : fetch_targets) {
+    for (auto& fetch_target : (*fetch_targets)) {
       std::string var_name = fetch_target.first;
       VLOG(3) << "fetch target's name: " << var_name;
@@ -361,16 +361,16 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
 void Executor::RunPreparedContext(
     ExecutorPrepareContext* ctx, Scope* scope,
-    std::map<std::string, const LoDTensor*>& feed_targets,
-    std::map<std::string, LoDTensor*>& fetch_targets, bool create_vars,
+    std::map<std::string, const LoDTensor*>* feed_targets,
+    std::map<std::string, LoDTensor*>* fetch_targets, bool create_vars,
     const std::string& feed_holder_name, const std::string& fetch_holder_name) {
   auto& global_block = ctx->prog_.Block(ctx->block_id_);
   PADDLE_ENFORCE(
-      has_feed_operators(global_block, feed_targets, feed_holder_name),
+      has_feed_operators(global_block, *feed_targets, feed_holder_name),
       "Program in ExecutorPrepareContext should has feed_ops.");
   PADDLE_ENFORCE(
-      has_fetch_operators(global_block, fetch_targets, fetch_holder_name),
+      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
       "Program in the prepared context should has fetch_ops.");
   // map the data of feed_targets to feed_holder
@@ -378,8 +378,8 @@ void Executor::RunPreparedContext(
     if (op->Type() == kFeedOpType) {
       std::string feed_target_name = op->Output("Out")[0];
       int idx = boost::get<int>(op->GetAttr("col"));
-      SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
-                      idx);
+      SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
+                      feed_holder_name, idx);
     }
   }
@@ -390,7 +390,7 @@ void Executor::RunPreparedContext(
     if (op->Type() == kFetchOpType) {
       std::string fetch_target_name = op->Input("X")[0];
       int idx = boost::get<int>(op->GetAttr("col"));
-      *fetch_targets[fetch_target_name] =
+      *(*fetch_targets)[fetch_target_name] =
           GetFetchVariable(*scope, fetch_holder_name, idx);
     }
   }
......
@@ -55,8 +55,8 @@ class Executor {
            bool create_local_scope = true, bool create_vars = true);
   void Run(const ProgramDesc& program, Scope* scope,
-           std::map<std::string, const LoDTensor*>& feed_targets,
-           std::map<std::string, LoDTensor*>& fetch_targets,
+           std::map<std::string, const LoDTensor*>* feed_targets,
+           std::map<std::string, LoDTensor*>* fetch_targets,
            bool create_vars = true,
            const std::string& feed_holder_name = "feed",
            const std::string& fetch_holder_name = "fetch");
@@ -74,8 +74,8 @@ class Executor {
                           bool create_vars = true);
   void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
-                          std::map<std::string, const LoDTensor*>& feed_targets,
-                          std::map<std::string, LoDTensor*>& fetch_targets,
+                          std::map<std::string, const LoDTensor*>* feed_targets,
+                          std::map<std::string, LoDTensor*>* fetch_targets,
                           bool create_vars = true,
                           const std::string& feed_holder_name = "feed",
                           const std::string& fetch_holder_name = "fetch");
......
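With this change, callers pass the feed/fetch maps by pointer instead of by non-const reference; the updated test_helper calls further down in this commit show the pattern. As a minimal caller-side sketch, assuming a program whose feed and fetch targets are named "x" and "y" (those names, and the `RunInference` wrapper, are illustrative only, not part of this commit):

```cpp
#include <map>
#include <string>
#include "paddle/fluid/framework/executor.h"

using paddle::framework::Executor;
using paddle::framework::LoDTensor;
using paddle::framework::ProgramDesc;
using paddle::framework::Scope;

// Hypothetical helper: feed one tensor as "x", fetch one tensor as "y".
void RunInference(Executor* executor, const ProgramDesc& program, Scope* scope,
                  const LoDTensor& input, LoDTensor* output) {
  std::map<std::string, const LoDTensor*> feed_targets{{"x", &input}};
  std::map<std::string, LoDTensor*> fetch_targets{{"y", output}};
  // The maps are now passed by address rather than by reference.
  executor->Run(program, scope, &feed_targets, &fetch_targets,
                /*create_vars=*/true);
}
```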
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace inference {
/*
* EngineBase is the base class of all inference engines. An inference engine
* takes a paddle program as input, and outputs the result in fluid Tensor
* format. It can be used to optimize performance of computation sub-blocks, for
* example, break down the original block into sub-blocks and execute each
* sub-block in a different engine.
*
* For example:
* At inference time, most of the ResNet50 model can be put into a subgraph
* and run on a TensorRT engine.
*
* There are several engines such as TensorRT and other frameworks, so an
* EngineBase is put forward to give a unified interface for all the
* different engine implementations.
*/
class EngineBase {
public:
using DescType = ::paddle::framework::proto::BlockDesc;
// Build the model and do some preparation, for example, in TensorRT, run
// createInferBuilder, buildCudaEngine.
virtual void Build(const DescType& paddle_model) = 0;
// Execute the engine, that will run the inference network.
virtual void Execute(int batch_size) = 0;
virtual ~EngineBase() {}
}; // class EngineBase
} // namespace inference
} // namespace paddle
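For illustration only, a sketch of what a concrete engine built on this interface could look like; `DummyEngine` is hypothetical and not part of this commit:

```cpp
// Hypothetical example engine; only the EngineBase interface above is real.
#include "paddle/fluid/inference/engine.h"

namespace paddle {
namespace inference {

class DummyEngine : public EngineBase {
 public:
  // Record the sub-block this engine is responsible for.
  void Build(const DescType& paddle_model) override { block_ = &paddle_model; }

  // A real engine would launch the optimized network here.
  void Execute(int batch_size) override { last_batch_size_ = batch_size; }

  ~DummyEngine() override = default;

 private:
  const DescType* block_{nullptr};
  int last_batch_size_{0};
};

}  // namespace inference
}  // namespace paddle
```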
-nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
+if(WITH_TESTING)
+  nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
+  nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
+endif()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/engine.h"
#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
void TensorRTEngine::Build(const DescType& paddle_model) {
PADDLE_ENFORCE(false, "not implemented");
}
void TensorRTEngine::Execute(int batch_size) {
infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr);
cudaStreamSynchronize(*stream_);
}
TensorRTEngine::~TensorRTEngine() {
// clean buffer
for (auto& buffer : buffers_) {
if (buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buffer));
buffer = nullptr;
}
}
}
void TensorRTEngine::FreezeNetwork() {
PADDLE_ENFORCE(infer_builder_ != nullptr,
"Call InitNetwork first to initialize network.");
PADDLE_ENFORCE(infer_network_ != nullptr,
"Call InitNetwork first to initialize network.");
// build engine.
infer_builder_->setMaxBatchSize(max_batch_);
infer_builder_->setMaxWorkspaceSize(max_workspace_);
infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
infer_context_.reset(infer_engine_->createExecutionContext());
// allocate GPU buffers.
buffers_.resize(buffer_sizes_.size(), nullptr);
for (auto& item : buffer_sizes_) {
if (item.second == 0) {
auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
item.second = kDataTypeSize[static_cast<int>(
infer_engine_->getBindingDataType(slot_offset))] *
AccumDims(infer_engine_->getBindingDimensions(slot_offset));
}
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second));
}
}
nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
nvinfer1::DataType dtype,
const nvinfer1::Dims& dim) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
name);
PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
auto* input = infer_network_->addInput(name.c_str(), dtype, dim);
PADDLE_ENFORCE(input, "infer network add input %s failed", name);
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim);
return input;
}
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
const std::string& name) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
name);
auto* output = layer->getOutput(offset);
PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str());
infer_network_->markOutput(*output);
// output buffers' size can only be decided later, set zero here to mark this
// and will reset later.
buffer_sizes_[name] = 0;
}
void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
return buffer(name);
}
void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
size_t max_size) {
// determine data size
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
PADDLE_ENFORCE_GT(it->second, 0);
PADDLE_ENFORCE_GE(max_size, it->second);
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second,
cudaMemcpyDeviceToHost, *stream_));
}
void*& TensorRTEngine::buffer(const std::string& name) {
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
return buffers_[slot_offset];
}
void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
size_t size) {
void* buf = buffer(name);
PADDLE_ENFORCE_EQ(
0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_));
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <NvInfer.h>
#include <memory>
#include <unordered_map>
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* TensorRT Engine.
*
* There are two alternative ways to use it: one is to build from a paddle
* protobuf model, the other is to manually construct the network.
*/
class TensorRTEngine : public EngineBase {
public:
// Weight is model parameter.
class Weight {
public:
Weight(nvinfer1::DataType dtype, void* value, int num_elem) {
w_.type = dtype;
w_.values = value;
w_.count = num_elem;
}
const nvinfer1::Weights& get() { return w_; }
private:
nvinfer1::Weights w_;
};
TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
stream_(stream),
logger_(logger) {}
virtual ~TensorRTEngine();
// TODO(Superjomn) implement it later when graph segmentation is supported.
virtual void Build(const DescType& paddle_model) override;
virtual void Execute(int batch_size) override;
// Initialize the inference network, so that TensorRT layers can be added to
// this network.
void InitNetwork() {
infer_builder_.reset(createInferBuilder(logger_));
infer_network_.reset(infer_builder_->createNetwork());
}
// After finishing adding ops, freeze this network and create the execution
// environment.
void FreezeNetwork();
// Add an input and set its name, data type and dimension.
nvinfer1::ITensor* DeclareInput(const std::string& name,
nvinfer1::DataType dtype,
const nvinfer1::Dims& dim);
// Set the offset-th output from a layer as the network's output, and set its
// name.
void DeclareOutput(const nvinfer1::ILayer* layer, int offset,
const std::string& name);
// GPU memory address for an ITensor with specific name. One can operate on
// this memory directly for acceleration, for example, output the converted
// data directly to the buffer to save data copy overhead.
// NOTE this should be used after calling `FreezeNetwork`.
void*& buffer(const std::string& name);
// Fill an input from CPU memory with name and size.
void SetInputFromCPU(const std::string& name, void* data, size_t size);
// TODO(Superjomn) is this method necessary given that buffer(xxx) can be
// accessed directly? Fill an input from GPU memory with name and size.
void SetInputFromGPU(const std::string& name, void* data, size_t size);
// Get an output called name, the output of tensorrt is in GPU, so this method
// will just return the output's GPU memory address.
void* GetOutputInGPU(const std::string& name);
// LOW EFFICIENCY! Get output to CPU, this will trigger a memory copy from GPU
// to CPU.
void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
private:
// the max batch size
int max_batch_;
// the max memory size the engine uses
int max_workspace_;
cudaStream_t* stream_;
nvinfer1::ILogger& logger_;
std::vector<void*> buffers_;
// max data size for the buffers.
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
// TensorRT related internal members
template <typename T>
struct Destroyer {
void operator()(T* x) { x->destroy(); }
};
template <typename T>
using infer_ptr = std::unique_ptr<T, Destroyer<T>>;
infer_ptr<nvinfer1::IBuilder> infer_builder_;
infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
infer_ptr<nvinfer1::IExecutionContext> infer_context_;
}; // class TensorRTEngine
// Add a layer__ into engine__ with args ARGS.
// For example:
// TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias)
//
// Reference
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network
//
// will add a fully connected layer into the engine.
// TensorRT has too many layers, so it is not wise to add a member function for
// each of them; a macro like this is more extensible when the underlying
// TensorRT library adds support for new layers.
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
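To make the macro concrete: the call used in the unit test later in this commit, `TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size, weight.get(), bias.get())`, expands roughly into the following (the names are taken from that test):

```cpp
// Rough expansion of the macro invocation above; engine_, x, size, weight and
// bias come from the TensorRTEngineTest fixture shown later in this commit.
auto* fc_layer =
    engine_->network()->addFullyConnected(*x, size, weight.get(), bias.get());
```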
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace dy = paddle::platform::dynload;
static size_t AccumDims(nvinfer1::Dims dims) {
size_t num = dims.nbDims == 0 ? 0 : 1;
for (int i = 0; i < dims.nbDims; i++) {
PADDLE_ENFORCE_GT(dims.d[i], 0);
num *= dims.d[i];
}
return num;
}
// TensorRT data type to size
const int kDataTypeSize[] = {
4, // kFLOAT
2, // kHALF
1, // kINT8
4 // kINT32
};
// The following two APIs are implemented in TensorRT's header file and cannot
// be loaded from the dynamic library, so we create our own implementations
// that call the methods from the dynamic library directly.
static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
return static_cast<nvinfer1::IBuilder*>(
dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
}
static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
return static_cast<nvinfer1::IRuntime*>(
dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
}
// A logger for creating the TensorRT infer builder.
class NaiveLogger : public nvinfer1::ILogger {
public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
switch (severity) {
case Severity::kINFO:
LOG(INFO) << msg;
break;
case Severity::kWARNING:
LOG(WARNING) << msg;
break;
case Severity::kINTERNAL_ERROR:
case Severity::kERROR:
LOG(ERROR) << msg;
break;
default:
break;
}
}
static nvinfer1::ILogger& Global() {
static nvinfer1::ILogger* x = new NaiveLogger;
return *x;
}
virtual ~NaiveLogger() override {}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
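As a small worked example of how `kDataTypeSize` and `AccumDims` are combined in `TensorRTEngine::FreezeNetwork` above, a hedged sketch (not part of this commit) of the per-binding buffer-size rule:

```cpp
// bytes = element size (from kDataTypeSize) * number of elements (AccumDims).
// For a kFLOAT binding with dims {1, 3, 4} this gives 4 * 12 = 48 bytes.
static size_t BindingBytes(nvinfer1::DataType dtype, nvinfer1::Dims dims) {
  return kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dims);
}
```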
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/engine.h"
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class TensorRTEngineTest : public ::testing::Test {
protected:
void SetUp() override {
ASSERT_EQ(0, cudaStreamCreate(&stream_));
engine_ = new TensorRTEngine(1, 1 << 10, &stream_);
engine_->InitNetwork();
}
void TearDown() override {
delete engine_;
cudaStreamDestroy(stream_);
}
protected:
TensorRTEngine* engine_;
cudaStream_t stream_;
};
TEST_F(TensorRTEngineTest, add_layer) {
const int size = 1;
float raw_weight[size] = {2.}; // Weight in CPU memory.
float raw_bias[size] = {3.};
LOG(INFO) << "create weights";
TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, size);
TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, size);
auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 1, 1});
auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size,
weight.get(), bias.get());
PADDLE_ENFORCE(fc_layer != nullptr);
engine_->DeclareOutput(fc_layer, 0, "y");
LOG(INFO) << "freeze network";
engine_->FreezeNetwork();
ASSERT_EQ(engine_->engine()->getNbBindings(), 2);
// fill in real data
float x_v = 1234;
engine_->SetInputFromCPU("x", (void*)&x_v, 1 * sizeof(float));
LOG(INFO) << "to execute";
engine_->Execute(1);
LOG(INFO) << "to get output";
// void* y_v =
float y_cpu;
engine_->GetOutputInCPU("y", &y_cpu, sizeof(float));
LOG(INFO) << "to checkout output";
ASSERT_EQ(y_cpu, x_v * 2 + 3);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
...
@@ -178,10 +178,10 @@ void TestInference(const std::string& dirname,
   std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
   if (PrepareContext) {
     ctx = executor.Prepare(*inference_program, 0);
-    executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
-                                CreateVars);
+    executor.RunPreparedContext(ctx.get(), scope, &feed_targets,
+                                &fetch_targets, CreateVars);
   } else {
-    executor.Run(*inference_program, scope, feed_targets, fetch_targets,
+    executor.Run(*inference_program, scope, &feed_targets, &fetch_targets,
                  CreateVars);
   }
@@ -197,10 +197,10 @@ void TestInference(const std::string& dirname,
     if (PrepareContext) {
       // Note: if you change the inference_program, you need to call
       // executor.Prepare() again to get a new ExecutorPrepareContext.
-      executor.RunPreparedContext(ctx.get(), scope, feed_targets,
-                                  fetch_targets, CreateVars);
+      executor.RunPreparedContext(ctx.get(), scope, &feed_targets,
+                                  &fetch_targets, CreateVars);
     } else {
-      executor.Run(*inference_program, scope, feed_targets, fetch_targets,
+      executor.Run(*inference_program, scope, &feed_targets, &fetch_targets,
                    CreateVars);
     }
   }
......
@@ -223,8 +223,9 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
         sentence_vector_list[src_idx].size());
   }
-  auto cpu_place = new paddle::platform::CPUPlace();
-  paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place);
+  auto cpu_place = std::unique_ptr<paddle::platform::CPUPlace>(
+      new paddle::platform::CPUPlace());
+  paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place.get());
   framework::LoD lod;
   lod.push_back(source_level_lod);
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
 #include "paddle/fluid/platform/float16.h"
@@ -161,7 +162,8 @@ void batched_gemm<platform::CPUDeviceContext, float16>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C, const int batchCount, const int strideA, const int strideB) {
+    float16* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   PADDLE_THROW("float16 batched_gemm not supported on CPU");
 }
@@ -172,7 +174,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -194,7 +197,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -220,7 +224,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   for (int k = 0; k < batchCount; ++k) {
     const float* Ak = &A[k * strideA];
     const float* Bk = &B[k * strideB];
@@ -235,7 +240,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   for (int k = 0; k < batchCount; ++k) {
     const double* Ak = &A[k * strideA];
     const double* Bk = &B[k * strideB];
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #define EIGEN_USE_GPU
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
@@ -267,7 +268,8 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C, const int batchCount, const int strideA, const int strideB) {
+    float16* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -278,7 +280,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
   const half h_alpha = static_cast<const half>(alpha);
   const half h_beta = static_cast<const half>(beta);
@@ -303,7 +305,8 @@ void batched_gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -314,7 +317,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
   PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
@@ -329,7 +332,8 @@ void batched_gemm<platform::CUDADeviceContext, double>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -340,7 +344,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
   PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
......
@@ -26,7 +26,7 @@ limitations under the License. */
 #ifndef LAPACK_FOUND
 extern "C" {
-#include <cblas.h>
+#include <cblas.h>  // NOLINT
 int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
                    int* ipiv);
 int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
@@ -39,6 +39,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
 #endif
 #include <cmath>
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -78,8 +79,8 @@ template <typename DeviceContext, typename T>
 void batched_gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA,
                   const CBLAS_TRANSPOSE transB, const int M, const int N,
                   const int K, const T alpha, const T* A, const T* B,
-                  const T beta, T* C, const int batchCount, const int strideA,
-                  const int strideB);
+                  const T beta, T* C, const int batchCount,
                  const int64_t strideA, const int64_t strideB);
 template <typename DeviceContext, typename T>
 void gemv(const DeviceContext& context, const bool trans_a, const int M,
......
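For reference, the stride arguments (now `int64_t`) are element offsets between consecutive matrices in a batch; the CPU fallback loops above index `A[k * strideA]` and `B[k * strideB]`. A hedged caller-side sketch for densely packed row-major batches (the helper name is illustrative, not part of this commit):

```cpp
#include <cstdint>

// Illustrative only: strides for densely packed row-major batches, where each
// matrix k starts at A + k * strideA and B + k * strideB. Widening to int64_t
// avoids overflow when M * K or K * N exceeds INT_MAX.
inline void DenseBatchStrides(int M, int N, int K, int64_t* strideA,
                              int64_t* strideB) {
  *strideA = static_cast<int64_t>(M) * K;  // elements per A matrix (M x K)
  *strideB = static_cast<int64_t>(K) * N;  // elements per B matrix (K x N)
}
```

These values would then be passed as the last two arguments of `batched_gemm`.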
@@ -77,7 +77,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     const bool is_test = ctx.Attr<bool>("is_test");
     if (!is_test) {
       T threshold = exp(-64);
-      for (size_t i = 0; i < dst_tz[0] * dst_tz[1]; ++i) {
+      for (int i = 0; i < dst_tz[0] * dst_tz[1]; ++i) {
         output_data[i] =
             output_data[i] < threshold ? threshold : output_data[i];
       }
......
@@ -14,10 +14,12 @@
 #pragma once
+#include <cublasXt.h>
 #include <cublas_v2.h>
 #include <cuda.h>
 #include <dlfcn.h>
 #include <mutex>  // NOLINT
+#include <type_traits>
 #include "paddle/fluid/platform/dynload/dynamic_loader.h"
 namespace paddle {
@@ -37,14 +39,14 @@ extern void *cublas_dso_handle;
 #ifdef PADDLE_USE_DSO
 #define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)                            \
   struct DynLoad__##__name {                                                \
+    using FUNC_TYPE = decltype(&::__name);                                  \
     template <typename... Args>                                             \
     inline cublasStatus_t operator()(Args... args) {                        \
-      typedef cublasStatus_t (*cublasFunc)(Args...);                        \
       std::call_once(cublas_dso_flag, []() {                                \
        cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
      });                                                                    \
      void *p_##__name = dlsym(cublas_dso_handle, #__name);                  \
-      return reinterpret_cast<cublasFunc>(p_##__name)(args...);             \
+      return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...);              \
    }                                                                        \
  };                                                                         \
  extern DynLoad__##__name __name
@@ -71,8 +73,8 @@ extern void *cublas_dso_handle;
   __macro(cublasDgemm_v2);       \
   __macro(cublasHgemm);          \
   __macro(cublasSgemmEx);        \
-  __macro(cublasSgeam_v2);       \
-  __macro(cublasDgeam_v2);       \
+  __macro(cublasSgeam);          \
+  __macro(cublasDgeam);          \
   __macro(cublasCreate_v2);      \
   __macro(cublasDestroy_v2);     \
   __macro(cublasSetStream_v2);   \
......
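The change above (mirrored in the cuDNN, CUPTI, cuRAND, NCCL and warpctc wrappers that follow) replaces hand-written `Ret (*)(Args...)` typedefs with `decltype(&::__name)`, so the pointer type used in the `reinterpret_cast` always matches the library's actual declaration. A stripped-down, standalone illustration of the same dlopen/dlsym pattern, assuming a Linux system where `libm.so.6` is available (this snippet is not Paddle code):

```cpp
#include <dlfcn.h>
#include <cstdio>

extern "C" double cos(double);  // any function declared for a loaded library

int main() {
  void* handle = dlopen("libm.so.6", RTLD_LAZY);  // assumed library path
  if (!handle) return 1;
  using cos_func = decltype(&::cos);  // exact pointer type: double (*)(double)
  auto* p = reinterpret_cast<cos_func>(dlsym(handle, "cos"));
  if (p) std::printf("%f\n", p(0.0));  // prints 1.000000
  dlclose(handle);
  return 0;
}
```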
@@ -34,7 +34,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
   struct DynLoad__##__name {                                              \
     template <typename... Args>                                           \
     auto operator()(Args... args) -> decltype(__name(args...)) {          \
-      using cudnn_func = decltype(__name(args...)) (*)(Args...);          \
+      using cudnn_func = decltype(&::__name);                             \
       std::call_once(cudnn_dso_flag, []() {                               \
         cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
       });                                                                 \
......
@@ -41,7 +41,7 @@ extern void *cupti_dso_handle;
   struct DynLoad__##__name {                                              \
     template <typename... Args>                                           \
     inline CUptiResult CUPTIAPI operator()(Args... args) {                \
-      typedef CUptiResult CUPTIAPI (*cuptiFunc)(Args...);                 \
+      using cuptiFunc = decltype(&::__name);                              \
       std::call_once(cupti_dso_flag, []() {                               \
         cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \
       });                                                                 \
......
@@ -30,7 +30,7 @@ extern void *curand_dso_handle;
   struct DynLoad__##__name {                                               \
     template <typename... Args>                                            \
     curandStatus_t operator()(Args... args) {                              \
-      typedef curandStatus_t (*curandFunc)(Args...);                       \
+      using curandFunc = decltype(&::__name);                              \
       std::call_once(curand_dso_flag, []() {                               \
         curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \
       });                                                                  \
......
@@ -33,7 +33,7 @@ extern void* nccl_dso_handle;
   struct DynLoad__##__name {                                              \
     template <typename... Args>                                           \
     auto operator()(Args... args) -> decltype(__name(args...)) {          \
-      using nccl_func = decltype(__name(args...)) (*)(Args...);           \
+      using nccl_func = decltype(&::__name);                              \
       std::call_once(nccl_dso_flag, []() {                                \
         nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle();  \
       });                                                                 \
......
@@ -36,7 +36,7 @@ extern void* warpctc_dso_handle;
   struct DynLoad__##__name {                                                 \
     template <typename... Args>                                              \
     auto operator()(Args... args) -> decltype(__name(args...)) {             \
-      using warpctcFunc = decltype(__name(args...)) (*)(Args...);            \
+      using warpctcFunc = decltype(&::__name);                               \
       std::call_once(warpctc_dso_flag, []() {                                \
         warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
       });                                                                    \
......
@@ -13,40 +13,49 @@ We want to make the building procedures:

1. Build docker images with PaddlePaddle pre-installed, so that we can run
PaddlePaddle applications directly in docker or on Kubernetes clusters.

To achieve this, we maintain a dockerhub repo: https://hub.docker.com/r/paddlepaddle/paddle
which provides pre-built environment images to build PaddlePaddle and generate the corresponding `whl`
binaries. (**We strongly recommend building PaddlePaddle in our pre-specified Docker environment.**)

## Development Workflow

Here we describe how the workflow goes on. We start from considering our daily development environment.

Developers work on a computer, which is usually a laptop or desktop:

<img src="doc/paddle-development-environment.png" width=500 />

or, they might rely on a more sophisticated box (like with GPUs):

<img src="doc/paddle-development-environment-gpu.png" width=500 />

A principle here is that source code lies on the development computer (host) so that editors like Eclipse can parse the source code to support auto-completion.

## Build With Docker

### Build Environments

The latest pre-built build environment images are:

| Image | Tag |
| ----- | --- |
| paddlepaddle/paddle | latest-dev |
| paddlepaddle/paddle | latest-dev-android |

### Start Build

```bash
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
./paddle/scripts/paddle_docker_build.sh build
```

After the build finishes, you can get the output `whl` package under
`build/python/dist`.

This command downloads the most recent dev image from Docker Hub, starts a container in the background, and then runs the build script `/paddle/paddle/scripts/paddle_build.sh build` inside the container.
The container mounts the source directory on the host into `/paddle`.
When it writes to `/paddle/build` in the container, it actually writes to `$PWD/build` on the host.

### Build Options
@@ -68,7 +77,6 @@ Users can specify the following Docker build arguments with either "ON" or "OFF":
| `WITH_DOC` | OFF | Build docs after build binaries. |
| `WOBOQ` | OFF | Generate WOBOQ code viewer under `build/woboq_out` |

## Docker Images

You can get the latest PaddlePaddle docker images by
@@ -144,59 +152,37 @@ docker push
kubectl ...
```

### Reading source code with woboq codebrowser

For developers who are interested in the C++ source code, you can build the C++ source code into HTML pages using [Woboq codebrowser](https://github.com/woboq/woboq_codebrowser).

- The following command builds PaddlePaddle, generates HTML pages from the C++ source code, and writes the HTML pages into `$HOME/woboq_out` on the host:

  ```bash
  ./paddle/scripts/paddle_docker_build.sh html
  ```

- You can open the generated HTML files in your Web browser. Or, if you want to run a Nginx container to serve them for a wider audience, you can run:

  ```
  docker run -v $HOME/woboq_out:/usr/share/nginx/html -d -p 8080:80 nginx
  ```

## More Options

### Build Without Docker

Follow the *Dockerfile* in the paddlepaddle repo to set up your local dev environment and run:

```bash
./paddle/scripts/paddle_build.sh build
```

### Additional Tasks

You can get the help menu for the build scripts by running with no options:

```bash
./paddle/scripts/paddle_build.sh
or ./paddle/scripts/paddle_docker_build.sh
```
#!/usr/bin/env bash
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#=================================================
# Utils
#=================================================
function print_usage() {
RED='\033[0;31m'
BLUE='\033[0;34m'
BOLD='\033[1m'
NONE='\033[0m'
echo -e "\n${RED}Usage${NONE}:
${BOLD}$0${NONE} [OPTION]"
echo -e "\n${RED}Options${NONE}:
${BLUE}build${NONE}: run build for x86 platform
${BLUE}build_android${NONE}: run build for android platform
${BLUE}build_ios${NONE}: run build for ios platform
${BLUE}test${NONE}: run all unit tests
${BLUE}bind_test${NONE}: parallel tests bind to different GPU
${BLUE}doc${NONE}: generate paddle documents
${BLUE}html${NONE}: convert C++ source code into HTML
${BLUE}dockerfile${NONE}: generate paddle release dockerfile
${BLUE}capi${NONE}: generate paddle CAPI package
${BLUE}fluid_inference_lib${NONE}: deploy fluid inference library
${BLUE}check_style${NONE}: run code style check
"
}
function init() {
PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )"
}
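# Configure CMake in ${PADDLE_ROOT}/build, honoring the WITH_* environment
# variables and an optional Python ABI argument (cp27-cp27m or cp27-cp27mu).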
function cmake_gen() {
mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build
# build script will not fail if *.deb does not exist
rm *.deb 2>/dev/null || true
# delete previously built whl packages
rm -rf python/dist 2>/dev/null || true
# Support build for all python versions, currently
# including cp27-cp27m and cp27-cp27mu.
PYTHON_FLAGS=""
if [ "$1" != "" ]; then
echo "using python abi: $1"
if [ "$1" == "cp27-cp27m" ]; then
export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:}
export PATH=/opt/python/cp27-cp27m/bin/:${PATH}
PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python
-DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7
-DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so"
elif [ "$1" == "cp27-cp27mu" ]; then
export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:}
export PATH=/opt/python/cp27-cp27mu/bin/:${PATH}
PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python
-DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7
-DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so"
fi
fi
cat <<EOF
========================================
Configuring cmake in /paddle/build ...
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release}
${PYTHON_FLAGS}
-DWITH_DSO=ON
-DWITH_DOC=${WITH_DOC:-OFF}
-DWITH_GPU=${WITH_GPU:-OFF}
-DWITH_AMD_GPU=${WITH_AMD_GPU:-OFF}
-DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF}
-DWITH_MKL=${WITH_MKL:-ON}
-DWITH_AVX=${WITH_AVX:-OFF}
-DWITH_GOLANG=${WITH_GOLANG:-OFF}
-DCUDA_ARCH_NAME=${CUDA_ARCH_NAME:-All}
-DWITH_C_API=${WITH_C_API:-OFF}
-DWITH_PYTHON=${WITH_PYTHON:-ON}
-DWITH_SWIG_PY=${WITH_SWIG_PY:-ON}
-DCUDNN_ROOT=/usr/
-DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON}
-DWITH_TESTING=${WITH_TESTING:-ON}
-DWITH_FAST_BUNDLE_TEST=ON
-DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
-DWITH_FLUID_ONLY=${WITH_FLUID_ONLY:-OFF}
========================================
EOF
# Disable UNITTEST_USE_VIRTUALENV in docker because
# docker environment is fully controlled by this script.
# See /Paddle/CMakeLists.txt, UNITTEST_USE_VIRTUALENV option.
cmake .. \
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release} \
${PYTHON_FLAGS} \
-DWITH_DSO=ON \
-DWITH_DOC=${WITH_DOC:-OFF} \
-DWITH_GPU=${WITH_GPU:-OFF} \
-DWITH_AMD_GPU=${WITH_AMD_GPU:-OFF} \
-DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF} \
-DWITH_MKL=${WITH_MKL:-ON} \
-DWITH_AVX=${WITH_AVX:-OFF} \
-DWITH_GOLANG=${WITH_GOLANG:-OFF} \
-DCUDA_ARCH_NAME=${CUDA_ARCH_NAME:-All} \
-DWITH_SWIG_PY=${WITH_SWIG_PY:-ON} \
-DWITH_C_API=${WITH_C_API:-OFF} \
-DWITH_PYTHON=${WITH_PYTHON:-ON} \
-DCUDNN_ROOT=/usr/ \
-DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON} \
-DWITH_TESTING=${WITH_TESTING:-ON} \
-DWITH_FAST_BUNDLE_TEST=ON \
-DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake \
-DWITH_FLUID_ONLY=${WITH_FLUID_ONLY:-OFF} \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
}
function abort(){
echo "Your change doesn't follow PaddlePaddle's code style." 1>&2
echo "Please use pre-commit to check what is wrong." 1>&2
exit 1
}
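# Install the style-check tooling (glide, gometalinter, pre-commit) and fail
# if any file does not pass the pre-commit checks.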
function check_style() {
trap 'abort' 0
set -e
# install glide
curl https://glide.sh/get | bash
eval "$(GIMME_GO_VERSION=1.8.3 gimme)"
# set up go environment for running gometalinter
mkdir -p $GOPATH/src/github.com/PaddlePaddle/
ln -sf ${PADDLE_ROOT} $GOPATH/src/github.com/PaddlePaddle/Paddle
cd $GOPATH/src/github.com/PaddlePaddle/Paddle/go; glide install; cd -
go get github.com/alecthomas/gometalinter
gometalinter --install
cd ${PADDLE_ROOT}
export PATH=/usr/bin:$PATH
pre-commit install
clang-format --version
if ! pre-commit run -a ; then
git diff
exit 1
fi
trap : 0
}
#=================================================
# Build
#=================================================
function build() {
mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build
cat <<EOF
============================================
Building in /paddle/build ...
============================================
EOF
make clean
make -j `nproc`
}
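# Cross-compile the C-API library for Android with a standalone NDK toolchain;
# expects ANDROID_ABI, ANDROID_API, ANDROID_NDK_HOME and ANDROID_TOOLCHAINS_DIR to be set.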
function build_android() {
if [ $ANDROID_ABI == "arm64-v8a" ]; then
ANDROID_ARCH=arm64
if [ $ANDROID_API -lt 21 ]; then
echo "Warning: arm64-v8a requires ANDROID_API >= 21."
ANDROID_API=21
fi
else # armeabi, armeabi-v7a
ANDROID_ARCH=arm
fi
ANDROID_STANDALONE_TOOLCHAIN=$ANDROID_TOOLCHAINS_DIR/$ANDROID_ARCH-android-$ANDROID_API
cat <<EOF
============================================
Generating the standalone toolchain ...
${ANDROID_NDK_HOME}/build/tools/make-standalone-toolchain.sh
--arch=$ANDROID_ARCH
--platform=android-$ANDROID_API
--install-dir=${ANDROID_STANDALONE_TOOLCHAIN}
============================================
EOF
${ANDROID_NDK_HOME}/build/tools/make-standalone-toolchain.sh \
--arch=$ANDROID_ARCH \
--platform=android-$ANDROID_API \
--install-dir=$ANDROID_STANDALONE_TOOLCHAIN
BUILD_ROOT=${PADDLE_ROOT}/build
DEST_ROOT=${PADDLE_ROOT}/install
mkdir -p $BUILD_ROOT
cd $BUILD_ROOT
if [ $ANDROID_ABI == "armeabi-v7a" ]; then
cmake -DCMAKE_SYSTEM_NAME=Android \
-DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_STANDALONE_TOOLCHAIN \
-DANDROID_ABI=$ANDROID_ABI \
-DANDROID_ARM_NEON=ON \
-DANDROID_ARM_MODE=ON \
-DHOST_C_COMPILER=/usr/bin/gcc \
-DHOST_CXX_COMPILER=/usr/bin/g++ \
-DCMAKE_INSTALL_PREFIX=$DEST_ROOT \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DUSE_EIGEN_FOR_BLAS=ON \
-DWITH_C_API=ON \
-DWITH_SWIG_PY=OFF \
-DWITH_STYLE_CHECK=OFF \
..
elif [ $ANDROID_ABI == "arm64-v8a" ]; then
cmake -DCMAKE_SYSTEM_NAME=Android \
-DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_STANDALONE_TOOLCHAIN \
-DANDROID_ABI=$ANDROID_ABI \
-DANDROID_ARM_MODE=ON \
-DHOST_C_COMPILER=/usr/bin/gcc \
-DHOST_CXX_COMPILER=/usr/bin/g++ \
-DCMAKE_INSTALL_PREFIX=$DEST_ROOT \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DUSE_EIGEN_FOR_BLAS=OFF \
-DWITH_C_API=ON \
-DWITH_SWIG_PY=OFF \
-DWITH_STYLE_CHECK=OFF \
..
elif [ $ANDROID_ABI == "armeabi" ]; then
cmake -DCMAKE_SYSTEM_NAME=Android \
-DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_STANDALONE_TOOLCHAIN \
-DANDROID_ABI=$ANDROID_ABI \
-DANDROID_ARM_MODE=ON \
-DHOST_C_COMPILER=/usr/bin/gcc \
-DHOST_CXX_COMPILER=/usr/bin/g++ \
-DCMAKE_INSTALL_PREFIX=$DEST_ROOT \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DWITH_C_API=ON \
-DWITH_SWIG_PY=OFF \
-DWITH_STYLE_CHECK=OFF \
..
else
echo "Invalid ANDROID_ABI: $ANDROID_ABI"
fi
cat <<EOF
============================================
Building in $BUILD_ROOT ...
============================================
EOF
make -j `nproc`
make install -j `nproc`
}
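# Cross-compile the C-API library for iOS (arm64, Release, tests disabled).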
function build_ios() {
# Create the build directory for CMake.
mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build
# Compile paddle binaries
cmake .. \
-DCMAKE_SYSTEM_NAME=iOS \
-DIOS_PLATFORM=OS \
-DCMAKE_OSX_ARCHITECTURES="arm64" \
-DWITH_C_API=ON \
-DUSE_EIGEN_FOR_BLAS=ON \
-DWITH_TESTING=OFF \
-DWITH_SWIG_PY=OFF \
-DWITH_STYLE_CHECK=OFF \
-DCMAKE_BUILD_TYPE=Release
make -j 2
}
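# Run ctest on the build tree, install the built wheel, and (unless
# WITH_FLUID_ONLY=ON) sanity-check the CLI with "paddle version".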
function run_test() {
mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build
if [ ${WITH_TESTING:-ON} == "ON" ] ; then
cat <<EOF
========================================
Running unit tests ...
========================================
EOF
ctest --output-on-failure
# make install should also be tested when running the unit tests
make install -j `nproc`
pip install /usr/local/opt/paddle/share/wheels/*.whl
if [[ ${WITH_FLUID_ONLY:-OFF} == "OFF" ]] ; then
paddle version
fi
fi
}
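# Split the unit tests into NUM_PROC parallel ctest shards, binding each shard
# to a rotating subset of the visible GPUs and dividing GPU memory between them.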
function bind_test() {
# the number of processes used to run tests
NUM_PROC=6
# calculate and set the memory usage for each process
MEM_USAGE=$(printf "%.2f" `echo "scale=5; 1.0 / $NUM_PROC" | bc`)
export FLAGS_fraction_of_gpu_memory_to_use=$MEM_USAGE
# get the CUDA device count
CUDA_DEVICE_COUNT=$(nvidia-smi -L | wc -l)
for (( i = 0; i < $NUM_PROC; i++ )); do
cuda_list=()
for (( j = 0; j < $CUDA_DEVICE_COUNT; j++ )); do
s=$[i+j]
n=$[s%CUDA_DEVICE_COUNT]
if [ $j -eq 0 ]; then
cuda_list=("$n")
else
cuda_list="$cuda_list,$n"
fi
done
echo $cuda_list
# CUDA_VISIBLE_DEVICES http://acceleware.com/blog/cudavisibledevices-masking-gpus
# ctest -I https://cmake.org/cmake/help/v3.0/manual/ctest.1.html?highlight=ctest
env CUDA_VISIBLE_DEVICES=$cuda_list ctest -I $i,,$NUM_PROC --output-on-failure &
done
wait
}
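# Configure a documentation build (WITH_DOC=ON) and build the paddle_docs and
# paddle_apis targets.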
function gen_docs() {
mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build
cat <<EOF
========================================
Building documentation ...
In /paddle/build
========================================
EOF
cmake .. \
-DWITH_DOC=ON \
-DWITH_GPU=OFF \
-DWITH_AVX=${WITH_AVX:-ON} \
-DWITH_SWIG_PY=ON \
-DWITH_STYLE_CHECK=OFF
make -j `nproc` paddle_docs paddle_apis
}
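# Run the Woboq codebrowser generator over the existing build tree and write
# the HTML pages to ${PADDLE_ROOT}/build/woboq_out.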
function gen_html() {
cat <<EOF
========================================
Converting C++ source code into HTML ...
========================================
EOF
export WOBOQ_OUT=${PADDLE_ROOT}/build/woboq_out
mkdir -p $WOBOQ_OUT
cp -rv /woboq/data $WOBOQ_OUT/../data
/woboq/generator/codebrowser_generator \
-b ${PADDLE_ROOT}/build \
-a \
-o $WOBOQ_OUT \
-p paddle:${PADDLE_ROOT}
/woboq/indexgenerator/codebrowser_indexgenerator $WOBOQ_OUT
}
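# Write a runtime Dockerfile to ${PADDLE_ROOT}/build that installs the freshly
# built wheel on top of an Ubuntu (or CUDA, when WITH_GPU=ON) base image.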
function gen_dockerfile() {
# Set BASE_IMAGE according to env variables
if [[ ${WITH_GPU} == "ON" ]]; then
BASE_IMAGE="nvidia/cuda:8.0-cudnn5-runtime-ubuntu16.04"
else
BASE_IMAGE="ubuntu:16.04"
fi
DOCKERFILE_GPU_ENV=""
DOCKERFILE_CUDNN_DSO=""
if [[ ${WITH_GPU:-OFF} == 'ON' ]]; then
DOCKERFILE_GPU_ENV="ENV LD_LIBRARY_PATH /usr/lib/x86_64-linux-gnu:\${LD_LIBRARY_PATH}"
DOCKERFILE_CUDNN_DSO="RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so.5 /usr/lib/x86_64-linux-gnu/libcudnn.so"
fi
cat <<EOF
========================================
Generate /paddle/build/Dockerfile ...
========================================
EOF
cat > ${PADDLE_ROOT}/build/Dockerfile <<EOF
FROM ${BASE_IMAGE}
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
ENV HOME /root
EOF
if [[ ${WITH_GPU} == "ON" ]]; then
NCCL_DEPS="apt-get install -y libnccl2=2.1.2-1+cuda8.0 libnccl-dev=2.1.2-1+cuda8.0 &&"
else
NCCL_DEPS=""
fi
if [[ ${WITH_FLUID_ONLY:-OFF} == "OFF" ]]; then
PADDLE_VERSION="paddle version"
CMD='"paddle", "version"'
else
PADDLE_VERSION="true"
CMD='"true"'
fi
cat >> ${PADDLE_ROOT}/build/Dockerfile <<EOF
ADD python/dist/*.whl /
# run paddle version to install python packages first
RUN apt-get update &&\
${NCCL_DEPS}\
apt-get install -y wget python-pip dmidecode python-tk && pip install -U pip==9.0.3 && \
pip install /*.whl; apt-get install -f -y && \
apt-get clean -y && \
rm -f /*.whl && \
${PADDLE_VERSION} && \
ldconfig
${DOCKERFILE_CUDNN_DSO}
${DOCKERFILE_GPU_ENV}
ENV NCCL_LAUNCH_MODE PARALLEL
ADD go/cmd/pserver/pserver /usr/bin/
ADD go/cmd/master/master /usr/bin/
# default command shows the paddle version and exits
CMD [${CMD}]
EOF
}
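# Package the installed C-API headers and libraries into
# ${PADDLE_ROOT}/build/paddle.tgz (only when WITH_C_API=ON).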
function gen_capi_package() {
if [[ ${WITH_C_API} == "ON" ]]; then
install_prefix="${PADDLE_ROOT}/build/capi_output"
rm -rf $install_prefix
make DESTDIR="$install_prefix" install
cd $install_prefix/usr/local
ls | egrep -v "^Found.*item$" | xargs tar -cf ${PADDLE_ROOT}/build/paddle.tgz
fi
}
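# Build the standalone fluid inference library (skipped when WITH_C_API=ON).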
function gen_fluid_inference_lib() {
if [ ${WITH_C_API:-OFF} == "OFF" ] ; then
cat <<EOF
========================================
Deploying fluid inference library ...
========================================
EOF
make inference_lib_dist
fi
}
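# Entry point: dispatch the requested task; see print_usage for the supported options.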
function main() {
local CMD=$1
init
case $CMD in
build)
cmake_gen ${PYTHON_ABI:-""}
build
;;
build_android)
build_android
;;
build_ios)
build_ios
;;
test)
run_test
;;
bind_test)
bind_test
;;
doc)
gen_docs
;;
html)
gen_html
;;
dockerfile)
gen_dockerfile
;;
capi)
cmake_gen ${PYTHON_ABI:-""}
gen_capi_package
;;
fluid_inference_lib)
cmake_gen ${PYTHON_ABI:-""}
gen_fluid_inference_lib
;;
check_style)
check_style
;;
*)
print_usage
exit 0
;;
esac
}
main $@
#!/usr/bin/env bash
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
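# Return success if a container with the given name exists (running or stopped).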
function container_running() {
name=$1
docker ps -a --format "{{.Names}}" | grep "${name}" > /dev/null
return $?
}
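# Pull the dev image and (re)start a detached build container with the repo
# mounted at /paddle and the CI build settings exported as environment variables.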
function start_build_docker() {
docker pull $IMG
if container_running "${CONTAINER_ID}"; then
docker stop "${CONTAINER_ID}" 1>/dev/null
docker rm -f "${CONTAINER_ID}" 1>/dev/null
fi
DOCKER_ENV=$(cat <<EOL
-e FLAGS_fraction_of_gpu_memory_to_use=0.15 \
-e CTEST_OUTPUT_ON_FAILURE=1 \
-e CTEST_PARALLEL_LEVEL=5 \
-e WITH_GPU=ON \
-e WITH_TESTING=ON \
-e WITH_C_API=OFF \
-e WITH_COVERAGE=ON \
-e COVERALLS_UPLOAD=ON \
-e WITH_DEB=OFF \
-e CMAKE_BUILD_TYPE=RelWithDebInfo \
-e PADDLE_FRACTION_GPU_MEMORY_TO_USE=0.15 \
-e CUDA_VISIBLE_DEVICES=0,1 \
-e WITH_DISTRIBUTE=ON \
-e RUN_TEST=ON
EOL
)
set -x
nvidia-docker run -it \
-d \
--name $CONTAINER_ID \
${DOCKER_ENV} \
-v $PADDLE_ROOT:/paddle \
-w /paddle \
$IMG \
/bin/bash
set +x
}
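# Entry point: pick the dev image and container name, then either start the
# build container or run paddle_build.sh inside it with the given arguments.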
function main() {
DOCKER_REPO="paddlepaddle/paddle"
VERSION="latest-dev"
CONTAINER_ID="${USER}_paddle_dev"
PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )"
if [ "$1" == "build_android" ]; then
CONTAINER_ID="${USER}_paddle_dev_android"
VERSION="latest-dev-android"
fi
IMG=${DOCKER_REPO}:${VERSION}
case $1 in
start)
start_build_docker
;;
build_android)
start_build_docker
docker exec ${CONTAINER_ID} bash -c "./paddle/scripts/paddle_build.sh $@"
;;
*)
if container_running "${CONTAINER_ID}"; then
docker exec ${CONTAINER_ID} bash -c "./paddle/scripts/paddle_build.sh $@"
else
echo "Please start container first, with command:"
echo "$0 start"
fi
;;
esac
}
main $@