Commit ed2d7d7d authored by Yu Yang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/mix_cpu_gpu_op

@@ -78,7 +78,7 @@ if(NOT CMAKE_CROSSCOMPILING)
     /usr/lib/reference/
   )
 else()
-  # Diable the finding of reference cblas under host's system path
+  # Disable the finding of reference cblas under host's system path
   set(REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/include)
   set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib)
 endif()
......
@@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name,
   if (tensor.memory_size() == 0) {
     return;
   }
-  if (tensor.type().hash_code() != typeid(float).hash_code() &&
-      tensor.type().hash_code() != typeid(double).hash_code()) {
+  if (tensor.type().hash_code() != typeid(float).hash_code() &&   // NOLINT
+      tensor.type().hash_code() != typeid(double).hash_code()) {  // NOLINT
     return;
   }
   PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
@@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
 // Return true if the block has feed operators and holder of matching info.
 static bool has_feed_operators(
     const BlockDesc& block,
-    std::map<std::string, const LoDTensor*>& feed_targets,
+    const std::map<std::string, const LoDTensor*>& feed_targets,
     const std::string& feed_holder_name) {
   size_t feed_count = 0;
   for (auto* op : block.AllOps()) {
     if (op->Type() == kFeedOpType) {
       feed_count++;
+      // The input variable's name of feed_op should be feed_holder_name.
       PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
                         "Input to feed op should be '%s'", feed_holder_name);
       std::string feed_target_name = op->Output("Out")[0];
@@ -166,13 +167,15 @@ static bool has_feed_operators(
         feed_count, feed_targets.size(),
         "The number of feed operators should match 'feed_targets'");
-    // When feed operator are present, so should be feed_holder
-    auto var = block.FindVar(feed_holder_name);
-    PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
-                            feed_holder_name);
-    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
-                      "'%s' variable should be 'FEED_MINIBATCH' type",
-                      feed_holder_name);
+    if (!feed_holder_name.empty()) {
+      // When feed operator are present, so should be feed_holder.
+      auto var = block.FindVar(feed_holder_name);
+      PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
+                              feed_holder_name);
+      PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
+                        "'%s' variable should be 'FEED_MINIBATCH' type",
+                        feed_holder_name);
+    }
   }
   return feed_count > 0;
@@ -185,12 +188,14 @@ static bool has_feed_operators(
 // and fetch_holder_name. Raise exception when any mismatch is found.
 // Return true if the block has fetch operators and holder of matching info.
 static bool has_fetch_operators(
-    const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
+    const BlockDesc& block,
+    const std::map<std::string, LoDTensor*>& fetch_targets,
     const std::string& fetch_holder_name) {
   size_t fetch_count = 0;
   for (auto* op : block.AllOps()) {
     if (op->Type() == kFetchOpType) {
       fetch_count++;
+      // The output variable's name of fetch_op should be fetch_holder_name.
       PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
                         "Output of fetch op should be '%s'", fetch_holder_name);
       std::string fetch_target_name = op->Input("X")[0];
@@ -206,13 +211,15 @@ static bool has_fetch_operators(
         fetch_count, fetch_targets.size(),
         "The number of fetch operators should match 'fetch_targets'");
-    // When fetch operator are present, so should be fetch_holder
-    auto var = block.FindVar(fetch_holder_name);
-    PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
-                            fetch_holder_name);
-    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
-                      "'%s' variable should be 'FETCH_LIST' type",
-                      fetch_holder_name);
+    if (!fetch_holder_name.empty()) {
+      // When fetch operator are present, so should be fetch_holder.
+      auto var = block.FindVar(fetch_holder_name);
+      PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
+                              fetch_holder_name);
+      PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
+                        "'%s' variable should be 'FETCH_LIST' type",
+                        fetch_holder_name);
+    }
   }
   return fetch_count > 0;
@@ -259,16 +266,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     }
   }
 
-  // map the data of feed_targets to feed_holder
-  for (auto* op : global_block->AllOps()) {
-    if (op->Type() == kFeedOpType) {
-      std::string feed_target_name = op->Output("Out")[0];
-      int idx = boost::get<int>(op->GetAttr("col"));
-      SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
-                      idx);
-    }
-  }
-
   if (!has_fetch_ops) {
     // create fetch_holder variable
     auto* fetch_holder = global_block->Var(fetch_holder_name);
@@ -292,17 +289,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     }
   }
 
-  Run(*copy_program, scope, 0, create_vars, create_vars);
-
-  // obtain the data of fetch_targets from fetch_holder
-  for (auto* op : global_block->AllOps()) {
-    if (op->Type() == kFetchOpType) {
-      std::string fetch_target_name = op->Input("X")[0];
-      int idx = boost::get<int>(op->GetAttr("col"));
-      *fetch_targets[fetch_target_name] =
-          GetFetchVariable(*scope, fetch_holder_name, idx);
-    }
-  }
+  auto ctx = Prepare(*copy_program, 0);
+  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, create_vars,
+                     feed_holder_name, fetch_holder_name);
 }
 
 std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
@@ -370,5 +359,42 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }
 }
 
+void Executor::RunPreparedContext(
+    ExecutorPrepareContext* ctx, Scope* scope,
+    std::map<std::string, const LoDTensor*>& feed_targets,
+    std::map<std::string, LoDTensor*>& fetch_targets, bool create_vars,
+    const std::string& feed_holder_name, const std::string& fetch_holder_name) {
+  auto& global_block = ctx->prog_.Block(ctx->block_id_);
+
+  PADDLE_ENFORCE(
+      has_feed_operators(global_block, feed_targets, feed_holder_name),
+      "Program in ExecutorPrepareContext should has feed_ops.");
+  PADDLE_ENFORCE(
+      has_fetch_operators(global_block, fetch_targets, fetch_holder_name),
+      "Program in the prepared context should has fetch_ops.");
+
+  // map the data of feed_targets to feed_holder
+  for (auto* op : global_block.AllOps()) {
+    if (op->Type() == kFeedOpType) {
+      std::string feed_target_name = op->Output("Out")[0];
+      int idx = boost::get<int>(op->GetAttr("col"));
+      SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
+                      idx);
+    }
+  }
+
+  RunPreparedContext(ctx, scope, create_vars, create_vars);
+
+  // obtain the data of fetch_targets from fetch_holder
+  for (auto* op : global_block.AllOps()) {
+    if (op->Type() == kFetchOpType) {
+      std::string fetch_target_name = op->Input("X")[0];
+      int idx = boost::get<int>(op->GetAttr("col"));
+      *fetch_targets[fetch_target_name] =
+          GetFetchVariable(*scope, fetch_holder_name, idx);
+    }
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -14,6 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#include <map>
+#include <string>
+#include <vector>
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
@@ -70,6 +73,13 @@ class Executor {
                           bool create_local_scope = true,
                           bool create_vars = true);
 
+  void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
+                          std::map<std::string, const LoDTensor*>& feed_targets,
+                          std::map<std::string, LoDTensor*>& fetch_targets,
+                          bool create_vars = true,
+                          const std::string& feed_holder_name = "feed",
+                          const std::string& fetch_holder_name = "fetch");
+
  private:
   const platform::Place place_;
 };
......
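For orientation (not part of the commit itself): the new Executor::RunPreparedContext overload lets a caller prepare a program once and then run it repeatedly while feeding and fetching through the named holder variables. A minimal sketch of how inference code might use it, assuming `program` is a ProgramDesc that already contains feed and fetch ops (for example, a saved inference program); the target names "image" and "prediction" are hypothetical:

#include <map>
#include <string>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

// Sketch only: the prepared context enforces that the program already has
// feed_ops and fetch_ops, mirroring the inference test helper later in this diff.
void RunPreparedOnce(const paddle::framework::ProgramDesc& program,
                     const paddle::framework::LoDTensor& input,
                     paddle::framework::LoDTensor* output) {
  paddle::framework::Executor executor(paddle::platform::CPUPlace());
  paddle::framework::Scope scope;

  std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
  std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
  feed_targets["image"] = &input;        // hypothetical feed target name
  fetch_targets["prediction"] = output;  // hypothetical fetch target name

  // Prepare once; the context can be reused across runs as long as the
  // program itself does not change.
  auto ctx = executor.Prepare(program, /*block_id=*/0);
  executor.RunPreparedContext(ctx.get(), &scope, feed_targets, fetch_targets,
                              /*create_vars=*/true, "feed", "fetch");
}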
@@ -14,8 +14,12 @@
 
 #include "paddle/fluid/framework/threadpool.h"
 
+#include "gflags/gflags.h"
 #include "paddle/fluid/platform/enforce.h"
 
+DEFINE_int32(io_threadpool_size, 100,
+             "number of threads used for doing IO, default 100");
+
 namespace paddle {
 namespace framework {
@@ -91,5 +95,20 @@ void ThreadPool::TaskLoop() {
   }
 }
 
+std::unique_ptr<ThreadPool> ThreadPoolIO::io_threadpool_(nullptr);
+std::once_flag ThreadPoolIO::io_init_flag_;
+
+ThreadPool* ThreadPoolIO::GetInstanceIO() {
+  std::call_once(io_init_flag_, &ThreadPoolIO::InitIO);
+  return io_threadpool_.get();
+}
+
+void ThreadPoolIO::InitIO() {
+  if (io_threadpool_.get() == nullptr) {
+    // TODO(typhoonzero1986): make this configurable
+    io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size));
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -14,12 +14,12 @@ limitations under the License. */
 
 #pragma once
 
-#include <condition_variable>
+#include <condition_variable>  // NOLINT
 #include <functional>
-#include <future>
-#include <mutex>
+#include <future>  // NOLINT
+#include <mutex>   // NOLINT
 #include <queue>
-#include <thread>
+#include <thread>  // NOLINT
 #include <vector>
 
 #include "glog/logging.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -28,6 +28,22 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+struct ExceptionHandler {
+  mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
+  explicit ExceptionHandler(
+      std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
+      : future_(std::move(f)) {}
+  void operator()() const {
+    auto ex = this->future_.get();
+    if (ex != nullptr) {
+      LOG(FATAL) << "The exception is thrown inside the thread pool. You "
+                    "should use RunAndGetException to handle the exception.\n"
+                    "The default exception handler is LOG(FATAL)."
+                 << ex->what();
+    }
+  }
+};
+
 // ThreadPool maintains a queue of tasks, and runs them using a fixed
 // number of threads.
 class ThreadPool {
@@ -87,22 +103,6 @@ class ThreadPool {
   void Wait();
 
  private:
-  struct ExceptionHandler {
-    mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
-    explicit ExceptionHandler(
-        std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
-        : future_(std::move(f)) {}
-    void operator()() const {
-      auto ex = this->future_.get();
-      if (ex != nullptr) {
-        LOG(FATAL) << "The exception is thrown inside the thread pool. You "
-                      "should use RunAndGetException to handle the exception.\n"
-                      "The default exception handler is LOG(FATAL)."
-                   << ex->what();
-      }
-    }
-  };
   DISABLE_COPY_AND_ASSIGN(ThreadPool);
 
   // If the task queue is empty and avaialbe is equal to the number of
@@ -135,6 +135,17 @@ class ThreadPool {
   std::condition_variable completed_;
 };
 
+class ThreadPoolIO : ThreadPool {
+ public:
+  static ThreadPool* GetInstanceIO();
+  static void InitIO();
+
+ private:
+  // NOTE: threadpool in base will be inhereted here.
+  static std::unique_ptr<ThreadPool> io_threadpool_;
+  static std::once_flag io_init_flag_;
+};
+
 // Run a function asynchronously.
 // NOTE: The function must return void. If the function need to return a value,
 // you can use lambda to capture a value pointer.
@@ -143,5 +154,10 @@ std::future<void> Async(Callback callback) {
   return ThreadPool::GetInstance()->Run(callback);
 }
 
+template <typename Callback>
+std::future<void> AsyncIO(Callback callback) {
+  return ThreadPoolIO::GetInstanceIO()->Run(callback);
+}
+
 }  // namespace framework
 }  // namespace paddle
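As a usage note (not part of the commit itself): ThreadPoolIO keeps a separate pool, sized by the new --io_threadpool_size flag, so that blocking I/O work does not tie up the default compute pool. A minimal sketch, assuming only the interfaces added above:

#include <future>  // NOLINT

#include "paddle/fluid/framework/threadpool.h"

// Sketch only: AsyncIO mirrors Async but routes the task to ThreadPoolIO,
// so a long-blocking call (e.g. an RPC) does not occupy a compute thread.
void DoBlockingWorkAsync() {
  std::future<void> done = paddle::framework::AsyncIO([] {
    // A blocking serialization / RPC step would go here (omitted).
  });
  done.wait();  // Async and AsyncIO both return std::future<void>
}

The gRPC client changes later in this diff switch from framework::Async to framework::AsyncIO for exactly this reason.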
@@ -23,7 +23,7 @@ limitations under the License. */
 namespace paddle {
 namespace inference {
 
-// Temporarilly add this function for exposing framework::InitDevices() when
+// Temporarily add this function for exposing framework::InitDevices() when
 // linking the inference shared library.
 void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
......
@@ -46,8 +46,8 @@ TEST(inference, image_classification) {
 
   // Run inference on CPU
   LOG(INFO) << "--- CPU Runs: ---";
-  TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds,
-                                                   cpu_fetchs1, FLAGS_repeat);
+  TestInference<paddle::platform::CPUPlace, false, true>(
+      dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
   LOG(INFO) << output1.dims();
 
 #ifdef PADDLE_WITH_CUDA
@@ -57,8 +57,8 @@ TEST(inference, image_classification) {
 
   // Run inference on CUDA GPU
   LOG(INFO) << "--- GPU Runs: ---";
-  TestInference<paddle::platform::CUDAPlace, false>(dirname, cpu_feeds,
-                                                    cpu_fetchs2, FLAGS_repeat);
+  TestInference<paddle::platform::CUDAPlace, false, true>(
+      dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
   LOG(INFO) << output2.dims();
 
   CheckError<float>(output1, output2);
......
@@ -89,7 +89,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
   EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
 }
 
-template <typename Place, bool CreateVars = true>
+template <typename Place, bool CreateVars = true, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
@@ -175,8 +175,15 @@ void TestInference(const std::string& dirname,
     }
 
     // Ignore the profiling results of the first run
-    executor.Run(*inference_program, scope, feed_targets, fetch_targets,
-                 CreateVars);
+    std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
+    if (PrepareContext) {
+      ctx = executor.Prepare(*inference_program, 0);
+      executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
+                                  CreateVars);
+    } else {
+      executor.Run(*inference_program, scope, feed_targets, fetch_targets,
+                   CreateVars);
+    }
 
     // Enable the profiler
     paddle::platform::EnableProfiler(state);
@@ -187,8 +194,15 @@ void TestInference(const std::string& dirname,
           "run_inference",
           paddle::platform::DeviceContextPool::Instance().Get(place));
 
-      executor.Run(*inference_program, scope, feed_targets, fetch_targets,
-                   CreateVars);
+      if (PrepareContext) {
+        // Note: if you change the inference_program, you need to call
+        // executor.Prepare() again to get a new ExecutorPrepareContext.
+        executor.RunPreparedContext(ctx.get(), scope, feed_targets,
+                                    fetch_targets, CreateVars);
+      } else {
+        executor.Run(*inference_program, scope, feed_targets, fetch_targets,
+                     CreateVars);
+      }
     }
 
     // Disable the profiler and print the timing information
......
@@ -35,7 +35,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
   const framework::Scope* p_scope = &scope;
   const auto ch = GetChannel(ep_val);
 
-  framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
+  framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
+                      this] {
     auto* var = p_scope->FindVar(var_name_val);
 
     ::grpc::ByteBuffer req;
@@ -89,7 +90,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   const framework::Scope* p_scope = &scope;
   const auto ch = GetChannel(ep_val);
 
-  framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {
+  framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
+                      this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
@@ -132,8 +134,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
   const framework::Scope* p_scope = &scope;
   const auto ch = GetChannel(ep_val);
 
-  framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
-                    time_out, ch, this] {
+  framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
+                      time_out, ch, this] {
     auto* var = p_scope->FindVar(in_var_name_val);
 
     ::grpc::ByteBuffer req;
@@ -196,7 +198,7 @@ bool RPCClient::Wait() {
   std::vector<std::future<void>> waits(req_count_);
 
   for (int i = 0; i < req_count_; i++) {
-    waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); });
+    waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); });
   }
 
   for (int i = 0; i < req_count_; i++) {
......
@@ -217,10 +217,10 @@ void AsyncGRPCServer::RunSyncUpdate() {
   std::function<void()> prefetch_register =
       std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);
 
+  // TODO(wuyi): Run these "HandleRequest" in thread pool
   t_send_.reset(
       new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                 cq_send_.get(), "cq_send", send_register)));
   t_get_.reset(
       new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                 cq_get_.get(), "cq_get", get_register)));
......
@@ -33,28 +33,14 @@ static constexpr size_t kChannelSize = 0;  // kCacheSize - 2
 
 class DoubleBufferReader : public framework::DecoratedReader {
  public:
-  struct Item {
-    Item() : ctx_(nullptr) {}
-    Item(Item&& b) {
-      payloads_ = std::move(b.payloads_);
-      ctx_ = std::move(b.ctx_);
-    }
-    Item& operator=(Item&& b) {
-      payloads_ = std::move(b.payloads_);
-      ctx_ = std::move(b.ctx_);
-      return *this;
-    }
-
-    std::vector<framework::LoDTensor> payloads_;
-    platform::DeviceContext* ctx_;
-  };
-
   explicit DoubleBufferReader(
       ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
       : DecoratedReader(reader), place_(target_place) {
+    cpu_tensor_cache_.resize(kCacheSize);
+    gpu_tensor_cache_.resize(kCacheSize);
 #ifdef PADDLE_WITH_CUDA
-    for (size_t i = 0; i < kCacheSize; ++i) {
-      if (platform::is_gpu_place(place_)) {
+    if (platform::is_gpu_place(place_)) {
+      for (size_t i = 0; i < kCacheSize; ++i) {
         ctxs_.emplace_back(new platform::CUDADeviceContext(
             boost::get<platform::CUDAPlace>(place_)));
       }
@@ -72,7 +58,7 @@ class DoubleBufferReader : public framework::DecoratedReader {
   bool HasNext() const;
 
   void StartPrefetcher() {
-    channel_ = framework::MakeChannel<Item>(kChannelSize);
+    channel_ = framework::MakeChannel<size_t>(kChannelSize);
     prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
   }
 
@@ -88,8 +74,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
   void PrefetchThreadFunc();
 
   std::thread prefetcher_;
-  framework::Channel<Item>* channel_;
+  framework::Channel<size_t>* channel_;
   platform::Place place_;
+  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache_;
+  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache_;
   std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
 };
 
@@ -153,11 +141,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   out->clear();
   if (HasNext()) {
-    Item batch;
-    channel_->Receive(&batch);
-    *out = batch.payloads_;
-    if (batch.ctx_) {
-      batch.ctx_->Wait();
+    size_t cached_tensor_id;
+    channel_->Receive(&cached_tensor_id);
+    if (platform::is_gpu_place(place_)) {
+      *out = gpu_tensor_cache_[cached_tensor_id];
+      ctxs_[cached_tensor_id]->Wait();
+    } else {
+      // CPU place
+      *out = cpu_tensor_cache_[cached_tensor_id];
     }
   }
 }
@@ -176,42 +167,33 @@ bool DoubleBufferReader::HasNext() const {
 
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
-  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
-  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
   size_t cached_tensor_id = 0;
 
   while (true) {
-    Item batch;
-    auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
+    auto& cpu_batch = cpu_tensor_cache_[cached_tensor_id];
     reader_->ReadNext(&cpu_batch);
     if (cpu_batch.empty()) {
       // The underlying reader have no next data.
       break;
     }
     if (platform::is_gpu_place(place_)) {
-      auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
+      auto& gpu_batch = gpu_tensor_cache_[cached_tensor_id];
       auto* gpu_ctx = ctxs_[cached_tensor_id].get();
       gpu_batch.resize(cpu_batch.size());
       for (size_t i = 0; i < cpu_batch.size(); ++i) {
         framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
         gpu_batch[i].set_lod(cpu_batch[i].lod());
       }
-      batch.payloads_ = gpu_batch;
-      batch.ctx_ = gpu_ctx;
-    } else {
-      // CPUPlace
-      batch.payloads_ = cpu_batch;
     }
-    ++cached_tensor_id;
-    cached_tensor_id %= kCacheSize;
 
     try {
-      channel_->Send(&batch);
+      size_t tmp = cached_tensor_id;
+      channel_->Send(&tmp);
     } catch (paddle::platform::EnforceNotMet e) {
       VLOG(5) << "WARNING: The double buffer channel has been closed. The "
                  "prefetch thread will terminate.";
       break;
     }
+    ++cached_tensor_id;
+    cached_tensor_id %= kCacheSize;
   }
   channel_->Close();
   VLOG(5) << "Prefetch thread terminates.";
......
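For orientation (not part of the commit itself): the channel no longer carries tensor payloads; the prefetch thread fills a fixed slot in cpu_tensor_cache_/gpu_tensor_cache_ and sends only the slot index, which the consumer uses to read the cached tensors directly. A stripped-down sketch of the same idea using only standard-library types (the bounded index queue stands in for framework::Channel<size_t> and is an assumption, not Paddle code):

#include <array>
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Sketch only: a fixed cache of slots plus a bounded queue of slot indices.
constexpr std::size_t kCacheSize = 4;
constexpr std::size_t kQueueCapacity = kCacheSize - 2;  // same throttling idea as kChannelSize

int main() {
  std::array<std::vector<int>, kCacheSize> cache;  // stands in for the tensor caches
  std::queue<std::size_t> ready;                   // stands in for the index channel
  std::mutex mu;
  std::condition_variable not_full, not_empty;
  bool done = false;

  std::thread prefetcher([&] {
    std::size_t slot = 0;
    for (int batch = 0; batch < 6; ++batch) {
      cache[slot] = {batch, batch + 1, batch + 2};  // "read" a batch into the slot
      std::unique_lock<std::mutex> lock(mu);
      not_full.wait(lock, [&] { return ready.size() < kQueueCapacity; });
      ready.push(slot);                             // publish only the index
      not_empty.notify_one();
      slot = (slot + 1) % kCacheSize;
    }
    std::lock_guard<std::mutex> lock(mu);
    done = true;
    not_empty.notify_one();
  });

  while (true) {
    std::unique_lock<std::mutex> lock(mu);
    not_empty.wait(lock, [&] { return !ready.empty() || done; });
    if (ready.empty()) break;  // producer finished and everything is consumed
    std::size_t slot = ready.front();
    ready.pop();
    not_full.notify_one();
    lock.unlock();
    std::cout << "consumed batch starting with " << cache[slot][0] << "\n";
  }
  prefetcher.join();
  return 0;
}

Capping the queue at kCacheSize - 2 outstanding indices (the idea behind the kChannelSize comment above) keeps the producer from refilling a slot the consumer may still be reading.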
@@ -147,6 +147,7 @@ class ReshapeKernel : public framework::OpKernel<T> {
     if (!inplace) {
       out->mutable_data<T>(ctx.GetPlace());
       framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
+      ctx.device_context().Wait();
       // TensorCopy will resize to in_dims.
       out->Resize(out_dims);
     } else {
@@ -169,6 +170,7 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
     auto in_dims = d_x->dims();
     if (!inplace) {
       framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
+      ctx.device_context().Wait();
       d_x->Resize(in_dims);
     } else {
       d_x->ShareDataWith(*d_out);
......
@@ -24,7 +24,19 @@ template <typename T>
 class CPUUniformRandomKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* tensor = ctx.Output<framework::Tensor>("Out");
+    framework::Tensor* tensor = nullptr;
+    auto out_var = ctx.OutputVar("Out");
+    if (out_var->IsType<framework::LoDTensor>()) {
+      tensor = out_var->GetMutable<framework::LoDTensor>();
+    } else if (out_var->IsType<framework::SelectedRows>()) {
+      auto shape = ctx.Attr<std::vector<int>>("shape");
+      tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
+      tensor->Resize(framework::make_ddim(shape));
+    } else {
+      PADDLE_THROW(
+          "uniform_random_op's output only"
+          "supports SelectedRows and Tensor");
+    }
     T* data = tensor->mutable_data<T>(ctx.GetPlace());
     unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
     std::minstd_rand engine;
......
@@ -43,7 +43,19 @@ template <typename T>
 class GPUUniformRandomKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* tensor = context.Output<framework::Tensor>("Out");
+    framework::Tensor* tensor = nullptr;
+    auto out_var = context.OutputVar("Out");
+    if (out_var->IsType<framework::LoDTensor>()) {
+      tensor = out_var->GetMutable<framework::LoDTensor>();
+    } else if (out_var->IsType<framework::SelectedRows>()) {
+      auto shape = context.Attr<std::vector<int>>("shape");
+      tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
+      tensor->Resize(framework::make_ddim(shape));
+    } else {
+      PADDLE_THROW(
+          "uniform_random_op's output only"
+          "supports SelectedRows and Tensor");
+    }
     T* data = tensor->mutable_data<T>(context.GetPlace());
     unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
     if (seed == 0) {
......
@@ -157,7 +157,6 @@ def train(nn_type,
         for ip in pserver_ips.split(","):
             eplist.append(':'.join([ip, port]))
         pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
-        pserver_endpoints = os.getenv("PSERVERS")
         trainers = int(os.getenv("TRAINERS"))
         current_endpoint = os.getenv("POD_IP") + ":" + port
         trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))
......
@@ -15,6 +15,16 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
+
+
+def output_hist(out):
+    hist, _ = np.histogram(out, range=(-5, 10))
+    hist = hist.astype("float32")
+    hist /= float(out.size)
+    prob = 0.1 * np.ones((10))
+    return hist, prob
 
 
 class TestUniformRandomOp(OpTest):
@@ -33,11 +43,37 @@ class TestUniformRandomOp(OpTest):
         self.check_output_customized(self.verify_output)
 
     def verify_output(self, outs):
-        tensor = outs[0]
-        hist, _ = np.histogram(outs[0], range=(-5, 10))
-        hist = hist.astype("float32")
-        hist /= float(outs[0].size)
-        prob = 0.1 * np.ones((10))
+        hist, prob = output_hist(np.array(outs[0]))
+        self.assertTrue(
+            np.allclose(
+                hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
+
+
+class TestUniformRandomOpSelectedRows(unittest.TestCase):
+    def get_places(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+        return places
+
+    def test_check_output(self):
+        for place in self.get_places():
+            self.check_with_place(place)
+
+    def check_with_place(self, place):
+        scope = core.Scope()
+        out = scope.var("X").get_selected_rows()
+        op = Operator(
+            "uniform_random",
+            Out="X",
+            shape=[4, 784],
+            min=-5.0,
+            max=10.0,
+            seed=10)
+        op.run(scope, place)
+        self.assertEqual(out.get_tensor().shape(), [4, 784])
+        hist, prob = output_hist(np.array(out.get_tensor()))
         self.assertTrue(
             np.allclose(
                 hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
......