diff --git a/doc/v2/howto/cluster/multi_cluster/index_en.rst b/doc/v2/howto/cluster/multi_cluster/index_en.rst
index dac7aaef085c80851c1bbb89250faf2151de4ca6..b69bd5b2dbf1967d65558da06812d76f431c1d5a 100644
--- a/doc/v2/howto/cluster/multi_cluster/index_en.rst
+++ b/doc/v2/howto/cluster/multi_cluster/index_en.rst
@@ -1,19 +1,35 @@
Use different clusters
======================
-PaddlePaddle supports running jobs on several platforms including:
- `Kubernetes `_ open-source system for automating deployment, scaling, and management of containerized applications from Google.
- `OpenMPI `_ Mature high performance parallel computing framework.
- `Fabric `_ A cluster management tool. Write scripts to submit jobs or manage the cluster.
+Users' cluster environments differ, so we provide several ways to deploy a cluster and submit distributed training jobs. They are introduced below:
-We'll introduce cluster job management on these platforms. The examples can be found under `cluster_train_v2 `_ .
+`Kubernetes `_ is Google's open-source container cluster scheduling framework; it offers a complete cluster solution for large-scale production environments. The following guides show how PaddlePaddle runs on Kubernetes:
-These cluster platforms provide API or environment variables for training processes, when the job is dispatched to different nodes. Like node ID, IP or total number of nodes etc.
+.. toctree::
+  :maxdepth: 1
+
+  k8s_cn.md
+  k8s_distributed_cn.md
+
+`OpenMPI `_ is a mature high-performance parallel computing framework that is widely used in HPC. The following guide describes how to run PaddlePaddle cluster training with OpenMPI:
.. toctree::
  :maxdepth: 1
-  fabric_en.md
-  openmpi_en.md
-  k8s_en.md
-  k8s_aws_en.md
+  openmpi_cn.md
+
+`Fabric `_ is a convenient tool for program deployment and management. We provide a way to deploy and manage clusters with Fabric; to learn more, read the following guide:
+
+.. toctree::
+  :maxdepth: 1
+
+  fabric_cn.md
+
+We also support deploying PaddlePaddle on AWS. Learn more in the following guide:
+
+.. toctree::
+  :maxdepth: 1
+
+  k8s_aws_cn.md
+
+The examples can be found under `cluster_train_v2 `_ .
\ No newline at end of file diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 513e720fd099bcd898a6c73afd1a3a16f6f53aab..766bf0ab0c1c50146ad3f6e048738209428707b9 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -226,15 +226,15 @@ static bool has_fetch_operators( } void Executor::Run(const ProgramDesc& program, Scope* scope, - std::map& feed_targets, - std::map& fetch_targets, + std::map* feed_targets, + std::map* fetch_targets, bool create_vars, const std::string& feed_holder_name, const std::string& fetch_holder_name) { platform::RecordBlock b(kProgramId); bool has_feed_ops = - has_feed_operators(program.Block(0), feed_targets, feed_holder_name); + has_feed_operators(program.Block(0), *feed_targets, feed_holder_name); bool has_fetch_ops = - has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name); + has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name); ProgramDesc* copy_program = const_cast(&program); if (!has_feed_ops || !has_fetch_ops) { @@ -250,7 +250,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, feed_holder->SetPersistable(true); int i = 0; - for (auto& feed_target : feed_targets) { + for (auto& feed_target : (*feed_targets)) { std::string var_name = feed_target.first; VLOG(3) << "feed target's name: " << var_name; @@ -273,7 +273,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, fetch_holder->SetPersistable(true); int i = 0; - for (auto& fetch_target : fetch_targets) { + for (auto& fetch_target : (*fetch_targets)) { std::string var_name = fetch_target.first; VLOG(3) << "fetch target's name: " << var_name; @@ -361,16 +361,16 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext( ExecutorPrepareContext* ctx, Scope* scope, - std::map& feed_targets, - std::map& fetch_targets, bool create_vars, + std::map* feed_targets, + std::map* fetch_targets, bool create_vars, const std::string& feed_holder_name, const std::string& fetch_holder_name) { auto& global_block = ctx->prog_.Block(ctx->block_id_); PADDLE_ENFORCE( - has_feed_operators(global_block, feed_targets, feed_holder_name), + has_feed_operators(global_block, *feed_targets, feed_holder_name), "Program in ExecutorPrepareContext should has feed_ops."); PADDLE_ENFORCE( - has_fetch_operators(global_block, fetch_targets, fetch_holder_name), + has_fetch_operators(global_block, *fetch_targets, fetch_holder_name), "Program in the prepared context should has fetch_ops."); // map the data of feed_targets to feed_holder @@ -378,8 +378,8 @@ void Executor::RunPreparedContext( if (op->Type() == kFeedOpType) { std::string feed_target_name = op->Output("Out")[0]; int idx = boost::get(op->GetAttr("col")); - SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, - idx); + SetFeedVariable(scope, *(*feed_targets)[feed_target_name], + feed_holder_name, idx); } } @@ -390,7 +390,7 @@ void Executor::RunPreparedContext( if (op->Type() == kFetchOpType) { std::string fetch_target_name = op->Input("X")[0]; int idx = boost::get(op->GetAttr("col")); - *fetch_targets[fetch_target_name] = + *(*fetch_targets)[fetch_target_name] = GetFetchVariable(*scope, fetch_holder_name, idx); } } diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 43defdacf2a1c2f59cf3af2461ae6cfc4c61f5be..4a3d637e2d79f8cbd83412eea2d73e4b497ef1e7 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ 
-55,8 +55,8 @@ class Executor {
               bool create_local_scope = true, bool create_vars = true);
  void Run(const ProgramDesc& program, Scope* scope,
-           std::map& feed_targets,
-           std::map& fetch_targets,
+           std::map* feed_targets,
+           std::map* fetch_targets,
           bool create_vars = true,
           const std::string& feed_holder_name = "feed",
           const std::string& fetch_holder_name = "fetch");
@@ -74,8 +74,8 @@ class Executor {
                          bool create_vars = true);
  void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
-                          std::map& feed_targets,
-                          std::map& fetch_targets,
+                          std::map* feed_targets,
+                          std::map* fetch_targets,
                          bool create_vars = true,
                          const std::string& feed_holder_name = "feed",
                          const std::string& fetch_holder_name = "fetch");
diff --git a/paddle/fluid/inference/engine.h b/paddle/fluid/inference/engine.h
new file mode 100644
index 0000000000000000000000000000000000000000..0633c052e4dc61d495b9acacb9a7ab6c941a267a
--- /dev/null
+++ b/paddle/fluid/inference/engine.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/framework.pb.h"
+
+namespace paddle {
+namespace inference {
+
+/*
+ * EngineBase is the base class of all inference engines. An inference engine
+ * takes a paddle program as input and outputs the result in fluid Tensor
+ * format. It can be used to optimize the performance of computation sub-blocks,
+ * for example, by breaking the original block down into sub-blocks and
+ * executing each sub-block in a different engine.
+ *
+ * For example:
+ * During inference, most of a ResNet50 model can be put into a subgraph and
+ * run on a TensorRT engine.
+ *
+ * There are several engines, such as TensorRT and other frameworks, so
+ * EngineBase is put forward to give a unified interface for all the
+ * different engine implementations.
+ */
+class EngineBase {
+ public:
+  using DescType = ::paddle::framework::proto::BlockDesc;
+
+  // Build the model and do some preparation, for example, in TensorRT, run
+  // createInferBuilder, buildCudaEngine.
+  virtual void Build(const DescType& paddle_model) = 0;
+
+  // Execute the engine; this will run the inference network.
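+  // NOTE: callers are expected to have filled the engine's input buffers (for
+  // example through an engine-specific helper such as SetInputFromCPU in the
+  // TensorRT engine) before invoking Execute.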
+ virtual void Execute(int batch_size) = 0; + + virtual ~EngineBase() {} + +}; // class EngineBase + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt index e39c0daac76e0993382868289f66351da3d16f8f..4b5866ad5dd5c2e33830b61e575ab92954cecb39 100644 --- a/paddle/fluid/inference/tensorrt/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -1 +1,4 @@ -nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) +if(WITH_TESTING) + nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) + nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda) +endif() diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc new file mode 100644 index 0000000000000000000000000000000000000000..276502e4999df3aecf246fcb9591c59ea9d3f02b --- /dev/null +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -0,0 +1,134 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/engine.h" + +#include +#include +#include +#include "paddle/fluid/inference/tensorrt/helper.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +void TensorRTEngine::Build(const DescType& paddle_model) { + PADDLE_ENFORCE(false, "not implemented"); +} + +void TensorRTEngine::Execute(int batch_size) { + infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr); + cudaStreamSynchronize(*stream_); +} + +TensorRTEngine::~TensorRTEngine() { + // clean buffer + for (auto& buffer : buffers_) { + if (buffer != nullptr) { + PADDLE_ENFORCE_EQ(0, cudaFree(buffer)); + buffer = nullptr; + } + } +} + +void TensorRTEngine::FreezeNetwork() { + PADDLE_ENFORCE(infer_builder_ != nullptr, + "Call InitNetwork first to initialize network."); + PADDLE_ENFORCE(infer_network_ != nullptr, + "Call InitNetwork first to initialize network."); + // build engine. + infer_builder_->setMaxBatchSize(max_batch_); + infer_builder_->setMaxWorkspaceSize(max_workspace_); + + infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_)); + PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!"); + + infer_context_.reset(infer_engine_->createExecutionContext()); + + // allocate GPU buffers. 
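+  // Output slots were registered with a size of 0 in DeclareOutput; their real
+  // sizes are resolved here from the binding's data type and dimensions before
+  // the device memory is allocated.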
+ buffers_.resize(buffer_sizes_.size(), nullptr); + for (auto& item : buffer_sizes_) { + if (item.second == 0) { + auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str()); + item.second = kDataTypeSize[static_cast( + infer_engine_->getBindingDataType(slot_offset))] * + AccumDims(infer_engine_->getBindingDimensions(slot_offset)); + } + PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second)); + } +} + +nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name, + nvinfer1::DataType dtype, + const nvinfer1::Dims& dim) { + PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s", + name); + + PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first"); + auto* input = infer_network_->addInput(name.c_str(), dtype, dim); + PADDLE_ENFORCE(input, "infer network add input %s failed", name); + + buffer_sizes_[name] = kDataTypeSize[static_cast(dtype)] * AccumDims(dim); + return input; +} + +void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset, + const std::string& name) { + PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s", + name); + + auto* output = layer->getOutput(offset); + PADDLE_ENFORCE(output != nullptr); + output->setName(name.c_str()); + infer_network_->markOutput(*output); + // output buffers' size can only be decided latter, set zero here to mark this + // and will reset latter. + buffer_sizes_[name] = 0; +} + +void* TensorRTEngine::GetOutputInGPU(const std::string& name) { + return buffer(name); +} + +void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst, + size_t max_size) { + // determine data size + auto it = buffer_sizes_.find(name); + PADDLE_ENFORCE(it != buffer_sizes_.end()); + PADDLE_ENFORCE_GT(it->second, 0); + PADDLE_ENFORCE_GE(max_size, it->second); + + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second, + cudaMemcpyDeviceToHost, *stream_)); +} + +void*& TensorRTEngine::buffer(const std::string& name) { + PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first."); + auto it = buffer_sizes_.find(name); + PADDLE_ENFORCE(it != buffer_sizes_.end()); + auto slot_offset = infer_engine_->getBindingIndex(name.c_str()); + return buffers_[slot_offset]; +} + +void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data, + size_t size) { + void* buf = buffer(name); + PADDLE_ENFORCE_EQ( + 0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_)); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h new file mode 100644 index 0000000000000000000000000000000000000000..ff853455b8bf0c9db36353f3cbe6be0ed4948fb3 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -0,0 +1,144 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#pragma once
+
+#include
+#include
+#include
+#include "paddle/fluid/inference/engine.h"
+#include "paddle/fluid/inference/tensorrt/helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * TensorRT Engine.
+ *
+ * There are two alternative ways to use it: one is to build the network from a
+ * paddle protobuf model, the other is to manually construct the network.
+ */
+class TensorRTEngine : public EngineBase {
+ public:
+  // Weight is a model parameter.
+  class Weight {
+   public:
+    Weight(nvinfer1::DataType dtype, void* value, int num_elem) {
+      w_.type = dtype;
+      w_.values = value;
+      w_.count = num_elem;
+    }
+    const nvinfer1::Weights& get() { return w_; }
+
+   private:
+    nvinfer1::Weights w_;
+  };
+
+  TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
+                 nvinfer1::ILogger& logger = NaiveLogger::Global())
+      : max_batch_(max_batch),
+        max_workspace_(max_workspace),
+        stream_(stream),
+        logger_(logger) {}
+
+  virtual ~TensorRTEngine();
+
+  // TODO(Superjomn) implement it later when graph segmentation is supported.
+  virtual void Build(const DescType& paddle_model) override;
+
+  virtual void Execute(int batch_size) override;
+
+  // Initialize the inference network, so that TensorRT layers can be added to
+  // this network.
+  void InitNetwork() {
+    infer_builder_.reset(createInferBuilder(logger_));
+    infer_network_.reset(infer_builder_->createNetwork());
+  }
+  // After all ops have been added, freeze this network and create the
+  // execution environment.
+  void FreezeNetwork();
+
+  // Add an input and set its name, data type and dimension.
+  nvinfer1::ITensor* DeclareInput(const std::string& name,
+                                  nvinfer1::DataType dtype,
+                                  const nvinfer1::Dims& dim);
+  // Set the offset-th output from a layer as the network's output, and set its
+  // name.
+  void DeclareOutput(const nvinfer1::ILayer* layer, int offset,
+                     const std::string& name);
+
+  // GPU memory address for an ITensor with a specific name. One can operate on
+  // this memory directly for acceleration, for example, write the converted
+  // data directly to the buffer to save the data copy overhead.
+  // NOTE this should be used after calling `FreezeNetwork`.
+  void*& buffer(const std::string& name);
+
+  // Fill an input from CPU memory with name and size.
+  void SetInputFromCPU(const std::string& name, void* data, size_t size);
+  // TODO(Superjomn) is this method necessary given that buffer(xxx) can be
+  // accessed directly. Fill an input from GPU memory with name and size.
+  void SetInputFromGPU(const std::string& name, void* data, size_t size);
+  // Get an output called name; the output of TensorRT is in GPU memory, so this
+  // method just returns the output's GPU memory address.
+  void* GetOutputInGPU(const std::string& name);
+  // LOW EFFICIENCY! Get output to CPU; this will trigger a memory copy from GPU
+  // to CPU.
+  void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
+
+  nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
+  nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
+
+ private:
+  // the max batch size
+  int max_batch_;
+  // the max memory size the engine uses
+  int max_workspace_;
+  cudaStream_t* stream_;
+  nvinfer1::ILogger& logger_;
+
+  std::vector buffers_;
+  // max data size for the buffers.
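+  // (an entry of 0 marks an output whose size is only known after the engine
+  // is built; FreezeNetwork resolves it from the binding dimensions)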
+  std::unordered_map buffer_sizes_;
+
+  // TensorRT related internal members
+  template
+  struct Destroyer {
+    void operator()(T* x) { x->destroy(); }
+  };
+  template
+  using infer_ptr = std::unique_ptr>;
+  infer_ptr infer_builder_;
+  infer_ptr infer_network_;
+  infer_ptr infer_engine_;
+  infer_ptr infer_context_;
+};  // class TensorRTEngine
+
+// Add a layer__ into engine__ with args ARGS.
+// For example:
+// TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias)
+//
+// Reference
+// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network
+//
+// will add a fully connected layer into the engine.
+// TensorRT has too many layers, so it is not wise to add a member function for
+// each of them; a macro like this is more extensible when the underlying
+// TensorRT library adds support for new layers.
+#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
+  engine__->network()->add##layer__(ARGS);
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..796283d325ceb84c733eff5c119b808300bca069
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/helper.h
@@ -0,0 +1,88 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+#include "paddle/fluid/platform/dynload/tensorrt.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+namespace dy = paddle::platform::dynload;
+
+static size_t AccumDims(nvinfer1::Dims dims) {
+  size_t num = dims.nbDims == 0 ? 0 : 1;
+  for (int i = 0; i < dims.nbDims; i++) {
+    PADDLE_ENFORCE_GT(dims.d[i], 0);
+    num *= dims.d[i];
+  }
+  return num;
+}
+
+// TensorRT data type to size
+const int kDataTypeSize[] = {
+    4,  // kFLOAT
+    2,  // kHALF
+    1,  // kINT8
+    4   // kINT32
+};
+
+// The following two APIs are implemented in TensorRT's header file and cannot
+// be loaded from the dynamic library, so we provide our own implementations
+// and call into the dynamic library directly.
+static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+  return static_cast(
+      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+  return static_cast(
+      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+
+// A logger for creating the TensorRT infer builder.
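+// It forwards TensorRT's log messages to glog at a matching severity level, so
+// engine build and execution diagnostics appear in the regular Paddle logs.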
+class NaiveLogger : public nvinfer1::ILogger { + public: + void log(nvinfer1::ILogger::Severity severity, const char* msg) override { + switch (severity) { + case Severity::kINFO: + LOG(INFO) << msg; + break; + case Severity::kWARNING: + LOG(WARNING) << msg; + break; + case Severity::kINTERNAL_ERROR: + case Severity::kERROR: + LOG(ERROR) << msg; + break; + default: + break; + } + } + + static nvinfer1::ILogger& Global() { + static nvinfer1::ILogger* x = new NaiveLogger; + return *x; + } + + virtual ~NaiveLogger() override {} +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3dbdf11f217ca35bc541b9bdb95bf0c06377fd3 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/test_engine.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/engine.h" + +#include +#include +#include +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class TensorRTEngineTest : public ::testing::Test { + protected: + void SetUp() override { + ASSERT_EQ(0, cudaStreamCreate(&stream_)); + engine_ = new TensorRTEngine(1, 1 << 10, &stream_); + engine_->InitNetwork(); + } + + void TearDown() override { + delete engine_; + cudaStreamDestroy(stream_); + } + + protected: + TensorRTEngine* engine_; + cudaStream_t stream_; +}; + +TEST_F(TensorRTEngineTest, add_layer) { + const int size = 1; + + float raw_weight[size] = {2.}; // Weight in CPU memory. 
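+  // With weight 2 and bias 3 the fully connected layer computes y = 2 * x + 3;
+  // the assertion at the end of this test checks the output against that value.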
+ float raw_bias[size] = {3.}; + + LOG(INFO) << "create weights"; + TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, size); + TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, size); + auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT, + nvinfer1::DimsCHW{1, 1, 1}); + auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size, + weight.get(), bias.get()); + PADDLE_ENFORCE(fc_layer != nullptr); + + engine_->DeclareOutput(fc_layer, 0, "y"); + LOG(INFO) << "freeze network"; + engine_->FreezeNetwork(); + ASSERT_EQ(engine_->engine()->getNbBindings(), 2); + + // fill in real data + float x_v = 1234; + engine_->SetInputFromCPU("x", (void*)&x_v, 1 * sizeof(float)); + LOG(INFO) << "to execute"; + engine_->Execute(1); + + LOG(INFO) << "to get output"; + // void* y_v = + float y_cpu; + engine_->GetOutputInCPU("y", &y_cpu, sizeof(float)); + + LOG(INFO) << "to checkout output"; + ASSERT_EQ(y_cpu, x_v * 2 + 3); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/test_tensorrt.cc b/paddle/fluid/inference/tensorrt/test_tensorrt.cc index a81a708e7a79225fd52c4b8e081afdcd8fe7e9ad..aed5b5e1a22cbed1256d4f28d0a8a4c29c6cc744 100644 --- a/paddle/fluid/inference/tensorrt/test_tensorrt.cc +++ b/paddle/fluid/inference/tensorrt/test_tensorrt.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ #include #include diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 117472599f7c4874ab05e29c6ecb46fd61d0db9c..af2a7a5620487a10c1df6152fc4e4bf67b150752 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -178,10 +178,10 @@ void TestInference(const std::string& dirname, std::unique_ptr ctx; if (PrepareContext) { ctx = executor.Prepare(*inference_program, 0); - executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, - CreateVars); + executor.RunPreparedContext(ctx.get(), scope, &feed_targets, + &fetch_targets, CreateVars); } else { - executor.Run(*inference_program, scope, feed_targets, fetch_targets, + executor.Run(*inference_program, scope, &feed_targets, &fetch_targets, CreateVars); } @@ -197,10 +197,10 @@ void TestInference(const std::string& dirname, if (PrepareContext) { // Note: if you change the inference_program, you need to call // executor.Prepare() again to get a new ExecutorPrepareContext. - executor.RunPreparedContext(ctx.get(), scope, feed_targets, - fetch_targets, CreateVars); + executor.RunPreparedContext(ctx.get(), scope, &feed_targets, + &fetch_targets, CreateVars); } else { - executor.Run(*inference_program, scope, feed_targets, fetch_targets, + executor.Run(*inference_program, scope, &feed_targets, &fetch_targets, CreateVars); } } diff --git a/paddle/fluid/operators/beam_search_decode_op.h b/paddle/fluid/operators/beam_search_decode_op.h index 4cb0457d9285e20d4b6a2f9987b7fdb1c6ac157f..3c01f81c83555b985bb6b7a9e3330ab594a62863 100644 --- a/paddle/fluid/operators/beam_search_decode_op.h +++ b/paddle/fluid/operators/beam_search_decode_op.h @@ -223,8 +223,9 @@ void BeamSearchDecoder::ConvertSentenceVectorToLodTensor( sentence_vector_list[src_idx].size()); } - auto cpu_place = new paddle::platform::CPUPlace(); - paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place); + auto cpu_place = std::unique_ptr( + new paddle::platform::CPUPlace()); + paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place.get()); framework::LoD lod; lod.push_back(source_level_lod); diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 44fd739fb1d161c6c7d6ab1cc611c59220280a4e..b5ae41c8f9d7aeb8e410b795fb9fbbd57ec69d4b 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/math/math_function.h" +#include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/platform/float16.h" @@ -161,7 +162,8 @@ void batched_gemm( const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float16 alpha, const float16* A, const float16* B, const float16 beta, - float16* C, const int batchCount, const int strideA, const int strideB) { + float16* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { PADDLE_THROW("float16 batched_gemm not supported on CPU"); } @@ -172,7 +174,8 @@ void batched_gemm( const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float alpha, const float* A, const float* B, const float beta, - float* C, const int batchCount, const int strideA, const int strideB) { + float* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; @@ -194,7 +197,8 @@ void batched_gemm( const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const double alpha, const double* A, const double* B, const double beta, - double* C, const int batchCount, const int strideA, const int strideB) { + double* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; @@ -220,7 +224,8 @@ void batched_gemm( const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float alpha, const float* A, const float* B, const float beta, - float* C, const int batchCount, const int strideA, const int strideB) { + float* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { for (int k = 0; k < batchCount; ++k) { const float* Ak = &A[k * strideA]; const float* Bk = &B[k * strideB]; @@ -235,7 +240,8 @@ void batched_gemm( const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const double alpha, const double* A, const double* B, const double beta, - double* C, const int batchCount, const int strideA, const int strideB) { + double* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { for (int k = 0; k < batchCount; ++k) { const double* Ak = &A[k * strideA]; const double* Bk = &B[k * strideB]; diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 9badf26c9bb80acad029be3d1b63377cef63d929..2aa819625e0f5213a6001908e715bcc73d4747c3 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #define EIGEN_USE_GPU +#include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function_impl.h" @@ -267,7 +268,8 @@ void batched_gemm( const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float16 alpha, const float16* A, const float16* B, const float16 beta, - float16* C, const int batchCount, const int strideA, const int strideB) { + float16* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { #if CUDA_VERSION >= 8000 // Note that cublas follows fortran order, so the order is different from // the cblas convention. @@ -278,7 +280,7 @@ void batched_gemm( (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int strideC = M * N; + const int64_t strideC = M * N; const half h_alpha = static_cast(alpha); const half h_beta = static_cast(beta); @@ -303,7 +305,8 @@ void batched_gemm( const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float alpha, const float* A, const float* B, const float beta, - float* C, const int batchCount, const int strideA, const int strideB) { + float* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { #if CUDA_VERSION >= 8000 // Note that cublas follows fortran order, so the order is different from // the cblas convention. @@ -314,7 +317,7 @@ void batched_gemm( (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int strideC = M * N; + const int64_t strideC = M * N; PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, @@ -329,7 +332,8 @@ void batched_gemm( const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const double alpha, const double* A, const double* B, const double beta, - double* C, const int batchCount, const int strideA, const int strideB) { + double* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { #if CUDA_VERSION >= 8000 // Note that cublas follows fortran order, so the order is different from // the cblas convention. @@ -340,7 +344,7 @@ void batched_gemm( (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int strideC = M * N; + const int64_t strideC = M * N; PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, diff --git a/paddle/fluid/operators/math/math_function.h b/paddle/fluid/operators/math/math_function.h index cdbc7bfb37e83c6c2b696ba010277c9eec49f2a8..cdd02974722045457aacdfa517c147751185f332 100644 --- a/paddle/fluid/operators/math/math_function.h +++ b/paddle/fluid/operators/math/math_function.h @@ -26,7 +26,7 @@ limitations under the License. 
*/ #ifndef LAPACK_FOUND extern "C" { -#include +#include // NOLINT int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, int* ipiv); int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, @@ -39,6 +39,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #endif #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" @@ -78,8 +79,8 @@ template void batched_gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const T alpha, const T* A, const T* B, - const T beta, T* C, const int batchCount, const int strideA, - const int strideB); + const T beta, T* C, const int batchCount, + const int64_t strideA, const int64_t strideB); template void gemv(const DeviceContext& context, const bool trans_a, const int M, diff --git a/paddle/fluid/operators/softmax_mkldnn_op.cc b/paddle/fluid/operators/softmax_mkldnn_op.cc index d00bd1447e6114b6000b65799abb566a2a510127..71b541d98f6e0d3e12601c9988ca6ffb8bb7554d 100644 --- a/paddle/fluid/operators/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/softmax_mkldnn_op.cc @@ -77,7 +77,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { const bool is_test = ctx.Attr("is_test"); if (!is_test) { T threshold = exp(-64); - for (size_t i = 0; i < dst_tz[0] * dst_tz[1]; ++i) { + for (int i = 0; i < dst_tz[0] * dst_tz[1]; ++i) { output_data[i] = output_data[i] < threshold ? threshold : output_data[i]; } diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index 1ab55d6b9bf8fdbd14c9c2bd978e3e99dba3e73e..81acaff87d3c2025cf0d6185a1590b018bfbd83c 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -14,10 +14,12 @@ #pragma once +#include #include #include #include #include // NOLINT +#include #include "paddle/fluid/platform/dynload/dynamic_loader.h" namespace paddle { @@ -37,14 +39,14 @@ extern void *cublas_dso_handle; #ifdef PADDLE_USE_DSO #define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ struct DynLoad__##__name { \ + using FUNC_TYPE = decltype(&::__name); \ template \ inline cublasStatus_t operator()(Args... args) { \ - typedef cublasStatus_t (*cublasFunc)(Args...); \ std::call_once(cublas_dso_flag, []() { \ cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \ }); \ void *p_##__name = dlsym(cublas_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ + return reinterpret_cast(p_##__name)(args...); \ } \ }; \ extern DynLoad__##__name __name @@ -71,8 +73,8 @@ extern void *cublas_dso_handle; __macro(cublasDgemm_v2); \ __macro(cublasHgemm); \ __macro(cublasSgemmEx); \ - __macro(cublasSgeam_v2); \ - __macro(cublasDgeam_v2); \ + __macro(cublasSgeam); \ + __macro(cublasDgeam); \ __macro(cublasCreate_v2); \ __macro(cublasDestroy_v2); \ __macro(cublasSetStream_v2); \ diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 24475b62ca2825c45ff7edb39328dece3b822b25..34d83e395694f55eafca74d63ebf363169ab30e8 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -34,7 +34,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); struct DynLoad__##__name { \ template \ auto operator()(Args... 
args) -> decltype(__name(args...)) { \ - using cudnn_func = decltype(__name(args...)) (*)(Args...); \ + using cudnn_func = decltype(&::__name); \ std::call_once(cudnn_dso_flag, []() { \ cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \ }); \ diff --git a/paddle/fluid/platform/dynload/cupti.h b/paddle/fluid/platform/dynload/cupti.h index d0d676b9d8ac462900b48246bec43166d04ef97b..e64de7c20fc9d145e51cfc4528e321b3c4ec86c8 100644 --- a/paddle/fluid/platform/dynload/cupti.h +++ b/paddle/fluid/platform/dynload/cupti.h @@ -41,7 +41,7 @@ extern void *cupti_dso_handle; struct DynLoad__##__name { \ template \ inline CUptiResult CUPTIAPI operator()(Args... args) { \ - typedef CUptiResult CUPTIAPI (*cuptiFunc)(Args...); \ + using cuptiFunc = decltype(&::__name); \ std::call_once(cupti_dso_flag, []() { \ cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \ }); \ diff --git a/paddle/fluid/platform/dynload/curand.h b/paddle/fluid/platform/dynload/curand.h index 4697fb6cd96770127206bdabeea77e43eb09d1f5..46ad4379d5f9572d415ef1d747077217ae29391e 100644 --- a/paddle/fluid/platform/dynload/curand.h +++ b/paddle/fluid/platform/dynload/curand.h @@ -30,7 +30,7 @@ extern void *curand_dso_handle; struct DynLoad__##__name { \ template \ curandStatus_t operator()(Args... args) { \ - typedef curandStatus_t (*curandFunc)(Args...); \ + using curandFunc = decltype(&::__name); \ std::call_once(curand_dso_flag, []() { \ curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \ }); \ diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h index c5a10a78a4f432b431680c089f255fea777277cb..37902ae20c5d9d64486232bbd468375c4a50a615 100644 --- a/paddle/fluid/platform/dynload/nccl.h +++ b/paddle/fluid/platform/dynload/nccl.h @@ -33,7 +33,7 @@ extern void* nccl_dso_handle; struct DynLoad__##__name { \ template \ auto operator()(Args... args) -> decltype(__name(args...)) { \ - using nccl_func = decltype(__name(args...)) (*)(Args...); \ + using nccl_func = decltype(&::__name); \ std::call_once(nccl_dso_flag, []() { \ nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \ }); \ diff --git a/paddle/fluid/platform/dynload/warpctc.h b/paddle/fluid/platform/dynload/warpctc.h index 7fa468370463a51c486b80317f401612930bc72e..7c70649d21c547beb824576d4a8ecf6219a9bddf 100644 --- a/paddle/fluid/platform/dynload/warpctc.h +++ b/paddle/fluid/platform/dynload/warpctc.h @@ -36,7 +36,7 @@ extern void* warpctc_dso_handle; struct DynLoad__##__name { \ template \ auto operator()(Args... args) -> decltype(__name(args...)) { \ - using warpctcFunc = decltype(__name(args...)) (*)(Args...); \ + using warpctcFunc = decltype(&::__name); \ std::call_once(warpctc_dso_flag, []() { \ warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \ }); \ diff --git a/paddle/scripts/docker/README.md b/paddle/scripts/README.md similarity index 66% rename from paddle/scripts/docker/README.md rename to paddle/scripts/README.md index 78c0cc378231f763597556cc5450f6f03ab2b291..9e8b135c1bc7fc05d88fe6f3bed17dd3b48e9615 100644 --- a/paddle/scripts/docker/README.md +++ b/paddle/scripts/README.md @@ -13,40 +13,49 @@ We want to make the building procedures: 1. Build docker images with PaddlePaddle pre-installed, so that we can run PaddlePaddle applications directly in docker or on Kubernetes clusters. -To achieve this, we created a repo: https://github.com/PaddlePaddle/buildtools -which gives several docker images that are `manylinux1` sufficient. 
Then we
-can build PaddlePaddle using these images to generate corresponding `whl`
-binaries.
+To achieve this, we maintain a dockerhub repo: https://hub.docker.com/r/paddlepaddle/paddle
+which provides pre-built environment images to build PaddlePaddle and generate corresponding `whl`
+binaries. (**We strongly recommend building PaddlePaddle in our pre-specified Docker environment.**)
-## Run The Build
+## Development Workflow
+
+Here we describe how the workflow goes. We start by considering our daily development environment.
+
+Developers work on a computer, which is usually a laptop or desktop:
+
+
+
+or, they might rely on a more sophisticated box (like with GPUs):
+
+
+
+A principle here is that source code lies on the development computer (host) so that editors like Eclipse can parse the source code to support auto-completion.
+
+## Build With Docker
### Build Environments
-The pre-built build environment images are:
+The latest pre-built build environment images are:
| Image | Tag |
| ----- | --- |
-| paddlepaddle/paddle_manylinux_devel | cuda7.5_cudnn5 |
-| paddlepaddle/paddle_manylinux_devel | cuda8.0_cudnn5 |
-| paddlepaddle/paddle_manylinux_devel | cuda7.5_cudnn7 |
-| paddlepaddle/paddle_manylinux_devel | cuda9.0_cudnn7 |
+| paddlepaddle/paddle | latest-dev |
+| paddlepaddle/paddle | latest-dev-android |
### Start Build
-Choose one docker image that suit your environment and run the following
-command to start a build:
-
```bash
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
-docker run --rm -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_AVX=ON" -e "WITH_TESTING=OFF" -e "RUN_TEST=OFF" -e "PYTHON_ABI=cp27-cp27mu" paddlepaddle/paddle_manylinux_devel /paddle/paddle/scripts/docker/build.sh
+./paddle/scripts/paddle_docker_build.sh build
```
After the build finishes, you can get output `whl` package under `build/python/dist`.
-This command mounts the source directory on the host into `/paddle` in the container, then run the build script `/paddle/paddle/scripts/docker/build.sh`
-in the container. When it writes to `/paddle/build` in the container, it writes to `$PWD/build` on the host indeed.
+This command will download the most recent dev image from docker hub, start a container in the background, and then run the build script `/paddle/paddle/scripts/paddle_build.sh build` in the container.
+The container mounts the source directory on the host into `/paddle`.
+When it writes to `/paddle/build` in the container, it actually writes to `$PWD/build` on the host.
### Build Options
@@ -68,7 +77,6 @@ Users can specify the following Docker build arguments with either "ON" or "OFF"
| `WITH_DOC` | OFF | Build docs after build binaries. |
| `WOBOQ` | OFF | Generate WOBOQ code viewer under `build/woboq_out` |
-
## Docker Images
You can get the latest PaddlePaddle docker images by
@@ -144,59 +152,37 @@ docker push
kubectl ...
```
-## Docker Images for Developers
-
-We have a special docker image for developers:
-`paddlepaddle/paddle:-dev`. This image is also generated from
-https://github.com/PaddlePaddle/buildtools
-
-This a development image contains only the
-development tools and standardizes the building procedure. Users include:
-
-- developers -- no longer need to install development tools on the host, and can build their current work on the host (development computer).
-- release engineers -- use this to build the official release from certain branch/tag on Github.com.
-- document writers / Website developers -- Our documents are in the source repo in the form of .md/.rst files and comments in source code. We need tools to extract the information, typeset, and generate Web pages. - -Of course, developers can install building tools on their development computers. But different versions of PaddlePaddle might require different set or version of building tools. Also, it makes collaborative debugging easier if all developers use a unified development environment. - -The development image contains the following tools: - - - gcc/clang - - nvcc - - Python - - sphinx - - woboq - - sshd - -Many developers work on a remote computer with GPU; they could ssh into the computer and `docker exec` into the development container. However, running `sshd` in the container allows developers to ssh into the container directly. - - -### Development Workflow - -Here we describe how the workflow goes on. We start from considering our daily development environment. +### Reading source code with woboq codebrowser -Developers work on a computer, which is usually a laptop or desktop: +For developers who are interested in the C++ source code, you can build C++ source code into HTML pages using [Woboq codebrowser](https://github.com/woboq/woboq_codebrowser). - +- The following command builds PaddlePaddle, generates HTML pages from C++ source code, and writes HTML pages into `$HOME/woboq_out` on the host: -or, they might rely on a more sophisticated box (like with GPUs): +```bash +./paddle/scripts/paddle_docker_build.sh html +``` - +- You can open the generated HTML files in your Web browser. Or, if you want to run a Nginx container to serve them for a wider audience, you can run: -A principle here is that source code lies on the development computer (host) so that editors like Eclipse can parse the source code to support auto-completion. +``` +docker run -v $HOME/woboq_out:/usr/share/nginx/html -d -p 8080:80 nginx +``` -### Reading source code with woboq codebrowser +## More Options -For developers who are interested in the C++ source code, please use -e "WOBOQ=ON" to enable the building of C++ source code into HTML pages using [Woboq codebrowser](https://github.com/woboq/woboq_codebrowser). +### Build Without Docker -- The following command builds PaddlePaddle, generates HTML pages from C++ source code, and writes HTML pages into `$HOME/woboq_out` on the host: +Follow the *Dockerfile* in the paddlepaddle repo to set up your local dev environment and run: ```bash -docker run -v $PWD:/paddle -v $HOME/woboq_out:/woboq_out -e "WITH_GPU=OFF" -e "WITH_AVX=ON" -e "WITH_TESTING=ON" -e "WOBOQ=ON" paddlepaddle/paddle:latest-dev +./paddle/scripts/paddle_build.sh build ``` -- You can open the generated HTML files in your Web browser. 
Or, if you want to run a Nginx container to serve them for a wider audience, you can run: +### Additional Tasks -``` -docker run -v $HOME/woboq_out:/usr/share/nginx/html -d -p 8080:80 nginx +You can get the help menu for the build scripts by running with no options: + +```bash +./paddle/scripts/paddle_build.sh +or ./paddle/scripts/paddle_docker_build.sh ``` diff --git a/paddle/scripts/docker/doc/paddle-development-environment-gpu.graffle b/paddle/scripts/doc/paddle-development-environment-gpu.graffle similarity index 100% rename from paddle/scripts/docker/doc/paddle-development-environment-gpu.graffle rename to paddle/scripts/doc/paddle-development-environment-gpu.graffle diff --git a/paddle/scripts/docker/doc/paddle-development-environment-gpu.png b/paddle/scripts/doc/paddle-development-environment-gpu.png similarity index 100% rename from paddle/scripts/docker/doc/paddle-development-environment-gpu.png rename to paddle/scripts/doc/paddle-development-environment-gpu.png diff --git a/paddle/scripts/docker/doc/paddle-development-environment.graffle b/paddle/scripts/doc/paddle-development-environment.graffle similarity index 100% rename from paddle/scripts/docker/doc/paddle-development-environment.graffle rename to paddle/scripts/doc/paddle-development-environment.graffle diff --git a/paddle/scripts/docker/doc/paddle-development-environment.png b/paddle/scripts/doc/paddle-development-environment.png similarity index 100% rename from paddle/scripts/docker/doc/paddle-development-environment.png rename to paddle/scripts/doc/paddle-development-environment.png diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh new file mode 100755 index 0000000000000000000000000000000000000000..654c8272a18e5adb01e75be94985a80502ba2c8d --- /dev/null +++ b/paddle/scripts/paddle_build.sh @@ -0,0 +1,508 @@ +#!/usr/bin/env bash + +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
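+
+# Rough usage sketch (see paddle/scripts/README.md): run a task such as
+#   ./paddle/scripts/paddle_build.sh build
+# inside the paddlepaddle/paddle:latest-dev image, or let
+#   ./paddle/scripts/paddle_docker_build.sh build
+# start such a container for you; print_usage below lists the supported tasks.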
+ + +#================================================= +# Utils +#================================================= + +function print_usage() { + RED='\033[0;31m' + BLUE='\033[0;34m' + BOLD='\033[1m' + NONE='\033[0m' + + echo -e "\n${RED}Usage${NONE}: + ${BOLD}$0${NONE} [OPTION]" + + echo -e "\n${RED}Options${NONE}: + ${BLUE}build${NONE}: run build for x86 platform + ${BLUE}build_android${NONE}: run build for android platform + ${BLUE}build_ios${NONE}: run build for ios platform + ${BLUE}test${NONE}: run all unit tests + ${BLUE}bind_test${NONE}: parallel tests bind to different GPU + ${BLUE}doc${NONE}: generate paddle documents + ${BLUE}html${NONE}: convert C++ source code into HTML + ${BLUE}dockerfile${NONE}: generate paddle release dockerfile + ${BLUE}capi${NONE}: generate paddle CAPI package + ${BLUE}fluid_inference_lib${NONE}: deploy fluid inference library + ${BLUE}check_style${NONE}: run code style check + " +} + +function init() { + PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )" +} + +function cmake_gen() { + mkdir -p ${PADDLE_ROOT}/build + cd ${PADDLE_ROOT}/build + + # build script will not fail if *.deb does not exist + rm *.deb 2>/dev/null || true + # delete previous built whl packages + rm -rf python/dist 2>/dev/null || true + + # Support build for all python versions, currently + # including cp27-cp27m and cp27-cp27mu. + PYTHON_FLAGS="" + if [ "$1" != "" ]; then + echo "using python abi: $1" + if [ "$1" == "cp27-cp27m" ]; then + export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:} + export PATH=/opt/python/cp27-cp27m/bin/:${PATH} + PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python + -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7 + -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so" + elif [ "$1" == "cp27-cp27mu" ]; then + export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:} + export PATH=/opt/python/cp27-cp27mu/bin/:${PATH} + PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python + -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7 + -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so" + fi + fi + + cat <&2 + echo "Please use pre-commit to check what is wrong." 1>&2 + exit 1 +} + +function check_style() { + trap 'abort' 0 + set -e + + # install glide + curl https://glide.sh/get | bash + eval "$(GIMME_GO_VERSION=1.8.3 gimme)" + + # set up go environment for running gometalinter + mkdir -p $GOPATH/src/github.com/PaddlePaddle/ + ln -sf ${PADDLE_ROOT} $GOPATH/src/github.com/PaddlePaddle/Paddle + cd $GOPATH/src/github.com/PaddlePaddle/Paddle/go; glide install; cd - + + go get github.com/alecthomas/gometalinter + gometalinter --install + + cd ${PADDLE_ROOT} + export PATH=/usr/bin:$PATH + pre-commit install + clang-format --version + + if ! pre-commit run -a ; then + git diff + exit 1 + fi + + trap : 0 +} + +#================================================= +# Build +#================================================= + +function build() { + mkdir -p ${PADDLE_ROOT}/build + cd ${PADDLE_ROOT}/build + cat <= 21." 
+ ANDROID_API=21 + fi + else # armeabi, armeabi-v7a + ANDROID_ARCH=arm + fi + + ANDROID_STANDALONE_TOOLCHAIN=$ANDROID_TOOLCHAINS_DIR/$ANDROID_ARCH-android-$ANDROID_API + + cat < ${PADDLE_ROOT}/build/Dockerfile < + ENV HOME /root +EOF + + if [[ ${WITH_GPU} == "ON" ]]; then + NCCL_DEPS="apt-get install -y libnccl2=2.1.2-1+cuda8.0 libnccl-dev=2.1.2-1+cuda8.0 &&" + else + NCCL_DEPS="" + fi + + if [[ ${WITH_FLUID_ONLY:-OFF} == "OFF" ]]; then + PADDLE_VERSION="paddle version" + CMD='"paddle", "version"' + else + PADDLE_VERSION="true" + CMD='"true"' + fi + + cat >> /paddle/build/Dockerfile < /dev/null + return $? +} + +function start_build_docker() { + docker pull $IMG + + if container_running "${CONTAINER_ID}"; then + docker stop "${CONTAINER_ID}" 1>/dev/null + docker rm -f "${CONTAINER_ID}" 1>/dev/null + fi + + DOCKER_ENV=$(cat <