Commit 13ecb5e5 authored by qiaolongfei


Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into lookup_table_support_SelectedRows_as_parameter
@@ -17,8 +17,6 @@ limitations under the License. */
 #include <string>
 #include <vector>
-#include "paddle/fluid/framework/threadpool.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
......
@@ -2,7 +2,7 @@ if(WITH_DISTRIBUTE)
   grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
       grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
-  set_source_files_properties(serde_test.cc grpc_server_test PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
       cares zlib protobuf sendrecvop_grpc)
   cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
......
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "grpc_client.h"
+#include "paddle/fluid/operators/detail/grpc_client.h"
+#include <sys/time.h>
+#include <limits>
 #include "paddle/fluid/framework/threadpool.h"
 namespace paddle {
@@ -52,7 +54,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
     auto call = s->stub_g_.PrepareUnaryCall(
         s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
     call->StartCall();
-    call->Finish(&s->reply_, &s->status_, (void*)s);
+    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
   });
   req_count_++;
@@ -70,8 +72,7 @@ void ProcGetResponse(const VarHandle& var_h,
 template <typename T>
 void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
   ::grpc::Slice slice(proto.ByteSizeLong());
-  proto.SerializeWithCachedSizesToArray(
-      const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(slice.begin())));
+  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
   ::grpc::ByteBuffer tmp(&slice, 1);
   result->Swap(&tmp);
 }
@@ -109,7 +110,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
     auto call = s->stub_g_.PrepareUnaryCall(
         s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
     call->StartCall();
-    call->Finish(&s->reply_, &s->status_, (void*)s);
+    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
   });
   req_count_++;
@@ -153,7 +154,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
         s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
         &cq_);
     call->StartCall();
-    call->Finish(&s->reply_, &s->status_, (void*)s);
+    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
   });
   req_count_++;
@@ -169,7 +170,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
   sendrecv::VariableMessage req;
   req.set_varname(BATCH_BARRIER_MESSAGE);
   auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
-  rpc->Finish(&s->reply_, &s->status_, (void*)s);
+  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
   req_count_++;
 }
@@ -181,7 +182,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
   sendrecv::VariableMessage req;
   req.set_varname(FETCH_BARRIER_MESSAGE);
   auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
-  rpc->Finish(&s->reply_, &s->status_, (void*)s);
+  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
   req_count_++;
 }
......
@@ -14,6 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/detail/grpc_server.h"
+#include <limits>
+#include <string>
 using ::grpc::ServerAsyncResponseWriter;
 namespace paddle {
@@ -156,6 +159,8 @@ class RequestPrefetch final : public RequestBase {
     ::grpc::ByteBuffer reply;
     // TODO(Yancey1989): execute the Block which containers prefetch ops
+    VLOG(3) << "RequestPrefetch Process in";
     responder_.Finish(reply, ::grpc::Status::OK, this);
     status_ = FINISH;
   }
@@ -221,6 +226,7 @@ void AsyncGRPCServer::ShutdownQueue() {
   std::unique_lock<std::mutex> lock(cq_mutex_);
   cq_send_->Shutdown();
   cq_get_->Shutdown();
+  cq_prefetch_->Shutdown();
 }
 // This URL explains why shutdown is complicate:
@@ -233,6 +239,7 @@ void AsyncGRPCServer::ShutDown() {
 void AsyncGRPCServer::TryToRegisterNewSendOne() {
   std::unique_lock<std::mutex> lock(cq_mutex_);
   if (is_shut_down_) {
+    VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
     return;
   }
   RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
@@ -243,6 +250,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
 void AsyncGRPCServer::TryToRegisterNewGetOne() {
   std::unique_lock<std::mutex> lock(cq_mutex_);
   if (is_shut_down_) {
+    VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
     return;
   }
   RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
@@ -253,6 +261,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
 void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
   std::unique_lock<std::mutex> lock(cq_mutex_);
   if (is_shut_down_) {
+    VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
     return;
   }
   RequestPrefetch* prefetch =
@@ -270,25 +279,28 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
   void* tag = NULL;
   bool ok = false;
   while (true) {
+    VLOG(3) << "HandleRequest for " << cq_name << " while in";
     if (!cq->Next(&tag, &ok)) {
       LOG(INFO) << cq_name << " CompletionQueue shutdown!";
       break;
     }
+    VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
     PADDLE_ENFORCE(tag);
     // FIXME(typhoonzero): de-couple the barriers with recv_op
     if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
     if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
-    RequestBase* base = (RequestBase*)tag;
+    RequestBase* base = reinterpret_cast<RequestBase*>(tag);
     // reference:
     // https://github.com/tensorflow/tensorflow/issues/5596
     // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
     if (!ok) {
-      LOG(WARNING) << cq_name << " recv no regular event:argument name"
-                   << base->GetReqName();
+      LOG(WARNING) << cq_name << " recv no regular event:argument name["
+                   << base->GetReqName() << "]";
       TryToRegisterNewOne();
       delete base;
       continue;
......
@@ -15,7 +15,8 @@ limitations under the License. */
 #pragma once
 #include <grpc++/grpc++.h>
-#include <thread>
+#include <string>
+#include <utility>
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -93,6 +94,7 @@ class AsyncGRPCServer final {
   // received variable from RPC, operators fetch variable from this queue.
   SimpleBlockQueue<MessageWithName> var_get_queue_;
+  // client send variable to this queue.
   ReceivedQueue var_recv_queue_;
   // condition of the sub program
......
@@ -28,6 +28,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
 void StartServer(const std::string& endpoint) {
   rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
+  rpc_service_->RunSyncUpdate();
 }
 TEST(PREFETCH, CPU) {
@@ -39,13 +40,23 @@ TEST(PREFETCH, CPU) {
   platform::CPUPlace place;
   platform::CPUDeviceContext ctx(place);
   // create var on local scope
-  std::string var_name("tmp_0");
-  auto var = scope.Var(var_name);
-  auto tensor = var->GetMutable<framework::LoDTensor>();
-  tensor->Resize({10, 10});
+  std::string in_var_name("in");
+  std::string out_var_name("out");
+  auto* in_var = scope.Var(in_var_name);
+  auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
+  in_tensor->Resize({10, 10});
+  VLOG(3) << "before mutable_data";
+  in_tensor->mutable_data<int>(place);
+  scope.Var(out_var_name);
+  VLOG(3) << "before fetch";
   detail::RPCClient client;
-  client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, var_name, "");
+  client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
+                               out_var_name);
+  client.Wait();
+  rpc_service_->ShutDown();
   server_thread.join();
   rpc_service_.reset(nullptr);
 }
@@ -80,7 +80,7 @@ enum class GrpcMethod {
 };
 static const int kGrpcNumMethods =
-    static_cast<int>(GrpcMethod::kGetVariable) + 1;
+    static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;
 inline const char* GrpcMethodName(GrpcMethod id) {
   switch (id) {
@@ -89,7 +89,7 @@ inline const char* GrpcMethodName(GrpcMethod id) {
     case GrpcMethod::kGetVariable:
       return "/sendrecv.SendRecvService/GetVariable";
     case GrpcMethod::kPrefetchVariable:
-      return "/sendrecv.SendREcvService/PrefetchVariable";
+      return "/sendrecv.SendRecvService/PrefetchVariable";
   }
   // Shouldn't be reached.
@@ -117,5 +117,5 @@ class GrpcService final {
 };
 }  // namespace detail
-}  // namespace operator
+}  // namespace operators
 }  // namespace paddle
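
Note on the kGrpcNumMethods change above: the method count appears to be derived from the last enumerator, so after kPrefetchVariable is appended, leaving the count at kGetVariable + 1 would presumably size the method table one entry short. A minimal standalone illustration of the "count = last enumerator + 1" convention, using a hypothetical enum rather than the Paddle header:

#include <iostream>

// Hypothetical enum mirroring the convention in grpc_service.h: the method
// count is computed from the last enumerator, so it must be updated whenever
// a new method is appended at the end of the enum.
enum class Method { kSend = 0, kGet = 1, kPrefetch = 2 };

static const int kNumMethods = static_cast<int>(Method::kPrefetch) + 1;

int main() {
  std::cout << "number of RPC methods: " << kNumMethods << "\n";  // prints 3
  return 0;
}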
@@ -13,22 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <stdint.h>
-#include <sys/stat.h>
 #include <ostream>
-#include <thread>
-#include <unistd.h>
 #include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/proto_desc.h"
 #include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/operators/detail/grpc_server.h"
-#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
-#include "paddle/fluid/operators/detail/simple_block_queue.h"
-#include "paddle/fluid/string/printf.h"
 namespace paddle {
 namespace operators {
@@ -111,6 +102,11 @@ class ListenAndServOp : public framework::OperatorBase {
     framework::Executor executor(dev_place);
+    // TODO(qiao) set proper fields for table lookup and update
+    rpc_service_->SetExecutor(&executor);
+    rpc_service_->SetPrefetchBlkdId(0);
+    rpc_service_->SetProgram(program);
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
     // Record received sparse variables, so that
@@ -173,7 +169,8 @@ class ListenAndServOp : public framework::OperatorBase {
       }
       ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
-      VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
+      VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
+              << "(ms)";
       // Reset the received sparse variables, the sum operator would not
       // sum the input sparse variables which rows is empty at the next
......
@@ -20,12 +20,29 @@ namespace paddle {
 namespace operators {
 namespace reader {
-static constexpr size_t kDoubleBufferSize = 2;
+// 'Double buffer' means we shall maintain two batches of input data at the same
+// time. So the kCacheSize shoul be at least 2.
+static constexpr size_t kCacheSize = 2;
+// There will be two bacthes out of the channel during training:
+// 1. the one waiting to be sent to the channel
+// 2. the one just be received from the channel, which is also being used by
+// subsequent operators.
+// So the channel size should be kChacheSize - 2
+static constexpr size_t kChannelSize = 0;  // kCacheSize - 2
 class DoubleBufferReader : public framework::DecoratedReader {
  public:
   struct Item {
     Item() : ctx_(nullptr) {}
+    Item(Item&& b) {
+      payloads_ = std::move(b.payloads_);
+      ctx_ = std::move(b.ctx_);
+    }
+    Item& operator=(Item&& b) {
+      payloads_ = std::move(b.payloads_);
+      ctx_ = std::move(b.ctx_);
+      return *this;
+    }
     std::vector<framework::LoDTensor> payloads_;
     platform::DeviceContext* ctx_;
@@ -34,42 +51,44 @@ class DoubleBufferReader : public framework::DecoratedReader {
   explicit DoubleBufferReader(
       ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
       : DecoratedReader(reader), place_(target_place) {
-    for (size_t i = 0; i < kDoubleBufferSize; ++i) {
-      if (platform::is_gpu_place(place_)) {
 #ifdef PADDLE_WITH_CUDA
+    for (size_t i = 0; i < kCacheSize; ++i) {
+      if (platform::is_gpu_place(place_)) {
         ctxs_.emplace_back(new platform::CUDADeviceContext(
             boost::get<platform::CUDAPlace>(place_)));
-#endif
       }
     }
+#endif
-    start_thread();
-  }
-  void start_thread() {
-    buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
-    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
+    StartPrefetcher();
   }
+  bool HasNext() const override;
   void ReadNext(std::vector<framework::LoDTensor>* out) override;
   void ReInit() override;
-  ~DoubleBufferReader() {
-    buffer_->Close();
-    prefetcher_.join();
-    delete buffer_;
+  ~DoubleBufferReader() { EndPrefetcher(); }
+ private:
+  void StartPrefetcher() {
+    channel_ = framework::MakeChannel<Item>(kChannelSize);
+    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
   }
-  bool HasNext() const override;
+  void EndPrefetcher() {
+    channel_->Close();
+    if (prefetcher_.joinable()) {
+      prefetcher_.join();
+    }
+    delete channel_;
+    channel_ = nullptr;
+  }
- private:
   void PrefetchThreadFunc();
   std::thread prefetcher_;
-  framework::Channel<Item>* buffer_;
+  framework::Channel<Item>* channel_;
   platform::Place place_;
   std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
-  mutable Item local_buffer_;
 };
 class CreateDoubleBufferReaderOp : public framework::OperatorBase {
@@ -123,70 +142,70 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
   }
 };
+bool DoubleBufferReader::HasNext() const {
+  while (!channel_->IsClosed() && !channel_->CanReceive()) {
+  }
+  return channel_->CanReceive();
+}
 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   if (!HasNext()) {
     PADDLE_THROW("There is no next data!");
   }
-  if (local_buffer_.payloads_.empty()) {
-    buffer_->Receive(&local_buffer_);
-  }
-  *out = local_buffer_.payloads_;
-  local_buffer_.payloads_.clear();
-  if (local_buffer_.ctx_) {
-    local_buffer_.ctx_->Wait();
+  Item batch;
+  channel_->Receive(&batch);
+  *out = batch.payloads_;
+  if (batch.ctx_) {
+    batch.ctx_->Wait();
   }
 }
 void DoubleBufferReader::ReInit() {
   reader_->ReInit();
-  buffer_->Close();
-  prefetcher_.join();
-  delete buffer_;
-  start_thread();
+  EndPrefetcher();
+  StartPrefetcher();
 }
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
-  size_t gpu_ctx_offset = 0;
+  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
+  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
+  size_t cached_tensor_id = 0;
   while (reader_->HasNext()) {
     Item batch;
-    reader_->ReadNext(&batch.payloads_);
+    auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
+    reader_->ReadNext(&cpu_batch);
     if (platform::is_gpu_place(place_)) {
-      std::vector<framework::LoDTensor> gpu_batch;
-      auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
-      gpu_ctx_offset %= this->ctxs_.size();
-      gpu_batch.resize(batch.payloads_.size());
-      for (size_t i = 0; i < batch.payloads_.size(); ++i) {
-        framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
-                              &gpu_batch[i]);
-        gpu_batch[i].set_lod(batch.payloads_[i].lod());
+      auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
+      auto* gpu_ctx = ctxs_[cached_tensor_id].get();
+      gpu_batch.resize(cpu_batch.size());
+      for (size_t i = 0; i < cpu_batch.size(); ++i) {
+        framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
+        gpu_batch[i].set_lod(cpu_batch[i].lod());
       }
-      batch.ctx_ = gpu_ctx.get();
-      std::swap(gpu_batch, batch.payloads_);
+      batch.payloads_ = gpu_batch;
+      batch.ctx_ = gpu_ctx;
+    } else {
+      // CPUPlace
+      batch.payloads_ = cpu_batch;
     }
+    ++cached_tensor_id;
+    cached_tensor_id %= kCacheSize;
     try {
-      buffer_->Send(&batch);
+      channel_->Send(&batch);
     } catch (paddle::platform::EnforceNotMet e) {
       VLOG(5) << "WARNING: The double buffer channel has been closed. The "
                  "prefetch thread will terminate.";
       break;
     }
   }
-  buffer_->Close();
+  channel_->Close();
   VLOG(5) << "Prefetch thread terminates.";
 }
-bool DoubleBufferReader::HasNext() const {
-  if (local_buffer_.payloads_.empty()) {
-    bool ok = buffer_->Receive(&local_buffer_);
-    return ok;
-  } else {
-    return true;
-  }
-}
 }  // namespace reader
 }  // namespace operators
 }  // namespace paddle
......
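
Note on the double-buffer reader diff above: the added comments explain the sizing, with two batches alive outside the channel at any moment (one being produced by the prefetch thread, one being consumed by downstream operators), so the channel itself only needs kCacheSize - 2 slots. A minimal, self-contained sketch of that prefetcher/consumer hand-off follows; the toy BoundedQueue stands in for framework::Channel and uses a capacity of 1 only because a zero-capacity rendezvous channel would complicate the example. This is an illustration, not Paddle code.

#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <thread>

// Toy bounded blocking queue; framework::Channel<Item> plays this role in the
// real reader.
template <typename T>
class BoundedQueue {
 public:
  explicit BoundedQueue(size_t cap) : cap_(cap) {}

  // Blocks until there is room; returns false if the queue was closed first.
  bool Send(T v) {
    std::unique_lock<std::mutex> lk(mu_);
    not_full_.wait(lk, [&] { return closed_ || q_.size() < cap_; });
    if (closed_) return false;
    q_.push_back(std::move(v));
    not_empty_.notify_one();
    return true;
  }

  // Blocks until an item is available; returns false once closed and drained.
  bool Receive(T* out) {
    std::unique_lock<std::mutex> lk(mu_);
    not_empty_.wait(lk, [&] { return closed_ || !q_.empty(); });
    if (q_.empty()) return false;
    *out = std::move(q_.front());
    q_.pop_front();
    not_full_.notify_one();
    return true;
  }

  void Close() {
    std::lock_guard<std::mutex> lk(mu_);
    closed_ = true;
    not_empty_.notify_all();
    not_full_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable not_full_, not_empty_;
  std::deque<T> q_;
  size_t cap_;
  bool closed_ = false;
};

int main() {
  BoundedQueue<int> channel(1);

  // Prefetcher: reads "batches" ahead of the consumer, like PrefetchThreadFunc.
  std::thread prefetcher([&] {
    for (int batch = 0; batch < 5; ++batch) {
      if (!channel.Send(batch)) break;  // channel closed, stop prefetching
    }
    channel.Close();
  });

  // Consumer: blocks on Receive, like DoubleBufferReader::ReadNext.
  int batch;
  while (channel.Receive(&batch)) {
    std::cout << "consume batch " << batch << "\n";
  }
  prefetcher.join();
  return 0;
}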
@@ -276,20 +276,25 @@ class DistributeTranspiler:
             suff_idx = v.name.find(".trainer_")
             if suff_idx >= 0:
                 orig_var_name = v.name[:suff_idx]
-            pserver_program.global_block().create_var(
+            else:
+                orig_var_name = v.name
+            single_trainer_var = pserver_program.global_block().create_var(
                 name=orig_var_name,
                 persistable=True,
                 type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
-            for trainer_id in xrange(self.trainers):
-                var = pserver_program.global_block().create_var(
-                    name="%s.trainer_%d" % (orig_var_name, trainer_id),
-                    persistable=False,
-                    type=v.type,
-                    dtype=v.dtype,
-                    shape=v.shape)
-                recv_inputs.append(var)
+            if self.trainers > 1:
+                for trainer_id in xrange(self.trainers):
+                    var = pserver_program.global_block().create_var(
+                        name="%s.trainer_%d" % (orig_var_name, trainer_id),
+                        persistable=False,
+                        type=v.type,
+                        dtype=v.dtype,
+                        shape=v.shape)
+                    recv_inputs.append(var)
+            else:
+                recv_inputs.append(single_trainer_var)
             # step3
             optimize_block = pserver_program.create_block(0)
@@ -511,8 +516,11 @@ class DistributeTranspiler:
     def _append_split_op(self, program, gradblocks):
         # Split variables that need to be split and append respective ops
+        add_suffix = False
+        if self.trainers > 1:
+            add_suffix = True
         var_mapping = self._create_vars_from_blocklist(
-            program, gradblocks, add_trainer_suffix=True)
+            program, gradblocks, add_trainer_suffix=add_suffix)
         for varname, splited_vars in var_mapping.iteritems():
             # variable that don't need to split have empty splited_vars
             if len(splited_vars) <= 1:
......