提交 7b464d68 编写于 作者: C chengduoZH

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into speed_up_lod_tensor_to_array

......@@ -69,6 +69,7 @@ option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
option(WITH_INFERENCE "Compile fluid inference library" ON)
option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF)
option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
......
# PaddlePaddle Releasing Process
PaddlePaddle manages its branches using "git-flow branching model", and [Semantic Versioning](http://semver.org/) as it's version number semantics.
PaddlePaddle manages its branches using Trunk Based Development, and [Semantic Versioning](http://semver.org/) as it's version number semantics.
Each time we release a new PaddlePaddle version, we should follow the below steps:
......
......@@ -28,10 +28,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
if(WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
if(WITH_GPU)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
else()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
endif()
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......
......@@ -32,6 +32,10 @@ struct ComputationOpHandle : public OpHandleBase {
std::string Name() const override;
const Scope *GetScope() const { return scope_; }
const platform::Place &GetPlace() const { return place_; }
protected:
void RunImpl() override;
......
......@@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
void MultiDevSSAGraphBuilder::Init() const {
all_vars_.clear();
balance_vars_.clear();
loss_var_name_ = Get<const std::string>(kLossVarName);
places_ = Get<const std::vector<platform::Place>>(kPlaces);
local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
......
......@@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
size_t device_id) const;
void Init() const;
private:
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
mutable std::unordered_set<std::string> grad_names_;
#ifdef PADDLE_WITH_CUDA
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
......@@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
private:
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
mutable std::unordered_set<std::string> grad_names_;
mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::vector<int64_t> balance_vars_;
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
};
} // namespace details
} // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace details {
using ReferenceCountMap = std::unordered_map<std::string, int>;
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<int>>;
using DeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<ReferenceCountMap>>;
using AtomicDeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<AtomicReferenceCountMap>>;
using DeviceGarbageCollectorMap =
std::unordered_map<int,
std::unique_ptr<GarbageCollector<framework::Tensor>>>;
class ReferenceCountOpHandle : public OpHandleBase {
public:
ReferenceCountOpHandle(ir::Node *node, const Scope *scope,
const platform::CUDAPlace &place,
const std::vector<std::string> &var_names,
GarbageCollector<Tensor> *gc,
AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names),
gc_(gc),
ref_cnts_(ref_cnts) {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}
}
~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
}
std::string Name() const override { return "reference_count"; }
protected:
void RunImpl() override {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::vector<LoDTensor *> tensors;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
if (it == ref_cnts_->end()) continue;
auto *var = exec_scope->FindVar(name);
if (var == nullptr || !var->IsType<LoDTensor>()) continue;
if (it->second.fetch_sub(1) <= 1) {
tensors.emplace_back(var->GetMutable<LoDTensor>());
}
}
if (!tensors.empty()) {
ClearTensors(tensors);
}
}
private:
void ClearTensors(const std::vector<LoDTensor *> &tensors) {
auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
if (gc != nullptr) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream = gc->stream();
auto callback_func = [=]() {
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
};
gc_->Add(tensors, callback_func);
} else {
gc_->Add(tensors);
}
}
bool IsStreamGarabageCollector() const {
return dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
}
const Scope *scope_;
platform::CUDADeviceContext *dev_ctx_;
std::vector<std::string> var_names_;
GarbageCollector<Tensor> *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
cudaEvent_t event_;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
auto &cur_ref_cnts = Get<AtomicDeviceReferenceCountMap>(kCurReferenceCount);
auto &gcs = Get<DeviceGarbageCollectorMap>(kGarbageCollector);
// It is not easy to find the right reference counts of varaibles in graph
// Step 1: Find all variables in computation ops
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
std::unordered_set<std::string> names;
auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op,
const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op;
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
if (!platform::is_gpu_place(var_handle->place_) ||
boost::get<platform::CUDAPlace>(var_handle->place_) != place)
continue;
VarDesc *var_desc = var_handle->Node()->Var();
auto var_name = var_handle->Node()->Name();
// This is wierd but there is really some variables without var_desc
// in computation_op
if (var_desc == nullptr) {
if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
continue;
} else {
if (var_desc->Persistable() ||
var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR)
continue;
}
// compute op only runs in one device
if (ref_cnts[place.device]->count(var_name))
++(*ref_cnts[place.device])[var_name];
else
(*ref_cnts[place.device])[var_name] = 1;
names.insert(var_name);
var_names_in_op.push_back(var_name);
}
return var_names_in_op;
};
auto update_ref_cnts_from_non_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op,
const std::vector<VarHandleBase *> &vars) {
if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
auto var_name = var_handle->Node()->Name();
auto var_place = var_handle->place_;
if (!platform::is_gpu_place(var_place)) continue;
auto place = boost::get<platform::CUDAPlace>(var_place);
if (names.count(var_name) == 0) continue;
if (ref_cnts.count(place.device) &&
ref_cnts[place.device]->count(var_name)) {
++(*ref_cnts[place.device])[var_name];
}
}
};
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map;
auto &all_ops = graph->Get<GraphOps>(kGraphOps);
for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get());
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
compute_op->AddOutput(dep_var);
ref_cnt_handle->AddInput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
compute_ref_cnt_map[compute_op] = ref_cnt_handle;
}
for (auto &op : all_ops) {
update_ref_cnts_from_non_compute_op(op, op->Inputs());
update_ref_cnts_from_non_compute_op(op, op->Outputs());
}
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
if (it != compute_ref_cnt_map.end()) {
new_all_ops.emplace_back(it->second);
}
}
all_ops.swap(new_all_ops);
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kCurReferenceCount)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kGlobalReferenceCount[] = "reference_count";
constexpr char kCurReferenceCount[] = "current_reference_count";
constexpr char kGarbageCollector[] = "garbage_collector";
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -18,6 +18,9 @@
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#endif
namespace paddle {
namespace framework {
......@@ -65,12 +68,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
drop_scope_counter_ += 1;
#ifdef PADDLE_WITH_CUDA
const std::string gc_name = "garbage_collector";
DeviceGarbageCollectorMap *gc =
Graph().Has(gc_name) ? &(Graph().Get<DeviceGarbageCollectorMap>(gc_name))
: nullptr;
#endif
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0;
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
#ifdef PADDLE_WITH_CUDA
if (gc != nullptr && platform::is_gpu_place(p)) {
auto gpu_place = boost::get<platform::CUDAPlace>(p);
auto &gc_at_place = gc->at(gpu_place.device);
gc_at_place->Wait();
gc_at_place->Reset();
}
#endif
}
for (auto &scope : local_scopes_) {
auto &local_scope =
......
......@@ -37,7 +37,11 @@ int kProgramId = -1;
ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {}
: prog_(prog), block_id_(block_id) {
if (GetEagerDeletionThreshold() >= 0) {
ref_cnts_ = GetNonPersistableReferenceCount<int>(prog_, block_id_);
}
}
ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext";
......@@ -329,15 +333,81 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
}
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc;
if (max_memory_size >= 0) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
gc.reset(new DefaultStreamGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place_), max_memory_size));
} else {
#endif
gc.reset(new CPUGarbageCollector<Tensor>(
boost::get<platform::CPUPlace>(place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
}
#endif
}
for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_);
if (gc != nullptr) {
std::vector<std::string> erase_vars;
for (auto& input : op->Inputs()) {
for (auto& input_name : input.second) {
auto it = ctx->cur_ref_cnts_.find(input_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) { // should delete it
erase_vars.emplace_back(input_name);
ctx->cur_ref_cnts_.erase(input_name);
} else {
--(it->second);
}
}
}
for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) {
auto it = ctx->cur_ref_cnts_.find(output_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) {
erase_vars.emplace_back(output_name);
ctx->cur_ref_cnts_.erase(output_name);
} else {
--(it->second);
}
}
}
if (!erase_vars.empty()) {
std::vector<framework::LoDTensor*> erase_tensors;
for (auto& name : erase_vars) {
auto* var = local_scope->FindVar(name);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
erase_tensors.push_back(tensor);
}
}
if (!erase_tensors.empty()) gc->Add(erase_tensors);
}
}
if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_);
}
}
platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (gc != nullptr) {
gc->Wait();
} else {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
if (local_scope != scope) {
scope->DeleteScope(local_scope);
} else {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
......@@ -27,13 +28,58 @@ namespace paddle {
namespace framework {
extern void InitializeVariable(Variable* var, proto::VarType::Type var_type);
template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
const ProgramDesc& prog, size_t block_id) {
auto& block = prog.Block(block_id);
std::unordered_set<std::string> ignored_vars;
std::unordered_map<std::string, T> ref_cnts;
for (auto var_desc : block.AllVars()) {
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) {
ignored_vars.insert(var_desc->Name()); // ignore persistable vars
}
}
for (auto op_desc : block.AllOps()) {
for (auto& input : op_desc->Inputs()) {
for (auto& input_name : input.second) {
if (!ignored_vars.count(input_name)) {
if (ref_cnts.count(input_name))
++ref_cnts[input_name];
else
ref_cnts[input_name] = 1;
}
}
}
for (auto& output : op_desc->Outputs()) {
for (auto output_name : output.second) {
if (!ignored_vars.count(output_name)) {
if (ref_cnts.count(output_name))
++ref_cnts[output_name];
else
ref_cnts[output_name] = 1;
}
}
}
}
return ref_cnts;
}
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
~ExecutorPrepareContext();
void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
const framework::ProgramDesc& prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, int> ref_cnts_;
std::unordered_map<std::string, int> cur_ref_cnts_;
};
class Executor {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <deque>
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
// T should have memory_size() and clear() method
template <typename T>
class GarbageCollector {
public:
GarbageCollector(const platform::Place &place, size_t max_memory_size)
: max_memory_size_(std::max(max_memory_size, static_cast<size_t>(1))) {
garbages_.reset(new std::deque<T *>());
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}
virtual ~GarbageCollector() {}
void Reset() {
std::lock_guard<std::mutex> guard(mutex_);
garbages_.reset(new std::deque<T *>());
cur_memory_size_ = 0;
}
template <typename Container>
void Add(const Container &objs) {
Add(objs, []() {});
}
template <typename Container, typename Callback>
void Add(const Container &objs, Callback &&callback) {
std::shared_ptr<std::deque<T *>> clear_deque;
{
std::lock_guard<std::mutex> guard(mutex_);
for (auto *obj : objs) {
garbages_->push_back(obj);
cur_memory_size_ += obj->memory_size();
}
if (cur_memory_size_ >= max_memory_size_) {
cur_memory_size_ = 0;
clear_deque = garbages_;
garbages_.reset(new std::deque<T *>());
}
}
if (clear_deque != nullptr) {
callback();
ClearCallback([=]() {
for (auto *obj : *clear_deque) obj->clear();
});
}
}
virtual void Wait() const {}
protected:
virtual void ClearCallback(const std::function<void()> &callback) = 0;
platform::DeviceContext *dev_ctx_;
std::shared_ptr<std::deque<T *>> garbages_;
mutable std::mutex mutex_;
const size_t max_memory_size_;
size_t cur_memory_size_ = 0;
};
template <typename T>
class CPUGarbageCollector : public GarbageCollector<T> {
public:
CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {}
protected:
void ClearCallback(const std::function<void()> &callback) override {
callback();
}
};
#ifdef PADDLE_WITH_CUDA
template <typename T>
class DefaultStreamGarbageCollector : public GarbageCollector<T> {
public:
DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {}
cudaStream_t stream() const {
return static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
->stream();
}
void Wait() const override {
this->dev_ctx_->Wait();
static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
protected:
void ClearCallback(const std::function<void()> &callback) override {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
};
template <typename T>
class StreamGarbageCollector : public GarbageCollector<T> {
public:
StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}
~StreamGarbageCollector() {
auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
void Wait() const override {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->Wait();
}
cudaStream_t stream() const { return stream_; }
protected:
void ClearCallback(const std::function<void()> &callback) override {
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->AddCallback(callback);
}
private:
cudaStream_t stream_;
std::unique_ptr<platform::StreamCallbackManager> callback_manager_;
};
#endif
} // namespace framework
} // namespace paddle
......@@ -94,6 +94,14 @@ class Graph {
};
}
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph",
attr_name);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = []() {};
}
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
// Create a normal variable with non-null VarDesc.
......
......@@ -188,6 +188,30 @@ ParallelExecutor::ParallelExecutor(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, build_strategy,
member_->nccl_ctxs_.get());
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
for (auto &place : member_->places_) {
if (!platform::is_gpu_place(place)) continue;
auto gpu_place = boost::get<platform::CUDAPlace>(place);
if (gcs_[gpu_place.device] == nullptr) {
ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap());
cur_ref_cnts_[gpu_place.device].reset(
new details::AtomicReferenceCountMap());
gcs_[gpu_place.device].reset(
new StreamGarbageCollector<Tensor>(gpu_place, max_memory_size));
}
}
if (!gcs_.empty()) {
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
graph = ref_cnt_pass->Apply(std::move(graph));
graph->SetNotOwned("garbage_collector", &gcs_);
}
}
#else
std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
main_program, member_->places_, loss_var_name, params,
......@@ -209,30 +233,9 @@ ParallelExecutor::ParallelExecutor(
void ParallelExecutor::BCastParamsToDevices(
const std::unordered_set<std::string> &vars) const {
// the initializing bcast, all vars would be bcast from device(0),
// otherwise
// bcast from the specified device.
bool initializing = member_->executor_ ? false : true;
// the initializing bcast, all vars would be bcast from device(0).
for (auto &var : vars) {
int var_dev_id = -1;
if (member_->executor_) {
auto &sharded_var_device =
member_->executor_->Graph().Get<details::ShardedVarDevice>(
details::kShardedVarDevice);
if (sharded_var_device.find(var) != sharded_var_device.end()) {
var_dev_id = sharded_var_device.at(var);
}
}
if (!initializing && var_dev_id == -1) continue;
framework::Variable *main_var = nullptr;
if (initializing) {
main_var = member_->local_scopes_[0]->FindVar(var);
} else {
main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
}
framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue;
}
......@@ -248,8 +251,7 @@ void ParallelExecutor::BCastParamsToDevices(
auto place = member_->places_[i];
void *buffer;
if ((initializing && i == 0) ||
(!initializing && static_cast<int>(i) == var_dev_id)) {
if (i == 0) {
buffer = const_cast<void *>(main_tensor.data<void>());
} else {
auto local_scope = member_->local_scopes_[i];
......@@ -266,29 +268,18 @@ void ParallelExecutor::BCastParamsToDevices(
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
if (initializing) {
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
} else {
if (var_dev_id >= 0) {
platform::dynload::ncclBcast(buffers[i], numel, data_type,
var_dev_id, nccl_ctx.comm_,
nccl_ctx.stream());
}
}
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
}
member_->nccl_ctxs_->WaitAll();
}
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
} else {
platform::CPUPlace cpu;
for (size_t i = 0; i < member_->places_.size(); ++i) {
if ((initializing && i == 0) ||
(!initializing && static_cast<int>(i) == var_dev_id))
continue;
if (i == 0) continue;
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
......@@ -310,6 +301,11 @@ void ParallelExecutor::BCastParamsToDevices(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
platform::RecordBlock b(0);
#ifdef PADDLE_WITH_CUDA
if (!gcs_.empty()) {
ResetReferenceCount();
}
#endif
auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data;
......@@ -367,3 +363,6 @@ USE_PASS(graph_viz_pass);
USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
#ifdef PADDLE_WITH_CUDA
USE_PASS(reference_count_pass);
#endif
......@@ -15,7 +15,9 @@ limitations under the License. */
#pragma once
#include <paddle/fluid/framework/details/build_strategy.h>
#include <atomic>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
......@@ -27,6 +29,10 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_pass.h"
#endif
namespace paddle {
namespace framework {
......@@ -66,10 +72,27 @@ class ParallelExecutor {
void Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name);
private:
void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
private:
ParallelExecutorPrivate *member_;
#ifdef PADDLE_WITH_CUDA
// ref_cnts_ is only initialized when ParallelExecutor constructs, and then
// keeps unchanged
// Before each iteration, cur_ref_cnts_ is reset to ref_cnts_
details::DeviceReferenceCountMap ref_cnts_;
details::AtomicDeviceReferenceCountMap cur_ref_cnts_;
details::DeviceGarbageCollectorMap gcs_;
void ResetReferenceCount() {
for (auto &pair1 : ref_cnts_) {
for (auto &pair2 : *(pair1.second)) {
(*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second;
}
}
}
#endif
};
} // namespace framework
......
......@@ -31,9 +31,21 @@ DEFINE_bool(
"Delete local scope eagerly. It will reduce GPU memory usage but "
"slow down the destruction of variables.(around 1% performance harm)");
DEFINE_double(
eager_delete_tensor_gb, -1.0,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
namespace paddle {
namespace framework {
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
: static_cast<int64_t>(FLAGS_eager_delete_tensor_gb *
(static_cast<int64_t>(1) << 30));
}
Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const {
......
......@@ -26,6 +26,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
int64_t GetEagerDeletionThreshold();
class Scope;
/**
......
......@@ -151,6 +151,8 @@ class Tensor {
void set_layout(const DataLayout layout) { layout_ = layout; }
void clear() { holder_ = nullptr; }
private:
/**
* @note Placeholder hides type T, so it doesn't appear as a template
......
......@@ -17,9 +17,7 @@ get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
# paddle_fluid_origin exclude inference api interface
cc_library(paddle_fluid_origin DEPS ${fluid_modules} paddle_fluid_api)
#if(APPLE)
add_subdirectory(api)
#endif()
add_subdirectory(api)
# Create static library
cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api paddle_inference_api analysis_predictor)
......@@ -57,5 +55,7 @@ endif()
if(WITH_TESTING)
# tests/book depends the models that generated by python/paddle/fluid/tests/book
add_subdirectory(tests/book)
add_subdirectory(tests/api)
if(WITH_INFERENCE_API_TEST)
add_subdirectory(tests/api)
endif()
endif()
......@@ -69,8 +69,9 @@ class DfgPassManagerImpl final : public DfgPassManager {
if (FLAGS_IA_enable_tensorrt_subgraph_engine) {
auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax",
"depthwise_conv2d", "batch_norm", "concat"});
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh",
"elementwise_add", "dropout"});
if (!node->IsFunction()) return false;
const auto* func = static_cast<const Function*>(node);
......
......@@ -69,25 +69,4 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
endfunction()
anakin_target(inference_anakin_api)
anakin_target(inference_anakin_api_shared)
if (WITH_TESTING)
# TODO(luotao): ANAKIN_MODLE_URL etc will move to demo ci later.
set(INFERENCE_URL "http://paddle-inference-dist.bj.bcebos.com")
set(ANAKIN_RNN_MODLE_URL "${INFERENCE_URL}/anakin_test%2Fditu_rnn.anakin2.model.bin")
set(ANAKIN_RNN_DATA_URL "${INFERENCE_URL}/anakin_test%2Fditu_rnn_data.txt")
execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_SOURCE_DIR}")
execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_RNN_MODLE_URL} -N")
execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_RNN_DATA_URL} -N")
if(WITH_GPU)
set(anakin_test_extra_deps dynload_cuda)
set(ANAKIN_MODLE_URL "${INFERENCE_URL}/mobilenet_v2.anakin.bin")
execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_MODLE_URL} -N")
cc_test(api_anakin_engine_tester SRCS api_anakin_engine_tester.cc
ARGS --model=${ANAKIN_SOURCE_DIR}/mobilenet_v2.anakin.bin
DEPS inference_anakin_api_shared ${anakin_test_extra_deps} SERIAL)
endif()
cc_test(api_anakin_engine_rnn_tester SRCS api_anakin_engine_rnn_tester.cc
ARGS --model=${ANAKIN_SOURCE_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin
--datapath=${ANAKIN_SOURCE_DIR}/anakin_test%2Fditu_rnn_data.txt
DEPS inference_anakin_api_shared ${anakin_test_extra_deps} SERIAL)
endif(WITH_TESTING)
endif()
......@@ -153,11 +153,21 @@ CreatePaddlePredictor<TensorRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
} // namespace paddle
USE_TRT_CONVERTER(elementwise_add_weight);
USE_TRT_CONVERTER(elementwise_add_tensor);
USE_TRT_CONVERTER(elementwise_sub_tensor);
USE_TRT_CONVERTER(elementwise_div_tensor);
USE_TRT_CONVERTER(elementwise_mul_tensor);
USE_TRT_CONVERTER(elementwise_max_tensor);
USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(mul);
USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu);
USE_TRT_CONVERTER(sigmoid);
USE_TRT_CONVERTER(tanh);
USE_TRT_CONVERTER(fc);
USE_TRT_CONVERTER(pool2d);
USE_TRT_CONVERTER(softmax);
USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout);
# Add TRT tests
nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......@@ -24,6 +24,8 @@ nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL)
nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
......@@ -19,23 +19,31 @@ namespace paddle {
namespace inference {
namespace tensorrt {
class ReluOpConverter : public OpConverter {
class ActivationOpConverter : public OpConverter {
public:
ReluOpConverter() {}
ActivationOpConverter() {}
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
// Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr);
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
"type is Relu";
LOG(INFO)
<< "convert a fluid Activation op to tensorrt activation layer whose "
"type is "
<< op_type_;
const nvinfer1::ITensor* input_tensor =
engine_->GetITensor(op_desc.Input("X")[0]);
auto op_pair = ops.find(op_type_);
if (op_pair == ops.end()) {
PADDLE_THROW("Wrong activation op type!");
}
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
nvinfer1::ActivationType::kRELU);
op_pair->second);
auto output_name = op_desc.Output("Out")[0];
layer->setName(("relu (Output: " + output_name + ")").c_str());
layer->setName((op_type_ + " (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
......@@ -43,6 +51,32 @@ class ReluOpConverter : public OpConverter {
engine_->DeclareOutput(output_name);
}
}
protected:
std::string op_type_;
static const std::unordered_map<std::string, nvinfer1::ActivationType> ops;
};
const std::unordered_map<std::string, nvinfer1::ActivationType>
ActivationOpConverter::ops = {
{"relu", nvinfer1::ActivationType::kRELU},
{"sigmoid", nvinfer1::ActivationType::kSIGMOID},
{"tanh", nvinfer1::ActivationType::kTANH},
};
class ReluOpConverter : public ActivationOpConverter {
public:
ReluOpConverter() { op_type_ = "relu"; }
};
class SigmoidOpConverter : public ActivationOpConverter {
public:
SigmoidOpConverter() { op_type_ = "sigmoid"; }
};
class TanhOpConverter : public ActivationOpConverter {
public:
TanhOpConverter() { op_type_ = "tanh"; }
};
} // namespace tensorrt
......@@ -50,3 +84,5 @@ class ReluOpConverter : public OpConverter {
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter);
REGISTER_TRT_OP_CONVERTER(tanh, TanhOpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* DropoutOp. This Layer doesn't has weights.
*/
class DropoutOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid dropout op to tensorrt dropout layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
float dropout_prob = boost::get<float>(op_desc.GetAttr("dropout_prob"));
platform::CPUPlace cpu_place;
std::unique_ptr<framework::LoDTensor> weight_tensor(
new framework::LoDTensor());
weight_tensor->Resize(framework::make_ddim({1}));
auto* weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
weight_data[0] = 1 - dropout_prob;
TensorRTEngine::Weight scale_weights{
nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
weight_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *const_cast<nvinfer1::ITensor*>(input1),
nvinfer1::ScaleMode::kUNIFORM, shift_weights.get(), scale_weights.get(),
power_weights.get());
engine_->weight_map[op_desc.Output("Out").front() + "_dropout"] =
std::move(weight_tensor);
auto output_name = op_desc.Output("Out")[0];
layer->setName(("dropout (Output: " + output_name + ")").c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) {
engine_->DeclareOutput(output_name);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(dropout);
REGISTER_TRT_OP_CONVERTER(dropout, DropoutOpConverter);
......@@ -20,18 +20,18 @@ namespace paddle {
namespace inference {
namespace tensorrt {
TEST(ReluOpConverter, main) {
void test_activation(std::string act_type) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("relu-X", nvinfer1::Dims2(10, 6));
validator.DeclOutputVar("relu-Out", nvinfer1::Dims2(10, 6));
validator.DeclInputVar("act-X", nvinfer1::Dims2(10, 6));
validator.DeclOutputVar("act-Out", nvinfer1::Dims2(10, 6));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("relu");
desc.SetInput("X", {"relu-X"});
desc.SetOutput("Out", {"relu-Out"});
desc.SetType(act_type);
desc.SetInput("X", {"act-X"});
desc.SetOutput("Out", {"act-Out"});
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
......@@ -40,8 +40,16 @@ TEST(ReluOpConverter, main) {
validator.Execute(5);
}
TEST(ReluOpConverter, main) { test_activation("relu"); }
TEST(SigmoidOpConverter, main) { test_activation("sigmoid"); }
TEST(TanhOpConverter, main) { test_activation("tanh"); }
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(relu);
USE_OP(sigmoid);
USE_OP(tanh);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(DropoutOpConverter, main) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(8, parameters, scope, 1000);
std::vector<int> tensor_shape{8, 10};
validator.DeclInputVar("dropout-X", tensor_shape,
nvinfer1::DimsCHW(10, 1, 1));
validator.DeclOutputVar("dropout-Out", nvinfer1::DimsCHW(10, 1, 1));
validator.DeclOutputVar("mask-Out", nvinfer1::DimsCHW(10, 1, 1));
// Prepare Op description
framework::OpDesc desc;
int is_test = 1;
float dropout_prob = 0.4;
desc.SetType("dropout");
desc.SetInput("X", {"dropout-X"});
desc.SetOutput("Mask", {"mask-Out"});
desc.SetOutput("Out", {"dropout-Out"});
desc.SetAttr("is_test", is_test);
desc.SetAttr("dropout_prob", dropout_prob);
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
std::unordered_set<std::string> neglected_output = {"mask-Out"};
validator.Execute(8, neglected_output);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(dropout);
set(INFERENCE_URL "http://paddle-inference-dist.bj.bcebos.com")
set(INFERENCE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo")
set(INFERENCE_URL "http://paddle-inference-dist.cdn.bcebos.com")
set(INFERENCE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.")
set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor)
function (inference_download_and_uncompress install_dir filename)
message(STATUS "Download inference test stuff from ${INFERENCE_URL}/${filename}")
function (inference_download install_dir url filename)
message(STATUS "Download inference test stuff from ${url}/${filename}")
execute_process(COMMAND bash -c "mkdir -p ${install_dir}")
execute_process(COMMAND bash -c "cd ${install_dir} && wget -q ${INFERENCE_URL}/${filename}")
execute_process(COMMAND bash -c "cd ${install_dir} && tar xzf ${filename}")
execute_process(COMMAND bash -c "cd ${install_dir} && wget -q ${url}/${filename}")
message(STATUS "finish downloading ${filename}")
endfunction(inference_download_and_uncompress)
endfunction()
function (inference_download_and_uncompress install_dir url filename)
inference_download(${install_dir} ${url} ${filename})
execute_process(COMMAND bash -c "cd ${install_dir} && tar xzf ${filename}")
endfunction()
function(download_model_and_data install_dir model_name data_name)
if (NOT EXISTS ${install_dir} AND WITH_INFERENCE)
inference_download_and_uncompress(${install_dir} ${model_name})
inference_download_and_uncompress(${install_dir} ${data_name})
if (NOT EXISTS ${install_dir})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${model_name})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${data_name})
endif()
endfunction()
function(inference_analysis_api_test target install_dir filename)
inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt)
endfunction()
# RNN1
set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz")
inference_analysis_test(test_analyzer_rnn1 SRCS analyzer_rnn1_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${RNN1_INSTALL_DIR}/model
--infer_data=${RNN1_INSTALL_DIR}/data.txt)
if(NOT APPLE)
set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz")
inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc)
else()
# TODO: fix this test on MACOS, the reason is that
# fusion_seqexpand_concat_fc_op is not supported on MACOS
message(WARNING "These tests has been disabled in OSX before being fixed: \n test_analyzer_rnn1")
endif()
# RNN2
set(RNN2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn2")
download_model_and_data(${RNN2_INSTALL_DIR} "rnn2_model.tar.gz" "rnn2_data.txt.tar.gz")
inference_analysis_test(test_analyzer_rnn2 SRCS analyzer_rnn2_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${RNN2_INSTALL_DIR}/model
--infer_data=${RNN2_INSTALL_DIR}/data.txt)
inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2_tester.cc)
# chinese_ner
set(CHINESE_NER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/chinese_ner")
download_model_and_data(${CHINESE_NER_INSTALL_DIR} "chinese_ner_model.tar.gz" "chinese_ner-data.txt.tar.gz")
inference_analysis_test(test_analyzer_ner SRCS analyzer_ner_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${CHINESE_NER_INSTALL_DIR}/model
--infer_data=${CHINESE_NER_INSTALL_DIR}/data.txt)
inference_analysis_api_test(test_analyzer_ner ${CHINESE_NER_INSTALL_DIR} analyzer_ner_tester.cc)
# lac
set(LAC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lac")
download_model_and_data(${LAC_INSTALL_DIR} "lac_model.tar.gz" "lac_data.txt.tar.gz")
inference_analysis_test(test_analyzer_lac SRCS analyzer_lac_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${LAC_INSTALL_DIR}/model
--infer_data=${LAC_INSTALL_DIR}/data.txt)
inference_analysis_api_test(test_analyzer_lac ${LAC_INSTALL_DIR} analyzer_lac_tester.cc)
# text_classification
set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification")
download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz")
inference_analysis_test(test_analyzer_text_classification SRCS analyzer_text_classification_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/model
--infer_data=${TEXT_CLASSIFICATION_INSTALL_DIR}/data.txt)
inference_analysis_api_test(test_analyzer_text_classification ${TEXT_CLASSIFICATION_INSTALL_DIR} analyzer_text_classification_tester.cc)
# ocr
set(OCR_MODEL_URL "http://paddlemodels.cdn.bcebos.com/inference-vis-demos%2Focr.tar.gz")
set(OCR_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR} AND WITH_INFERENCE)
get_filename_component(filename ${OCR_MODEL_URL} NAME)
message(STATUS "Download inference test stuff ${filename} from ${OCR_MODEL_URL}")
execute_process(COMMAND bash -c "mkdir -p ${OCR_INSTALL_DIR}")
execute_process(COMMAND bash -c "cd ${OCR_INSTALL_DIR} && wget -q ${OCR_MODEL_URL}")
execute_process(COMMAND bash -c "cd ${OCR_INSTALL_DIR} && tar xzf ${filename}")
message(STATUS "finish downloading ${filename}")
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR})
inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz")
endif()
inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
# anakin
if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
# anakin rnn1
set(ANAKIN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/anakin")
set(ANAKIN_RNN1_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/rnn1")
inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn.anakin2.model.bin")
inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn_data.txt")
cc_test(test_anakin_rnn1 SRCS anakin_rnn1_tester.cc
ARGS --model=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin
--datapath=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn_data.txt
DEPS inference_anakin_api_shared SERIAL)
# anakin mobilenet
if(WITH_GPU)
set(ANAKIN_MOBILENET_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/mobilenet")
inference_download(${ANAKIN_MOBILENET_INSTALL_DIR} ${INFERENCE_URL} "mobilenet_v2.anakin.bin")
cc_test(test_anakin_mobilenet SRCS anakin_mobilenet_tester.cc
ARGS --model=${ANAKIN_MOBILENET_INSTALL_DIR}/mobilenet_v2.anakin.bin
DEPS inference_anakin_api_shared dynload_cuda SERIAL)
endif()
endif()
inference_analysis_test(test_analyzer_ocr SRCS analyzer_vis_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${OCR_INSTALL_DIR}/model
--infer_data=${OCR_INSTALL_DIR}/data.txt)
......@@ -167,6 +167,8 @@ void BuddyAllocator::Free(void* p) {
}
size_t BuddyAllocator::Used() { return total_used_; }
size_t BuddyAllocator::GetMinChunkSize() { return min_chunk_size_; }
size_t BuddyAllocator::GetMaxChunkSize() { return max_chunk_size_; }
void* BuddyAllocator::SystemAlloc(size_t size) {
size_t index = 0;
......
......@@ -42,6 +42,8 @@ class BuddyAllocator {
void* Alloc(size_t unaligned_size);
void Free(void* ptr);
size_t Used();
size_t GetMinChunkSize();
size_t GetMaxChunkSize();
public:
// Disable copy and assignment
......
......@@ -119,8 +119,8 @@ void* Alloc<platform::CUDAPlace>(platform::CUDAPlace place, size_t size) {
LOG(WARNING) << "Cannot allocate " << size << " bytes in GPU "
<< place.device << ", available " << avail << " bytes";
LOG(WARNING) << "total " << total;
LOG(WARNING) << "GpuMinChunkSize " << platform::GpuMinChunkSize();
LOG(WARNING) << "GpuMaxChunkSize " << platform::GpuMaxChunkSize();
LOG(WARNING) << "GpuMinChunkSize " << buddy_allocator->GetMinChunkSize();
LOG(WARNING) << "GpuMaxChunkSize " << buddy_allocator->GetMaxChunkSize();
LOG(WARNING) << "GPU memory used: " << Used<platform::CUDAPlace>(place);
platform::SetDeviceId(cur_dev);
}
......
......@@ -296,6 +296,7 @@ op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op)
op_library(fake_quantize_op DEPS memory)
op_library(fusion_lstm_op DEPS cpu_lstm_compute)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
......@@ -56,7 +56,7 @@ class VarHandle {
const std::string& name,
const platform::DeviceContext* p_ctx = nullptr,
const framework::Scope* p_scope = nullptr)
: ok_(kVarHandleDefaultState) {
: status_(kDefaultState) {
ep_ = ep;
ctx_ = p_ctx;
scope_ = p_scope;
......@@ -68,18 +68,20 @@ class VarHandle {
public:
bool Wait() {
int ret = kDefaultState;
{
std::unique_lock<std::mutex> lk(sync_mutex_);
wait_cond_.wait(lk, [this] { return ok_ != kVarHandleDefaultState; });
wait_cond_.wait(lk, [this] { return status_ != kDefaultState; });
ret = status_;
}
VLOG(7) << "VarHandle wait:" << ok_;
return ok_ != 0;
VLOG(7) << "VarHandle wait:" << ret;
return ret != kErrorState;
}
void Finish(bool ok) {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
ok_ = ok;
status_ = ok ? kFinishState : kErrorState;
}
VLOG(7) << "VarHandle finish:" << ok;
wait_cond_.notify_all();
......@@ -87,8 +89,8 @@ class VarHandle {
std::string String() const {
std::ostringstream s;
s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], ok:[" << ok_
<< "]";
s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:["
<< status_ << "]";
return s.str();
}
......@@ -111,9 +113,13 @@ class VarHandle {
protected:
std::mutex sync_mutex_;
std::condition_variable wait_cond_;
int ok_;
static const int kVarHandleDefaultState = -1;
enum VarHandleStatus {
kDefaultState = -1,
kErrorState = 0,
kFinishState = 1,
};
VarHandleStatus status_;
private:
DISABLE_COPY_AND_ASSIGN(VarHandle);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
......@@ -269,7 +270,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
wh_data, D4, static_cast<T>(1), out, D4)
// gates: W_ch, W_ih, W_fh, W_oh
#define GET_Ct(ct_1, gates, ct) \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
act_cand(D, gates, gates); \
......@@ -395,11 +395,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
}
} else {
// TODO(TJ): unly workaround, clean me
std::function<void(T*, const T*, T*, T*)> compute_ctht;
if (platform::jit::MayIUse(platform::jit::avx) &&
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
act_cell_str == "tanh" && D == 8) {
compute_ctht = math::lstm_compute_ctht<T>;
} else {
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
COMPUTE_CtHt(gates, ct_1, ct, ht);
};
}
for (int i = 0; i < N; ++i) {
PROCESS_H0C0
for (int step = tstart; step < seq_len; ++step) {
GEMM_WH_ADDON(1, prev_h_data, xx_data);
COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data);
MOVE_ONE_STEP;
}
}
......@@ -532,12 +543,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
MOVE_ONE_STEP;
}
} else {
// TODO(TJ): unly workaround, clean me
std::function<void(T*, const T*, T*, T*)> compute_ctht;
if (platform::jit::MayIUse(platform::jit::avx) &&
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
act_cell_str == "tanh" && D == 8) {
compute_ctht = math::lstm_compute_ctht<T>;
} else {
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
COMPUTE_CtHt(gates, ct_1, ct, ht);
};
}
for (int step = tstart; step < max_seq_len; ++step) {
const int cur_bs = batch_starts[step + 1] - batch_starts[step];
GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
DEFINE_CUR;
for (int i = 0; i < cur_bs; ++i) {
COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data,
cur_h_out_data);
MOVE_ONE_BATCH;
}
......
......@@ -45,6 +45,8 @@ math_library(im2col)
if (NOT WIN32) # windows do not support avx functions yet.
math_library(gru_compute DEPS activation_functions math_function)
math_library(lstm_compute DEPS activation_functions)
# TODO(TJ): ugly workaround, clean me
cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas cpu_info)
endif (NOT WIN32)
cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
namespace paddle {
namespace operators {
namespace math {} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle {
namespace operators {
namespace math {
// TODO(TJ): ugly workaround, clean me
template <typename T>
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
// gates: W_ch, W_ih, W_fh, W_oh
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
vec_tanh<T, platform::jit::avx>(8, gates, gates);
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int d = 0; d < 8; ++d) {
// C_t = C_t-1 * fgated + cand_gated * igated
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
// H_t = act_cell(C_t) * ogated
T tmp = ct[d] * 2;
tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vec_exp<T>(1, &tmp, &tmp);
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
ht[d] = tmp * o[d];
}
}
#ifdef __AVX__
namespace detail {
namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
} // namespace avx
} // namespace forward
} // namespace detail
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
float* ht) {
namespace act = detail::forward::avx;
// gates: W_ch, W_ih, W_fh, W_oh
__m256 c, i, f, o;
c = _mm256_loadu_ps(gates);
i = _mm256_loadu_ps(gates + 8);
f = _mm256_loadu_ps(gates + 16);
o = _mm256_loadu_ps(gates + 24);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
i = _mm256_loadu_ps(ct_1);
f = _mm256_mul_ps(i, act::Sigmoid(f));
f = _mm256_add_ps(c, f);
_mm256_storeu_ps(ct, f);
/* H_t = act_cell(C_t) * ogated */
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
_mm256_storeu_ps(ht, o);
}
#endif
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <functional>
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
......@@ -476,7 +477,7 @@ class VecActivations {
} else if (type == "identity" || type == "") {
return vec_identity<T, isa>;
}
LOG(FATAL) << "Not support type: " << type;
PADDLE_THROW("Not support type: %s", type);
}
};
......
......@@ -71,8 +71,7 @@ class MaxOutOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MaxoutOp"
"should not be null.");
"Input(X) of MaxoutOpshould not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MaxoutOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
......@@ -90,9 +89,10 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MaxOutOpGrad must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
"Output(Grad@X) of MaxOutOpGrad should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
......
......@@ -42,7 +42,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
auto &mem_tensor = mem_var->Get<framework::LoDTensor>();
out_tensor->ShareDataWith(mem_tensor);
framework::TensorCopySync(mem_tensor, dev_place, out_tensor);
out_tensor->set_lod(mem_tensor.lod());
}
};
......@@ -50,8 +50,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of rnn_memory_helper op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output of rnn_memory_helper op should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -107,7 +109,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
} else {
auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>();
auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>();
in_grad_tensor->ShareDataWith(out_grad_tensor);
framework::TensorCopySync(out_grad_tensor, dev_place, in_grad_tensor);
in_grad_tensor->set_lod(out_grad_tensor.lod());
}
}
......@@ -133,8 +135,11 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
auto x_grad_name = framework::GradVarName("X");
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), "");
PADDLE_ENFORCE(ctx->HasInput("X"), "");
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name),
"Gradient of Input(X) in rnn_memory_helper_grad of should "
"not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of rnn_memory_helper_grad of should not be null.");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
......
......@@ -25,7 +25,7 @@ class SliceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input (Input) of slice op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -58,7 +58,7 @@ class SliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
ctx.GetPlace());
......@@ -87,13 +87,13 @@ Slice Operator.
Produces a slice of the input tensor along multiple axes. Similar to numpy:
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
Slice uses `axes`, `starts` and `ends` attributes to specify the start and
Slice uses `axes`, `starts` and `ends` attributes to specify the start and
end dimension for each axis in the list of axes, it uses this information
to slice the input data tensor. If a negative value is passed for any of
the start or end indices, it represents number of elements before the end
to slice the input data tensor. If a negative value is passed for any of
the start or end indices, it represents number of elements before the end
of that dimension. If the value passed to start or end is larger than
the n (the number of elements in this dimension), it represents n.
For slicing to the end of a dimension with unknown size, it is recommended
the n (the number of elements in this dimension), it represents n.
For slicing to the end of a dimension with unknown size, it is recommended
to pass in INT_MAX. If axes are omitted, they are set to [0, ..., ndim-1].
Following examples will explain how slice works:
......@@ -119,15 +119,54 @@ Following examples will explain how slice works:
}
};
class SliceOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* bind = new framework::OpDesc();
bind->SetInput("Input", Input("Input"));
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
bind->SetAttrMap(Attrs());
bind->SetType("slice_grad");
return std::unique_ptr<framework::OpDesc>(bind);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
paddle::framework::EmptyGradOpMaker);
ops::SliceOpGradMaker);
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad);
REGISTER_OP_CPU_KERNEL(
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -20,3 +20,10 @@ REGISTER_OP_CUDA_KERNEL(
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
slice_grad,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......@@ -84,5 +85,79 @@ class SliceKernel : public framework::OpKernel<T> {
out_t.device(place) = in_t.slice(offsets, extents);
}
};
template <typename DeviceContext, typename T>
class SliceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
size_t rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
->dims()
.size();
switch (rank) {
case 1:
SliceCompute<1>(ctx);
break;
case 2:
SliceCompute<2>(ctx);
break;
case 3:
SliceCompute<3>(ctx);
break;
case 4:
SliceCompute<4>(ctx);
break;
case 5:
SliceCompute<5>(ctx);
break;
case 6:
SliceCompute<6>(ctx);
break;
}
}
private:
template <size_t D>
void SliceCompute(const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_input =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
d_input->mutable_data<T>(context.GetPlace());
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto offsets = Eigen::array<int, D>();
auto extents = Eigen::array<int, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = out_dims[i];
}
int start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
start = std::max(start, 0);
offsets[axes[i]] = start;
}
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
}
auto d_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input);
auto d_out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_out);
d_in_t.device(place) = d_out_t.pad(paddings, 0);
}
};
} // namespace operators
} // namespace paddle
......@@ -63,7 +63,7 @@ class WhileOp : public framework::OperatorBase {
while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false);
executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
if (is_test) {
scope.DeleteScope(&current_scope);
}
......@@ -169,7 +169,8 @@ class WhileGradOp : public framework::OperatorBase {
}
}
}
executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false);
executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true,
true);
auto &pg_names = Outputs(kXGRAD);
auto &p_names = Inputs(kX);
......
......@@ -51,7 +51,7 @@ ENDIF()
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS malloc
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
......
......@@ -210,11 +210,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place));
}
callback_manager_.reset(new StreamCallbackManager(stream_));
}
CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
WaitStreamCallback();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
eigen_stream_.reset();
eigen_device_.reset();
......
......@@ -31,6 +31,9 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream_callback_manager.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
......@@ -112,6 +115,17 @@ class CUDADeviceContext : public DeviceContext {
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
}
template <typename Callback>
void AddStreamCallback(Callback&& callback) const {
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->AddCallback(callback);
}
void WaitStreamCallback() const {
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->Wait();
}
private:
CUDAPlace place_;
......@@ -125,7 +139,12 @@ class CUDADeviceContext : public DeviceContext {
int multi_process;
int max_threads_per_mp;
std::mutex mtx_;
mutable std::mutex mtx_;
// This lock is only used by callback
// If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
mutable std::mutex callback_mtx_;
std::unique_ptr<StreamCallbackManager> callback_manager_;
};
template <>
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <functional>
#include <memory>
#include "ThreadPool.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
using StreamCallback = std::function<void(cudaStream_t, cudaError_t)>;
class StreamCallbackManager;
struct StreamCallbackContext {
template <typename Callback>
inline StreamCallbackContext(const StreamCallbackManager *manager,
Callback &&callback)
: manager_(manager), callback_(callback) {}
const StreamCallbackManager *manager_; // do not own
StreamCallback callback_;
};
class StreamCallbackManager {
public:
explicit inline StreamCallbackManager(cudaStream_t stream = nullptr)
: stream_(stream), thread_pool_(new ThreadPool(1)) {}
template <typename Callback>
inline void AddCallback(Callback &&callback) const {
AddCallbackWithStreamAndErrorInfo(
[=](cudaStream_t, cudaError_t) { callback(); });
}
template <typename Callback>
inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const {
auto *stream_callback_context = new StreamCallbackContext(this, callback);
PADDLE_ENFORCE(cudaStreamAddCallback(
stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0));
}
void Wait() const { thread_pool_.reset(new ThreadPool(1)); }
private:
const cudaStream_t stream_;
mutable std::unique_ptr<ThreadPool> thread_pool_;
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status,
void *user_data) {
auto *callback_context_ptr =
reinterpret_cast<StreamCallbackContext *>(user_data);
callback_context_ptr->manager_->thread_pool_->enqueue([=]() {
std::unique_ptr<StreamCallbackContext> callback_context(
callback_context_ptr);
callback_context->callback_(stream, status);
});
}
};
} // namespace platform
} // namespace paddle
......@@ -68,26 +68,44 @@ function cmake_gen() {
# Support build for all python versions, currently
# including cp27-cp27m and cp27-cp27mu.
PYTHON_FLAGS=""
if [ "$1" != "" ]; then
echo "using python abi: $1"
if [ "$1" == "cp27-cp27m" ]; then
export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:}
export PATH=/opt/python/cp27-cp27m/bin/:${PATH}
PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python
-DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7
-DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so"
elif [ "$1" == "cp27-cp27mu" ]; then
export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:}
export PATH=/opt/python/cp27-cp27mu/bin/:${PATH}
PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python
-DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7
-DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so"
elif [ "$1" == "cp35-cp35m" ]; then
export LD_LIBRARY_PATH=/opt/_internal/cpython-3.5.1/lib/:${LD_LIBRARY_PATH}
export PATH=/opt/_internal/cpython-3.5.1/bin/:${PATH}
export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/_internal/cpython-3.5.1/bin/python3
SYSTEM=`uname -s`
if [ "$SYSTEM" == "Darwin" ]; then
if [[ "$1" == "cp27-cp27m" ]] || [[ "$1" == "" ]]; then
echo "using python abi: $1"
if [ -d "/Library/Frameworks/Python.framework/Versions/2.7" ]; then
export LD_LIBRARY_PATH=/Library/Frameworks/Python.framework/Versions/2.7
export DYLD_LIBRARY_PATH=/Library/Frameworks/Python.framework/Versions/2.7
export PATH=/Library/Frameworks/Python.framework/Versions/2.7/bin/:${PATH}
PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/Library/Frameworks/Python.framework/Versions/2.7/bin/python2.7
-DPYTHON_INCLUDE_DIR:PATH=/Library/Frameworks/Python.framework/Versions/2.7/include/python2.7
-DPYTHON_LIBRARY:FILEPATH=/Library/Frameworks/Python.framework/Versions/2.7/lib/libpython2.7.dylib"
else
exit 1
fi
# TODO: qiyang add python3 part here
fi
else
if [ "$1" != "" ]; then
echo "using python abi: $1"
if [ "$1" == "cp27-cp27m" ]; then
export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:}
export PATH=/opt/python/cp27-cp27m/bin/:${PATH}
PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python
-DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7
-DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so"
elif [ "$1" == "cp27-cp27mu" ]; then
export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:}
export PATH=/opt/python/cp27-cp27mu/bin/:${PATH}
PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python
-DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7
-DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so"
elif [ "$1" == "cp35-cp35m" ]; then
export LD_LIBRARY_PATH=/opt/_internal/cpython-3.5.1/lib/:${LD_LIBRARY_PATH}
export PATH=/opt/_internal/cpython-3.5.1/bin/:${PATH}
export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/_internal/cpython-3.5.1/bin/python3
-DPYTHON_INCLUDE_DIR:PATH=/opt/_internal/cpython-3.5.1/include/python3.5m
-DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-3.5.1/lib/libpython3.so"
fi
fi
fi
......@@ -117,6 +135,8 @@ function cmake_gen() {
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
-DWITH_CONTRIB=${WITH_CONTRIB:-ON}
-DWITH_INFERENCE=${WITH_INFERENCE:-ON}
-DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON}
-DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo}
-DWITH_ANAKIN=${WITH_ANAKIN:-OFF}
-DPY_VERSION=${PY_VERSION:-2.7}
========================================
......@@ -147,6 +167,8 @@ EOF
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DWITH_CONTRIB=${WITH_CONTRIB:-ON} \
-DWITH_INFERENCE=${WITH_INFERENCE:-ON} \
-DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} \
-DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo} \
-DWITH_ANAKIN=${WITH_ANAKIN:-OFF} \
-DPY_VERSION=${PY_VERSION:-2.7}
}
......@@ -201,6 +223,19 @@ EOF
make install -j `nproc`
}
function build_mac() {
mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build
cat <<EOF
============================================
Building in /paddle/build ...
============================================
EOF
make clean
sudo make -j 8
sudo make install -j 8
}
function build_android() {
if [ $ANDROID_ABI == "arm64-v8a" ]; then
ANDROID_ARCH=arm64
......@@ -324,6 +359,27 @@ EOF
fi
}
function run_mac_test() {
mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build
if [ ${WITH_TESTING:-ON} == "ON" ] ; then
cat <<EOF
========================================
Running unit tests ...
========================================
EOF
# TODO: jiabin need to refine this part when these tests fixed on mac
ctest --output-on-failure -j8
# make install should also be test when unittest
make install -j 8
pip install /usr/local/opt/paddle/share/wheels/*.whl
if [[ ${WITH_FLUID_ONLY:-OFF} == "OFF" ]] ; then
paddle version
fi
fi
}
function assert_api_not_changed() {
mkdir -p ${PADDLE_ROOT}/build/.check_api_workspace
cd ${PADDLE_ROOT}/build/.check_api_workspace
......@@ -432,24 +488,6 @@ EOF
linkchecker doc/v2/cn/html/index.html
linkchecker doc/v2/api/en/html/index.html
# if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit 0; fi;
#
# # Deploy to the the content server if its a "develop" or "release/version" branch
# # The "develop_doc" branch is reserved to test full deploy process without impacting the real content.
# if [ "$TRAVIS_BRANCH" == "develop_doc" ]; then
# PPO_SCRIPT_BRANCH=develop
# elif [[ "$TRAVIS_BRANCH" == "develop" || "$TRAVIS_BRANCH" =~ ^v|release/[[:digit:]]+\.[[:digit:]]+(\.[[:digit:]]+)?(-\S*)?$ ]]; then
# PPO_SCRIPT_BRANCH=master
# else
# # Early exit, this branch doesn't require documentation build
# return 0;
# fi
# # Fetch the paddlepaddle.org deploy_docs.sh from the appopriate branch
# export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/$PPO_SCRIPT_BRANCH/scripts/deploy/deploy_docs.sh
# export PYTHONPATH=$PYTHONPATH:${PADDLE_ROOT}/build/python:/paddle/build/python
# cd ..
# curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH ${PADDLE_ROOT} ${PADDLE_ROOT}/build/doc/ ${PPO_SCRIPT_BRANCH}
# cd -
}
function gen_doc_lib() {
......@@ -677,6 +715,17 @@ function main() {
test_fluid_inference_lib
assert_api_spec_approvals
;;
maccheck)
cmake_gen ${PYTHON_ABI:-""}
build_mac
run_mac_test
;;
cicheck_py35)
cmake_gen ${PYTHON_ABI:-""}
build
run_test
assert_api_not_changed
;;
*)
print_usage
exit 0
......
......@@ -67,7 +67,7 @@ def get_word_dict():
for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field):
word_freq_dict[words] += 1
words_sort_list = six.iteritems(word_freq_dict)
words_sort_list = list(six.iteritems(word_freq_dict))
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index))
......
......@@ -122,7 +122,7 @@ def __bootstrap__():
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir',
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
"dist_threadpool_size", 'cpu_deterministic'
"dist_threadpool_size", 'cpu_deterministic', 'eager_delete_tensor_gb'
]
if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline')
......
......@@ -28,9 +28,18 @@ list(REMOVE_ITEM TEST_OPS test_cond_op) # FIXME(qijun): https://github.com/Paddl
list(REMOVE_ITEM TEST_OPS op_test) # op_test is a helper python file, not a test
list(REMOVE_ITEM TEST_OPS decorators) # decorators is a helper python file, not a test
if(APPLE)
message(WARNING "These tests has been disabled in OSX before being fixed: \n test_detection_map_op \n test_desc_clone \n test_debugger \n test_program_code \n test_dist_transformer \n test_dist_se_resnext")
# this op is not support on mac
list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op)
# TODO: add the unitest back when it fixed
list(REMOVE_ITEM TEST_OPS test_detection_map_op)
list(REMOVE_ITEM TEST_OPS test_desc_clone)
list(REMOVE_ITEM TEST_OPS test_debugger)
list(REMOVE_ITEM TEST_OPS test_program_code)
list(REMOVE_ITEM TEST_OPS test_dist_transformer)
list(REMOVE_ITEM TEST_OPS test_dist_se_resnext)
endif()
function(py_test_modules TARGET_NAME)
......
......@@ -109,15 +109,20 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
return t
from paddle.fluid.transpiler.details import op_to_code
def operator_equal(a, b):
if op_to_code(a) != op_to_code(b):
raise ValueError("In operator_equal not equal\n")
for k, v in six.iteritems(a.__dict__):
if isinstance(v, fluid.framework.Program) or \
isinstance(v, fluid.framework.Block):
continue
elif isinstance(v, core.OpDesc):
if v.serialize_to_string() != b.__dict__[k].serialize_to_string():
raise ValueError("In operator_equal not equal:{0}\n".format(k))
continue
elif isinstance(v, collections.OrderedDict):
v0 = sorted(list(six.iteritems(v)), key=lambda x: x[0])
......
......@@ -61,9 +61,7 @@ class TestDistTransformer2x2Sync(TestDistBase):
def test_transformer(self):
download_files()
#Note: loss on test dataset of the first 5 batch are:
# 10.518872, 10.518871, 10.518868, 10.518862, 10.518855
self.check_with_place("dist_transformer.py", delta=1e-7)
self.check_with_place("dist_transformer.py", delta=1e-5)
class TestDistTransformer2x2Async(TestDistBase):
......
......@@ -41,6 +41,9 @@ class TestSliceOp(OpTest):
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestCase1(TestSliceOp):
def config(self):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -30,8 +30,10 @@ class TestWhileOp(unittest.TestCase):
"d1", shape=[10], append_batch_size=False, dtype='float32')
d2 = layers.data(
"d2", shape=[10], append_batch_size=False, dtype='float32')
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
init = layers.zeros(shape=[10], dtype='float32')
mem_array = layers.array_write(x=init, i=i)
data_array = layers.array_write(x=d0, i=i)
......@@ -45,11 +47,19 @@ class TestWhileOp(unittest.TestCase):
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
array_len = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len = layers.fill_constant(shape=[1], dtype='int64', value=1)
array_len.stop_gradient = True
cond = layers.less_than(x=i, y=array_len)
j = layers.fill_constant(shape=[1], dtype='int64', value=1)
j.stop_gradient = True
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = layers.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
......@@ -59,7 +69,16 @@ class TestWhileOp(unittest.TestCase):
layers.array_write(result, i=i, array=mem_array)
layers.less_than(x=i, y=array_len, cond=cond)
sum_result = layers.array_read(array=mem_array, i=i)
with while_op2.block():
d2 = layers.array_read(array=data_array, i=j)
prev2 = layers.array_read(array=mem_array, i=j)
result2 = layers.sums(input=[d2, prev2])
j = layers.increment(x=j, in_place=True)
layers.array_write(result2, i=j, array=mem_array)
layers.less_than(x=j, y=array_len2, cond=cond2)
sum_result = layers.array_read(array=mem_array, i=j)
loss = layers.mean(sum_result)
append_backward(loss)
......
......@@ -113,27 +113,32 @@ def op_to_code(op):
inputs_str += ", "
inputs_str += "}"
attr_names = sorted(op.attr_names)
attrs_str = ""
for i in range(0, len(op.attr_names)):
name = op.attr_names[i]
for i in range(0, len(attr_names)):
name = attr_names[i]
attr_type = op.desc.attr_type(name)
if attr_type == core.AttrType.BLOCK:
a = "{name} = block[{value}]".format(
name=name, type=attr_type, value=op.block_attr_id(name))
attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
continue
if attr_type == core.AttrType.BLOCKS:
a = "{name} = blocks{value}".format(
name=name, type=attr_type, value=op.blocks_attr_ids(name))
attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
continue
a = "{name} = {value}".format(
name=name, type=attr_type, value=op.desc.attr(name))
attrs_str += a
if i != len(op.attr_names) - 1:
if i != len(attr_names) - 1:
attrs_str += ", "
if outputs_str != "{}":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册