提交 b645dfac 编写于 作者: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into update-api-reference-1

...@@ -180,7 +180,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, ...@@ -180,7 +180,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))), print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))),
# evaluation # evaluation
if not args.no_test and batch_acc: if not args.no_test and batch_acc and not args.use_reader_op:
pass_test_acc = test(exe, infer_prog, test_reader, feeder, pass_test_acc = test(exe, infer_prog, test_reader, feeder,
batch_acc) batch_acc)
print(", Test Accuracy: %f" % pass_test_acc) print(", Test Accuracy: %f" % pass_test_acc)
...@@ -277,11 +277,12 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -277,11 +277,12 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
batch_id += 1 batch_id += 1
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
if not args.no_test and batch_acc: if not args.no_test and batch_acc and not args.use_reader_op:
# we have not implement record io for test
# skip test when use args.use_reader_op
test_acc = test(startup_exe, infer_prog, test_reader, feeder, test_acc = test(startup_exe, infer_prog, test_reader, feeder,
batch_acc) batch_acc)
print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc)) print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc))
exit(0)
def print_arguments(args): def print_arguments(args):
......
...@@ -199,7 +199,10 @@ def get_model(args): ...@@ -199,7 +199,10 @@ def get_model(args):
batched_train_reader = paddle.batch( batched_train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
train_reader, buf_size=5120), train_reader, buf_size=5120),
batch_size=args.batch_size * args.gpus) batch_size=args.batch_size * args.gpus,
batched_test_reader = paddle.batch(train_reader, batch_size=args.batch_size) drop_last=True)
batched_test_reader = paddle.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
return avg_cost, inference_program, optimizer, batched_train_reader, batched_test_reader, batch_acc return avg_cost, inference_program, optimizer, batched_train_reader,\
batched_test_reader, batch_acc
...@@ -118,6 +118,10 @@ endif() ...@@ -118,6 +118,10 @@ endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
if(WITH_DISTRIBUTE)
add_definitions(-DPADDLE_WITH_DISTRIBUTE)
endif()
if(WITH_GOLANG) if(WITH_GOLANG)
# we need to symlink Paddle directory into GOPATH. If we # we need to symlink Paddle directory into GOPATH. If we
# don't do it and we have code that depends on Paddle, go # don't do it and we have code that depends on Paddle, go
......
...@@ -83,8 +83,13 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) ...@@ -83,8 +83,13 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog) cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope if(WITH_DISTRIBUTE)
framework_proto glog lod_rank_table feed_fetch_method) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
class SSAGraph; struct SSAGraph;
class SSAGraghBuilderWithChecker : public SSAGraphBuilder { class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public: public:
......
...@@ -20,6 +20,9 @@ limitations under the License. */ ...@@ -20,6 +20,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/detail/grpc_client.h"
#endif
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -44,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { ...@@ -44,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() {
::paddle::operators::detail::RPCClient::GetInstance<
::paddle::operators::detail::GRPCClient>()
->SendComplete();
}
#endif
void InitializeVariable(Variable* var, proto::VarType::Type var_type) { void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
......
...@@ -44,6 +44,13 @@ class Executor { ...@@ -44,6 +44,13 @@ class Executor {
explicit Executor(const platform::Place& place); explicit Executor(const platform::Place& place);
#ifdef PADDLE_WITH_DISTRIBUTE
/*
* Sending signal to pserver to mark current trainer stop.
*/
void Complete();
#endif
/* @Brief /* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope * Runtime evaluation of the given ProgramDesc under certain Scope
* *
......
...@@ -35,14 +35,15 @@ class ReaderBase { ...@@ -35,14 +35,15 @@ class ReaderBase {
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase {
public: public:
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) { explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
void ReInit() override { reader_->ReInit(); } void ReInit() override { reader_->ReInit(); }
protected: protected:
ReaderBase* reader_; std::shared_ptr<ReaderBase> reader_;
}; };
class FileReader : public ReaderBase { class FileReader : public ReaderBase {
...@@ -64,7 +65,7 @@ class ReaderHolder { ...@@ -64,7 +65,7 @@ class ReaderHolder {
public: public:
void Reset(ReaderBase* reader) { reader_.reset(reader); } void Reset(ReaderBase* reader) { reader_.reset(reader); }
ReaderBase* Get() const { return reader_.get(); } std::shared_ptr<ReaderBase> Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) { void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
...@@ -76,7 +77,7 @@ class ReaderHolder { ...@@ -76,7 +77,7 @@ class ReaderHolder {
} }
private: private:
std::unique_ptr<ReaderBase> reader_; std::shared_ptr<ReaderBase> reader_;
}; };
} // namespace framework } // namespace framework
......
...@@ -34,6 +34,12 @@ void GRPCClient::InitEventLoop() { ...@@ -34,6 +34,12 @@ void GRPCClient::InitEventLoop() {
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
} }
void GRPCClient::SendComplete() {
for (auto& it : channels_) {
this->AsyncSendComplete(it.first);
}
}
GRPCClient::~GRPCClient() { GRPCClient::~GRPCClient() {
Wait(); Wait();
cq_.Shutdown(); cq_.Shutdown();
...@@ -210,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, ...@@ -210,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_++; req_count_++;
} }
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::Wait() { void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_); std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; }); sync_cond_.wait(lk, [this] { return req_count_ == 0; });
......
...@@ -195,6 +195,8 @@ class GRPCClient : public RPCClient { ...@@ -195,6 +195,8 @@ class GRPCClient : public RPCClient {
void Wait() override; void Wait() override;
void SendComplete() override;
protected: protected:
void InitImpl() override; void InitImpl() override;
...@@ -204,6 +206,9 @@ class GRPCClient : public RPCClient { ...@@ -204,6 +206,9 @@ class GRPCClient : public RPCClient {
void Proceed(); void Proceed();
void AsyncSendComplete(const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep); std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
private: private:
......
...@@ -162,16 +162,18 @@ class RequestPrefetch final : public RequestBase { ...@@ -162,16 +162,18 @@ class RequestPrefetch final : public RequestBase {
void Process() override { void Process() override {
// prefetch process... // prefetch process...
std::string varname = request_->OutVarname(); std::string in_var_name = request_->Varname();
VLOG(3) << "RequestPrefetch " << varname; std::string out_var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
auto scope = request_->GetMutableLocalScope(); auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(varname); auto invar = scope->FindVar(in_var_name);
framework::Variable* outvar = nullptr; framework::Variable* outvar = scope->FindVar(out_var_name);
request_handler_->Handle(varname, scope, invar, &outvar); request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_); &reply_);
Finish(reply_, &responder_); Finish(reply_, &responder_);
} }
...@@ -287,7 +289,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, ...@@ -287,7 +289,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
} else if (rpc_name == kRequestPrefetch) { } else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id); b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else { } else {
PADDLE_ENFORCE(false, "not surpported rpc"); PADDLE_ENFORCE(false, "not supported rpc");
} }
reqs[req_id] = b; reqs[req_id] = b;
......
...@@ -40,6 +40,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch"; ...@@ -40,6 +40,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
class RPCServer; class RPCServer;
...@@ -60,9 +61,12 @@ class RequestHandler { ...@@ -60,9 +61,12 @@ class RequestHandler {
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; } void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; } void SetExecutor(framework::Executor* executor) { executor_ = executor; }
// Used for dist lookup table prefetch
void SetPrefetchPreparedCtx( void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) { std::unordered_map<
prefetch_ctx_.reset(prepared.release()); std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
prefetch_var_name_to_prepared_ctx_ = g;
} }
// Used for async. // Used for async.
...@@ -78,9 +82,6 @@ class RequestHandler { ...@@ -78,9 +82,6 @@ class RequestHandler {
bool sync_mode() { return sync_mode_; } bool sync_mode() { return sync_mode_; }
framework::Scope* scope() { return scope_; } framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; } const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ExecutorPrepareContext* prefetch_ctx() {
return prefetch_ctx_.get();
}
framework::ProgramDesc* program() { return program_; } framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; } framework::Executor* executor() { return executor_; }
...@@ -99,8 +100,8 @@ class RequestHandler { ...@@ -99,8 +100,8 @@ class RequestHandler {
// *request_handler_->dev_ctx(), &reply_); // *request_handler_->dev_ctx(), &reply_);
// } // }
virtual bool Handle(const std::string& varname, framework::Scope* scope, virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable* var, framework::Variable** outvar,
framework::Variable** outvar) = 0; const std::string& out_var_name = "") = 0;
protected: protected:
const bool sync_mode_; const bool sync_mode_;
...@@ -109,12 +110,17 @@ class RequestHandler { ...@@ -109,12 +110,17 @@ class RequestHandler {
framework::Executor* executor_; framework::Executor* executor_;
framework::Scope* scope_; framework::Scope* scope_;
framework::ProgramDesc* program_; framework::ProgramDesc* program_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
// used for distribute lookup table prefetch
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// Used for async. // Used for async.
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>* std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_; grad_to_prepared_ctx_;
RPCServer* rpc_server_; RPCServer* rpc_server_;
}; };
......
...@@ -30,7 +30,8 @@ namespace detail { ...@@ -30,7 +30,8 @@ namespace detail {
bool RequestSendHandler::Handle(const std::string& varname, bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar) { framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestSendHandler:" << varname; VLOG(4) << "RequestSendHandler:" << varname;
// Async // Async
...@@ -49,6 +50,9 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -49,6 +50,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
if (varname == BATCH_BARRIER_MESSAGE) { if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message"; VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend); rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
rpc_server_->DecreaseClientNum();
} else { } else {
VLOG(3) << "sync: received var_name: " << varname; VLOG(3) << "sync: received var_name: " << varname;
if (sync_mode_) { if (sync_mode_) {
...@@ -79,7 +83,8 @@ void RequestSendHandler::ResetSparseVarRecorder() { ...@@ -79,7 +83,8 @@ void RequestSendHandler::ResetSparseVarRecorder() {
bool RequestGetHandler::Handle(const std::string& varname, bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar) { framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestGetHandler:" << varname; VLOG(4) << "RequestGetHandler:" << varname;
if (varname != FETCH_BARRIER_MESSAGE) { if (varname != FETCH_BARRIER_MESSAGE) {
...@@ -102,13 +107,14 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -102,13 +107,14 @@ bool RequestGetHandler::Handle(const std::string& varname,
bool RequestPrefetchHandler::Handle(const std::string& varname, bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar) { framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestPrefetchHandler " << varname; VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(varname); auto var_desc = program_->Block(0).FindVar(out_var_name);
*outvar = scope->FindVar(varname);
InitializeVariable(*outvar, var_desc->GetType()); InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_.get(), scope); executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
return true; return true;
} }
......
...@@ -39,7 +39,8 @@ class RequestSendHandler final : public RequestHandler { ...@@ -39,7 +39,8 @@ class RequestSendHandler final : public RequestHandler {
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {} explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestSendHandler() {} virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override; framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
void ResetSparseVarRecorder(); void ResetSparseVarRecorder();
private: private:
...@@ -52,7 +53,8 @@ class RequestGetHandler final : public RequestHandler { ...@@ -52,7 +53,8 @@ class RequestGetHandler final : public RequestHandler {
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {} explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestGetHandler() {} virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override; framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
}; };
class RequestPrefetchHandler final : public RequestHandler { class RequestPrefetchHandler final : public RequestHandler {
...@@ -60,7 +62,8 @@ class RequestPrefetchHandler final : public RequestHandler { ...@@ -60,7 +62,8 @@ class RequestPrefetchHandler final : public RequestHandler {
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {} explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {} virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override; framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
}; };
} // namespace detail } // namespace detail
......
...@@ -53,6 +53,11 @@ class RPCClient { ...@@ -53,6 +53,11 @@ class RPCClient {
virtual void AsyncSendFetchBarrier(const std::string& ep, virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0; int64_t time_out = rpc_time_out) = 0;
// SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue
// to train with other trainers.
virtual void SendComplete() = 0;
virtual void Wait() = 0; virtual void Wait() = 0;
static constexpr int64_t rpc_time_out = 120 * 1000; static constexpr int64_t rpc_time_out = 120 * 1000;
......
...@@ -43,7 +43,7 @@ void RPCServer::SavePort() const { ...@@ -43,7 +43,7 @@ void RPCServer::SavePort() const {
void RPCServer::WaitBarrier(const std::string& rpc_name) { void RPCServer::WaitBarrier(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(this->mutex_); std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [=] { barrier_cond_.wait(lock, [this, &rpc_name] {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load()); return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
}); });
...@@ -53,19 +53,23 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) { ...@@ -53,19 +53,23 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
int b = 0; int b = 0;
{ std::unique_lock<std::mutex> lock(mutex_);
std::unique_lock<std::mutex> lock(mutex_); b = ++barrier_counter_[rpc_name];
b = ++barrier_counter_[rpc_name];
}
VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
<< ", barrier_count:" << b << ", fan_in" << client_num_;
if (b >= client_num_) { if (b >= client_num_) {
lock.unlock();
barrier_cond_.notify_all(); barrier_cond_.notify_all();
lock.lock();
} }
} }
void RPCServer::DecreaseClientNum() {
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
}
barrier_cond_.notify_all();
}
void RPCServer::ResetBarrierCounter() { void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter "; VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
......
...@@ -60,7 +60,7 @@ class RPCServer { ...@@ -60,7 +60,7 @@ class RPCServer {
void SetCond(const std::string& rpc_name); void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name); void IncreaseBatchBarrier(const std::string rpc_name);
void DecreaseClientNum();
void ResetBarrierCounter(); void ResetBarrierCounter();
protected: protected:
...@@ -79,8 +79,7 @@ class RPCServer { ...@@ -79,8 +79,7 @@ class RPCServer {
std::string bind_address_; std::string bind_address_;
std::atomic<int> exit_flag_; std::atomic<int> exit_flag_;
int selected_port_; int selected_port_;
int client_num_;
const int client_num_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_; std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_; std::unordered_map<std::string, int> rpc_thread_num_;
......
...@@ -98,11 +98,17 @@ void StartServer() { ...@@ -98,11 +98,17 @@ void StartServer() {
framework::Executor exe(place); framework::Executor exe(place);
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program); auto* block = AppendPrefetchBlcok(&program);
auto prepared = exe.Prepare(program, block->ID()); std::string in_var_name("ids");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10); InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program); g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(std::move(prepared)); g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx); g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope); g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe); g_req_handler->SetExecutor(&exe);
......
...@@ -66,40 +66,41 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,40 +66,41 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(-1) .SetDefault(-1)
.EqualGreaterThan(-1); .EqualGreaterThan(-1);
AddComment(string::Sprintf(R"DOC( AddComment(string::Sprintf(R"DOC(
Limited Elementwise %s Operator. Limited Elementwise %s Operator
The equation is: The equation is:
$$%s$$ $$%s$$
$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be - $X$: a tensor of any dimension.
smaller than or equal to the dimensions of $X$. - $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.
There are two cases for this operator: There are two cases for this operator:
1. The shape of $Y$ is same with $X$;
2. The shape of $Y$ is a congiguous subsequencet of $X$. The trailing dimensions
of size 1 for $Y$ will be ignored for the consideration of subsequence.
1. The shape of $Y$ is the same with $X$.
2. The shape of $Y$ is a continuous subsequence of $X$.
For case 2: For case 2:
$Y$ will be broadcasted to match the shape of $X$ and axis should be 1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
set to index of the start dimension to broadcast $Y$ onto $X$. for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
subsequence, such as shape(Y) = (2, 1) => (2).
If axis is -1, it is treated as axis=rank(X)-rank(Y). For example:
For example
.. code-block:: python .. code-block:: python
shape(X) = (2, 3, 4, 5), shape(Y) = (,) shape(X) = (2, 3, 4, 5), shape(Y) = (,)
shape(X) = (2, 3, 4, 5), shape(Y) = (5,) shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0 shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0 shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details) The inputs $X$ and $Y$ can carry the different LoD information.
information. However, the output only shares the LoD information with input $X$. But the output only shares the LoD information with the input $X$.
)DOC", )DOC",
GetName(), GetEquation())); GetName(), GetEquation()));
......
...@@ -96,19 +96,22 @@ static int64_t GetTimestamp() { ...@@ -96,19 +96,22 @@ static int64_t GetTimestamp() {
return tp.tv_sec * 1000 + tp.tv_usec / 1000; return tp.tv_sec * 1000 + tp.tv_usec / 1000;
} }
void ListenAndServOp::RunSyncLoop(framework::Executor *executor, void ListenAndServOp::RunSyncLoop(
framework::ProgramDesc *program, framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope, framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const { const std::vector<int> &prefetch_block_id_list) const {
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
std::vector<int> block_list; std::vector<int> optimize_block_id_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) { for (int blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid); if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(),
blkid) == prefetch_block_id_list.end()) {
optimize_block_id_list.push_back(blkid);
}
} }
auto optimize_prepared = executor->Prepare(*program, block_list); auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list);
// Insert placeholder for block0 which holds current op itself. // Insert placeholder for block0 which holds current op itself.
optimize_prepared.insert( optimize_prepared.insert(
optimize_prepared.begin(), optimize_prepared.begin(),
...@@ -135,16 +138,17 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -135,16 +138,17 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
std::vector<size_t> parallel_blkids; std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(1); parallel_blkids.push_back(1);
double ts = GetTimestamp(); double ts = GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) { for (size_t i = 1; i < optimize_block_id_list.size(); ++i) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) { // skip the first optimize block because it is already in the
if (program->Block(blkid).Parent() != last_parent_blkid) { // parallel_blkids.
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, int blkid = optimize_block_id_list[i];
program, recv_scope); if (program->Block(blkid).Parent() != last_parent_blkid) {
parallel_blkids.clear(); ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
last_parent_blkid = program->Block(blkid).Parent(); program, recv_scope);
} parallel_blkids.clear();
parallel_blkids.push_back(blkid); last_parent_blkid = program->Block(blkid).Parent();
} }
parallel_blkids.push_back(blkid);
} }
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
recv_scope); recv_scope);
...@@ -210,18 +214,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -210,18 +214,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
} // while(true) } // while(true)
} }
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope, static void FillRequestCtx(
platform::DeviceContext *dev_ctx, detail::RequestHandler *h, framework::Scope *scope,
framework::Executor *executor, platform::DeviceContext *dev_ctx, framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
framework::ExecutorPrepareContext *prefetch_ctx, std::unordered_map<std::string,
detail::RPCServer *rpc_server) { std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx,
detail::RPCServer *rpc_server) {
h->SetScope(scope); h->SetScope(scope);
h->SetDevCtx(dev_ctx); h->SetDevCtx(dev_ctx);
h->SetExecutor(executor); h->SetExecutor(executor);
h->SetProgram(program); h->SetProgram(program);
h->SetPrefetchPreparedCtx( h->SetPrefetchPreparedCtx(prefetch_ctx);
std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
h->SetRPCServer(rpc_server); h->SetRPCServer(rpc_server);
} }
...@@ -255,17 +260,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -255,17 +260,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.get()); request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program(); auto *program = optimize_block->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// prepare for prefetch // prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID(); std::vector<int> prefetch_block_id_list;
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
auto prefetch_var_name_to_block_id_str =
Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId);
for (const auto &prefetch_var_name_and_id :
prefetch_var_name_to_block_id_str) {
std::vector<std::string> pieces;
split(prefetch_var_name_and_id, ':', &pieces);
VLOG(3) << "after split, prefetch_var = " << pieces[0]
<< ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
int block_id = std::stoi(pieces[1]);
prefetch_block_id_list.push_back(block_id);
block_id_to_prefetch_var_name[block_id] = pieces[0];
}
auto prefetch_prepared = executor.Prepare(*program, prefetch_block_id_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared_ctx;
for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) {
auto block_id = prefetch_block_id_list[i];
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
}
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, prefetch_prepared.release(), &dev_ctx, &executor, program,
rpc_service_.get()); &prefetch_var_name_to_prepared_ctx, rpc_service_.get());
f(request_send_handler_.get()); f(request_send_handler_.get());
f(request_get_handler_.get()); f(request_get_handler_.get());
...@@ -283,7 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -283,7 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
SavePort(); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block); RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
} else { } else {
RunAsyncLoop(&executor, program); RunAsyncLoop(&executor, program);
} }
...@@ -309,8 +339,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -309,8 +339,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true); AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock, AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side."); "BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock, AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch block to run on server side."); "prefetch blocks to run on server side.")
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.") AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1); .SetDefault(1);
} }
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <atomic> #include <atomic>
#include <set> #include <set>
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -30,7 +31,7 @@ namespace paddle { ...@@ -30,7 +31,7 @@ namespace paddle {
namespace operators { namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service); void RunServer(std::shared_ptr<detail::RPCServer> service);
...@@ -46,7 +47,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -46,7 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor, void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::Scope* recv_scope, framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const; const std::vector<int>& prefetch_block_id_list) const;
void RunAsyncLoop(framework::Executor* executor, void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program) const; framework::ProgramDesc* program) const;
......
...@@ -20,7 +20,7 @@ namespace reader { ...@@ -20,7 +20,7 @@ namespace reader {
class BatchReader : public framework::DecoratedReader { class BatchReader : public framework::DecoratedReader {
public: public:
BatchReader(ReaderBase* reader, int batch_size) BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) { : DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
} }
......
...@@ -22,7 +22,8 @@ namespace reader { ...@@ -22,7 +22,8 @@ namespace reader {
class CustomReader : public framework::DecoratedReader { class CustomReader : public framework::DecoratedReader {
public: public:
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block, CustomReader(const std::shared_ptr<ReaderBase>& reader,
const framework::BlockDesc& sub_block,
const std::vector<std::string>& source_var_names, const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names) const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader), : DecoratedReader(reader),
......
...@@ -34,7 +34,8 @@ static constexpr size_t kChannelSize = 1; // kCacheSize - 2 ...@@ -34,7 +34,8 @@ static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
explicit DoubleBufferReader( explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace()) const std::shared_ptr<ReaderBase>& reader,
platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) { : DecoratedReader(reader), place_(target_place) {
cpu_tensor_cache_.resize(kCacheSize); cpu_tensor_cache_.resize(kCacheSize);
gpu_tensor_cache_.resize(kCacheSize); gpu_tensor_cache_.resize(kCacheSize);
......
...@@ -21,7 +21,7 @@ namespace reader { ...@@ -21,7 +21,7 @@ namespace reader {
class MultiPassReader : public framework::DecoratedReader { class MultiPassReader : public framework::DecoratedReader {
public: public:
MultiPassReader(ReaderBase* reader, int pass_num) MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
......
...@@ -23,7 +23,8 @@ namespace reader { ...@@ -23,7 +23,8 @@ namespace reader {
class ShuffleReader : public framework::DecoratedReader { class ShuffleReader : public framework::DecoratedReader {
public: public:
ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0) ShuffleReader(const std::shared_ptr<ReaderBase>& reader, size_t buffer_size,
size_t seed = 0)
: DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) { : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
VLOG(10) << "Create shuffle reader of " << reader_; VLOG(10) << "Create shuffle reader of " << reader_;
if (seed_ == 0) { if (seed_ == 0) {
......
...@@ -21,7 +21,8 @@ namespace reader { ...@@ -21,7 +21,8 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader { class ThreadedReader : public framework::DecoratedReader {
public: public:
explicit ThreadedReader(ReaderBase* reader) : DecoratedReader(reader) {} explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
......
...@@ -21,12 +21,17 @@ limitations under the License. */ ...@@ -21,12 +21,17 @@ limitations under the License. */
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <algorithm>
#include "gflags/gflags.h" #include "gflags/gflags.h"
DEFINE_double(fraction_of_cpu_memory_to_use, 1, DEFINE_double(fraction_of_cpu_memory_to_use, 1,
"Default use 100% of CPU memory for PaddlePaddle," "Default use 100% of CPU memory for PaddlePaddle,"
"reserve the rest for page tables, etc"); "reserve the rest for page tables, etc");
DEFINE_uint64(
initial_cpu_memory_in_mb, 500,
"Default initial 500MB of CPU memory for PaddlePaddle, in MD unit.");
DEFINE_double( DEFINE_double(
fraction_of_cuda_pinned_memory_to_use, 0.5, fraction_of_cuda_pinned_memory_to_use, 0.5,
"Default use 50% of CPU memory as the pinned_memory for PaddlePaddle," "Default use 50% of CPU memory as the pinned_memory for PaddlePaddle,"
...@@ -54,7 +59,9 @@ inline size_t CpuTotalPhysicalMemory() { ...@@ -54,7 +59,9 @@ inline size_t CpuTotalPhysicalMemory() {
size_t CpuMaxAllocSize() { size_t CpuMaxAllocSize() {
// For distributed systems, it requires configuring and limiting // For distributed systems, it requires configuring and limiting
// the fraction of memory to use. // the fraction of memory to use.
return FLAGS_fraction_of_cpu_memory_to_use * CpuTotalPhysicalMemory(); return std::min(static_cast<size_t>(FLAGS_fraction_of_cpu_memory_to_use *
CpuTotalPhysicalMemory()),
FLAGS_initial_cpu_memory_in_mb * 1 << 20);
} }
size_t CpuMinChunkSize() { size_t CpuMinChunkSize() {
......
...@@ -413,6 +413,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -413,6 +413,9 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor") py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>()) .def(py::init<const platform::Place &>())
#ifdef PADDLE_WITH_DISTRIBUTE
.def("complete", &Executor::Complete)
#endif
.def("run", .def("run",
(void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) & (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
Executor::Run); Executor::Run);
......
此差异已折叠。
...@@ -515,35 +515,38 @@ class DistributeTranspiler: ...@@ -515,35 +515,38 @@ class DistributeTranspiler:
grad_to_block_id, None) grad_to_block_id, None)
# process distributed lookup_table # process distributed lookup_table
prefetch_block = None prefetch_var_name_to_block_id = []
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint) pserver_index = self.pserver_endpoints.index(endpoint)
table_opt_block = self._create_table_optimize_block( table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx, grad_to_block_id) pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
prefetch_block = self._create_prefetch_block( prefetch_var_name_to_block_id = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block) pserver_index, pserver_program, table_opt_block)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place # not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
assert prefetch_block is not None assert len(prefetch_var_name_to_block_id) > 0
else: else:
assert prefetch_block is None assert len(prefetch_var_name_to_block_id) == 0
prefetch_block = pserver_program.global_block()
attrs = {
"OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint,
"Fanin": self.trainer_num,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id
}
if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \
= prefetch_var_name_to_block_id
# step5 append the listen_and_serv op # step5 append the listen_and_serv op
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
type="listen_and_serv", type="listen_and_serv",
inputs={'X': recv_inputs}, inputs={'X': recv_inputs},
outputs={}, outputs={},
attrs={ attrs=attrs)
"OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint,
"Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id
})
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
return pserver_program return pserver_program
...@@ -608,8 +611,15 @@ class DistributeTranspiler: ...@@ -608,8 +611,15 @@ class DistributeTranspiler:
def _replace_lookup_table_op_with_prefetch(self, program, def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints): pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None # self.all_prefetch_input_vars =
self.prefetch_output_vars = None # [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_input_vars = []
# self.all_prefetch_input_vars =
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_output_vars = []
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
while continue_search_lookup_table_op: while continue_search_lookup_table_op:
...@@ -623,18 +633,19 @@ class DistributeTranspiler: ...@@ -623,18 +633,19 @@ class DistributeTranspiler:
ids_name = op.input("Ids") ids_name = op.input("Ids")
out_name = op.output("Out") out_name = op.output("Out")
if self.prefetch_input_vars is None: ids_var = program.global_block().vars[ids_name[0]]
ids_var = program.global_block().vars[ids_name[0]] prefetch_input_vars = self.create_splited_vars(
self.prefetch_input_vars = self.create_splited_vars( source_var=ids_var,
source_var=ids_var, block=program.global_block(),
block=program.global_block(), tag="_prefetch_in_")
tag="_prefetch_in_") self.all_prefetch_input_vars.append(prefetch_input_vars)
if self.prefetch_output_vars is None:
out_var = program.global_block().vars[out_name[0]] out_var = program.global_block().vars[out_name[0]]
self.prefetch_output_vars = self.create_splited_vars( prefetch_output_vars = self.create_splited_vars(
source_var=out_var, source_var=out_var,
block=program.global_block(), block=program.global_block(),
tag="_prefetch_out_") tag="_prefetch_out_")
self.all_prefetch_output_vars.append(prefetch_output_vars)
# insert split_ids_op # insert split_ids_op
program.global_block().insert_op( program.global_block().insert_op(
...@@ -646,14 +657,14 @@ class DistributeTranspiler: ...@@ -646,14 +657,14 @@ class DistributeTranspiler:
for varname in ids_name for varname in ids_name
] ]
}, },
outputs={"Out": self.prefetch_input_vars}) outputs={"Out": prefetch_input_vars})
# insert prefetch_op # insert prefetch_op
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 1, index=op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': self.prefetch_input_vars}, inputs={'X': prefetch_input_vars},
outputs={"Out": self.prefetch_output_vars}, outputs={"Out": prefetch_output_vars},
attrs={ attrs={
"epmap": pserver_endpoints, "epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
...@@ -663,7 +674,7 @@ class DistributeTranspiler: ...@@ -663,7 +674,7 @@ class DistributeTranspiler:
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 2, index=op_index + 2,
type="concat", type="concat",
inputs={'X': self.prefetch_output_vars}, inputs={'X': prefetch_output_vars},
outputs={ outputs={
"Out": [ "Out": [
program.global_block().vars[varname] program.global_block().vars[varname]
...@@ -709,30 +720,34 @@ class DistributeTranspiler: ...@@ -709,30 +720,34 @@ class DistributeTranspiler:
optimize_block): optimize_block):
# STEP: create prefetch block # STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name] table_var = pserver_program.global_block().vars[self.table_name]
prefetch_block = pserver_program.create_block(optimize_block.idx) prefetch_var_name_to_block_id = []
trainer_ids = self.prefetch_input_vars[pserver_index] for index in range(len(self.all_prefetch_input_vars)):
pserver_ids = pserver_program.global_block().create_var( prefetch_block = pserver_program.create_block(optimize_block.idx)
name=trainer_ids.name, trainer_ids = self.all_prefetch_input_vars[index][pserver_index]
type=trainer_ids.type, pserver_ids = pserver_program.global_block().create_var(
shape=trainer_ids.shape, name=trainer_ids.name,
dtype=trainer_ids.dtype) type=trainer_ids.type,
trainer_out = self.prefetch_output_vars[pserver_index] shape=trainer_ids.shape,
pserver_out = pserver_program.global_block().create_var( dtype=trainer_ids.dtype)
name=trainer_out.name, trainer_out = self.all_prefetch_output_vars[index][pserver_index]
type=trainer_out.type, pserver_out = pserver_program.global_block().create_var(
shape=trainer_out.shape, name=trainer_out.name,
dtype=trainer_out.dtype) type=trainer_out.type,
prefetch_block.append_op( shape=trainer_out.shape,
type="lookup_sparse_table", dtype=trainer_out.dtype)
inputs={'Ids': pserver_ids, prefetch_block.append_op(
"W": table_var}, type="lookup_sparse_table",
outputs={"Out": pserver_out}, inputs={'Ids': pserver_ids,
attrs={ "W": table_var},
"is_sparse": True, # has no effect on lookup_table op outputs={"Out": pserver_out},
"is_distributed": True, attrs={
"padding_idx": -1 "is_sparse": True, # has no effect on lookup_table op
}) "is_distributed": True,
return prefetch_block "padding_idx": -1
})
prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str(
prefetch_block.idx))
return prefetch_var_name_to_block_id
def _create_table_optimize_block(self, pserver_index, pserver_program, def _create_table_optimize_block(self, pserver_index, pserver_program,
pre_block_idx, grad_to_block_id): pre_block_idx, grad_to_block_id):
......
...@@ -126,9 +126,10 @@ class DocstringChecker(BaseChecker): ...@@ -126,9 +126,10 @@ class DocstringChecker(BaseChecker):
'W9002': 'W9002':
('Doc string does not end with "." period', symbol + "-end-with", ('Doc string does not end with "." period', symbol + "-end-with",
'Used when a doc string does not end with a period'), 'Used when a doc string does not end with a period'),
'W9003': ('All args with their types must be mentioned in doc string', 'W9003':
symbol + "-with-all-args", ('All args with their types must be mentioned in doc string %s',
'Used when not all arguments are in the doc string '), symbol + "-with-all-args",
'Used when not all arguments are in the doc string '),
'W9005': ('Missing docstring or docstring is too short', 'W9005': ('Missing docstring or docstring is too short',
symbol + "-missing", 'Add docstring longer >=10'), symbol + "-missing", 'Add docstring longer >=10'),
'W9006': ('Docstring indent error, use 4 space for indent', 'W9006': ('Docstring indent error, use 4 space for indent',
...@@ -178,6 +179,8 @@ class DocstringChecker(BaseChecker): ...@@ -178,6 +179,8 @@ class DocstringChecker(BaseChecker):
self.indent_style(node) self.indent_style(node)
def missing_doc_string(self, node): def missing_doc_string(self, node):
if node.name.startswith("__") or node.name.startswith("_"):
return True
if node.tolineno - node.fromlineno <= 10: if node.tolineno - node.fromlineno <= 10:
return True return True
...@@ -199,12 +202,16 @@ class DocstringChecker(BaseChecker): ...@@ -199,12 +202,16 @@ class DocstringChecker(BaseChecker):
doc = node.doc doc = node.doc
lines = doc.splitlines() lines = doc.splitlines()
line_num = 0
for l in lines: for l in lines:
if line_num == 0:
continue
cur_indent = len(l) - len(l.lstrip()) cur_indent = len(l) - len(l.lstrip())
if cur_indent % indent != 0: if cur_indent % indent != 0:
self.add_message('W9006', node=node, line=node.fromlineno) self.add_message('W9006', node=node, line=node.fromlineno)
return False return False
line_num += 1
return True return True
...@@ -320,15 +327,19 @@ class DocstringChecker(BaseChecker): ...@@ -320,15 +327,19 @@ class DocstringChecker(BaseChecker):
return True return True
parsed_args = doc.args parsed_args = doc.args
args_not_documented = set(args) - set(parsed_args)
if len(args) > 0 and len(parsed_args) <= 0: if len(args) > 0 and len(parsed_args) <= 0:
print "debug:parsed args: ", parsed_args self.add_message(
self.add_message('W9003', node=node, line=node.fromlineno) 'W9003',
node=node,
line=node.fromlineno,
args=list(args_not_documented))
return False return False
for t in args: for t in args:
if t not in parsed_args: if t not in parsed_args:
print t, " with (type) not in ", parsed_args self.add_message(
self.add_message('W9003', node=node, line=node.fromlineno) 'W9003', node=node, line=node.fromlineno, args=[t, ])
return False return False
return True return True
...@@ -7,13 +7,13 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" ...@@ -7,13 +7,13 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
export PYTHONPATH=$DIR:$PYTHONPATH export PYTHONPATH=$DIR:$PYTHONPATH
# The trick to remove deleted files: https://stackoverflow.com/a/2413151 # The trick to remove deleted files: https://stackoverflow.com/a/2413151
for file in $(git diff --cached --name-status | awk '$1 != "D" {print $2}'); do for file in $(git diff --name-status | awk '$1 != "D" {print $2}'); do
pylint --disable=all --load-plugins=docstring_checker \ pylint --disable=all --load-plugins=docstring_checker \
--enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-triple-quotes,doc-string-missing,doc-string-indent-error,doc-string-with-returns,doc-string-with-raises $file; --enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-triple-quotes,doc-string-missing,doc-string-indent-error,doc-string-with-returns,doc-string-with-raises $file;
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
done done
#exit $TOTAL_ERRORS exit $TOTAL_ERRORS
#For now, just warning: #For now, just warning:
exit 0 #exit 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册