提交 d9942cd1 编写于 作者: S sneaxiy

Merge develop

# PaddlePaddle Releasing Process # PaddlePaddle Releasing Process
PaddlePaddle manages its branches using "git-flow branching model", and [Semantic Versioning](http://semver.org/) as it's version number semantics. PaddlePaddle manages its branches using Trunk Based Development, and [Semantic Versioning](http://semver.org/) as it's version number semantics.
Each time we release a new PaddlePaddle version, we should follow the below steps: Each time we release a new PaddlePaddle version, we should follow the below steps:
......
...@@ -28,10 +28,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ ...@@ -28,10 +28,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
if(WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) if(WITH_GPU)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
else()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
endif()
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context) simple_threadpool device_context)
......
...@@ -32,6 +32,10 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -32,6 +32,10 @@ struct ComputationOpHandle : public OpHandleBase {
std::string Name() const override; std::string Name() const override;
const Scope *GetScope() const { return scope_; }
const platform::Place &GetPlace() const { return place_; }
protected: protected:
void RunImpl() override; void RunImpl() override;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace details {
using ReferenceCountMap = std::unordered_map<std::string, int>;
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<int>>;
using DeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<ReferenceCountMap>>;
using AtomicDeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<AtomicReferenceCountMap>>;
using DeviceGarbageCollectorMap =
std::unordered_map<int,
std::unique_ptr<GarbageCollector<framework::Tensor>>>;
class ReferenceCountOpHandle : public OpHandleBase {
public:
ReferenceCountOpHandle(ir::Node *node, const Scope *scope,
const platform::CUDAPlace &place,
const std::vector<std::string> &var_names,
GarbageCollector<Tensor> *gc,
AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names),
gc_(gc),
ref_cnts_(ref_cnts) {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}
}
~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
}
std::string Name() const override { return "reference_count"; }
protected:
void RunImpl() override {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::vector<LoDTensor *> tensors;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
if (it == ref_cnts_->end()) continue;
auto *var = exec_scope->FindVar(name);
if (var == nullptr || !var->IsType<LoDTensor>()) continue;
if (it->second.fetch_sub(1) <= 1) {
tensors.emplace_back(var->GetMutable<LoDTensor>());
}
}
if (!tensors.empty()) {
ClearTensors(tensors);
}
}
private:
void ClearTensors(const std::vector<LoDTensor *> &tensors) {
auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
if (gc != nullptr) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream = gc->stream();
auto callback_func = [=]() {
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
};
gc_->Add(tensors, callback_func);
} else {
gc_->Add(tensors);
}
}
bool IsStreamGarabageCollector() const {
return dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
}
const Scope *scope_;
platform::CUDADeviceContext *dev_ctx_;
std::vector<std::string> var_names_;
GarbageCollector<Tensor> *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
cudaEvent_t event_;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
auto &cur_ref_cnts = Get<AtomicDeviceReferenceCountMap>(kCurReferenceCount);
auto &gcs = Get<DeviceGarbageCollectorMap>(kGarbageCollector);
// It is not easy to find the right reference counts of varaibles in graph
// Step 1: Find all variables in computation ops
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
std::unordered_set<std::string> names;
auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op,
const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op;
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
if (!platform::is_gpu_place(var_handle->place_) ||
boost::get<platform::CUDAPlace>(var_handle->place_) != place)
continue;
VarDesc *var_desc = var_handle->Node()->Var();
auto var_name = var_handle->Node()->Name();
// This is wierd but there is really some variables without var_desc
// in computation_op
if (var_desc == nullptr) {
if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
continue;
} else {
if (var_desc->Persistable() ||
var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR)
continue;
}
// compute op only runs in one device
if (ref_cnts[place.device]->count(var_name))
++(*ref_cnts[place.device])[var_name];
else
(*ref_cnts[place.device])[var_name] = 1;
names.insert(var_name);
var_names_in_op.push_back(var_name);
}
return var_names_in_op;
};
auto update_ref_cnts_from_non_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op,
const std::vector<VarHandleBase *> &vars) {
if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
auto var_name = var_handle->Node()->Name();
auto var_place = var_handle->place_;
if (!platform::is_gpu_place(var_place)) continue;
auto place = boost::get<platform::CUDAPlace>(var_place);
if (names.count(var_name) == 0) continue;
if (ref_cnts.count(place.device) &&
ref_cnts[place.device]->count(var_name)) {
++(*ref_cnts[place.device])[var_name];
}
}
};
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map;
auto &all_ops = graph->Get<GraphOps>(kGraphOps);
for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get());
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
compute_op->AddOutput(dep_var);
ref_cnt_handle->AddInput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
compute_ref_cnt_map[compute_op] = ref_cnt_handle;
}
for (auto &op : all_ops) {
update_ref_cnts_from_non_compute_op(op, op->Inputs());
update_ref_cnts_from_non_compute_op(op, op->Outputs());
}
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
if (it != compute_ref_cnt_map.end()) {
new_all_ops.emplace_back(it->second);
}
}
all_ops.swap(new_all_ops);
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kCurReferenceCount)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kGlobalReferenceCount[] = "reference_count";
constexpr char kCurReferenceCount[] = "current_reference_count";
constexpr char kGarbageCollector[] = "garbage_collector";
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -65,12 +68,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -65,12 +68,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr); platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
drop_scope_counter_ += 1; drop_scope_counter_ += 1;
#ifdef PADDLE_WITH_CUDA
const std::string gc_name = "garbage_collector";
DeviceGarbageCollectorMap *gc =
Graph().Has(gc_name) ? &(Graph().Get<DeviceGarbageCollectorMap>(gc_name))
: nullptr;
#endif
if (!fetch_tensors.empty() || if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0; drop_scope_counter_ = 0;
// Wait All computational streams // Wait All computational streams
for (auto p : places_) { for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait(); platform::DeviceContextPool::Instance().Get(p)->Wait();
#ifdef PADDLE_WITH_CUDA
if (gc != nullptr && platform::is_gpu_place(p)) {
auto gpu_place = boost::get<platform::CUDAPlace>(p);
auto &gc_at_place = gc->at(gpu_place.device);
gc_at_place->Wait();
gc_at_place->Reset();
}
#endif
} }
for (auto &scope : local_scopes_) { for (auto &scope : local_scopes_) {
auto &local_scope = auto &local_scope =
......
...@@ -37,7 +37,11 @@ int kProgramId = -1; ...@@ -37,7 +37,11 @@ int kProgramId = -1;
ExecutorPrepareContext::ExecutorPrepareContext( ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id) const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {} : prog_(prog), block_id_(block_id) {
if (GetEagerDeletionThreshold() >= 0) {
ref_cnts_ = GetNonPersistableReferenceCount<int>(prog_, block_id_);
}
}
ExecutorPrepareContext::~ExecutorPrepareContext() { ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext"; VLOG(5) << "destroy ExecutorPrepareContext";
...@@ -329,15 +333,81 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -329,15 +333,81 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
CreateVariables(ctx->prog_, local_scope, ctx->block_id_); CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
} }
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc;
if (max_memory_size >= 0) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
gc.reset(new DefaultStreamGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place_), max_memory_size));
} else {
#endif
gc.reset(new CPUGarbageCollector<Tensor>(
boost::get<platform::CPUPlace>(place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
}
#endif
}
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (gc != nullptr) {
std::vector<std::string> erase_vars;
for (auto& input : op->Inputs()) {
for (auto& input_name : input.second) {
auto it = ctx->cur_ref_cnts_.find(input_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) { // should delete it
erase_vars.emplace_back(input_name);
ctx->cur_ref_cnts_.erase(input_name);
} else {
--(it->second);
}
}
}
for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) {
auto it = ctx->cur_ref_cnts_.find(output_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) {
erase_vars.emplace_back(output_name);
ctx->cur_ref_cnts_.erase(output_name);
} else {
--(it->second);
}
}
}
if (!erase_vars.empty()) {
std::vector<framework::LoDTensor*> erase_tensors;
for (auto& name : erase_vars) {
auto* var = local_scope->FindVar(name);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
erase_tensors.push_back(tensor);
}
}
if (!erase_tensors.empty()) gc->Add(erase_tensors);
}
}
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_); << memory::memory_usage(place_);
} }
} }
platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (gc != nullptr) {
gc->Wait();
} else {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
if (local_scope != scope) { if (local_scope != scope) {
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
} else { } else {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -27,13 +28,58 @@ namespace paddle { ...@@ -27,13 +28,58 @@ namespace paddle {
namespace framework { namespace framework {
extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); extern void InitializeVariable(Variable* var, proto::VarType::Type var_type);
template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
const ProgramDesc& prog, size_t block_id) {
auto& block = prog.Block(block_id);
std::unordered_set<std::string> ignored_vars;
std::unordered_map<std::string, T> ref_cnts;
for (auto var_desc : block.AllVars()) {
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) {
ignored_vars.insert(var_desc->Name()); // ignore persistable vars
}
}
for (auto op_desc : block.AllOps()) {
for (auto& input : op_desc->Inputs()) {
for (auto& input_name : input.second) {
if (!ignored_vars.count(input_name)) {
if (ref_cnts.count(input_name))
++ref_cnts[input_name];
else
ref_cnts[input_name] = 1;
}
}
}
for (auto& output : op_desc->Outputs()) {
for (auto output_name : output.second) {
if (!ignored_vars.count(output_name)) {
if (ref_cnts.count(output_name))
++ref_cnts[output_name];
else
ref_cnts[output_name] = 1;
}
}
}
}
return ref_cnts;
}
struct ExecutorPrepareContext { struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
~ExecutorPrepareContext(); ~ExecutorPrepareContext();
void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
const framework::ProgramDesc& prog_; const framework::ProgramDesc& prog_;
size_t block_id_; size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_; std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, int> ref_cnts_;
std::unordered_map<std::string, int> cur_ref_cnts_;
}; };
class Executor { class Executor {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <deque>
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
// T should have memory_size() and clear() method
template <typename T>
class GarbageCollector {
public:
GarbageCollector(const platform::Place &place, size_t max_memory_size)
: max_memory_size_(std::max(max_memory_size, static_cast<size_t>(1))) {
garbages_.reset(new std::deque<T *>());
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}
virtual ~GarbageCollector() {}
void Reset() {
std::lock_guard<std::mutex> guard(mutex_);
garbages_.reset(new std::deque<T *>());
cur_memory_size_ = 0;
}
template <typename Container>
void Add(const Container &objs) {
Add(objs, []() {});
}
template <typename Container, typename Callback>
void Add(const Container &objs, Callback &&callback) {
std::shared_ptr<std::deque<T *>> clear_deque;
{
std::lock_guard<std::mutex> guard(mutex_);
for (auto *obj : objs) {
garbages_->push_back(obj);
cur_memory_size_ += obj->memory_size();
}
if (cur_memory_size_ >= max_memory_size_) {
cur_memory_size_ = 0;
clear_deque = garbages_;
garbages_.reset(new std::deque<T *>());
}
}
if (clear_deque != nullptr) {
callback();
ClearCallback([=]() {
for (auto *obj : *clear_deque) obj->clear();
});
}
}
virtual void Wait() const {}
protected:
virtual void ClearCallback(const std::function<void()> &callback) = 0;
platform::DeviceContext *dev_ctx_;
std::shared_ptr<std::deque<T *>> garbages_;
mutable std::mutex mutex_;
const size_t max_memory_size_;
size_t cur_memory_size_ = 0;
};
template <typename T>
class CPUGarbageCollector : public GarbageCollector<T> {
public:
CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {}
protected:
void ClearCallback(const std::function<void()> &callback) override {
callback();
}
};
#ifdef PADDLE_WITH_CUDA
template <typename T>
class DefaultStreamGarbageCollector : public GarbageCollector<T> {
public:
DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {}
cudaStream_t stream() const {
return static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
->stream();
}
void Wait() const override {
this->dev_ctx_->Wait();
static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
protected:
void ClearCallback(const std::function<void()> &callback) override {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
};
template <typename T>
class StreamGarbageCollector : public GarbageCollector<T> {
public:
StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}
~StreamGarbageCollector() {
auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
void Wait() const override {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->Wait();
}
cudaStream_t stream() const { return stream_; }
protected:
void ClearCallback(const std::function<void()> &callback) override {
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->AddCallback(callback);
}
private:
cudaStream_t stream_;
std::unique_ptr<platform::StreamCallbackManager> callback_manager_;
};
#endif
} // namespace framework
} // namespace paddle
...@@ -94,6 +94,14 @@ class Graph { ...@@ -94,6 +94,14 @@ class Graph {
}; };
} }
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph",
attr_name);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = []() {};
}
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; } const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
......
...@@ -188,6 +188,30 @@ ParallelExecutor::ParallelExecutor( ...@@ -188,6 +188,30 @@ ParallelExecutor::ParallelExecutor(
main_program, member_->places_, loss_var_name, params, main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, build_strategy, member_->local_scopes_, member_->use_cuda_, build_strategy,
member_->nccl_ctxs_.get()); member_->nccl_ctxs_.get());
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
for (auto &place : member_->places_) {
if (!platform::is_gpu_place(place)) continue;
auto gpu_place = boost::get<platform::CUDAPlace>(place);
if (gcs_[gpu_place.device] == nullptr) {
ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap());
cur_ref_cnts_[gpu_place.device].reset(
new details::AtomicReferenceCountMap());
gcs_[gpu_place.device].reset(
new StreamGarbageCollector<Tensor>(gpu_place, max_memory_size));
}
}
if (!gcs_.empty()) {
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
graph = ref_cnt_pass->Apply(std::move(graph));
graph->SetNotOwned("garbage_collector", &gcs_);
}
}
#else #else
std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass( std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
main_program, member_->places_, loss_var_name, params, main_program, member_->places_, loss_var_name, params,
...@@ -310,6 +334,11 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -310,6 +334,11 @@ void ParallelExecutor::BCastParamsToDevices(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) { const std::string &fetched_var_name) {
platform::RecordBlock b(0); platform::RecordBlock b(0);
#ifdef PADDLE_WITH_CUDA
if (!gcs_.empty()) {
ResetReferenceCount();
}
#endif
auto fetch_data = member_->executor_->Run(fetch_tensors); auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() = *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data; fetch_data;
...@@ -367,3 +396,6 @@ USE_PASS(graph_viz_pass); ...@@ -367,3 +396,6 @@ USE_PASS(graph_viz_pass);
USE_PASS(multi_devices_pass); USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass); USE_PASS(multi_devices_print_pass);
#ifdef PADDLE_WITH_CUDA
USE_PASS(reference_count_pass);
#endif
...@@ -15,7 +15,9 @@ limitations under the License. */ ...@@ -15,7 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include <paddle/fluid/framework/details/build_strategy.h> #include <paddle/fluid/framework/details/build_strategy.h>
#include <atomic>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
...@@ -27,6 +29,10 @@ limitations under the License. */ ...@@ -27,6 +29,10 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_pass.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -70,6 +76,23 @@ class ParallelExecutor { ...@@ -70,6 +76,23 @@ class ParallelExecutor {
private: private:
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
#ifdef PADDLE_WITH_CUDA
// ref_cnts_ is only initialized when ParallelExecutor constructs, and then
// keeps unchanged
// Before each iteration, cur_ref_cnts_ is reset to ref_cnts_
details::DeviceReferenceCountMap ref_cnts_;
details::AtomicDeviceReferenceCountMap cur_ref_cnts_;
details::DeviceGarbageCollectorMap gcs_;
void ResetReferenceCount() {
for (auto &pair1 : ref_cnts_) {
for (auto &pair2 : *(pair1.second)) {
(*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second;
}
}
}
#endif
}; };
} // namespace framework } // namespace framework
......
...@@ -31,9 +31,21 @@ DEFINE_bool( ...@@ -31,9 +31,21 @@ DEFINE_bool(
"Delete local scope eagerly. It will reduce GPU memory usage but " "Delete local scope eagerly. It will reduce GPU memory usage but "
"slow down the destruction of variables.(around 1% performance harm)"); "slow down the destruction of variables.(around 1% performance harm)");
DEFINE_double(
eager_delete_tensor_gb, -1.0,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
: static_cast<int64_t>(FLAGS_eager_delete_tensor_gb *
(static_cast<int64_t>(1) << 30));
}
Scope::~Scope() { DropKids(); } Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
......
...@@ -26,6 +26,8 @@ limitations under the License. */ ...@@ -26,6 +26,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
int64_t GetEagerDeletionThreshold();
class Scope; class Scope;
/** /**
......
...@@ -151,6 +151,8 @@ class Tensor { ...@@ -151,6 +151,8 @@ class Tensor {
void set_layout(const DataLayout layout) { layout_ = layout; } void set_layout(const DataLayout layout) { layout_ = layout; }
void clear() { holder_ = nullptr; }
private: private:
/** /**
* @note Placeholder hides type T, so it doesn't appear as a template * @note Placeholder hides type T, so it doesn't appear as a template
......
...@@ -69,8 +69,9 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -69,8 +69,9 @@ class DfgPassManagerImpl final : public DfgPassManager {
if (FLAGS_IA_enable_tensorrt_subgraph_engine) { if (FLAGS_IA_enable_tensorrt_subgraph_engine) {
auto trt_teller = [&](const Node* node) { auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax", {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat"}); "depthwise_conv2d", "batch_norm", "concat", "tanh",
"elementwise_add", "dropout"});
if (!node->IsFunction()) return false; if (!node->IsFunction()) return false;
const auto* func = static_cast<const Function*>(node); const auto* func = static_cast<const Function*>(node);
......
...@@ -153,11 +153,21 @@ CreatePaddlePredictor<TensorRTConfig, PaddleEngineKind::kAutoMixedTensorRT>( ...@@ -153,11 +153,21 @@ CreatePaddlePredictor<TensorRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
} // namespace paddle } // namespace paddle
USE_TRT_CONVERTER(elementwise_add_weight); USE_TRT_CONVERTER(elementwise_add_weight);
USE_TRT_CONVERTER(elementwise_add_tensor);
USE_TRT_CONVERTER(elementwise_sub_tensor);
USE_TRT_CONVERTER(elementwise_div_tensor);
USE_TRT_CONVERTER(elementwise_mul_tensor);
USE_TRT_CONVERTER(elementwise_max_tensor);
USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(mul); USE_TRT_CONVERTER(mul);
USE_TRT_CONVERTER(conv2d); USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu); USE_TRT_CONVERTER(relu);
USE_TRT_CONVERTER(sigmoid);
USE_TRT_CONVERTER(tanh);
USE_TRT_CONVERTER(fc); USE_TRT_CONVERTER(fc);
USE_TRT_CONVERTER(pool2d); USE_TRT_CONVERTER(pool2d);
USE_TRT_CONVERTER(softmax); USE_TRT_CONVERTER(softmax);
USE_TRT_CONVERTER(batch_norm); USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat); USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout);
# Add TRT tests # Add TRT tests
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry) DEPS tensorrt_engine operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
...@@ -24,6 +24,8 @@ nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc ...@@ -24,6 +24,8 @@ nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL)
nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
...@@ -19,23 +19,31 @@ namespace paddle { ...@@ -19,23 +19,31 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
class ReluOpConverter : public OpConverter { class ActivationOpConverter : public OpConverter {
public: public:
ReluOpConverter() {} ActivationOpConverter() {}
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
// Here the two nullptr looks strange, that's because the // Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange. // framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " LOG(INFO)
"type is Relu"; << "convert a fluid Activation op to tensorrt activation layer whose "
"type is "
<< op_type_;
const nvinfer1::ITensor* input_tensor = const nvinfer1::ITensor* input_tensor =
engine_->GetITensor(op_desc.Input("X")[0]); engine_->GetITensor(op_desc.Input("X")[0]);
auto op_pair = ops.find(op_type_);
if (op_pair == ops.end()) {
PADDLE_THROW("Wrong activation op type!");
}
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor), engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
nvinfer1::ActivationType::kRELU); op_pair->second);
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
layer->setName(("relu (Output: " + output_name + ")").c_str()); layer->setName((op_type_ + " (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str()); layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0)); engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the if (test_mode) { // the test framework can not determine which is the
...@@ -43,6 +51,32 @@ class ReluOpConverter : public OpConverter { ...@@ -43,6 +51,32 @@ class ReluOpConverter : public OpConverter {
engine_->DeclareOutput(output_name); engine_->DeclareOutput(output_name);
} }
} }
protected:
std::string op_type_;
static const std::unordered_map<std::string, nvinfer1::ActivationType> ops;
};
const std::unordered_map<std::string, nvinfer1::ActivationType>
ActivationOpConverter::ops = {
{"relu", nvinfer1::ActivationType::kRELU},
{"sigmoid", nvinfer1::ActivationType::kSIGMOID},
{"tanh", nvinfer1::ActivationType::kTANH},
};
class ReluOpConverter : public ActivationOpConverter {
public:
ReluOpConverter() { op_type_ = "relu"; }
};
class SigmoidOpConverter : public ActivationOpConverter {
public:
SigmoidOpConverter() { op_type_ = "sigmoid"; }
};
class TanhOpConverter : public ActivationOpConverter {
public:
TanhOpConverter() { op_type_ = "tanh"; }
}; };
} // namespace tensorrt } // namespace tensorrt
...@@ -50,3 +84,5 @@ class ReluOpConverter : public OpConverter { ...@@ -50,3 +84,5 @@ class ReluOpConverter : public OpConverter {
} // namespace paddle } // namespace paddle
REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter); REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter);
REGISTER_TRT_OP_CONVERTER(tanh, TanhOpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* DropoutOp. This Layer doesn't has weights.
*/
class DropoutOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid dropout op to tensorrt dropout layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
float dropout_prob = boost::get<float>(op_desc.GetAttr("dropout_prob"));
platform::CPUPlace cpu_place;
std::unique_ptr<framework::LoDTensor> weight_tensor(
new framework::LoDTensor());
weight_tensor->Resize(framework::make_ddim({1}));
auto* weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
weight_data[0] = 1 - dropout_prob;
TensorRTEngine::Weight scale_weights{
nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
weight_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *const_cast<nvinfer1::ITensor*>(input1),
nvinfer1::ScaleMode::kUNIFORM, shift_weights.get(), scale_weights.get(),
power_weights.get());
engine_->weight_map[op_desc.Output("Out").front() + "_dropout"] =
std::move(weight_tensor);
auto output_name = op_desc.Output("Out")[0];
layer->setName(("dropout (Output: " + output_name + ")").c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) {
engine_->DeclareOutput(output_name);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(dropout);
REGISTER_TRT_OP_CONVERTER(dropout, DropoutOpConverter);
...@@ -20,18 +20,18 @@ namespace paddle { ...@@ -20,18 +20,18 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
TEST(ReluOpConverter, main) { void test_activation(std::string act_type) {
framework::Scope scope; framework::Scope scope;
std::unordered_set<std::string> parameters; std::unordered_set<std::string> parameters;
TRTConvertValidation validator(10, parameters, scope, 1000); TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("relu-X", nvinfer1::Dims2(10, 6)); validator.DeclInputVar("act-X", nvinfer1::Dims2(10, 6));
validator.DeclOutputVar("relu-Out", nvinfer1::Dims2(10, 6)); validator.DeclOutputVar("act-Out", nvinfer1::Dims2(10, 6));
// Prepare Op description // Prepare Op description
framework::OpDesc desc; framework::OpDesc desc;
desc.SetType("relu"); desc.SetType(act_type);
desc.SetInput("X", {"relu-X"}); desc.SetInput("X", {"act-X"});
desc.SetOutput("Out", {"relu-Out"}); desc.SetOutput("Out", {"act-Out"});
LOG(INFO) << "set OP"; LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
...@@ -40,8 +40,16 @@ TEST(ReluOpConverter, main) { ...@@ -40,8 +40,16 @@ TEST(ReluOpConverter, main) {
validator.Execute(5); validator.Execute(5);
} }
TEST(ReluOpConverter, main) { test_activation("relu"); }
TEST(SigmoidOpConverter, main) { test_activation("sigmoid"); }
TEST(TanhOpConverter, main) { test_activation("tanh"); }
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(relu); USE_OP(relu);
USE_OP(sigmoid);
USE_OP(tanh);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(DropoutOpConverter, main) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(8, parameters, scope, 1000);
std::vector<int> tensor_shape{8, 10};
validator.DeclInputVar("dropout-X", tensor_shape,
nvinfer1::DimsCHW(10, 1, 1));
validator.DeclOutputVar("dropout-Out", nvinfer1::DimsCHW(10, 1, 1));
validator.DeclOutputVar("mask-Out", nvinfer1::DimsCHW(10, 1, 1));
// Prepare Op description
framework::OpDesc desc;
int is_test = 1;
float dropout_prob = 0.4;
desc.SetType("dropout");
desc.SetInput("X", {"dropout-X"});
desc.SetOutput("Mask", {"mask-Out"});
desc.SetOutput("Out", {"dropout-Out"});
desc.SetAttr("is_test", is_test);
desc.SetAttr("dropout_prob", dropout_prob);
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
std::unordered_set<std::string> neglected_output = {"mask-Out"};
validator.Execute(8, neglected_output);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(dropout);
...@@ -167,6 +167,8 @@ void BuddyAllocator::Free(void* p) { ...@@ -167,6 +167,8 @@ void BuddyAllocator::Free(void* p) {
} }
size_t BuddyAllocator::Used() { return total_used_; } size_t BuddyAllocator::Used() { return total_used_; }
size_t BuddyAllocator::GetMinChunkSize() { return min_chunk_size_; }
size_t BuddyAllocator::GetMaxChunkSize() { return max_chunk_size_; }
void* BuddyAllocator::SystemAlloc(size_t size) { void* BuddyAllocator::SystemAlloc(size_t size) {
size_t index = 0; size_t index = 0;
......
...@@ -42,6 +42,8 @@ class BuddyAllocator { ...@@ -42,6 +42,8 @@ class BuddyAllocator {
void* Alloc(size_t unaligned_size); void* Alloc(size_t unaligned_size);
void Free(void* ptr); void Free(void* ptr);
size_t Used(); size_t Used();
size_t GetMinChunkSize();
size_t GetMaxChunkSize();
public: public:
// Disable copy and assignment // Disable copy and assignment
......
...@@ -119,8 +119,8 @@ void* Alloc<platform::CUDAPlace>(platform::CUDAPlace place, size_t size) { ...@@ -119,8 +119,8 @@ void* Alloc<platform::CUDAPlace>(platform::CUDAPlace place, size_t size) {
LOG(WARNING) << "Cannot allocate " << size << " bytes in GPU " LOG(WARNING) << "Cannot allocate " << size << " bytes in GPU "
<< place.device << ", available " << avail << " bytes"; << place.device << ", available " << avail << " bytes";
LOG(WARNING) << "total " << total; LOG(WARNING) << "total " << total;
LOG(WARNING) << "GpuMinChunkSize " << platform::GpuMinChunkSize(); LOG(WARNING) << "GpuMinChunkSize " << buddy_allocator->GetMinChunkSize();
LOG(WARNING) << "GpuMaxChunkSize " << platform::GpuMaxChunkSize(); LOG(WARNING) << "GpuMaxChunkSize " << buddy_allocator->GetMaxChunkSize();
LOG(WARNING) << "GPU memory used: " << Used<platform::CUDAPlace>(place); LOG(WARNING) << "GPU memory used: " << Used<platform::CUDAPlace>(place);
platform::SetDeviceId(cur_dev); platform::SetDeviceId(cur_dev);
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -13,76 +10,9 @@ See the License for the specific language governing permissions and ...@@ -13,76 +10,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/cpu_lstm_compute.h" #include "paddle/fluid/operators/math/cpu_lstm_compute.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {} // namespace math
// TODO(TJ): ugly workaround, clean me
template <typename T>
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
// gates: W_ch, W_ih, W_fh, W_oh
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
vec_tanh<T, platform::jit::avx>(8, gates, gates);
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int d = 0; d < 8; ++d) {
// C_t = C_t-1 * fgated + cand_gated * igated
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
// H_t = act_cell(C_t) * ogated
T tmp = ct[d] * 2;
tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vec_exp<T>(1, &tmp, &tmp);
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
ht[d] = tmp * o[d];
}
}
#ifdef __AVX__
namespace detail {
namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
} // namespace avx
} // namespace forward
} // namespace detail
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
float* ht) {
namespace act = detail::forward::avx;
// gates: W_ch, W_ih, W_fh, W_oh
__m256 c, i, f, o;
c = _mm256_loadu_ps(gates);
i = _mm256_loadu_ps(gates + 8);
f = _mm256_loadu_ps(gates + 16);
o = _mm256_loadu_ps(gates + 24);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
i = _mm256_loadu_ps(ct_1);
f = _mm256_mul_ps(i, act::Sigmoid(f));
f = _mm256_add_ps(c, f);
_mm256_storeu_ps(ct, f);
/* H_t = act_cell(C_t) * ogated */
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
_mm256_storeu_ps(ht, o);
}
#endif
template void lstm_compute_ctht<float>(float* gates, const float* ct_1,
float* ct, float* ht);
template void lstm_compute_ctht<double>(double* gates, const double* ct_1,
double* ct, double* ht);
} // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -14,6 +11,11 @@ limitations under the License. */ ...@@ -14,6 +11,11 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,7 +23,58 @@ namespace math { ...@@ -21,7 +23,58 @@ namespace math {
// TODO(TJ): ugly workaround, clean me // TODO(TJ): ugly workaround, clean me
template <typename T> template <typename T>
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht); void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
// gates: W_ch, W_ih, W_fh, W_oh
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
vec_tanh<T, platform::jit::avx>(8, gates, gates);
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int d = 0; d < 8; ++d) {
// C_t = C_t-1 * fgated + cand_gated * igated
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
// H_t = act_cell(C_t) * ogated
T tmp = ct[d] * 2;
tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vec_exp<T>(1, &tmp, &tmp);
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
ht[d] = tmp * o[d];
}
}
#ifdef __AVX__
namespace detail {
namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
} // namespace avx
} // namespace forward
} // namespace detail
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
float* ht) {
namespace act = detail::forward::avx;
// gates: W_ch, W_ih, W_fh, W_oh
__m256 c, i, f, o;
c = _mm256_loadu_ps(gates);
i = _mm256_loadu_ps(gates + 8);
f = _mm256_loadu_ps(gates + 16);
o = _mm256_loadu_ps(gates + 24);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
i = _mm256_loadu_ps(ct_1);
f = _mm256_mul_ps(i, act::Sigmoid(f));
f = _mm256_add_ps(c, f);
_mm256_storeu_ps(ct, f);
/* H_t = act_cell(C_t) * ogated */
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
_mm256_storeu_ps(ht, o);
}
#endif
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -63,7 +63,7 @@ class WhileOp : public framework::OperatorBase { ...@@ -63,7 +63,7 @@ class WhileOp : public framework::OperatorBase {
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false); executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
if (is_test) { if (is_test) {
scope.DeleteScope(&current_scope); scope.DeleteScope(&current_scope);
} }
...@@ -169,7 +169,8 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -169,7 +169,8 @@ class WhileGradOp : public framework::OperatorBase {
} }
} }
} }
executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false); executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true,
true);
auto &pg_names = Outputs(kXGRAD); auto &pg_names = Outputs(kXGRAD);
auto &p_names = Inputs(kX); auto &p_names = Inputs(kX);
......
...@@ -51,7 +51,7 @@ ENDIF() ...@@ -51,7 +51,7 @@ ENDIF()
# memcpy depends on device_context, here add deps individually for # memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies # avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS malloc cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
......
...@@ -210,11 +210,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) ...@@ -210,11 +210,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
if (dynload::HasCUDNN()) { if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place)); cudnn_holder_.reset(new CudnnHolder(&stream_, place));
} }
callback_manager_.reset(new StreamCallbackManager(stream_));
} }
CUDADeviceContext::~CUDADeviceContext() { CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device); SetDeviceId(place_.device);
Wait(); Wait();
WaitStreamCallback();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
eigen_stream_.reset(); eigen_stream_.reset();
eigen_device_.reset(); eigen_device_.reset();
......
...@@ -31,6 +31,9 @@ limitations under the License. */ ...@@ -31,6 +31,9 @@ limitations under the License. */
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream_callback_manager.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
...@@ -112,6 +115,17 @@ class CUDADeviceContext : public DeviceContext { ...@@ -112,6 +115,17 @@ class CUDADeviceContext : public DeviceContext {
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
template <typename Callback>
void AddStreamCallback(Callback&& callback) const {
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->AddCallback(callback);
}
void WaitStreamCallback() const {
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->Wait();
}
private: private:
CUDAPlace place_; CUDAPlace place_;
...@@ -125,7 +139,12 @@ class CUDADeviceContext : public DeviceContext { ...@@ -125,7 +139,12 @@ class CUDADeviceContext : public DeviceContext {
int multi_process; int multi_process;
int max_threads_per_mp; int max_threads_per_mp;
std::mutex mtx_; mutable std::mutex mtx_;
// This lock is only used by callback
// If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
mutable std::mutex callback_mtx_;
std::unique_ptr<StreamCallbackManager> callback_manager_;
}; };
template <> template <>
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <functional>
#include <memory>
#include "ThreadPool.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
using StreamCallback = std::function<void(cudaStream_t, cudaError_t)>;
class StreamCallbackManager;
struct StreamCallbackContext {
template <typename Callback>
inline StreamCallbackContext(const StreamCallbackManager *manager,
Callback &&callback)
: manager_(manager), callback_(callback) {}
const StreamCallbackManager *manager_; // do not own
StreamCallback callback_;
};
class StreamCallbackManager {
public:
explicit inline StreamCallbackManager(cudaStream_t stream = nullptr)
: stream_(stream), thread_pool_(new ThreadPool(1)) {}
template <typename Callback>
inline void AddCallback(Callback &&callback) const {
AddCallbackWithStreamAndErrorInfo(
[=](cudaStream_t, cudaError_t) { callback(); });
}
template <typename Callback>
inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const {
auto *stream_callback_context = new StreamCallbackContext(this, callback);
PADDLE_ENFORCE(cudaStreamAddCallback(
stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0));
}
void Wait() const { thread_pool_.reset(new ThreadPool(1)); }
private:
const cudaStream_t stream_;
mutable std::unique_ptr<ThreadPool> thread_pool_;
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status,
void *user_data) {
auto *callback_context_ptr =
reinterpret_cast<StreamCallbackContext *>(user_data);
callback_context_ptr->manager_->thread_pool_->enqueue([=]() {
std::unique_ptr<StreamCallbackContext> callback_context(
callback_context_ptr);
callback_context->callback_(stream, status);
});
}
};
} // namespace platform
} // namespace paddle
...@@ -716,6 +716,12 @@ function main() { ...@@ -716,6 +716,12 @@ function main() {
build_mac build_mac
run_mac_test run_mac_test
;; ;;
cicheck_py35)
cmake_gen ${PYTHON_ABI:-""}
build
run_test
assert_api_not_changed
;;
*) *)
print_usage print_usage
exit 0 exit 0
......
...@@ -67,7 +67,7 @@ def get_word_dict(): ...@@ -67,7 +67,7 @@ def get_word_dict():
for field in movie_reviews.fileids(category): for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field): for words in movie_reviews.words(field):
word_freq_dict[words] += 1 word_freq_dict[words] += 1
words_sort_list = six.iteritems(word_freq_dict) words_sort_list = list(six.iteritems(word_freq_dict))
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1]) words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list): for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index)) words_freq_sorted.append((word[0], index))
......
...@@ -122,7 +122,7 @@ def __bootstrap__(): ...@@ -122,7 +122,7 @@ def __bootstrap__():
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir',
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
"dist_threadpool_size", 'cpu_deterministic' "dist_threadpool_size", 'cpu_deterministic', 'eager_delete_tensor_gb'
] ]
if core.is_compiled_with_dist(): if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline') read_env_flags.append('rpc_deadline')
......
...@@ -109,15 +109,20 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers): ...@@ -109,15 +109,20 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
return t return t
from paddle.fluid.transpiler.details import op_to_code
def operator_equal(a, b): def operator_equal(a, b):
if op_to_code(a) != op_to_code(b):
raise ValueError("In operator_equal not equal\n")
for k, v in six.iteritems(a.__dict__): for k, v in six.iteritems(a.__dict__):
if isinstance(v, fluid.framework.Program) or \ if isinstance(v, fluid.framework.Program) or \
isinstance(v, fluid.framework.Block): isinstance(v, fluid.framework.Block):
continue continue
elif isinstance(v, core.OpDesc): elif isinstance(v, core.OpDesc):
if v.serialize_to_string() != b.__dict__[k].serialize_to_string(): continue
raise ValueError("In operator_equal not equal:{0}\n".format(k))
elif isinstance(v, collections.OrderedDict): elif isinstance(v, collections.OrderedDict):
v0 = sorted(list(six.iteritems(v)), key=lambda x: x[0]) v0 = sorted(list(six.iteritems(v)), key=lambda x: x[0])
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -30,8 +30,10 @@ class TestWhileOp(unittest.TestCase): ...@@ -30,8 +30,10 @@ class TestWhileOp(unittest.TestCase):
"d1", shape=[10], append_batch_size=False, dtype='float32') "d1", shape=[10], append_batch_size=False, dtype='float32')
d2 = layers.data( d2 = layers.data(
"d2", shape=[10], append_batch_size=False, dtype='float32') "d2", shape=[10], append_batch_size=False, dtype='float32')
i = layers.zeros(shape=[1], dtype='int64') i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True i.stop_gradient = True
init = layers.zeros(shape=[10], dtype='float32') init = layers.zeros(shape=[10], dtype='float32')
mem_array = layers.array_write(x=init, i=i) mem_array = layers.array_write(x=init, i=i)
data_array = layers.array_write(x=d0, i=i) data_array = layers.array_write(x=d0, i=i)
...@@ -45,11 +47,19 @@ class TestWhileOp(unittest.TestCase): ...@@ -45,11 +47,19 @@ class TestWhileOp(unittest.TestCase):
i = layers.zeros(shape=[1], dtype='int64') i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True i.stop_gradient = True
array_len = layers.fill_constant(shape=[1], dtype='int64', value=3) array_len = layers.fill_constant(shape=[1], dtype='int64', value=1)
array_len.stop_gradient = True array_len.stop_gradient = True
cond = layers.less_than(x=i, y=array_len) cond = layers.less_than(x=i, y=array_len)
j = layers.fill_constant(shape=[1], dtype='int64', value=1)
j.stop_gradient = True
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = layers.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond) while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
with while_op.block(): with while_op.block():
d = layers.array_read(array=data_array, i=i) d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i) prev = layers.array_read(array=mem_array, i=i)
...@@ -59,7 +69,16 @@ class TestWhileOp(unittest.TestCase): ...@@ -59,7 +69,16 @@ class TestWhileOp(unittest.TestCase):
layers.array_write(result, i=i, array=mem_array) layers.array_write(result, i=i, array=mem_array)
layers.less_than(x=i, y=array_len, cond=cond) layers.less_than(x=i, y=array_len, cond=cond)
sum_result = layers.array_read(array=mem_array, i=i) with while_op2.block():
d2 = layers.array_read(array=data_array, i=j)
prev2 = layers.array_read(array=mem_array, i=j)
result2 = layers.sums(input=[d2, prev2])
j = layers.increment(x=j, in_place=True)
layers.array_write(result2, i=j, array=mem_array)
layers.less_than(x=j, y=array_len2, cond=cond2)
sum_result = layers.array_read(array=mem_array, i=j)
loss = layers.mean(sum_result) loss = layers.mean(sum_result)
append_backward(loss) append_backward(loss)
......
...@@ -113,27 +113,32 @@ def op_to_code(op): ...@@ -113,27 +113,32 @@ def op_to_code(op):
inputs_str += ", " inputs_str += ", "
inputs_str += "}" inputs_str += "}"
attr_names = sorted(op.attr_names)
attrs_str = "" attrs_str = ""
for i in range(0, len(op.attr_names)): for i in range(0, len(attr_names)):
name = op.attr_names[i] name = attr_names[i]
attr_type = op.desc.attr_type(name) attr_type = op.desc.attr_type(name)
if attr_type == core.AttrType.BLOCK: if attr_type == core.AttrType.BLOCK:
a = "{name} = block[{value}]".format( a = "{name} = block[{value}]".format(
name=name, type=attr_type, value=op.block_attr_id(name)) name=name, type=attr_type, value=op.block_attr_id(name))
attrs_str += a attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
continue continue
if attr_type == core.AttrType.BLOCKS: if attr_type == core.AttrType.BLOCKS:
a = "{name} = blocks{value}".format( a = "{name} = blocks{value}".format(
name=name, type=attr_type, value=op.blocks_attr_ids(name)) name=name, type=attr_type, value=op.blocks_attr_ids(name))
attrs_str += a attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
continue continue
a = "{name} = {value}".format( a = "{name} = {value}".format(
name=name, type=attr_type, value=op.desc.attr(name)) name=name, type=attr_type, value=op.desc.attr(name))
attrs_str += a attrs_str += a
if i != len(op.attr_names) - 1: if i != len(attr_names) - 1:
attrs_str += ", " attrs_str += ", "
if outputs_str != "{}": if outputs_str != "{}":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册