未验证 提交 e29cbfe4 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #14829 from velconia/accelerate_ddpg

Accelerate little models
......@@ -72,13 +72,13 @@ cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_library(threadpool SRCS threadpool.cc DEPS enforce)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(var_type_traits SRCS var_type_traits DEPS lod_tensor selected_rows framework_proto)
cc_library(var_type_traits SRCS var_type_traits DEPS lod_tensor selected_rows framework_proto)
if (WITH_GPU)
target_link_libraries(var_type_traits dynload_cuda)
endif()
cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits)
cc_library(scope SRCS scope.cc DEPS glog threadpool var_type_traits)
cc_library(scope SRCS scope.cc DEPS glog threadpool xxhash var_type_traits)
cc_library(scope_pool SRCS scope_pool.cc DEPS scope)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_test(variable_test SRCS variable_test.cc DEPS tensor var_type_traits)
......
......@@ -25,7 +25,7 @@ struct ExecutionStrategy {
size_t num_threads_{0};
bool use_cuda_{true};
bool allow_op_delay_{false};
size_t num_iteration_per_drop_scope_{100};
size_t num_iteration_per_drop_scope_{1};
ExecutorType type_{kDefault};
bool dry_run_{false};
};
......
......@@ -64,20 +64,26 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
drop_scope_counter_ += 1;
++drop_scope_counter_;
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0;
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
bool stream_end = false;
if (!fetch_tensors.empty()) {
WaitComputationalStreams();
stream_end = true;
}
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
if (!stream_end) {
WaitComputationalStreams();
}
for (auto &scope : local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
}
drop_scope_counter_ = 0;
}
if (eptr) {
std::rethrow_exception(eptr);
......
......@@ -47,6 +47,14 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:
inline void WaitComputationalStreams() {
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
}
private:
size_t drop_scope_counter_{0};
......
......@@ -16,7 +16,9 @@ limitations under the License. */
#if !defined(_WIN32)
#include <pthread.h>
#endif // !_WIN32
#else
#include <mutex> // NOLINT
#endif // !_WIN32
#include "paddle/fluid/platform/enforce.h"
......@@ -29,17 +31,17 @@ struct RWLock {
~RWLock() { pthread_rwlock_destroy(&lock_); }
void RDLock() {
inline void RDLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_rdlock(&lock_), 0,
"acquire read lock failed");
}
void WRLock() {
inline void WRLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_), 0,
"acquire write lock failed");
}
void UNLock() {
inline void UNLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_unlock(&lock_), 0, "unlock failed");
}
......@@ -51,81 +53,46 @@ struct RWLock {
// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive
// In windows, rw_lock seems like a hack. Use empty object and do nothing.
struct RWLock {
void RDLock() {}
void WRLock() {}
void UNLock() {}
// FIXME(minqiyang): use mutex here to do fake lock
inline void RDLock() { mutex_.lock(); }
inline void WRLock() { mutex_.lock(); }
inline void UNLock() { mutex_.unlock(); }
private:
std::mutex mutex_;
};
#endif
class RWLockGuard {
class AutoWRLock {
public:
enum Status { kUnLock, kWRLock, kRDLock };
RWLockGuard(RWLock* rw_lock, Status init_status)
: lock_(rw_lock), status_(Status::kUnLock) {
switch (init_status) {
case Status::kRDLock: {
RDLock();
break;
}
case Status::kWRLock: {
WRLock();
break;
}
case Status::kUnLock: {
break;
}
}
}
explicit AutoWRLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); }
void WRLock() {
switch (status_) {
case Status::kUnLock: {
lock_->WRLock();
status_ = Status::kWRLock;
break;
}
case Status::kWRLock: {
break;
}
case Status::kRDLock: {
PADDLE_THROW(
"Please unlock read lock first before invoking write lock.");
break;
}
}
}
~AutoWRLock() { UnLock(); }
void RDLock() {
switch (status_) {
case Status::kUnLock: {
lock_->RDLock();
status_ = Status::kRDLock;
break;
}
case Status::kRDLock: {
break;
}
case Status::kWRLock: {
PADDLE_THROW(
"Please unlock write lock first before invoking read lock.");
break;
}
}
}
private:
inline void Lock() { lock_->WRLock(); }
void UnLock() {
if (status_ != Status::kUnLock) {
lock_->UNLock();
status_ = Status::kUnLock;
}
}
inline void UnLock() { lock_->UNLock(); }
private:
RWLock* lock_;
};
class AutoRDLock {
public:
explicit AutoRDLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); }
~AutoRDLock() { UnLock(); }
private:
inline void Lock() { lock_->RDLock(); }
~RWLockGuard() { UnLock(); }
inline void UnLock() { lock_->UNLock(); }
private:
RWLock* lock_;
Status status_;
};
} // namespace framework
......
......@@ -47,9 +47,15 @@ DEFINE_bool(fast_eager_deletion_mode, false,
// the mutex will cause serious performance issue.
// So the mutex is disabled when `ON_INFER`.
#ifdef PADDLE_ON_INFERENCE
#define SCOPE_LOCK_GUARD
#define SCOPE_KIDS_READER_LOCK
#define SCOPE_KIDS_WRITER_LOCK
#define SCOPE_VARS_READER_LOCK
#define SCOPE_VARS_WRITER_LOCK
#else
#define SCOPE_LOCK_GUARD std::lock_guard<std::mutex> lock(mutex_);
#define SCOPE_KIDS_READER_LOCK AutoRDLock auto_lock(&kids_lock_);
#define SCOPE_KIDS_WRITER_LOCK AutoWRLock auto_lock(&kids_lock_);
#define SCOPE_VARS_READER_LOCK AutoRDLock auto_lock(&vars_lock_);
#define SCOPE_VARS_WRITER_LOCK AutoWRLock auto_lock(&vars_lock_);
#endif
namespace paddle {
......@@ -67,64 +73,69 @@ bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const {
SCOPE_LOCK_GUARD
kids_.push_back(new Scope(this));
return *kids_.back();
Scope* child = new Scope(this);
{
SCOPE_KIDS_WRITER_LOCK
kids_.push_back(child);
}
return *child;
}
Variable* Scope::Var(const std::string& name) {
SCOPE_LOCK_GUARD
SCOPE_VARS_WRITER_LOCK
return VarInternal(name);
}
Variable* Scope::Var(std::string* name) {
SCOPE_LOCK_GUARD
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) {
*name = new_name;
}
SCOPE_VARS_WRITER_LOCK
return VarInternal(new_name);
}
Variable* Scope::FindVar(const std::string& name) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_READER_LOCK
return FindVarInternal(name);
}
Variable* Scope::FindLocalVar(const std::string& name) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_READER_LOCK
return FindVarLocally(name);
}
const Scope* Scope::FindScope(const Variable* var) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_READER_LOCK
return FindScopeInternal(var);
}
void Scope::DropKids() {
SCOPE_LOCK_GUARD
SCOPE_KIDS_WRITER_LOCK
for (Scope* s : kids_) delete s;
kids_.clear();
}
bool Scope::HasKid(const Scope* scope) const {
SCOPE_LOCK_GUARD
SCOPE_KIDS_READER_LOCK
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
return it != this->kids_.end();
}
std::vector<std::string> Scope::LocalVarNames() const {
SCOPE_LOCK_GUARD
std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
known_vars.emplace_back(p.first);
{
SCOPE_VARS_READER_LOCK
known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
known_vars.emplace_back(p.first);
}
}
return known_vars;
}
void Scope::DeleteScope(Scope* scope) const {
SCOPE_LOCK_GUARD
SCOPE_KIDS_WRITER_LOCK
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "%p Cannot find %p as kid scope",
this, scope);
......@@ -138,8 +149,8 @@ void Scope::DeleteScope(Scope* scope) const {
}
void Scope::EraseVars(const std::vector<std::string>& var_names) {
SCOPE_LOCK_GUARD
std::set<std::string> var_set(var_names.begin(), var_names.end());
SCOPE_VARS_WRITER_LOCK
for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) {
it = vars_.erase(it);
......@@ -151,12 +162,12 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_WRITER_LOCK
RenameInternal(origin_name, new_name);
}
std::string Scope::Rename(const std::string& origin_name) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_WRITER_LOCK
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name);
return new_name;
......
......@@ -14,12 +14,18 @@ limitations under the License. */
#pragma once
extern "C" {
#include <xxhash.h>
}
#include <list>
#include <mutex> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/macros.h"
......@@ -95,7 +101,14 @@ class Scope {
std::string Rename(const std::string& origin_name) const;
protected:
mutable std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
struct KeyHasher {
std::size_t operator()(const std::string& key) const {
return XXH32(key.c_str(), key.size(), 1);
}
};
mutable std::unordered_map<std::string, std::unique_ptr<Variable>, KeyHasher>
vars_;
private:
// Call Scope::NewScope for a sub-scope.
......@@ -124,7 +137,8 @@ class Scope {
DISABLE_COPY_AND_ASSIGN(Scope);
private:
mutable std::mutex mutex_;
mutable RWLock kids_lock_;
mutable RWLock vars_lock_;
};
// Generate some debug string about the inherience structure of scope, quite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册