diff --git a/CMakeLists.txt b/CMakeLists.txt index c2fa5420e916fd5958f6198d6e97c9b1092b5aa1..d43df124bdee2d568a0c09d5acd35d5ff96f4654 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,7 @@ option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) option(WITH_INFERENCE "Compile fluid inference library" ON) +option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) diff --git a/doc/fluid/dev/releasing_process_en.md b/doc/fluid/dev/releasing_process_en.md index b810dc941d27fdb5004812ab58e105502e83280f..00650946ff2e658cfad0e63a8f1e008902a2d36e 100644 --- a/doc/fluid/dev/releasing_process_en.md +++ b/doc/fluid/dev/releasing_process_en.md @@ -1,6 +1,6 @@ # PaddlePaddle Releasing Process -PaddlePaddle manages its branches using "git-flow branching model", and [Semantic Versioning](http://semver.org/) as it's version number semantics. +PaddlePaddle manages its branches using Trunk Based Development, and [Semantic Versioning](http://semver.org/) as it's version number semantics. Each time we release a new PaddlePaddle version, we should follow the below steps: diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index abd5459f6d47da6d1341284916b419325dc5977c..a8e0c4a3fedfd56e38de7568be6b3f2e76a4b25f 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -28,10 +28,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) +if(WITH_GPU) + cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle + all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) +endif() + cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) -cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) +if(WITH_GPU) + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) +else() + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) +endif() + cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index d9fcd92427ef38b131b4ce782c0ada37765682db..e98f1ab148db083ac63a1afd43e334fbfae62539 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -32,6 +32,10 @@ struct ComputationOpHandle : public OpHandleBase { std::string Name() const override; + const Scope *GetScope() const { return scope_; } + + const platform::Place &GetPlace() const { return place_; } + protected: void RunImpl() override; diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 250e093a5f789dba6b06df4889c060c294d469fe..8f319116ab80b75c624f35b0e1315e7362e88d9a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes"; static const char kStrategy[] = "strategy"; void MultiDevSSAGraphBuilder::Init() const { + all_vars_.clear(); + balance_vars_.clear(); + loss_var_name_ = Get(kLossVarName); places_ = Get>(kPlaces); local_scopes_ = Get>(kLocalScopes); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 1ca8c4b855f9468589e537245380451a91a50b14..47aaa80f4d66a48b729d0638badcab885a50585c 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t device_id) const; void Init() const; - private: - mutable std::string loss_var_name_; - mutable std::vector places_; - mutable std::vector local_scopes_; - mutable std::unordered_set grad_names_; - #ifdef PADDLE_WITH_CUDA mutable platform::NCCLContextMap *nccl_ctxs_; #endif @@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t GetAppropriateDeviceID( const std::vector &var_names) const; - private: + void SetCommunicationContext(OpHandleBase *op_handle, + const platform::Place &p) const; + + mutable std::string loss_var_name_; + mutable std::vector places_; + mutable std::vector local_scopes_; + mutable std::unordered_set grad_names_; + mutable BuildStrategy strategy_; mutable std::unordered_map all_vars_; mutable std::vector balance_vars_; - - void SetCommunicationContext(OpHandleBase *op_handle, - const platform::Place &p) const; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..71db8d952f4c205b875ad254dc19c0c1f74e61b3 --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_op_handle.h @@ -0,0 +1,123 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace framework { +namespace details { + +using ReferenceCountMap = std::unordered_map; +using AtomicReferenceCountMap = + std::unordered_map>; +using DeviceReferenceCountMap = + std::unordered_map>; +using AtomicDeviceReferenceCountMap = + std::unordered_map>; +using DeviceGarbageCollectorMap = + std::unordered_map>>; + +class ReferenceCountOpHandle : public OpHandleBase { + public: + ReferenceCountOpHandle(ir::Node *node, const Scope *scope, + const platform::CUDAPlace &place, + const std::vector &var_names, + GarbageCollector *gc, + AtomicReferenceCountMap *ref_cnts) + : OpHandleBase(node), + scope_(scope), + var_names_(var_names), + gc_(gc), + ref_cnts_(ref_cnts) { + dev_ctx_ = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + if (IsStreamGarabageCollector()) { + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + } + } + + ~ReferenceCountOpHandle() { + if (IsStreamGarabageCollector()) { + auto gpu_place = boost::get(dev_ctx_->GetPlace()); + PADDLE_ENFORCE(cudaSetDevice(gpu_place.device)); + PADDLE_ENFORCE(cudaEventDestroy(event_)); + } + } + + std::string Name() const override { return "reference_count"; } + + protected: + void RunImpl() override { + auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get(); + std::vector tensors; + for (auto &name : var_names_) { + auto it = ref_cnts_->find(name); + if (it == ref_cnts_->end()) continue; + + auto *var = exec_scope->FindVar(name); + if (var == nullptr || !var->IsType()) continue; + + if (it->second.fetch_sub(1) <= 1) { + tensors.emplace_back(var->GetMutable()); + } + } + + if (!tensors.empty()) { + ClearTensors(tensors); + } + } + + private: + void ClearTensors(const std::vector &tensors) { + auto *gc = dynamic_cast *>(gc_); + if (gc != nullptr) { + auto compute_stream = dev_ctx_->stream(); + auto callback_stream = gc->stream(); + auto callback_func = [=]() { + PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream)); + PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0)); + }; + gc_->Add(tensors, callback_func); + } else { + gc_->Add(tensors); + } + } + + bool IsStreamGarabageCollector() const { + return dynamic_cast *>(gc_) != nullptr; + } + + const Scope *scope_; + platform::CUDADeviceContext *dev_ctx_; + std::vector var_names_; + GarbageCollector *gc_; // not own + AtomicReferenceCountMap *ref_cnts_; // not own + cudaEvent_t event_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..344754d5a1e119c04cae08ad50126924b5824315 --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/reference_count_pass.h" + +namespace paddle { +namespace framework { +namespace details { + +std::unique_ptr ReferenceCountPass::ApplyImpl( + std::unique_ptr graph) const { + auto &ref_cnts = Get(kGlobalReferenceCount); + auto &cur_ref_cnts = Get(kCurReferenceCount); + auto &gcs = Get(kGarbageCollector); + + // It is not easy to find the right reference counts of varaibles in graph + // Step 1: Find all variables in computation ops + // Step 2: Find all variables in non-computation ops which refers to variables + // in computation ops + std::unordered_set names; + auto get_ref_cnts_from_compute_op = [&]( + const std::unique_ptr &op, + const std::vector &vars) { + std::vector var_names_in_op; + auto *compute_op = dynamic_cast(op.get()); + if (compute_op == nullptr || + !platform::is_gpu_place(compute_op->GetPlace())) + return var_names_in_op; + auto place = boost::get(compute_op->GetPlace()); + for (VarHandleBase *var_handle_base : vars) { + auto *var_handle = dynamic_cast(var_handle_base); + if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; + + if (!platform::is_gpu_place(var_handle->place_) || + boost::get(var_handle->place_) != place) + continue; + + VarDesc *var_desc = var_handle->Node()->Var(); + auto var_name = var_handle->Node()->Name(); + + // This is wierd but there is really some variables without var_desc + // in computation_op + if (var_desc == nullptr) { + if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr) + continue; + } else { + if (var_desc->Persistable() || + var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR) + continue; + } + + // compute op only runs in one device + if (ref_cnts[place.device]->count(var_name)) + ++(*ref_cnts[place.device])[var_name]; + else + (*ref_cnts[place.device])[var_name] = 1; + + names.insert(var_name); + var_names_in_op.push_back(var_name); + } + return var_names_in_op; + }; + + auto update_ref_cnts_from_non_compute_op = [&]( + const std::unique_ptr &op, + const std::vector &vars) { + if (dynamic_cast(op.get()) != nullptr) return; + for (VarHandleBase *var_handle_base : vars) { + auto *var_handle = dynamic_cast(var_handle_base); + if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; + + auto var_name = var_handle->Node()->Name(); + auto var_place = var_handle->place_; + if (!platform::is_gpu_place(var_place)) continue; + auto place = boost::get(var_place); + if (names.count(var_name) == 0) continue; + if (ref_cnts.count(place.device) && + ref_cnts[place.device]->count(var_name)) { + ++(*ref_cnts[place.device])[var_name]; + } + } + }; + + std::unordered_map + compute_ref_cnt_map; + auto &all_ops = graph->Get(kGraphOps); + for (auto &op : all_ops) { + auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); + auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); + if (in_var_names.empty() && out_var_names.empty()) continue; + in_var_names.insert(in_var_names.end(), out_var_names.begin(), + out_var_names.end()); + auto *compute_op = dynamic_cast(op.get()); + auto place = boost::get(compute_op->GetPlace()); + ir::Node *ref_cnt_node = + graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); + auto *ref_cnt_handle = new ReferenceCountOpHandle( + ref_cnt_node, compute_op->GetScope(), place, in_var_names, + gcs[place.device].get(), cur_ref_cnts[place.device].get()); + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + compute_op->AddOutput(dep_var); + ref_cnt_handle->AddInput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + compute_ref_cnt_map[compute_op] = ref_cnt_handle; + } + + for (auto &op : all_ops) { + update_ref_cnts_from_non_compute_op(op, op->Inputs()); + update_ref_cnts_from_non_compute_op(op, op->Outputs()); + } + + std::vector> new_all_ops; + new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); + for (auto &op : all_ops) { + new_all_ops.emplace_back(std::move(op)); + auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); + if (it != compute_ref_cnt_map.end()) { + new_all_ops.emplace_back(it->second); + } + } + + all_ops.swap(new_all_ops); + return graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(reference_count_pass, + paddle::framework::details::ReferenceCountPass) + .RequirePassAttr(paddle::framework::details::kGlobalReferenceCount) + .RequirePassAttr(paddle::framework::details::kCurReferenceCount) + .RequirePassAttr(paddle::framework::details::kGarbageCollector); diff --git a/paddle/fluid/framework/details/reference_count_pass.h b/paddle/fluid/framework/details/reference_count_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..7081280b0600b9c1985987d02d679c298ad4b8bd --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_pass.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/reference_count_op_handle.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +constexpr char kGlobalReferenceCount[] = "reference_count"; +constexpr char kCurReferenceCount[] = "current_reference_count"; +constexpr char kGarbageCollector[] = "garbage_collector"; + +class ReferenceCountPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index 5bd974d6b789a2f085c0a69de5e133187342f587..e5b1eaa7318aecde1dbf89de8fe242a3008db97c 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -18,6 +18,9 @@ #include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/reference_count_op_handle.h" +#endif namespace paddle { namespace framework { @@ -65,12 +68,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr); drop_scope_counter_ += 1; + +#ifdef PADDLE_WITH_CUDA + const std::string gc_name = "garbage_collector"; + DeviceGarbageCollectorMap *gc = + Graph().Has(gc_name) ? &(Graph().Get(gc_name)) + : nullptr; +#endif + if (!fetch_tensors.empty() || drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { drop_scope_counter_ = 0; // Wait All computational streams for (auto p : places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); +#ifdef PADDLE_WITH_CUDA + if (gc != nullptr && platform::is_gpu_place(p)) { + auto gpu_place = boost::get(p); + auto &gc_at_place = gc->at(gpu_place.device); + gc_at_place->Wait(); + gc_at_place->Reset(); + } +#endif } for (auto &scope : local_scopes_) { auto &local_scope = diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index dad170ed78c64202b5c812bd8682887fe3b736d6..8d8042a0563a21dad216ffd53a474322c378ace6 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -37,7 +37,11 @@ int kProgramId = -1; ExecutorPrepareContext::ExecutorPrepareContext( const framework::ProgramDesc& prog, size_t block_id) - : prog_(prog), block_id_(block_id) {} + : prog_(prog), block_id_(block_id) { + if (GetEagerDeletionThreshold() >= 0) { + ref_cnts_ = GetNonPersistableReferenceCount(prog_, block_id_); + } +} ExecutorPrepareContext::~ExecutorPrepareContext() { VLOG(5) << "destroy ExecutorPrepareContext"; @@ -329,15 +333,81 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, CreateVariables(ctx->prog_, local_scope, ctx->block_id_); } + int64_t max_memory_size = GetEagerDeletionThreshold(); + + std::unique_ptr> gc; + if (max_memory_size >= 0) { + ctx->ResetReferenceCount(); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place_)) { + gc.reset(new DefaultStreamGarbageCollector( + boost::get(place_), max_memory_size)); + } else { +#endif + gc.reset(new CPUGarbageCollector( + boost::get(place_), max_memory_size)); +#ifdef PADDLE_WITH_CUDA + } +#endif + } + for (auto& op : ctx->ops_) { op->Run(*local_scope, place_); + if (gc != nullptr) { + std::vector erase_vars; + for (auto& input : op->Inputs()) { + for (auto& input_name : input.second) { + auto it = ctx->cur_ref_cnts_.find(input_name); + if (it == ctx->cur_ref_cnts_.end()) continue; + if (it->second == 1) { // should delete it + erase_vars.emplace_back(input_name); + ctx->cur_ref_cnts_.erase(input_name); + } else { + --(it->second); + } + } + } + + for (auto& output : op->Outputs()) { + for (auto& output_name : output.second) { + auto it = ctx->cur_ref_cnts_.find(output_name); + if (it == ctx->cur_ref_cnts_.end()) continue; + if (it->second == 1) { + erase_vars.emplace_back(output_name); + ctx->cur_ref_cnts_.erase(output_name); + } else { + --(it->second); + } + } + } + + if (!erase_vars.empty()) { + std::vector erase_tensors; + for (auto& name : erase_vars) { + auto* var = local_scope->FindVar(name); + if (var == nullptr) continue; + if (var->IsType()) { + auto* tensor = var->GetMutable(); + erase_tensors.push_back(tensor); + } + } + if (!erase_tensors.empty()) gc->Add(erase_tensors); + } + } + if (FLAGS_benchmark) { VLOG(2) << "Memory used after operator " + op->Type() + " running: " << memory::memory_usage(place_); } } - platform::DeviceContextPool::Instance().Get(place_)->Wait(); + + if (gc != nullptr) { + gc->Wait(); + } else { + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + } + if (local_scope != scope) { scope->DeleteScope(local_scope); } else { diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index f95808c199b9de693ec653c29374c9130be7fd59..f0cc1338a8af50030a70a9797cbcd1b0567272b5 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -27,13 +28,58 @@ namespace paddle { namespace framework { extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); +template +std::unordered_map GetNonPersistableReferenceCount( + const ProgramDesc& prog, size_t block_id) { + auto& block = prog.Block(block_id); + std::unordered_set ignored_vars; + std::unordered_map ref_cnts; + + for (auto var_desc : block.AllVars()) { + auto type = var_desc->Proto()->type().type(); + if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) { + ignored_vars.insert(var_desc->Name()); // ignore persistable vars + } + } + + for (auto op_desc : block.AllOps()) { + for (auto& input : op_desc->Inputs()) { + for (auto& input_name : input.second) { + if (!ignored_vars.count(input_name)) { + if (ref_cnts.count(input_name)) + ++ref_cnts[input_name]; + else + ref_cnts[input_name] = 1; + } + } + } + + for (auto& output : op_desc->Outputs()) { + for (auto output_name : output.second) { + if (!ignored_vars.count(output_name)) { + if (ref_cnts.count(output_name)) + ++ref_cnts[output_name]; + else + ref_cnts[output_name] = 1; + } + } + } + } + return ref_cnts; +} + struct ExecutorPrepareContext { ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); ~ExecutorPrepareContext(); + void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; } + const framework::ProgramDesc& prog_; size_t block_id_; std::vector> ops_; + + std::unordered_map ref_cnts_; + std::unordered_map cur_ref_cnts_; }; class Executor { diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h new file mode 100644 index 0000000000000000000000000000000000000000..b403252c972d26da6deeca54ce88a9547ffe7afa --- /dev/null +++ b/paddle/fluid/framework/garbage_collector.h @@ -0,0 +1,163 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include // NOLINT +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { + +// T should have memory_size() and clear() method +template +class GarbageCollector { + public: + GarbageCollector(const platform::Place &place, size_t max_memory_size) + : max_memory_size_(std::max(max_memory_size, static_cast(1))) { + garbages_.reset(new std::deque()); + dev_ctx_ = platform::DeviceContextPool::Instance().Get(place); + } + + virtual ~GarbageCollector() {} + + void Reset() { + std::lock_guard guard(mutex_); + garbages_.reset(new std::deque()); + cur_memory_size_ = 0; + } + + template + void Add(const Container &objs) { + Add(objs, []() {}); + } + + template + void Add(const Container &objs, Callback &&callback) { + std::shared_ptr> clear_deque; + { + std::lock_guard guard(mutex_); + for (auto *obj : objs) { + garbages_->push_back(obj); + cur_memory_size_ += obj->memory_size(); + } + if (cur_memory_size_ >= max_memory_size_) { + cur_memory_size_ = 0; + clear_deque = garbages_; + garbages_.reset(new std::deque()); + } + } + + if (clear_deque != nullptr) { + callback(); + ClearCallback([=]() { + for (auto *obj : *clear_deque) obj->clear(); + }); + } + } + + virtual void Wait() const {} + + protected: + virtual void ClearCallback(const std::function &callback) = 0; + + platform::DeviceContext *dev_ctx_; + std::shared_ptr> garbages_; + mutable std::mutex mutex_; + const size_t max_memory_size_; + size_t cur_memory_size_ = 0; +}; + +template +class CPUGarbageCollector : public GarbageCollector { + public: + CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + + protected: + void ClearCallback(const std::function &callback) override { + callback(); + } +}; + +#ifdef PADDLE_WITH_CUDA +template +class DefaultStreamGarbageCollector : public GarbageCollector { + public: + DefaultStreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + + cudaStream_t stream() const { + return static_cast(this->dev_ctx_) + ->stream(); + } + + void Wait() const override { + this->dev_ctx_->Wait(); + static_cast(this->dev_ctx_) + ->WaitStreamCallback(); + } + + protected: + void ClearCallback(const std::function &callback) override { + static_cast(this->dev_ctx_) + ->AddStreamCallback(callback); + } +}; + +template +class StreamGarbageCollector : public GarbageCollector { + public: + StreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) { + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaStreamCreate(&stream_)); + callback_manager_.reset(new platform::StreamCallbackManager(stream_)); + } + + ~StreamGarbageCollector() { + auto place = boost::get(this->dev_ctx_->GetPlace()); + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + PADDLE_ENFORCE(cudaStreamDestroy(stream_)); + } + + void Wait() const override { + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + std::lock_guard guard(this->mutex_); + callback_manager_->Wait(); + } + + cudaStream_t stream() const { return stream_; } + + protected: + void ClearCallback(const std::function &callback) override { + std::lock_guard guard(this->mutex_); + callback_manager_->AddCallback(callback); + } + + private: + cudaStream_t stream_; + std::unique_ptr callback_manager_; +}; +#endif + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index ae8496204d4aeb88c04154d571325d440274e821..ab687e760a761d4e445726bd5149966adc2403d0 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -94,6 +94,14 @@ class Graph { }; } + template + void SetNotOwned(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph", + attr_name); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = []() {}; + } + const std::unordered_set &Nodes() const { return node_set_; } // Create a normal variable with non-null VarDesc. diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 5b8c75a93de2ddd8f7260d2191c22a5945b3d2d9..dbc3ff8657a1f2238951a791fb5ac3356c885770 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -188,6 +188,30 @@ ParallelExecutor::ParallelExecutor( main_program, member_->places_, loss_var_name, params, member_->local_scopes_, member_->use_cuda_, build_strategy, member_->nccl_ctxs_.get()); + + auto max_memory_size = GetEagerDeletionThreshold(); + if (max_memory_size >= 0) { + for (auto &place : member_->places_) { + if (!platform::is_gpu_place(place)) continue; + auto gpu_place = boost::get(place); + if (gcs_[gpu_place.device] == nullptr) { + ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap()); + cur_ref_cnts_[gpu_place.device].reset( + new details::AtomicReferenceCountMap()); + gcs_[gpu_place.device].reset( + new StreamGarbageCollector(gpu_place, max_memory_size)); + } + } + if (!gcs_.empty()) { + auto ref_cnt_pass = + ir::PassRegistry::Instance().Get("reference_count_pass"); + ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_); + graph = ref_cnt_pass->Apply(std::move(graph)); + graph->SetNotOwned("garbage_collector", &gcs_); + } + } #else std::unique_ptr graph = ApplyParallelExecutorPass( main_program, member_->places_, loss_var_name, params, @@ -209,30 +233,9 @@ ParallelExecutor::ParallelExecutor( void ParallelExecutor::BCastParamsToDevices( const std::unordered_set &vars) const { - // the initializing bcast, all vars would be bcast from device(0), - // otherwise - // bcast from the specified device. - bool initializing = member_->executor_ ? false : true; + // the initializing bcast, all vars would be bcast from device(0). for (auto &var : vars) { - int var_dev_id = -1; - if (member_->executor_) { - auto &sharded_var_device = - member_->executor_->Graph().Get( - details::kShardedVarDevice); - if (sharded_var_device.find(var) != sharded_var_device.end()) { - var_dev_id = sharded_var_device.at(var); - } - } - - if (!initializing && var_dev_id == -1) continue; - - framework::Variable *main_var = nullptr; - if (initializing) { - main_var = member_->local_scopes_[0]->FindVar(var); - } else { - main_var = member_->local_scopes_[var_dev_id]->FindVar(var); - } - + framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var); if (main_var == nullptr || !main_var->IsType()) { continue; } @@ -248,8 +251,7 @@ void ParallelExecutor::BCastParamsToDevices( auto place = member_->places_[i]; void *buffer; - if ((initializing && i == 0) || - (!initializing && static_cast(i) == var_dev_id)) { + if (i == 0) { buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; @@ -266,29 +268,18 @@ void ParallelExecutor::BCastParamsToDevices( platform::NCCLGroupGuard guard; for (size_t i = 0; i < member_->places_.size(); ++i) { auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]); - if (initializing) { - platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, - nccl_ctx.comm_, nccl_ctx.stream()); - } else { - if (var_dev_id >= 0) { - platform::dynload::ncclBcast(buffers[i], numel, data_type, - var_dev_id, nccl_ctx.comm_, - nccl_ctx.stream()); - } - } + platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, + nccl_ctx.comm_, nccl_ctx.stream()); } member_->nccl_ctxs_->WaitAll(); } - #else PADDLE_THROW("Not compiled with CUDA"); #endif } else { platform::CPUPlace cpu; for (size_t i = 0; i < member_->places_.size(); ++i) { - if ((initializing && i == 0) || - (!initializing && static_cast(i) == var_dev_id)) - continue; + if (i == 0) continue; auto local_scope = member_->local_scopes_[i]; auto *t = local_scope->Var(var)->GetMutable(); @@ -310,6 +301,11 @@ void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { platform::RecordBlock b(0); +#ifdef PADDLE_WITH_CUDA + if (!gcs_.empty()) { + ResetReferenceCount(); + } +#endif auto fetch_data = member_->executor_->Run(fetch_tensors); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetch_data; @@ -367,3 +363,6 @@ USE_PASS(graph_viz_pass); USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); +#ifdef PADDLE_WITH_CUDA +USE_PASS(reference_count_pass); +#endif diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 5fb748fa205d5e9dbd2943b615c69aedd0e7a26f..c64906ff230df5f2b7cc9f5c6b29d68956ab8f33 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -15,7 +15,9 @@ limitations under the License. */ #pragma once #include +#include #include +#include #include #include #include "paddle/fluid/framework/details/execution_strategy.h" @@ -27,6 +29,10 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/reference_count_pass.h" +#endif + namespace paddle { namespace framework { @@ -66,10 +72,27 @@ class ParallelExecutor { void Run(const std::vector &fetch_tensors, const std::string &fetched_var_name); + private: void BCastParamsToDevices(const std::unordered_set &vars) const; - private: ParallelExecutorPrivate *member_; + +#ifdef PADDLE_WITH_CUDA + // ref_cnts_ is only initialized when ParallelExecutor constructs, and then + // keeps unchanged + // Before each iteration, cur_ref_cnts_ is reset to ref_cnts_ + details::DeviceReferenceCountMap ref_cnts_; + details::AtomicDeviceReferenceCountMap cur_ref_cnts_; + details::DeviceGarbageCollectorMap gcs_; + + void ResetReferenceCount() { + for (auto &pair1 : ref_cnts_) { + for (auto &pair2 : *(pair1.second)) { + (*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second; + } + } + } +#endif }; } // namespace framework diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 2be655b89a4caf2bf9874dcab6bc0bdb2856a026..1a727a2c8c759d010606d5b605823b7252b35c69 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -31,9 +31,21 @@ DEFINE_bool( "Delete local scope eagerly. It will reduce GPU memory usage but " "slow down the destruction of variables.(around 1% performance harm)"); +DEFINE_double( + eager_delete_tensor_gb, -1.0, + "Memory size threshold (GB) when the garbage collector clear tensors." + "Disabled when this value is less than 0"); + namespace paddle { namespace framework { +int64_t GetEagerDeletionThreshold() { + return FLAGS_eager_delete_tensor_gb < 0 + ? -1 + : static_cast(FLAGS_eager_delete_tensor_gb * + (static_cast(1) << 30)); +} + Scope::~Scope() { DropKids(); } Scope& Scope::NewScope() const { diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index b6165a595d537c314a95685e8b1edbc42e387ab7..e42fff1d79d92fb7ed61768a614d8cd98f6775a0 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -26,6 +26,8 @@ limitations under the License. */ namespace paddle { namespace framework { +int64_t GetEagerDeletionThreshold(); + class Scope; /** diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 4cf95fa0ae07823289fbf337062190f05e6c6bcf..f1d268548578fea12082e2edb213a3749eccbfaf 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -151,6 +151,8 @@ class Tensor { void set_layout(const DataLayout layout) { layout_ = layout; } + void clear() { holder_ = nullptr; } + private: /** * @note Placeholder hides type T, so it doesn't appear as a template diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index efb91bcf75a3cb99a67d5a3251b1d42fc4b04170..6698efd1fa773127a84b4bcb28f57f4226dd7ae2 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -17,9 +17,7 @@ get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) # paddle_fluid_origin exclude inference api interface cc_library(paddle_fluid_origin DEPS ${fluid_modules} paddle_fluid_api) -#if(APPLE) - add_subdirectory(api) -#endif() +add_subdirectory(api) # Create static library cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api paddle_inference_api analysis_predictor) @@ -57,5 +55,7 @@ endif() if(WITH_TESTING) # tests/book depends the models that generated by python/paddle/fluid/tests/book add_subdirectory(tests/book) - add_subdirectory(tests/api) + if(WITH_INFERENCE_API_TEST) + add_subdirectory(tests/api) + endif() endif() diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 6dc39cae0522efd48c2e2921611adebd6937ddf7..8a8aeb5e09a0d9a6746f6d6d61c547363e0e2d30 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -69,8 +69,9 @@ class DfgPassManagerImpl final : public DfgPassManager { if (FLAGS_IA_enable_tensorrt_subgraph_engine) { auto trt_teller = [&](const Node* node) { std::unordered_set teller_set( - {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax", - "depthwise_conv2d", "batch_norm", "concat"}); + {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", + "depthwise_conv2d", "batch_norm", "concat", "tanh", + "elementwise_add", "dropout"}); if (!node->IsFunction()) return false; const auto* func = static_cast(node); diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 5df486f345a98d7737d326c94e4854d24535ff61..e569df94c54c304852dab7c7496804c1b08d665c 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -69,25 +69,4 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI endfunction() anakin_target(inference_anakin_api) anakin_target(inference_anakin_api_shared) - if (WITH_TESTING) - # TODO(luotao): ANAKIN_MODLE_URL etc will move to demo ci later. - set(INFERENCE_URL "http://paddle-inference-dist.bj.bcebos.com") - set(ANAKIN_RNN_MODLE_URL "${INFERENCE_URL}/anakin_test%2Fditu_rnn.anakin2.model.bin") - set(ANAKIN_RNN_DATA_URL "${INFERENCE_URL}/anakin_test%2Fditu_rnn_data.txt") - execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_SOURCE_DIR}") - execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_RNN_MODLE_URL} -N") - execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_RNN_DATA_URL} -N") - if(WITH_GPU) - set(anakin_test_extra_deps dynload_cuda) - set(ANAKIN_MODLE_URL "${INFERENCE_URL}/mobilenet_v2.anakin.bin") - execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_MODLE_URL} -N") - cc_test(api_anakin_engine_tester SRCS api_anakin_engine_tester.cc - ARGS --model=${ANAKIN_SOURCE_DIR}/mobilenet_v2.anakin.bin - DEPS inference_anakin_api_shared ${anakin_test_extra_deps} SERIAL) - endif() - cc_test(api_anakin_engine_rnn_tester SRCS api_anakin_engine_rnn_tester.cc - ARGS --model=${ANAKIN_SOURCE_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin - --datapath=${ANAKIN_SOURCE_DIR}/anakin_test%2Fditu_rnn_data.txt - DEPS inference_anakin_api_shared ${anakin_test_extra_deps} SERIAL) - endif(WITH_TESTING) endif() diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc index abee375313850f1490bacec11f737706c061a5e9..d9d6e139b8735c8f07c52f63c70b6b9805e03642 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc @@ -153,11 +153,21 @@ CreatePaddlePredictor( } // namespace paddle USE_TRT_CONVERTER(elementwise_add_weight); +USE_TRT_CONVERTER(elementwise_add_tensor); +USE_TRT_CONVERTER(elementwise_sub_tensor); +USE_TRT_CONVERTER(elementwise_div_tensor); +USE_TRT_CONVERTER(elementwise_mul_tensor); +USE_TRT_CONVERTER(elementwise_max_tensor); +USE_TRT_CONVERTER(elementwise_min_tensor); +USE_TRT_CONVERTER(elementwise_pow_tensor); USE_TRT_CONVERTER(mul); USE_TRT_CONVERTER(conv2d); USE_TRT_CONVERTER(relu); +USE_TRT_CONVERTER(sigmoid); +USE_TRT_CONVERTER(tanh); USE_TRT_CONVERTER(fc); USE_TRT_CONVERTER(pool2d); USE_TRT_CONVERTER(softmax); USE_TRT_CONVERTER(batch_norm); USE_TRT_CONVERTER(concat); +USE_TRT_CONVERTER(dropout); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 9d7be2d03cf7bb12afe7e52d9630f184d689dc25..fac1babf6ec6131f84d3e3b9fc6efedd9f9f6cfc 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,7 +1,7 @@ # Add TRT tests nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc -batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc +batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc DEPS tensorrt_engine operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS @@ -24,6 +24,8 @@ nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL) nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) - nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) + +nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc + DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index 8168cdff1b85fc05d22fbec7fac6ab8892f3a907..e73c5bbf57501e4ff3c080a46d91685035652bfa 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -19,23 +19,31 @@ namespace paddle { namespace inference { namespace tensorrt { -class ReluOpConverter : public OpConverter { +class ActivationOpConverter : public OpConverter { public: - ReluOpConverter() {} + ActivationOpConverter() {} void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { // Here the two nullptr looks strange, that's because the // framework::OpDesc's constructor is strange. framework::OpDesc op_desc(op, nullptr); - LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " - "type is Relu"; + LOG(INFO) + << "convert a fluid Activation op to tensorrt activation layer whose " + "type is " + << op_type_; const nvinfer1::ITensor* input_tensor = engine_->GetITensor(op_desc.Input("X")[0]); + + auto op_pair = ops.find(op_type_); + if (op_pair == ops.end()) { + PADDLE_THROW("Wrong activation op type!"); + } + nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( engine_, Activation, *const_cast(input_tensor), - nvinfer1::ActivationType::kRELU); + op_pair->second); auto output_name = op_desc.Output("Out")[0]; - layer->setName(("relu (Output: " + output_name + ")").c_str()); + layer->setName((op_type_ + " (Output: " + output_name + ")").c_str()); layer->getOutput(0)->setName(output_name.c_str()); engine_->SetITensor(output_name, layer->getOutput(0)); if (test_mode) { // the test framework can not determine which is the @@ -43,6 +51,32 @@ class ReluOpConverter : public OpConverter { engine_->DeclareOutput(output_name); } } + + protected: + std::string op_type_; + static const std::unordered_map ops; +}; + +const std::unordered_map + ActivationOpConverter::ops = { + {"relu", nvinfer1::ActivationType::kRELU}, + {"sigmoid", nvinfer1::ActivationType::kSIGMOID}, + {"tanh", nvinfer1::ActivationType::kTANH}, +}; + +class ReluOpConverter : public ActivationOpConverter { + public: + ReluOpConverter() { op_type_ = "relu"; } +}; + +class SigmoidOpConverter : public ActivationOpConverter { + public: + SigmoidOpConverter() { op_type_ = "sigmoid"; } +}; + +class TanhOpConverter : public ActivationOpConverter { + public: + TanhOpConverter() { op_type_ = "tanh"; } }; } // namespace tensorrt @@ -50,3 +84,5 @@ class ReluOpConverter : public OpConverter { } // namespace paddle REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter); +REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter); +REGISTER_TRT_OP_CONVERTER(tanh, TanhOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/dropout_op.cc b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9533ecbcfda4e2500fd201d8efc64fc5bd97169a --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * DropoutOp. This Layer doesn't has weights. + */ +class DropoutOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert a fluid dropout op to tensorrt dropout layer"; + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); + float dropout_prob = boost::get(op_desc.GetAttr("dropout_prob")); + + platform::CPUPlace cpu_place; + std::unique_ptr weight_tensor( + new framework::LoDTensor()); + weight_tensor->Resize(framework::make_ddim({1})); + auto* weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); + weight_data[0] = 1 - dropout_prob; + + TensorRTEngine::Weight scale_weights{ + nvinfer1::DataType::kFLOAT, static_cast(weight_data), + weight_tensor->memory_size() / sizeof(float)}; + TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + + auto* layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *const_cast(input1), + nvinfer1::ScaleMode::kUNIFORM, shift_weights.get(), scale_weights.get(), + power_weights.get()); + + engine_->weight_map[op_desc.Output("Out").front() + "_dropout"] = + std::move(weight_tensor); + auto output_name = op_desc.Output("Out")[0]; + layer->setName(("dropout (Output: " + output_name + ")").c_str()); + engine_->SetITensor(output_name, layer->getOutput(0)); + if (test_mode) { + engine_->DeclareOutput(output_name); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(dropout); +REGISTER_TRT_OP_CONVERTER(dropout, DropoutOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc index e82762ea03ecd00bce7cfb83b130a3436ccbfed3..dd3dfb0bc7b609e28462954835a0d40e0a63b6cd 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc @@ -20,18 +20,18 @@ namespace paddle { namespace inference { namespace tensorrt { -TEST(ReluOpConverter, main) { +void test_activation(std::string act_type) { framework::Scope scope; std::unordered_set parameters; TRTConvertValidation validator(10, parameters, scope, 1000); - validator.DeclInputVar("relu-X", nvinfer1::Dims2(10, 6)); - validator.DeclOutputVar("relu-Out", nvinfer1::Dims2(10, 6)); + validator.DeclInputVar("act-X", nvinfer1::Dims2(10, 6)); + validator.DeclOutputVar("act-Out", nvinfer1::Dims2(10, 6)); // Prepare Op description framework::OpDesc desc; - desc.SetType("relu"); - desc.SetInput("X", {"relu-X"}); - desc.SetOutput("Out", {"relu-Out"}); + desc.SetType(act_type); + desc.SetInput("X", {"act-X"}); + desc.SetOutput("Out", {"act-Out"}); LOG(INFO) << "set OP"; validator.SetOp(*desc.Proto()); @@ -40,8 +40,16 @@ TEST(ReluOpConverter, main) { validator.Execute(5); } +TEST(ReluOpConverter, main) { test_activation("relu"); } + +TEST(SigmoidOpConverter, main) { test_activation("sigmoid"); } + +TEST(TanhOpConverter, main) { test_activation("tanh"); } + } // namespace tensorrt } // namespace inference } // namespace paddle USE_OP(relu); +USE_OP(sigmoid); +USE_OP(tanh); diff --git a/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b8e621b702d977f5868766a6eafb98c8522c3cd --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(DropoutOpConverter, main) { + framework::Scope scope; + std::unordered_set parameters; + TRTConvertValidation validator(8, parameters, scope, 1000); + + std::vector tensor_shape{8, 10}; + validator.DeclInputVar("dropout-X", tensor_shape, + nvinfer1::DimsCHW(10, 1, 1)); + validator.DeclOutputVar("dropout-Out", nvinfer1::DimsCHW(10, 1, 1)); + validator.DeclOutputVar("mask-Out", nvinfer1::DimsCHW(10, 1, 1)); + + // Prepare Op description + framework::OpDesc desc; + int is_test = 1; + float dropout_prob = 0.4; + + desc.SetType("dropout"); + desc.SetInput("X", {"dropout-X"}); + desc.SetOutput("Mask", {"mask-Out"}); + desc.SetOutput("Out", {"dropout-Out"}); + desc.SetAttr("is_test", is_test); + desc.SetAttr("dropout_prob", dropout_prob); + + LOG(INFO) << "set OP"; + validator.SetOp(*desc.Proto()); + LOG(INFO) << "execute"; + + std::unordered_set neglected_output = {"mask-Out"}; + + validator.Execute(8, neglected_output); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(dropout); diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index e397457061662c8afb9760ef52406c22caaeb213..508ef1ce40aa0882a0f39a85f97511fd9ea2a8a5 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -1,73 +1,87 @@ -set(INFERENCE_URL "http://paddle-inference-dist.bj.bcebos.com") -set(INFERENCE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo") +set(INFERENCE_URL "http://paddle-inference-dist.cdn.bcebos.com") +set(INFERENCE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING + "A path setting inference demo download directories.") set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor) -function (inference_download_and_uncompress install_dir filename) - message(STATUS "Download inference test stuff from ${INFERENCE_URL}/${filename}") +function (inference_download install_dir url filename) + message(STATUS "Download inference test stuff from ${url}/${filename}") execute_process(COMMAND bash -c "mkdir -p ${install_dir}") - execute_process(COMMAND bash -c "cd ${install_dir} && wget -q ${INFERENCE_URL}/${filename}") - execute_process(COMMAND bash -c "cd ${install_dir} && tar xzf ${filename}") + execute_process(COMMAND bash -c "cd ${install_dir} && wget -q ${url}/${filename}") message(STATUS "finish downloading ${filename}") -endfunction(inference_download_and_uncompress) +endfunction() + +function (inference_download_and_uncompress install_dir url filename) + inference_download(${install_dir} ${url} ${filename}) + execute_process(COMMAND bash -c "cd ${install_dir} && tar xzf ${filename}") +endfunction() function(download_model_and_data install_dir model_name data_name) - if (NOT EXISTS ${install_dir} AND WITH_INFERENCE) - inference_download_and_uncompress(${install_dir} ${model_name}) - inference_download_and_uncompress(${install_dir} ${data_name}) + if (NOT EXISTS ${install_dir}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${model_name}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${data_name}) endif() endfunction() +function(inference_analysis_api_test target install_dir filename) + inference_analysis_test(${target} SRCS ${filename} + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt) +endfunction() + # RNN1 -set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1") -download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz") -inference_analysis_test(test_analyzer_rnn1 SRCS analyzer_rnn1_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${RNN1_INSTALL_DIR}/model - --infer_data=${RNN1_INSTALL_DIR}/data.txt) +if(NOT APPLE) + set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1") + download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz") + inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc) +else() + # TODO: fix this test on MACOS, the reason is that + # fusion_seqexpand_concat_fc_op is not supported on MACOS + message(WARNING "These tests has been disabled in OSX before being fixed: \n test_analyzer_rnn1") +endif() # RNN2 set(RNN2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn2") download_model_and_data(${RNN2_INSTALL_DIR} "rnn2_model.tar.gz" "rnn2_data.txt.tar.gz") -inference_analysis_test(test_analyzer_rnn2 SRCS analyzer_rnn2_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${RNN2_INSTALL_DIR}/model - --infer_data=${RNN2_INSTALL_DIR}/data.txt) +inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2_tester.cc) # chinese_ner set(CHINESE_NER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/chinese_ner") download_model_and_data(${CHINESE_NER_INSTALL_DIR} "chinese_ner_model.tar.gz" "chinese_ner-data.txt.tar.gz") -inference_analysis_test(test_analyzer_ner SRCS analyzer_ner_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${CHINESE_NER_INSTALL_DIR}/model - --infer_data=${CHINESE_NER_INSTALL_DIR}/data.txt) +inference_analysis_api_test(test_analyzer_ner ${CHINESE_NER_INSTALL_DIR} analyzer_ner_tester.cc) # lac set(LAC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lac") download_model_and_data(${LAC_INSTALL_DIR} "lac_model.tar.gz" "lac_data.txt.tar.gz") -inference_analysis_test(test_analyzer_lac SRCS analyzer_lac_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${LAC_INSTALL_DIR}/model - --infer_data=${LAC_INSTALL_DIR}/data.txt) +inference_analysis_api_test(test_analyzer_lac ${LAC_INSTALL_DIR} analyzer_lac_tester.cc) # text_classification set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification") download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz") -inference_analysis_test(test_analyzer_text_classification SRCS analyzer_text_classification_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/model - --infer_data=${TEXT_CLASSIFICATION_INSTALL_DIR}/data.txt) +inference_analysis_api_test(test_analyzer_text_classification ${TEXT_CLASSIFICATION_INSTALL_DIR} analyzer_text_classification_tester.cc) # ocr -set(OCR_MODEL_URL "http://paddlemodels.cdn.bcebos.com/inference-vis-demos%2Focr.tar.gz") -set(OCR_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/ocr") -if (NOT EXISTS ${OCR_INSTALL_DIR} AND WITH_INFERENCE) - get_filename_component(filename ${OCR_MODEL_URL} NAME) - message(STATUS "Download inference test stuff ${filename} from ${OCR_MODEL_URL}") - execute_process(COMMAND bash -c "mkdir -p ${OCR_INSTALL_DIR}") - execute_process(COMMAND bash -c "cd ${OCR_INSTALL_DIR} && wget -q ${OCR_MODEL_URL}") - execute_process(COMMAND bash -c "cd ${OCR_INSTALL_DIR} && tar xzf ${filename}") - message(STATUS "finish downloading ${filename}") +set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr") +if (NOT EXISTS ${OCR_INSTALL_DIR}) + inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz") +endif() +inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc) + +# anakin +if (WITH_ANAKIN AND WITH_MKL) # only needed in CI + # anakin rnn1 + set(ANAKIN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/anakin") + set(ANAKIN_RNN1_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/rnn1") + inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn.anakin2.model.bin") + inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn_data.txt") + cc_test(test_anakin_rnn1 SRCS anakin_rnn1_tester.cc + ARGS --model=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin + --datapath=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn_data.txt + DEPS inference_anakin_api_shared SERIAL) + # anakin mobilenet + if(WITH_GPU) + set(ANAKIN_MOBILENET_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/mobilenet") + inference_download(${ANAKIN_MOBILENET_INSTALL_DIR} ${INFERENCE_URL} "mobilenet_v2.anakin.bin") + cc_test(test_anakin_mobilenet SRCS anakin_mobilenet_tester.cc + ARGS --model=${ANAKIN_MOBILENET_INSTALL_DIR}/mobilenet_v2.anakin.bin + DEPS inference_anakin_api_shared dynload_cuda SERIAL) + endif() endif() -inference_analysis_test(test_analyzer_ocr SRCS analyzer_vis_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${OCR_INSTALL_DIR}/model - --infer_data=${OCR_INSTALL_DIR}/data.txt) diff --git a/paddle/fluid/inference/api/api_anakin_engine_tester.cc b/paddle/fluid/inference/tests/api/anakin_mobilenet_tester.cc similarity index 100% rename from paddle/fluid/inference/api/api_anakin_engine_tester.cc rename to paddle/fluid/inference/tests/api/anakin_mobilenet_tester.cc diff --git a/paddle/fluid/inference/api/api_anakin_engine_rnn_tester.cc b/paddle/fluid/inference/tests/api/anakin_rnn1_tester.cc similarity index 100% rename from paddle/fluid/inference/api/api_anakin_engine_rnn_tester.cc rename to paddle/fluid/inference/tests/api/anakin_rnn1_tester.cc diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index c2f45fdc99b87bc12c2aadf1985de6e98a24fce7..26ef27c3caafadb4801b0ae52133f6175655ce0a 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -167,6 +167,8 @@ void BuddyAllocator::Free(void* p) { } size_t BuddyAllocator::Used() { return total_used_; } +size_t BuddyAllocator::GetMinChunkSize() { return min_chunk_size_; } +size_t BuddyAllocator::GetMaxChunkSize() { return max_chunk_size_; } void* BuddyAllocator::SystemAlloc(size_t size) { size_t index = 0; diff --git a/paddle/fluid/memory/detail/buddy_allocator.h b/paddle/fluid/memory/detail/buddy_allocator.h index f0c83efc23ce39c4fc89296d672e1e55751851bf..3f86a51f0d0b8504bbc4b0477f123093b343e9cf 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.h +++ b/paddle/fluid/memory/detail/buddy_allocator.h @@ -42,6 +42,8 @@ class BuddyAllocator { void* Alloc(size_t unaligned_size); void Free(void* ptr); size_t Used(); + size_t GetMinChunkSize(); + size_t GetMaxChunkSize(); public: // Disable copy and assignment diff --git a/paddle/fluid/memory/malloc.cc b/paddle/fluid/memory/malloc.cc index 7c800b3c164049244770ceb2070b177d8307e85e..283745e977533358ef52521b36e67f0ada950e61 100644 --- a/paddle/fluid/memory/malloc.cc +++ b/paddle/fluid/memory/malloc.cc @@ -119,8 +119,8 @@ void* Alloc(platform::CUDAPlace place, size_t size) { LOG(WARNING) << "Cannot allocate " << size << " bytes in GPU " << place.device << ", available " << avail << " bytes"; LOG(WARNING) << "total " << total; - LOG(WARNING) << "GpuMinChunkSize " << platform::GpuMinChunkSize(); - LOG(WARNING) << "GpuMaxChunkSize " << platform::GpuMaxChunkSize(); + LOG(WARNING) << "GpuMinChunkSize " << buddy_allocator->GetMinChunkSize(); + LOG(WARNING) << "GpuMaxChunkSize " << buddy_allocator->GetMaxChunkSize(); LOG(WARNING) << "GPU memory used: " << Used(place); platform::SetDeviceId(cur_dev); } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 7ec1e78da4ec642cb1e6248edfbcfed748fa11b8..ccb7fa1f8cce8cc757038904bce762af3b5ff30b 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -296,6 +296,7 @@ op_library(flatten_op DEPS reshape_op) op_library(sequence_pad_op DEPS sequence_padding) op_library(unstack_op DEPS stack_op) op_library(fake_quantize_op DEPS memory) +op_library(fusion_lstm_op DEPS cpu_lstm_compute) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 3c3f9d17c871ac1cb4df83db17cf489d5b9e0563..3dbbd75b1e945208395c42ace3235db7891936c5 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -56,7 +56,7 @@ class VarHandle { const std::string& name, const platform::DeviceContext* p_ctx = nullptr, const framework::Scope* p_scope = nullptr) - : ok_(kVarHandleDefaultState) { + : status_(kDefaultState) { ep_ = ep; ctx_ = p_ctx; scope_ = p_scope; @@ -68,18 +68,20 @@ class VarHandle { public: bool Wait() { + int ret = kDefaultState; { std::unique_lock lk(sync_mutex_); - wait_cond_.wait(lk, [this] { return ok_ != kVarHandleDefaultState; }); + wait_cond_.wait(lk, [this] { return status_ != kDefaultState; }); + ret = status_; } - VLOG(7) << "VarHandle wait:" << ok_; - return ok_ != 0; + VLOG(7) << "VarHandle wait:" << ret; + return ret != kErrorState; } void Finish(bool ok) { { std::unique_lock lk(sync_mutex_); - ok_ = ok; + status_ = ok ? kFinishState : kErrorState; } VLOG(7) << "VarHandle finish:" << ok; wait_cond_.notify_all(); @@ -87,8 +89,8 @@ class VarHandle { std::string String() const { std::ostringstream s; - s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], ok:[" << ok_ - << "]"; + s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:[" + << status_ << "]"; return s.str(); } @@ -111,9 +113,13 @@ class VarHandle { protected: std::mutex sync_mutex_; std::condition_variable wait_cond_; - int ok_; - static const int kVarHandleDefaultState = -1; + enum VarHandleStatus { + kDefaultState = -1, + kErrorState = 0, + kFinishState = 1, + }; + VarHandleStatus status_; private: DISABLE_COPY_AND_ASSIGN(VarHandle); diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 55e465e3af08c012b8cff7714452ed32b32a5556..8ca79d20ec4f6412b00dbf3990068f81b65e2efd 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_lstm_compute.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" @@ -269,7 +270,6 @@ class FuisonLSTMKernel : public framework::OpKernel { blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast(1), prev, D, \ wh_data, D4, static_cast(1), out, D4) -// gates: W_ch, W_ih, W_fh, W_oh #define GET_Ct(ct_1, gates, ct) \ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ act_cand(D, gates, gates); \ @@ -395,11 +395,22 @@ class FuisonLSTMKernel : public framework::OpKernel { } } } else { + // TODO(TJ): unly workaround, clean me + std::function compute_ctht; + if (platform::jit::MayIUse(platform::jit::avx) && + act_gate_str == "sigmoid" && act_cand_str == "tanh" && + act_cell_str == "tanh" && D == 8) { + compute_ctht = math::lstm_compute_ctht; + } else { + compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) { + COMPUTE_CtHt(gates, ct_1, ct, ht); + }; + } for (int i = 0; i < N; ++i) { PROCESS_H0C0 for (int step = tstart; step < seq_len; ++step) { GEMM_WH_ADDON(1, prev_h_data, xx_data); - COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data); + compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data); MOVE_ONE_STEP; } } @@ -532,12 +543,23 @@ class FuisonLSTMKernel : public framework::OpKernel { MOVE_ONE_STEP; } } else { + // TODO(TJ): unly workaround, clean me + std::function compute_ctht; + if (platform::jit::MayIUse(platform::jit::avx) && + act_gate_str == "sigmoid" && act_cand_str == "tanh" && + act_cell_str == "tanh" && D == 8) { + compute_ctht = math::lstm_compute_ctht; + } else { + compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) { + COMPUTE_CtHt(gates, ct_1, ct, ht); + }; + } for (int step = tstart; step < max_seq_len; ++step) { const int cur_bs = batch_starts[step + 1] - batch_starts[step]; GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); DEFINE_CUR; for (int i = 0; i < cur_bs; ++i) { - COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, + compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data, cur_h_out_data); MOVE_ONE_BATCH; } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index d7f0f3c6280db7d121bf8821ec6d578e22a33da6..91101356436c26171eaca2fe01dfd4d937e71717 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -45,6 +45,8 @@ math_library(im2col) if (NOT WIN32) # windows do not support avx functions yet. math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) +# TODO(TJ): ugly workaround, clean me +cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas cpu_info) endif (NOT WIN32) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..58e6512021203664573a0478dade052f92dd70bb --- /dev/null +++ b/paddle/fluid/operators/math/cpu_lstm_compute.cc @@ -0,0 +1,18 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/cpu_lstm_compute.h" + +namespace paddle { +namespace operators { +namespace math {} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..28b6f71729edf1b8cc5d610d76af78dea213313e --- /dev/null +++ b/paddle/fluid/operators/math/cpu_lstm_compute.h @@ -0,0 +1,81 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/cpu_info.h" +#ifdef __AVX__ +#include +#endif + +namespace paddle { +namespace operators { +namespace math { + +// TODO(TJ): ugly workaround, clean me +template +void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { + // gates: W_ch, W_ih, W_fh, W_oh + vec_sigmoid(24, gates + 8, gates + 8); + vec_tanh(8, gates, gates); + const T *i = gates + 8, *f = gates + 16, *o = gates + 24; + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int d = 0; d < 8; ++d) { + // C_t = C_t-1 * fgated + cand_gated * igated + ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; + // H_t = act_cell(C_t) * ogated + T tmp = ct[d] * 2; + tmp = static_cast(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); + vec_exp(1, &tmp, &tmp); + tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); + ht[d] = tmp * o[d]; + } +} + +#ifdef __AVX__ +namespace detail { +namespace forward { +namespace avx { +__m256 Sigmoid(const __m256 a); +__m256 Tanh(const __m256 a); +} // namespace avx +} // namespace forward +} // namespace detail + +template <> +void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, + float* ht) { + namespace act = detail::forward::avx; + // gates: W_ch, W_ih, W_fh, W_oh + __m256 c, i, f, o; + c = _mm256_loadu_ps(gates); + i = _mm256_loadu_ps(gates + 8); + f = _mm256_loadu_ps(gates + 16); + o = _mm256_loadu_ps(gates + 24); + + /* C_t = C_t-1 * fgated + cand_gated * igated*/ + c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); + i = _mm256_loadu_ps(ct_1); + f = _mm256_mul_ps(i, act::Sigmoid(f)); + f = _mm256_add_ps(c, f); + _mm256_storeu_ps(ct, f); + + /* H_t = act_cell(C_t) * ogated */ + o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); + _mm256_storeu_ps(ht, o); +} +#endif + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 9560e3a3c15ca63892fbe3552679a22f027f11e2..6a059968b79189458349e466079cc7a663a8e5ff 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/enforce.h" #ifdef __AVX__ #include #endif @@ -476,7 +477,7 @@ class VecActivations { } else if (type == "identity" || type == "") { return vec_identity; } - LOG(FATAL) << "Not support type: " << type; + PADDLE_THROW("Not support type: %s", type); } }; diff --git a/paddle/fluid/operators/maxout_op.cc b/paddle/fluid/operators/maxout_op.cc index 058115cb624627d81b31d0903f7d615d19708c77..b2543d3d0d80f0573f2cbc755318c1b5a0982324 100644 --- a/paddle/fluid/operators/maxout_op.cc +++ b/paddle/fluid/operators/maxout_op.cc @@ -71,8 +71,7 @@ class MaxOutOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of MaxoutOp" - "should not be null."); + "Input(X) of MaxoutOpshould not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of MaxoutOp should not be null."); auto in_x_dims = ctx->GetInputDim("X"); @@ -90,9 +89,10 @@ class MaxOutOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of MaxOutOpGrad must not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), - "Input(X@GRAD) should not be null."); + "Output(Grad@X) of MaxOutOpGrad should not be null."); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; diff --git a/paddle/fluid/operators/rnn_memory_helper_op.cc b/paddle/fluid/operators/rnn_memory_helper_op.cc index 23e5fc1112d0b1e634d0ab288721cbba57b3ffe5..13df1d4b4bb6c240610f96ccc8f223fc984d63f7 100644 --- a/paddle/fluid/operators/rnn_memory_helper_op.cc +++ b/paddle/fluid/operators/rnn_memory_helper_op.cc @@ -42,7 +42,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase { auto *out_tensor = out_var->GetMutable(); auto &mem_tensor = mem_var->Get(); - out_tensor->ShareDataWith(mem_tensor); + framework::TensorCopySync(mem_tensor, dev_place, out_tensor); out_tensor->set_lod(mem_tensor.lod()); } }; @@ -50,8 +50,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase { class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), ""); - PADDLE_ENFORCE(ctx->HasOutput("Out"), ""); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of rnn_memory_helper op should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output of rnn_memory_helper op should not be null."); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } @@ -107,7 +109,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { } else { auto &out_grad_tensor = out_grad_var->Get(); auto *in_grad_tensor = in_grad_var->GetMutable(); - in_grad_tensor->ShareDataWith(out_grad_tensor); + framework::TensorCopySync(out_grad_tensor, dev_place, in_grad_tensor); in_grad_tensor->set_lod(out_grad_tensor.lod()); } } @@ -133,8 +135,11 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { auto x_grad_name = framework::GradVarName("X"); - PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), ""); - PADDLE_ENFORCE(ctx->HasInput("X"), ""); + PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), + "Gradient of Input(X) in rnn_memory_helper_grad of should " + "not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of rnn_memory_helper_grad of should not be null."); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ x_grad_name); } diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 4bd23d594134f227e86b01fd75b7e202dd76c11b..e55462d6cfe389033a9c24a464fbf5b5d699f34f 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -25,7 +25,7 @@ class SliceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), "Input (Input) of slice op should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -58,7 +58,7 @@ class SliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { + const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace()); @@ -87,13 +87,13 @@ Slice Operator. Produces a slice of the input tensor along multiple axes. Similar to numpy: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html -Slice uses `axes`, `starts` and `ends` attributes to specify the start and +Slice uses `axes`, `starts` and `ends` attributes to specify the start and end dimension for each axis in the list of axes, it uses this information -to slice the input data tensor. If a negative value is passed for any of -the start or end indices, it represents number of elements before the end +to slice the input data tensor. If a negative value is passed for any of +the start or end indices, it represents number of elements before the end of that dimension. If the value passed to start or end is larger than -the n (the number of elements in this dimension), it represents n. -For slicing to the end of a dimension with unknown size, it is recommended +the n (the number of elements in this dimension), it represents n. +For slicing to the end of a dimension with unknown size, it is recommended to pass in INT_MAX. If axes are omitted, they are set to [0, ..., ndim-1]. Following examples will explain how slice works: @@ -119,15 +119,54 @@ Following examples will explain how slice works: } }; +class SliceOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("Input"); + auto x_grad_name = framework::GradVarName("Input"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } +}; + +class SliceOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* bind = new framework::OpDesc(); + bind->SetInput("Input", Input("Input")); + bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input")); + bind->SetAttrMap(Attrs()); + bind->SetType("slice_grad"); + return std::unique_ptr(bind); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker, - paddle::framework::EmptyGradOpMaker); + ops::SliceOpGradMaker); +REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad); REGISTER_OP_CPU_KERNEL( slice, ops::SliceKernel, ops::SliceKernel, ops::SliceKernel, ops::SliceKernel); + +REGISTER_OP_CPU_KERNEL( + slice_grad, ops::SliceGradKernel, + ops::SliceGradKernel, + ops::SliceGradKernel, + ops::SliceGradKernel); diff --git a/paddle/fluid/operators/slice_op.cu b/paddle/fluid/operators/slice_op.cu index 8c1767c70b19d1386af9610ef3405eb487a39878..5efecb78d1a4eaffc3a9c62e1e82a9bcb5922748 100644 --- a/paddle/fluid/operators/slice_op.cu +++ b/paddle/fluid/operators/slice_op.cu @@ -20,3 +20,10 @@ REGISTER_OP_CUDA_KERNEL( ops::SliceKernel, ops::SliceKernel, ops::SliceKernel); + +REGISTER_OP_CUDA_KERNEL( + slice_grad, + ops::SliceGradKernel, + ops::SliceGradKernel, + ops::SliceGradKernel, + ops::SliceGradKernel); diff --git a/paddle/fluid/operators/slice_op.h b/paddle/fluid/operators/slice_op.h index ba231aee176564b91a642912ce0b32bcdef8cfc1..f38d08d7640794bd9a456a6c4ee1da2e04e96b37 100644 --- a/paddle/fluid/operators/slice_op.h +++ b/paddle/fluid/operators/slice_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include #include "paddle/fluid/framework/op_registry.h" @@ -84,5 +85,79 @@ class SliceKernel : public framework::OpKernel { out_t.device(place) = in_t.slice(offsets, extents); } }; + +template +class SliceGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + size_t rank = ctx.Input(framework::GradVarName("Out")) + ->dims() + .size(); + switch (rank) { + case 1: + SliceCompute<1>(ctx); + break; + case 2: + SliceCompute<2>(ctx); + break; + case 3: + SliceCompute<3>(ctx); + break; + case 4: + SliceCompute<4>(ctx); + break; + case 5: + SliceCompute<5>(ctx); + break; + case 6: + SliceCompute<6>(ctx); + break; + } + } + + private: + template + void SliceCompute(const framework::ExecutionContext& context) const { + auto& place = + *context.template device_context().eigen_device(); + auto* d_out = + context.Input(framework::GradVarName("Out")); + auto* d_input = + context.Output(framework::GradVarName("Input")); + d_input->mutable_data(context.GetPlace()); + auto out_dims = d_out->dims(); + auto in_dims = d_input->dims(); + auto axes = context.Attr>("axes"); + auto starts = context.Attr>("starts"); + + auto offsets = Eigen::array(); + auto extents = Eigen::array(); + for (size_t i = 0; i < D; ++i) { + offsets[i] = 0; + extents[i] = out_dims[i]; + } + int start; + for (size_t i = 0; i < axes.size(); ++i) { + start = starts[i]; + if (start < 0) { + start = (start + in_dims[axes[i]]); + } + start = std::max(start, 0); + offsets[axes[i]] = start; + } + Eigen::array, D> paddings; + for (size_t i = 0; i < paddings.size(); ++i) { + paddings[i].first = offsets[i]; + paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i]; + } + auto d_in_t = + framework::EigenTensor::From( + *d_input); + auto d_out_t = + framework::EigenTensor::From( + *d_out); + d_in_t.device(place) = d_out_t.pad(paddings, 0); + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 65a3bc928e47ac60f06e7efc75f42703e45acbb4..791138a8c0eb3c477942a8b723206a8f8a3eac77 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -63,7 +63,7 @@ class WhileOp : public framework::OperatorBase { while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); step_scopes->push_back(¤t_scope); - executor.RunPreparedContext(ctx.get(), ¤t_scope, false); + executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, true); if (is_test) { scope.DeleteScope(¤t_scope); } @@ -169,7 +169,8 @@ class WhileGradOp : public framework::OperatorBase { } } } - executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false); + executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true, + true); auto &pg_names = Outputs(kXGRAD); auto &p_names = Inputs(kX); diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index e25efebe6c3555958f4f75e2b87b7dc45d4a4177..5af8af640e43a5b2e5ee9856f09f66a9fdf4463c 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -51,7 +51,7 @@ ENDIF() # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies -cc_library(device_context SRCS device_context.cc init.cc DEPS malloc +cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index c6f1d1f3d544117311821d980300dffea03891a5..dfc079e986e93c7f02f17b299e5d6293edbedd05 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -210,11 +210,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) if (dynload::HasCUDNN()) { cudnn_holder_.reset(new CudnnHolder(&stream_, place)); } + + callback_manager_.reset(new StreamCallbackManager(stream_)); } CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); + WaitStreamCallback(); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); eigen_stream_.reset(); eigen_device_.reset(); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 3ed49fc4233d4c0cd6cc16319eda08480ab9b434..79539195157d74d4d757edee5e008cbb76c93ee2 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -31,6 +31,9 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/stream_callback_manager.h" +#endif #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { @@ -112,6 +115,17 @@ class CUDADeviceContext : public DeviceContext { PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } + template + void AddStreamCallback(Callback&& callback) const { + std::lock_guard guard(callback_mtx_); + callback_manager_->AddCallback(callback); + } + + void WaitStreamCallback() const { + std::lock_guard guard(callback_mtx_); + callback_manager_->Wait(); + } + private: CUDAPlace place_; @@ -125,7 +139,12 @@ class CUDADeviceContext : public DeviceContext { int multi_process; int max_threads_per_mp; - std::mutex mtx_; + mutable std::mutex mtx_; + + // This lock is only used by callback + // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes + mutable std::mutex callback_mtx_; + std::unique_ptr callback_manager_; }; template <> diff --git a/paddle/fluid/platform/stream_callback_manager.h b/paddle/fluid/platform/stream_callback_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..6c984065aa5fa1a8875aebe84051ab396bc417ec --- /dev/null +++ b/paddle/fluid/platform/stream_callback_manager.h @@ -0,0 +1,82 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "ThreadPool.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +using StreamCallback = std::function; + +class StreamCallbackManager; + +struct StreamCallbackContext { + template + inline StreamCallbackContext(const StreamCallbackManager *manager, + Callback &&callback) + : manager_(manager), callback_(callback) {} + + const StreamCallbackManager *manager_; // do not own + StreamCallback callback_; +}; + +class StreamCallbackManager { + public: + explicit inline StreamCallbackManager(cudaStream_t stream = nullptr) + : stream_(stream), thread_pool_(new ThreadPool(1)) {} + + template + inline void AddCallback(Callback &&callback) const { + AddCallbackWithStreamAndErrorInfo( + [=](cudaStream_t, cudaError_t) { callback(); }); + } + + template + inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const { + auto *stream_callback_context = new StreamCallbackContext(this, callback); + PADDLE_ENFORCE(cudaStreamAddCallback( + stream_, StreamCallbackManager::StreamCallbackFunc, + stream_callback_context, 0)); + } + + void Wait() const { thread_pool_.reset(new ThreadPool(1)); } + + private: + const cudaStream_t stream_; + mutable std::unique_ptr thread_pool_; + + // cudaStreamCallback cannot call CUDA API inside, so we have to use + // thread_pool here + static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, + cudaError_t status, + void *user_data) { + auto *callback_context_ptr = + reinterpret_cast(user_data); + callback_context_ptr->manager_->thread_pool_->enqueue([=]() { + std::unique_ptr callback_context( + callback_context_ptr); + callback_context->callback_(stream, status); + }); + } +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index ba5065f468376d83488d7eade5dc2041d86dfd39..77b9b36e68c88eab35bcc1a88ce08a7b5940d55f 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -68,26 +68,44 @@ function cmake_gen() { # Support build for all python versions, currently # including cp27-cp27m and cp27-cp27mu. PYTHON_FLAGS="" - if [ "$1" != "" ]; then - echo "using python abi: $1" - if [ "$1" == "cp27-cp27m" ]; then - export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:} - export PATH=/opt/python/cp27-cp27m/bin/:${PATH} - PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python - -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7 - -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so" - elif [ "$1" == "cp27-cp27mu" ]; then - export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:} - export PATH=/opt/python/cp27-cp27mu/bin/:${PATH} - PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python - -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7 - -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so" - elif [ "$1" == "cp35-cp35m" ]; then - export LD_LIBRARY_PATH=/opt/_internal/cpython-3.5.1/lib/:${LD_LIBRARY_PATH} - export PATH=/opt/_internal/cpython-3.5.1/bin/:${PATH} - export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/_internal/cpython-3.5.1/bin/python3 + SYSTEM=`uname -s` + if [ "$SYSTEM" == "Darwin" ]; then + if [[ "$1" == "cp27-cp27m" ]] || [[ "$1" == "" ]]; then + echo "using python abi: $1" + if [ -d "/Library/Frameworks/Python.framework/Versions/2.7" ]; then + export LD_LIBRARY_PATH=/Library/Frameworks/Python.framework/Versions/2.7 + export DYLD_LIBRARY_PATH=/Library/Frameworks/Python.framework/Versions/2.7 + export PATH=/Library/Frameworks/Python.framework/Versions/2.7/bin/:${PATH} + PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/Library/Frameworks/Python.framework/Versions/2.7/bin/python2.7 + -DPYTHON_INCLUDE_DIR:PATH=/Library/Frameworks/Python.framework/Versions/2.7/include/python2.7 + -DPYTHON_LIBRARY:FILEPATH=/Library/Frameworks/Python.framework/Versions/2.7/lib/libpython2.7.dylib" + else + exit 1 + fi + # TODO: qiyang add python3 part here + fi + else + if [ "$1" != "" ]; then + echo "using python abi: $1" + if [ "$1" == "cp27-cp27m" ]; then + export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:} + export PATH=/opt/python/cp27-cp27m/bin/:${PATH} + PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python + -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7 + -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so" + elif [ "$1" == "cp27-cp27mu" ]; then + export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:} + export PATH=/opt/python/cp27-cp27mu/bin/:${PATH} + PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python + -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7 + -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so" + elif [ "$1" == "cp35-cp35m" ]; then + export LD_LIBRARY_PATH=/opt/_internal/cpython-3.5.1/lib/:${LD_LIBRARY_PATH} + export PATH=/opt/_internal/cpython-3.5.1/bin/:${PATH} + export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/_internal/cpython-3.5.1/bin/python3 -DPYTHON_INCLUDE_DIR:PATH=/opt/_internal/cpython-3.5.1/include/python3.5m -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-3.5.1/lib/libpython3.so" + fi fi fi @@ -117,6 +135,8 @@ function cmake_gen() { -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DWITH_CONTRIB=${WITH_CONTRIB:-ON} -DWITH_INFERENCE=${WITH_INFERENCE:-ON} + -DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} + -DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo} -DWITH_ANAKIN=${WITH_ANAKIN:-OFF} -DPY_VERSION=${PY_VERSION:-2.7} ======================================== @@ -147,6 +167,8 @@ EOF -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DWITH_CONTRIB=${WITH_CONTRIB:-ON} \ -DWITH_INFERENCE=${WITH_INFERENCE:-ON} \ + -DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} \ + -DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo} \ -DWITH_ANAKIN=${WITH_ANAKIN:-OFF} \ -DPY_VERSION=${PY_VERSION:-2.7} } @@ -201,6 +223,19 @@ EOF make install -j `nproc` } +function build_mac() { + mkdir -p ${PADDLE_ROOT}/build + cd ${PADDLE_ROOT}/build + cat <