未验证 提交 22ea4c30 编写于 作者: T tianshuo78520a 提交者: GitHub

Delete grpc.cmake/distribeted/distributed_ops (#32166)

* Delete grpc.cmake/distribeted/distributed_ops

* reset operators/CMakeLists.txt

* rm test_transpiler_ops.py

* del test_transpiler_ops.py
上级 995b5f2c
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
include (ExternalProject)
SET(GRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/grpc)
SET(GRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/grpc)
SET(GRPC_INCLUDE_DIR "${GRPC_INSTALL_DIR}/include/" CACHE PATH "grpc include directory." FORCE)
SET(GRPC_CPP_PLUGIN "${GRPC_INSTALL_DIR}/bin/grpc_cpp_plugin" CACHE FILEPATH "GRPC_CPP_PLUGIN" FORCE)
include(ProcessorCount)
ProcessorCount(NUM_OF_PROCESSOR)
IF(APPLE)
SET(BUILD_CMD make -n HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin | sed "s/-Werror//g" | sh)
SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install)
ELSE()
SET(GRPC_CFLAGS "-Wno-error -std=c11 ${CLFAGS}")
SET(GRPC_CXXFLAGS "-Wno-error -std=c++11 ${CXXFLAGS}")
SET(BUILD_CMD make CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS} HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin)
SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS})
ENDIF()
# FIXME(wuyi): do not build zlib cares protobuf twice, find a way to build grpc with them
ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
# NOTE(wuyi):
# this package is generated by following steps:
# 1. git clone -b v1.8.x https://github.com/grpc/grpc.git
# 2. git submodule update --init
# 3. keep only zlib, cares, protobuf, boringssl under "third_party",
# checkout and clean other dirs under third_party
# 4. remove .git, and package the directory.
URL http://paddlepaddledeps.bj.bcebos.com/grpc-v1.10.x_paddle.tar.gz
URL_MD5 f5442d137ddccee252e194b1bc90f98c
PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_IN_SOURCE 1
# NOTE(yuyang18):
# Disable -Werror, otherwise the compile will fail in MacOS.
# It seems that we cannot configure that by make command.
# Just dry run make command and remove `-Werror`, then use a shell to run make commands
BUILD_COMMAND ${BUILD_CMD}
INSTALL_COMMAND ${GRPC_INSTALL_CMD}
)
ADD_LIBRARY(grpc++_unsecure STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET grpc++_unsecure PROPERTY IMPORTED_LOCATION
"${GRPC_INSTALL_DIR}/lib/libgrpc++_unsecure.a")
ADD_LIBRARY(grpc++ STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET grpc++ PROPERTY IMPORTED_LOCATION
"${GRPC_INSTALL_DIR}/lib/libgrpc++.a")
ADD_LIBRARY(gpr STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gpr PROPERTY IMPORTED_LOCATION
"${GRPC_INSTALL_DIR}/lib/libgpr.a")
ADD_LIBRARY(grpc_unsecure STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET grpc_unsecure PROPERTY IMPORTED_LOCATION
"${GRPC_INSTALL_DIR}/lib/libgrpc_unsecure.a")
include_directories(${GRPC_INCLUDE_DIR})
ADD_DEPENDENCIES(grpc++_unsecure extern_grpc)
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"
#include "paddle/fluid/operators/collective/allreduce_op.h"
namespace paddle {
namespace operators {
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"
#include "paddle/fluid/operators/collective/allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
......
return()
if(WITH_GRPC)
set(cc_generic_services "false")
else()
set(cc_generic_services "true")
endif()
configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool)
cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
cc_library(large_scale_kv SRCS large_scale_kv.cc DEPS enforce simple_threadpool device_context)
cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC)
set(GRPC_DEPS grpc++_unsecure grpc_unsecure gpr zlib protobuf)
set(GRPC_SRCS grpc/grpc_client.cc grpc/grpc_server.cc grpc/grpc_serde.cc grpc/grpc_bytebuffer_stream.cc grpc/grpc_variable_response.cc)
grpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc
variable_response.cc
collective_client.cc collective_server.cc
${GRPC_SRCS}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor large_scale_kv)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
cc_test(grpc_serde_test SRCS grpc/grpc_serde_test.cc
DEPS ${RPC_DEPS} scope profiler math_function)
else()
set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc communicator.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(BRPC_DEPS brpc ssl crypto protobuf leveldb zlib)
brpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc
variable_response.cc
collective_client.cc collective_server.cc
${BRPC_SRCS}
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory scope ${BRPC_DEPS})
set(RPC_DEPS sendrecvop_rpc ${BRPC_DEPS})
cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc
DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_read_op)
endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op checkpoint_notify_op scale_op )
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory node)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv generator)
cc_test(communicator_test SRCS communicator_test.cc DEPS communicator)
if(WITH_GPU OR WITH_ROCM)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc executor ${RPC_DEPS}
selected_rows_functor scope math_function)
endif()
if(WITH_TESTING)
if(TEST rpc_server_test)
set_tests_properties(rpc_server_test PROPERTIES TIMEOUT 120)
endif()
if(TEST heart_beat_monitor_test)
set_tests_properties(heart_beat_monitor_test PROPERTIES TIMEOUT 120)
endif()
endif()
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag AsyncSparseParamUpdateRecorder::init_flag_;
std::unique_ptr<AsyncSparseParamUpdateRecorder>
AsyncSparseParamUpdateRecorder::recorder_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
class ConcurrentSet {
public:
ConcurrentSet() : pool_(new ::ThreadPool(1)) {}
~ConcurrentSet() {}
std::future<void> Update(const std::vector<int64_t>& rows) {
auto task = [this, rows] {
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& id : rows) {
sstream << id << ", ";
}
sstream << "]";
VLOG(3) << "update ids -> " << sstream.str();
}
for (auto row : rows) {
set_.insert(row);
}
};
return pool_->enqueue(std::move(task));
}
std::future<void> GetAndClear(std::vector<int64_t>* result) {
auto task = [this, &result] {
result->clear();
for (auto& id : set_) {
result->push_back(id);
}
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& id : *result) {
sstream << id << ", ";
}
sstream << "]";
VLOG(3) << "result ids size: " << result->size() << " "
<< sstream.str();
}
set_.clear();
};
return pool_->enqueue(std::move(task));
}
private:
std::unordered_set<int64_t> set_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
};
class AsyncSparseParamUpdateRecorder {
using TrainerToRows = std::vector<std::unique_ptr<ConcurrentSet>>;
public:
AsyncSparseParamUpdateRecorder(
int trainer_num,
const std::unordered_map<std::string, std::string>& grad_to_param)
: trainer_num_(trainer_num), grad_to_param_(grad_to_param) {
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& item : grad_to_param) {
sstream << item.first << ":" << item.second << ", ";
}
sstream << "]";
VLOG(3) << "trainer_num: " << trainer_num
<< " grad_to_param_: " << sstream.str();
}
for (auto& iter : grad_to_param) {
param_to_grad_[iter.second] = iter.first;
auto& param_name = iter.second;
param_to_updated_rows_[param_name] = TrainerToRows();
auto& trainer_to_rows = param_to_updated_rows_[param_name];
for (auto i = 0; i < trainer_num; ++i) {
trainer_to_rows.emplace_back(new ConcurrentSet());
}
}
}
~AsyncSparseParamUpdateRecorder() = default;
void Update(const std::string& grad_name,
const std::vector<int64_t>& update_rows) {
VLOG(3) << "update grad: " << grad_name
<< " row size: " << update_rows.size();
auto& param_name = grad_to_param_.at(grad_name);
auto& trainer_to_rows = param_to_updated_rows_.at(param_name);
std::vector<std::future<void>> fs;
for (auto& set : trainer_to_rows) {
fs.push_back(set->Update(update_rows));
}
for (auto& f : fs) {
f.wait();
}
}
void GetAndClear(const std::string& param_name, int trainer_id,
std::vector<int64_t>* result) {
VLOG(3) << "GetAndClear param: " << param_name
<< " for trainer: " << trainer_id;
PADDLE_ENFORCE_LT(
trainer_id, trainer_num_,
platform::errors::InvalidArgument(
"The value of trainer_id: %s should less than trainer_num: %s.",
trainer_id, trainer_num_));
param_to_updated_rows_.at(param_name)[trainer_id]
->GetAndClear(result)
.wait();
}
bool HasParam(const std::string& param_name) {
return param_to_grad_.find(param_name) != param_to_grad_.end();
}
bool HasGrad(const std::string& grad_name) {
return grad_to_param_.find(grad_name) != grad_to_param_.end();
}
private:
const int trainer_num_;
std::unordered_map<std::string, std::string> grad_to_param_;
std::unordered_map<std::string, std::string> param_to_grad_;
std::unordered_map<std::string, TrainerToRows> param_to_updated_rows_;
// init recorder
public:
static void Init(
int trainer_num,
const std::unordered_map<std::string, std::string>& grad_to_param) {
InitImpl(trainer_num, grad_to_param);
}
static AsyncSparseParamUpdateRecorder* GetInstance() {
return recorder_.get();
}
private:
// Init is called by GetInstance.
static void InitImpl(
int trainer_num,
const std::unordered_map<std::string, std::string>& grad_to_param) {
if (recorder_ == nullptr) {
recorder_.reset(
new AsyncSparseParamUpdateRecorder(trainer_num, grad_to_param));
}
}
static std::once_flag init_flag_;
static std::unique_ptr<AsyncSparseParamUpdateRecorder> recorder_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include <algorithm>
#include "gtest/gtest.h"
namespace paddle {
namespace operators {
namespace distributed {
TEST(ConcurrentSet, All) {
ConcurrentSet concurrent_set;
std::vector<int64_t> in1 = {1, 2, 3, 4};
std::vector<int64_t> in2 = {2, 3, 5, 6};
std::vector<std::future<void>> futures;
futures.push_back(concurrent_set.Update(in1));
futures.push_back(concurrent_set.Update(in2));
for (auto &f : futures) {
f.wait();
}
std::unordered_set<int64_t> in;
std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin()));
std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin()));
std::vector<int64_t> ret;
concurrent_set.GetAndClear(&ret).wait();
std::unordered_set<int64_t> out;
std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin()));
EXPECT_EQ(in, out);
concurrent_set.GetAndClear(&ret).wait();
EXPECT_EQ(ret.size(), 0UL);
}
TEST(AsyncSparseParamUpdateRecorder, All) {
std::unordered_map<std::string, std::string> grad_to_param;
grad_to_param["grad1"] = "param1";
grad_to_param["grad2"] = "param2";
int trainer_num = 10;
AsyncSparseParamUpdateRecorder recorder(trainer_num, grad_to_param);
std::vector<int64_t> in1 = {1, 2, 3, 4};
std::vector<int64_t> in2 = {2, 3, 5, 6};
std::unordered_set<int64_t> in;
std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin()));
std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin()));
recorder.Update("grad1", in1);
recorder.Update("grad1", in2);
EXPECT_TRUE(recorder.HasParam("param1"));
EXPECT_TRUE(recorder.HasParam("param2"));
EXPECT_FALSE(recorder.HasParam("param3"));
EXPECT_TRUE(recorder.HasGrad("grad1"));
EXPECT_TRUE(recorder.HasGrad("grad2"));
EXPECT_FALSE(recorder.HasGrad("grad3"));
std::vector<int64_t> ret;
EXPECT_ANY_THROW(recorder.GetAndClear("param1", trainer_num, &ret));
for (int i = 0; i < trainer_num; ++i) {
std::vector<int64_t> ret;
std::unordered_set<int64_t> out;
recorder.GetAndClear("param1", i, &ret);
std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin()));
EXPECT_EQ(in, out);
recorder.GetAndClear("param1", i, &ret);
EXPECT_EQ(ret.size(), 0UL);
}
}
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");
BRPCClient::~BRPCClient() { Wait(); }
void HandleSendResponse(brpc::Controller* cntl, sendrecv::VoidMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VoidMessage> response_guard(response);
// this channel can be used by other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
PADDLE_THROW(platform::errors::Unavailable(
"Failed to send variable %s, error text is %s.", var_h->name(),
cntl->ErrorText()));
var_h->Finish(false);
cls->DecreaseReqCount();
return;
}
var_h->Finish(true);
cls->DecreaseReqCount();
VLOG(4) << "HandleSendResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
VLOG(4) << "Finish HandleSendResponse";
}
VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
const std::string method = kSendRPC;
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
auto* var = p_scope->FindVar(var_name_val);
sendrecv::VariableMessage request;
distributed::SerializeToIOBuf(var_name_val, var, *p_ctx, &request,
&cntl->request_attachment(), "", false,
trainer_id_);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
platform::RecordRPCEvent record_event(method);
ch_ctx->stub->SendVariable(cntl, &request, response, done);
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
req_count_++;
return var_h;
}
void HandleFetchBarrierResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
// this channel can be used other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
PADDLE_THROW(platform::errors::Unavailable(
"Failed to get HandleFetchBarrierResponse %s, error text is %s.",
var_h->name(), cntl->ErrorText()));
var_h->Finish(false);
cls->DecreaseReqCount();
return;
}
var_h->Finish(true);
cls->DecreaseReqCount();
VLOG(4) << "HandleFetchBarrierResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
VLOG(4) << "Finish HandleFetchBarrierResponse";
}
void HandleGetResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response, VarHandlePtr var_h,
ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx,
BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
// this channel can be used other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
PADDLE_THROW(platform::errors::Unavailable(
"Failed to get variable %s, error text is %s.", var_h->name(),
cntl->ErrorText()));
cls->DecreaseReqCount();
var_h->Finish(false);
return;
}
VLOG(4) << "HandleGetResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
framework::Variable* outvar = nullptr;
int trainer_id;
distributed::DeserializeFromIOBuf(*response, cntl->response_attachment(),
*var_h->ctx(), var_h->scope(), &outvar,
&trainer_id);
VLOG(4) << "Finish HandleGetResponse";
cls->DecreaseReqCount();
var_h->Finish(true);
}
VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
const std::string& method_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const std::string out_varname_val = out_var_name;
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
const std::string method = kGetRPC;
VarHandlePtr var_h(
new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
platform::RecordRPCEvent record_event(method);
if (method_name == kGetMonomerRPC) {
ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
} else if (method_name == kGetNoBarrierRPC) {
ch_ctx->stub->GetVariableNoBarrier(cntl, &req, response, done);
} else {
ch_ctx->stub->GetVariable(cntl, &req, response, done);
}
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
req_count_++;
return var_h;
}
VarHandlePtr BRPCClient::AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_var_name, int64_t time_out) {
std::string var_name_no_barrier =
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
return _AsyncGetVar(ep, ctx, scope, var_name_no_barrier, out_var_name,
kGetNoBarrierRPC, time_out);
}
VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, var_name, kGetMonomerRPC,
time_out);
}
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const std::string& var_name,
int64_t time_out) {
return AsyncSendMessage(ep, kSendMonomerFetchBarrierRPC, var_name, time_out);
}
VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC,
time_out);
}
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
const std::string method = kPrefetchRPC;
VarHandlePtr var_h(
new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
auto* var = p_scope->FindVar(in_var_name_val);
sendrecv::VariableMessage req;
distributed::SerializeToIOBuf(in_var_name_val, var, *p_ctx, &req,
&cntl->request_attachment(), out_var_name_val,
false, 0, table_name_val);
platform::RecordRPCEvent record_event(method);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
ch_ctx->stub->PrefetchVariable(cntl, &req, response, done);
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
req_count_++;
return var_h;
}
VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
return AsyncSendMessage(ep, kBatchBarrierRPC, BATCH_BARRIER_MESSAGE,
time_out);
}
VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
auto ch_ptr = GetChannel(ep);
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
const std::string method = kFetchBarrierRPC;
// var handle
VarHandlePtr var_h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
platform::RecordRPCEvent record_event(method);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleFetchBarrierResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
ch_ctx->stub->GetVariable(cntl, &req, response, done);
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
return var_h;
}
bool BRPCClient::Wait() {
VLOG(9) << "begin to brpcclient wait";
{
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
}
VLOG(9) << "end to brpcclient wait";
return true;
}
ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
VLOG(4) << "begin to GetChannel:" << ep;
{
std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
VLOG(4) << "end to GetChannel:" << ep;
return it->second;
}
}
ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());
brpc::ChannelOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
options.use_rdma = true;
#endif
options.protocol = "baidu_std";
// don't use pooled type. the server can't afford that.
options.connection_type = "single";
options.connect_timeout_ms = 1000;
options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
options.max_retry = FLAGS_max_retry;
VLOG(1) << "create " << brpc_channel_num_per_server_
<< " brpc channels to pserver:" << ep;
for (int i = 0; i < brpc_channel_num_per_server_; ++i) {
std::shared_ptr<ChannelContext> c(new ChannelContext());
if (c->channel.Init(ep.c_str(), &options) != 0) {
PADDLE_THROW(
platform::errors::Unavailable("Failed to initialize channel."));
return nullptr;
}
c->stub.reset(new sendrecv::SendRecvService_Stub(
static_cast<google::protobuf::RpcChannel*>(&c->channel)));
q->Push(c);
}
{
std::lock_guard<std::mutex> guard(chan_mutex_);
channels_[ep] = q;
}
VLOG(4) << "end to GetChannel:" << ep;
return q;
}
VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
return AsyncSendMessage(ep, kSendCompleteRPC, COMPLETE_MESSAGE, time_out);
}
void BRPCClient::SendComplete() {
for (auto& kv : channels_) {
AsyncSendComplete(kv.first);
}
}
VarHandlePtr BRPCClient::AsyncSendVarMessage(
const std::string& ep, const std::string& method_name,
const sendrecv::VariableMessage& req, int64_t time_out) {
auto ch_ptr = GetChannel(ep);
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
platform::RecordRPCEvent record_event(method_name);
VarHandlePtr var_h(
new VarHandle(ep, method_name, req.varname(), nullptr, nullptr));
google::protobuf::Closure* done = brpc::NewCallback(
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
if (method_name == kCheckPointNotifyRPC) {
ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
} else if (method_name == kSendMonomerFetchBarrierRPC) {
ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
} else {
ch_ctx->stub->SendVariable(cntl, &req, response, done);
}
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
return var_h;
}
VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
const std::string& method_name,
const std::string& message,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(message);
return AsyncSendVarMessage(ep, method_name, req, time_out);
}
VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dirname,
const std::string& varname,
const int mode,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(varname);
req.set_out_varname(dirname);
return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <time.h>
#include <chrono> // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace operators {
namespace distributed {
struct ChannelContext {
brpc::Channel channel;
std::shared_ptr<sendrecv::SendRecvService_Stub> stub;
};
typedef std::shared_ptr<ChannelContext> ChannelContextPtr;
typedef std::shared_ptr<framework::BlockingQueue<ChannelContextPtr>>
ChannelQueuePtr;
class BRPCClient : public RPCClient {
public:
BRPCClient() {}
virtual ~BRPCClient();
VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerBarrier(
const std::string& ep, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVarNoBarrier(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline);
VarHandlePtr AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dirname,
const std::string& varname, const int mode,
int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
void SendComplete() override;
private:
VarHandlePtr _AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_var_name, const std::string& method_name,
const std::string& table_name, int64_t time_out = FLAGS_rpc_deadline);
void Proceed();
ChannelQueuePtr GetChannel(const std::string& ep);
VarHandlePtr AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline);
VarHandlePtr AsyncSendMessage(const std::string& ep,
const std::string& method_name,
const std::string& message, int64_t time_out);
VarHandlePtr AsyncSendVarMessage(const std::string& ep,
const std::string& method_name,
const sendrecv::VariableMessage& req,
int64_t time_out);
friend void HandleSendResponse(brpc::Controller* cntl,
sendrecv::VoidMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls);
friend void HandleGetResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls);
friend void HandleFetchBarrierResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h,
ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx,
BRPCClient* cls);
void DecreaseReqCount() {
if (--req_count_ <= 0) {
sync_cond_.notify_all();
}
}
private:
std::unordered_map<std::string, ChannelQueuePtr> channels_;
// mutex for Wait client sync
std::mutex sync_mutex_;
std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0};
static constexpr int brpc_channel_num_per_server_ = 4;
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
DISABLE_COPY_AND_ASSIGN(BRPCClient);
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BRPC_RDMA
#include "paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h"
#include "brpc/channel.h"
#include "brpc/rdma/rdma_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
RdmaMemPool& RdmaMemPool::Instance() {
static RdmaMemPool* g_rdma_mem_pool = new RdmaMemPool();
return *g_rdma_mem_pool;
}
void* RdmaMemPool::Find(const std::string& varname, int64_t size) {
pthread_rwlock_rdlock(&access_);
auto it = pool_.find(varname);
if (it == pool_.end()) {
pthread_rwlock_unlock(&access_);
return nullptr;
}
auto info = it->second;
if (info.data_size != size) {
pthread_rwlock_unlock(&access_);
PADDLE_THROW(platform::errors::InvalidArgument(
"var:%s size:%ld != %ld", varname, size, info.data_size));
return nullptr;
}
pthread_rwlock_unlock(&access_);
return info.data;
}
void RdmaMemPool::Register(const std::string& varname, void* data,
int64_t data_size) {
void* old = Find(varname, data_size);
if (old != nullptr) {
PADDLE_ENFORCE_EQ(
data, old, platform::errors::InvalidArgument("var:%s data:%ld != %ld",
varname, data, old));
VLOG(7) << "Find on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
return;
}
VarInfo info;
info.data = data;
info.data_size = data_size;
pthread_rwlock_wrlock(&access_);
pool_[varname] = info;
pthread_rwlock_unlock(&access_);
if (brpc::rdma::RegisterMemoryForRdma(data, data_size)) {
PADDLE_THROW(platform::errors::Unavailable(
"Register memory for RDMA failed. Register %s data: %s data size %d "
"error.",
varname, data, data_size));
}
VLOG(4) << "register on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_BRPC_RDMA
#include <pthread.h> // NOLINT
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
namespace distributed {
/*
* This class is used to avoid duplicated registion of brpc::rdma.
*/
class RdmaMemPool {
public:
static RdmaMemPool& Instance();
RdmaMemPool() : access_(PTHREAD_RWLOCK_INITIALIZER) {}
virtual ~RdmaMemPool() { pthread_rwlock_destroy(&access_); }
void Register(const std::string& varname, void* data, int64_t size);
void* Find(const std::string& varname, int64_t size);
private:
struct VarInfo {
void* data;
int64_t data_size;
VarInfo() : data(nullptr), data_size(0) {}
};
private:
std::unordered_map<std::string, VarInfo> pool_;
pthread_rwlock_t access_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NCCL
#include <nccl.h>
#endif
#ifdef PADDLE_WITH_RCCL
#include <rccl.h>
#endif
#include <sys/time.h>
#include <limits>
#include <memory>
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
class IOBufWriter {
public:
static void Append(const std::string& varname, butil::IOBuf* iobuf, int k,
const char* v, int64_t vlen) {
if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
PADDDLE_THROW(platform::errors::Unavailable(
"Variable lenght is invalid. Variable name is %s, length is %d.",
varname, vlen));
}
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
iobuf->append(v, vlen);
}
static void AppendTCPZeroCopy(butil::IOBuf* iobuf, int k, const char* v,
int64_t vlen, bool in_cuda_pinned,
void (*destroy)(void*), void* user_data) {
VLOG(7) << "AppendTCPZeroCopy "
<< " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
// FIXME(gongwb): use append_zerocopy
/*
if (in_cuda_pinned) {
iobuf->append_zerocopy(v, vlen, IOBufWriter::FreeMemory);
} else {
iobuf->append_zerocopy(v, vlen, nullptr);
}
*/
iobuf->append(v, vlen);
destroy(user_data);
}
#ifdef PADDLE_WITH_BRPC_RDMA
static void AppendRdmaZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
VLOG(7) << "AppendRdmaZeroCopy varname:" << varname << " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
RdmaMemPool::Instance().Register(
varname, static_cast<void*>(const_cast<char*>(v)), vlen);
// FIXME(gongwb): use append_zerocopy
// iobuf->append_zerocopy(v, vlen, nullptr);
iobuf->append(v, vlen);
destroy(user_data);
return;
}
#endif
static void AppendZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
PADDDLE_THROW(platform::errors::Unavailable(
"Variable lenght is invalid. Variable name is %s, length is %d.",
varname, vlen));
}
#ifdef PADDLE_WITH_BRPC_RDMA
IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
destroy, user_data);
#else
IOBufWriter::AppendTCPZeroCopy(iobuf, k, v, vlen, in_cuda_pinned, destroy,
user_data);
#endif
}
};
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, int trainer_id,
const std::string& table_name) {
std::unique_ptr<TensorPayload> payload;
request->set_varname(name);
request->set_trainer_id(trainer_id);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request->set_profile(platform::kEnableProfiler);
} else {
request->set_profile(platform::kDisableProfiler);
}
}
if (!out_varname.empty()) {
request->set_out_varname(out_varname);
}
if (!table_name.empty()) {
request->set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) {
request->set_type(::sendrecv::LOD_TENSOR);
payload.reset(new TensorPayload(GetTensorPayload(var, ctx, request)));
} else if (var->IsType<framework::SelectedRows>()) {
request->set_type(::sendrecv::SELECTED_ROWS);
payload.reset(new TensorPayload(GetSelectedRowsPayload(var, ctx, request)));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
} else if (var->IsType<ncclUniqueId>()) {
request->set_type(::sendrecv::NCCL_ID);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
// TODO(gongwb): use append_zero to avoid data copy.
IOBufWriter::Append(name, iobuf,
sendrecv::VariableMessage::kSerializedFieldNumber,
uid.internal, NCCL_UNIQUE_ID_BYTES);
return;
#endif
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Serialize does not support type: %s", typeid(var->Type()).name()));
}
PADDLE_ENFORCE_NOT_NULL(
payload,
platform::errors::InvalidArgument(
"Not support type: %s, need to be LOD_TENSOR or SELECTED_ROWS.",
var->Type()));
// FIXME(gongwb): it seems that can use zero copy.
if (var_is_not_stable) {
IOBufWriter::Append(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size());
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
true, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
#endif
} else {
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
false, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
}
}
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
PADDLE_ENFORCE_EQ(VectorElemName(slr->rows()), typeid(int64_t).name(),
platform::errors::InvalidArgument(
"Got wrong type: %s, expect type: int64_t",
VectorElemName(slr->rows())));
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
IOBufWriter::Append(name, iobuf,
::sendrecv::VariableMessage::kRowsFieldNumber,
reinterpret_cast<const char*>(slr->rows().data()),
static_cast<int64_t>(rows_memory_size));
}
}
void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta,
const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id) {
operators::distributed::BRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE_EQ(
resp.Parse(iobuf, meta), 0,
platform::errors::InvalidArgument("parse iobuf to tensor error!"));
*var = resp.GetVar();
*trainer_id = resp.GetTrainerId();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace distributed {
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromIOBuf(const VarMsg& meta, const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "brpc/channel.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 564 * 128;
// serialize var to IOBuf
{
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 32.7);
for (int i = 0; i < 564; ++i) rows->push_back(i);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// desrialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto* slr2 = var2->GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2->data<float>());
}
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
}
EXPECT_EQ(slr2->height(), 1000);
}
}
void RunTestLodTensor(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 512 * 8 * 4 * 2;
{
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// check sendrecv::VariableMessage meta data
{
EXPECT_EQ(msg.varname(), "myvar");
EXPECT_EQ(msg.type(), 0);
EXPECT_EQ(msg.dims()[0], 512);
EXPECT_EQ(msg.dims()[1], 8);
EXPECT_EQ(msg.dims()[2], 4);
EXPECT_EQ(msg.dims()[3], 2);
EXPECT_EQ(msg.lod_level(), 1);
EXPECT_EQ(msg.lod(0).lod_data(0), 1);
EXPECT_EQ(msg.lod(0).lod_data(1), 3);
EXPECT_EQ(msg.lod(0).lod_data(2), 8);
}
// deserialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto tensor2 = var2->Get<framework::LoDTensor>();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2.data<float>());
}
for (int i = 0; i < tensor_numel; ++i)
EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
}
}
TEST(LodTensor, Run) {
platform::CPUPlace place;
RunTestLodTensor(place);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace gpu(0);
RunTestLodTensor(gpu);
#endif
}
TEST(SelectedRows, Run) {
platform::CPUPlace place;
RunSerdeTestSelectedRows(place);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace gpu;
RunSerdeTestSelectedRows(gpu);
#endif
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv {
namespace distributed = paddle::operators::distributed;
typedef std::unordered_map<std::string, distributed::RequestHandler*>
HandlerMap;
class BRPCServiceImpl : public SendRecvService {
public:
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map,
distributed::RPCServer* rpc_server)
: rpc_server_(rpc_server) {
VLOG(3) << "BRPCServiceImpl size: " << rpc_call_map.size();
auto it = rpc_call_map.find(distributed::kRequestSend);
if (it != rpc_call_map.end()) {
request_send_h_ = it->second;
send_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestSend)));
}
it = rpc_call_map.find(distributed::kRequestGet);
if (it != rpc_call_map.end()) {
request_get_h_ = it->second;
get_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestGet)));
}
it = rpc_call_map.find(distributed::kRequestGetNoBarrier);
if (it != rpc_call_map.end()) {
request_getnobarrier_h_ = it->second;
getnobarrier_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestGetNoBarrier)));
}
it = rpc_call_map.find(distributed::kRequestPrefetch);
if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second;
prefetch_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
}
it = rpc_call_map.find(distributed::kRequestCheckpoint);
if (it != rpc_call_map.end()) {
request_checkpoint_h_ = it->second;
checkpoint_notify_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
}
it = rpc_call_map.find(distributed::kRequestGetMonomerVariable);
if (it != rpc_call_map.end()) {
request_get_monomer_handler_h_ = it->second;
}
it = rpc_call_map.find(distributed::kRequestGetMonomerBarrier);
if (it != rpc_call_map.end()) {
request_get_monomer_barrier_handler_h_ = it->second;
}
}
virtual ~BRPCServiceImpl() {}
void SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
send_threads_->Run(
[=] { _SendVariable(cntl_butil, request, response, done); });
}
void _SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE_NOT_NULL(
request_send_h_, platform::errors::PreconditionNotMet(
"RequestSend handler should be registed first!"));
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestSend var_name:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
distributed::BRPCVariableResponse resp(request_send_h_->scope(),
request_send_h_->dev_ctx(),
request_send_h_->distributed_mode());
PADDLE_ENFORCE_EQ(
resp.Parse(cntl->request_attachment(), *request), 0,
platform::errors::InvalidArgument("parse iobuf to tensor error!"));
auto scope = resp.GetMutableLocalScope();
auto invar = resp.GetVar();
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
request_send_h_->Handle(varname, scope, invar, &outvar, trainer_id);
}
void GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) override {
get_threads_->Run(
[=] { _GetVariable(cntl_butil, request, response, done); });
}
void GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
getnobarrier_threads_->Run(
[=] { _GetVariableNoBarrier(cntl_butil, request, response, done); });
}
void _GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE_NOT_NULL(
request_get_h_, platform::errors::PreconditionNotMet(
"RequestGet handler should be registed first!"));
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
std::string out_varname = request->out_varname();
VLOG(3) << "RequestGet varname:" << varname
<< ", out_varname:" << out_varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
auto scope = request_get_h_->scope();
paddle::framework::Variable* invar = nullptr;
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
distributed::SerializeToIOBuf(out_varname, outvar,
*request_get_h_->dev_ctx(), response,
&cntl->response_attachment(), "", false);
}
}
void _GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE_NOT_NULL(
request_getnobarrier_h_,
platform::errors::PreconditionNotMet(
"RequestGetNoBarrier handler should be registed first!"));
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
std::string out_varname = request->out_varname();
int trainer_id = request->trainer_id();
VLOG(3) << "RequestGetNoBarrier varname:" << varname
<< ", out_varname:" << out_varname << ", trainer_id:" << trainer_id
<< ", from:" << cntl->remote_side();
auto scope = request_getnobarrier_h_->scope();
paddle::framework::Variable* invar = nullptr;
paddle::framework::Variable* outvar = nullptr;
request_getnobarrier_h_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
distributed::SerializeToIOBuf(
out_varname, outvar, *request_getnobarrier_h_->dev_ctx(), response,
&cntl->response_attachment(), "", false);
}
}
void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
prefetch_threads_->Run(
[=] { _PrefetchVariable(cntl_butil, request, response, done); });
}
void _PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE_NOT_NULL(request_prefetch_h_,
platform::errors::PreconditionNotMet(
"kRequestPrefetch handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// prefetch process...
std::string in_var_name = request->varname();
std::string out_var_name = request->out_varname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
<< ", out_var_name: " << out_var_name
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
distributed::BRPCVariableResponse resp(
request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true);
PADDLE_ENFORCE_EQ(resp.Parse(cntl->request_attachment(), *request), 0,
platform::errors::InvalidArgument(
"parse iobuf to tensor error!"));
auto scope = resp.GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
std::string table_name = request->table_name();
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = scope->Var(out_var_name);
request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name, table_name);
distributed::SerializeToIOBuf(out_var_name, outvar,
*request_prefetch_h_->dev_ctx(), response,
&cntl->response_attachment(), "", true);
}
void CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
checkpoint_notify_threads_->Run(
[=] { _CheckpointNotify(cntl_butil, request, response, done); });
}
void _CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE_NOT_NULL(
request_checkpoint_h_,
platform::errors::PreconditionNotMet(
"kRequestCheckpointNotify handler should be registed first!"));
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(),
request_checkpoint_h_->dev_ctx());
auto scope = resp.GetMutableLocalScope();
std::string checkpoint_notify = request->varname();
std::string checkpoint_dir = request->out_varname();
int trainer_id = request->trainer_id();
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr,
trainer_id, checkpoint_dir);
}
void GetMonomerVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE_NOT_NULL(
request_get_monomer_handler_h_,
platform::errors::PreconditionNotMet(
"kRequestGetMonomerVariable handler should be registed first!"));
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// proc request.
std::string varname = request->varname();
VLOG(3) << "GetMonomerVariable " << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
auto scope = h.scope_;
auto invar = scope->FindVar(varname);
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar,
request->trainer_id());
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response,
&cntl->response_attachment(), "", false);
}
}
void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE_NOT_NULL(
request_get_monomer_barrier_handler_h_,
platform::errors::PreconditionNotMet(
"RequestGetMonomerBarrier handler should be registed first!"));
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestGetMonomerBarrier var_name:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
paddle::framework::Scope* scope = nullptr;
paddle::framework::Variable* invar = nullptr;
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_barrier_handler_h_->Handle(
varname, scope, invar, &outvar, request->trainer_id());
}
private:
distributed::RequestHandler* request_send_h_{nullptr};
distributed::RequestHandler* request_get_h_{nullptr};
distributed::RequestHandler* request_getnobarrier_h_{nullptr};
distributed::RequestHandler* request_prefetch_h_{nullptr};
distributed::RequestHandler* request_checkpoint_h_{nullptr};
distributed::RequestHandler* request_get_monomer_handler_h_{nullptr};
distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr};
distributed::RPCServer* rpc_server_{nullptr};
// FIXME(gongwb): brpc should support process one rpc use one threadpool.
std::unique_ptr<paddle::framework::ThreadPool> send_threads_;
std::unique_ptr<paddle::framework::ThreadPool> get_threads_;
std::unique_ptr<paddle::framework::ThreadPool> getnobarrier_threads_;
std::unique_ptr<paddle::framework::ThreadPool> prefetch_threads_;
std::unique_ptr<paddle::framework::ThreadPool> checkpoint_notify_threads_;
};
} // namespace sendrecv
namespace paddle {
namespace operators {
namespace distributed {
void AsyncBRPCServer::StartServer() {
// Instance of your service.
sendrecv::BRPCServiceImpl service_impl(rpc_call_map_, this);
// Add the service into server. Notice the second parameter, because the
// service is put on stack, we don't want server to delete it, otherwise
// use brpc::SERVER_OWNS_SERVICE.
if (server_.AddService(&service_impl, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
PADDDLE_THROW(platform::errors::Unavailable(
"Failed to add service into BRPC server."));
return;
}
brpc::ServerOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
options.use_rdma = true;
#endif
options.idle_timeout_sec = idle_timeout_s_;
options.max_concurrency = max_concurrency_;
if (server_.Start(bind_address_.c_str(), &options) != 0) {
PADDDLE_THROW(platform::errors::Unavailable(
"Failed to start EchoServer %s.", bind_address_));
return;
}
butil::EndPoint ep = server_.listen_address();
selected_port_ = ep.port;
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1;
}
condition_ready_.notify_all();
server_.Join();
}
void AsyncBRPCServer::ShutDownImpl() { server_.Stop(1000); }
void AsyncBRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
}
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <string>
#include "brpc/server.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle {
namespace operators {
namespace distributed {
class AsyncBRPCServer final : public RPCServer {
public:
explicit AsyncBRPCServer(const std::string& address, int client_num)
: RPCServer(address, client_num), ready_(0) {}
virtual ~AsyncBRPCServer() {}
void StartServer() override;
void WaitServerReady() override;
private:
void ShutDownImpl() override;
brpc::Server server_;
static constexpr int idle_timeout_s_ = -1;
static constexpr int max_concurrency_ = 0;
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace distributed {
namespace pb = ::google::protobuf;
using vr = ::sendrecv::VariableMessage;
int BRPCVariableResponse::Parse(Source* source) {
pb::io::ZeroCopyInputStream* input_stream = source->contents();
pb::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (1) {
unsigned int tag = 0;
if (!input.ReadLittleEndian32(&tag)) {
break;
}
uint64_t num_bytes = 0;
if (!input.ReadLittleEndian64(&num_bytes)) {
break;
}
int field = static_cast<int>(tag);
int ret = field == 0 ? -1 : field;
switch (field) {
case vr::kSerializedFieldNumber: {
if (!ProcSerializedField(field, &input, num_bytes)) {
return ret;
}
break;
}
case vr::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
platform::errors::PreconditionNotMet(
"meta info should be got first!"));
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return ret;
}
break;
}
default: {
PADDLE_THROW(platform::errors::Unavailable(
"not surpported %u fieldnumber", field));
return ret;
}
}
}
return 0;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace paddle {
namespace operators {
namespace distributed {
class BRPCSourceWrapper : public Source {
public:
explicit BRPCSourceWrapper(const butil::IOBuf& iobuf) : source_(iobuf) {}
::google::protobuf::io::ZeroCopyInputStream* contents() override {
return &source_;
}
private:
butil::IOBufAsZeroCopyInputStream source_;
};
class BRPCVariableResponse : public VariableResponse {
public:
BRPCVariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
bool create_scope = false)
: VariableResponse(scope, dev_ctx, create_scope) {}
virtual ~BRPCVariableResponse() {}
// parse attachment from iobuf
int Parse(Source* source) override;
int Parse(const butil::IOBuf& iobuf, const sendrecv::VariableMessage& meta) {
BRPCSourceWrapper wrapper(iobuf);
return VariableResponse::Parse(&wrapper, meta);
}
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/collective_client.h"
#include <memory>
#include "gflags/gflags.h"
DECLARE_int32(rpc_deadline);
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag CollectiveClient::init_flag_;
std::unique_ptr<CollectiveClient> CollectiveClient::client_(nullptr);
bool CollectiveClient::Gather(const std::vector<RemoteVar>& remote_vars,
std::vector<const framework::SelectedRows*>* dst,
const platform::DeviceContext& ctx,
framework::Scope* scope, int64_t time_out) {
for (auto r : remote_vars) {
VLOG(50) << "begin gather from ep:" << r.String();
scope->Var(r.var_name_)->GetMutable<framework::SelectedRows>();
VarHandlePtr ptr = rpc_client_->AsyncGetMonomerVariable(
r.ep_, ctx, *scope, r.var_name_, time_out);
}
rpc_client_->Wait();
for (auto r : remote_vars) {
auto select_rows =
scope->FindVar(r.var_name_)->GetMutable<framework::SelectedRows>();
dst->push_back(select_rows);
VLOG(4) << "gather from ep:" << r.String()
<< ", select_rows:" << GetSelectedRowsInfo(*select_rows);
rpc_client_->AsyncGetMonomerBarrier(r.ep_, r.var_name_);
}
rpc_client_->Wait();
return true;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable> // NOLINT
#include <memory>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle {
namespace framework {
class Scope;
class SelectedRows;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
DECLARE_int32(rpc_deadline);
namespace paddle {
namespace operators {
namespace distributed {
inline std::string GetSelectedRowsInfo(const framework::SelectedRows& slr) {
std::stringstream ss;
ss << ", height:" << slr.height() << ", rows:[";
for (unsigned int i = 0; i < slr.rows().size(); i++) {
if (i != slr.rows().size() - 1) {
ss << slr.rows()[i] << ",";
} else {
ss << slr.rows()[i];
}
}
ss << "], dims:" << slr.value().dims();
return ss.str();
}
struct RemoteVar {
std::string ep_;
std::string var_name_;
int trainer_id_{0};
std::string String() {
std::stringstream ss;
ss << "ep:" << ep_ << ", var_name:" << var_name_
<< ", trainer_id:" << trainer_id_;
return ss.str();
}
};
class CollectiveClient {
public:
CollectiveClient() {
rpc_client_.reset(new RPCCLIENT_T());
rpc_client_->InitImpl();
}
virtual ~CollectiveClient() {}
// note this function will retain the rank order.
bool Gather(const std::vector<RemoteVar>& remote_vars,
std::vector<const framework::SelectedRows*>* dst,
const platform::DeviceContext& ctx, framework::Scope* scope,
int64_t time_out = FLAGS_rpc_deadline);
static CollectiveClient* GetInstance() {
std::call_once(init_flag_, [&]() {
if (client_.get() == nullptr) {
client_.reset(new CollectiveClient());
}
});
return client_.get();
}
private:
std::unique_ptr<RPCClient> rpc_client_;
static std::once_flag init_flag_;
static std::unique_ptr<CollectiveClient> client_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed/collective_server.h"
#include <memory>
DEFINE_int32(collective_get_thread_num, 5, "number of threads for rpc get");
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag CollectiveServer::init_flag_;
std::shared_ptr<CollectiveServer> CollectiveServer::collective_server_(nullptr);
CollectiveServer::CollectiveServer(const std::string& end_point, int fan_in) {
VLOG(1) << "Create colllective server:" << end_point << ", fan_in:" << fan_in;
rpc_server_.reset(new RPCSERVER_T(end_point, fan_in));
}
void CollectiveServer::Stop() {
rpc_server_->ShutDown();
server_thread_->join();
loop_thread_->join();
}
void CollectiveServer::StartServer() {
get_monomer_handler_.reset(new GetMonomerHandler());
get_monomer_handler_->SetRPCServer(rpc_server_.get());
get_barrier_handler_.reset(new GetMonomerBarrierHandler());
get_barrier_handler_->SetRPCServer(rpc_server_.get());
rpc_server_->RegisterRPC(distributed::kRequestGetMonomerVariable,
get_monomer_handler_.get(),
FLAGS_collective_get_thread_num);
rpc_server_->RegisterRPC(distributed::kRequestGetMonomerBarrier,
get_barrier_handler_.get(), 1);
server_thread_.reset(new std::thread([&]() { rpc_server_->StartServer(); }));
rpc_server_->WaitServerReady();
loop_thread_.reset(new std::thread([&]() {
while (true) {
if (rpc_server_->IsExit()) {
LOG(WARNING) << "get exit!rpc_processor break!";
break;
}
sleep(1);
}
VLOG(1) << "CollectiveServer loop_thread end";
}));
}
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class CollectiveServer;
class GetMonomerHandler final : public RequestHandler {
public:
GetMonomerHandler() : RequestHandler(true) {}
virtual ~GetMonomerHandler() {}
bool Handle(const std::string& var_name, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override {
VLOG(50) << "GetMonomerHandler recv " << var_name;
*outvar = scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
outvar, platform::errors::NotFound("var: %s is not found.", var_name));
return true;
}
};
class GetMonomerBarrierHandler final : public RequestHandler {
public:
GetMonomerBarrierHandler() : RequestHandler(true) {}
virtual ~GetMonomerBarrierHandler() {}
bool Handle(const std::string& var_name, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override {
VLOG(50) << "GetMonomerHandler recv " << var_name;
rpc_server_->IncreaseVarBarrier(var_name);
return true;
}
};
class CollectiveServer final {
public:
explicit CollectiveServer(const std::string& end_point, int fan_in);
virtual ~CollectiveServer() {}
void StartServer();
static CollectiveServer* GetInstance(const std::string& end_point,
int fan_in) {
std::call_once(init_flag_, [&]() {
if (collective_server_.get() == nullptr) {
collective_server_.reset(new CollectiveServer(end_point, fan_in));
collective_server_->StartServer();
}
});
return collective_server_.get();
}
std::shared_ptr<RPCServer> GetRPCServer() { return rpc_server_; }
void Stop();
private:
std::unique_ptr<GetMonomerHandler> get_monomer_handler_;
std::unique_ptr<GetMonomerBarrierHandler> get_barrier_handler_;
std::shared_ptr<distributed::RPCServer> rpc_server_;
std::shared_ptr<std::thread> server_thread_;
std::shared_ptr<std::thread> loop_thread_;
bool ready_{false};
static std::once_flag init_flag_;
static std::shared_ptr<CollectiveServer> collective_server_;
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdlib.h>
#include <memory>
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/operators/distributed/collective_client.h"
#include "paddle/fluid/operators/distributed/collective_server.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;
std::unique_ptr<distributed::CollectiveServer> StartServer(
const std::string& ep, int fan_in, framework::Scope* scope,
platform::DeviceContext* dev_ctx) {
distributed::CollectiveServer* server =
distributed::CollectiveServer::GetInstance(ep, fan_in);
auto rpc_server = server->GetRPCServer();
rpc_server->RegisterVar("var1", distributed::kRequestGetMonomerVariable,
scope, dev_ctx);
std::cout << "StartServer return" << std::endl;
return std::unique_ptr<distributed::CollectiveServer>(server);
}
std::unique_ptr<framework::Scope> GenerateVars(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
framework::Scope* scope = new framework::Scope();
framework::Variable* var = scope->Var("var1");
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(20000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({3, 1024}));
tensor->mutable_data<float>(place);
paddle::operators::math::set_constant(ctx, tensor, 32.7);
for (int i = 0; i < 3; ++i) rows->push_back(i);
std::cout << "src:" << distributed::GetSelectedRowsInfo(*slr);
return std::unique_ptr<framework::Scope>(scope);
}
void Gather(const std::vector<distributed::RemoteVar>& vars,
platform::DeviceContext* dev_ctx) {
distributed::CollectiveClient* client =
distributed::CollectiveClient::GetInstance();
framework::Scope* scope = new framework::Scope();
framework::Variable* var = scope->Var("var1");
var->GetMutable<framework::SelectedRows>();
std::vector<const framework::SelectedRows*> dst;
client->Gather(vars, &dst, *dev_ctx, scope);
std::cout << "dst:" << distributed::GetSelectedRowsInfo(*dst[0]);
dev_ctx->Wait();
ASSERT_EQ(dst[0]->value().dims(), framework::make_ddim({3, 1024}));
ASSERT_EQ(dst[0]->height(), 20000);
ASSERT_EQ(dst[0]->rows().size(), static_cast<size_t>(3));
for (int i = 0; i < 3; i++) {
ASSERT_EQ(dst[0]->rows()[i], i);
}
std::vector<float> vec;
TensorToVector(dst[0]->value(), *dev_ctx, &vec);
for (size_t i = 0; i < 3 * 1024; i++) {
ASSERT_FLOAT_EQ(vec[i], 32.7);
}
}
TEST(CollectiveServer, GPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
platform::CUDAPlace place;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
std::string ep = "127.0.0.1:7164";
auto scope = GenerateVars(place);
auto* v1 = scope->FindVar("var1");
std::cout << "var1:" << v1 << std::endl;
auto server = StartServer(ep, 2, scope.get(), &ctx);
auto rpc_server = server->GetRPCServer();
distributed::RemoteVar var;
var.ep_ = ep;
var.var_name_ = "var1";
var.trainer_id_ = 0;
std::vector<distributed::RemoteVar> vars{var};
Gather(vars, &ctx);
Gather(vars, &ctx);
std::cout << "begin WaitVarBarrier" << std::endl;
rpc_server->WaitVarBarrier("var1");
rpc_server->ClearRegisteredVars();
server->Stop();
scope.release();
server.release();
}
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h"
#include <paddle/fluid/framework/program_desc.h>
#include <algorithm>
#include <chrono> // NOLINT
#include <map>
#include <thread> // NOLINT
#include <unordered_set>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace operators {
namespace distributed {
using Tree =
std::map<std::string, std::map<std::string, std::vector<std::string>>>;
using RpcCtxMap = operators::distributed::RpcCtxMap;
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
Communicator::Communicator() {}
std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) {
send_varname_to_ctx_ = std::move(send_varname_to_ctx);
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_ = std::move(recv_scope);
if (send_varname_to_ctx.size() == 0) {
VLOG(0) << "nothing need to be send, will not start send_thread";
} else {
send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
if (iter.first == STEP_COUNTER && !need_global_step_) continue;
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
send_queue_size_);
}
send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
if (recv_varname_to_ctx.size() == 0) {
VLOG(0) << "nothing need to be received, will not start recv_thread";
} else {
recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
InitParams();
}
void AsyncCommunicator::InitParams() { RecvNoBarrier(); }
AsyncCommunicator::~AsyncCommunicator() {
running_ = false;
if (main_thread_) main_thread_->join();
}
void AsyncCommunicator::SendGlobalStep(int batches) {
if (!need_global_step_) {
return;
}
if (batches == 0) {
return;
}
auto &var_name = STEP_COUNTER;
auto *out_var = send_scope_->Var(var_name);
auto *out_t = out_var->GetMutable<framework::LoDTensor>();
auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
data[0] = static_cast<int64_t>(batches);
auto &ctx = send_varname_to_ctx_.at(var_name);
auto send_functor = distributed::ParameterSend<float>();
send_functor(ctx, *send_scope_, true, 1);
}
void AsyncCommunicator::SendByCommunicator() {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
VLOG(3) << "run send graph";
auto before_run_send_graph = GetCurrentUS();
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
auto send_task = [this, &var_name, &var_queue] {
VLOG(3) << var_name << " merge and send; ";
std::vector<std::shared_ptr<Variable>> vars;
int merged_var_num = 0;
int wait_times = 0;
while (merged_var_num < max_merge_var_num_) {
if (var_queue->Size() == 0) {
VLOG(4) << "wait_times -> " << wait_times;
if (wait_times >= send_wait_times_) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
wait_times++;
continue;
} else {
wait_times = 0;
vars.push_back(var_queue->Pop());
merged_var_num++;
}
}
auto before_merge = GetCurrentUS();
if (var_name == STEP_COUNTER) {
SendGlobalStep(merged_var_num);
auto after_merge = GetCurrentUS();
VLOG(3) << "merge and send " << merged_var_num << " " << var_name
<< " use time " << after_merge - before_merge;
return;
}
auto &ctx = send_varname_to_ctx_.at(var_name);
MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
auto after_merge = GetCurrentUS();
VLOG(3) << "merge " << merged_var_num << " " << var_name << " use time "
<< after_merge - before_merge;
auto send_functor = distributed::ParameterSend<float>();
send_functor(ctx, *send_scope_, true, 1);
auto after_send = GetCurrentUS();
VLOG(3) << "send " << var_name << " use time "
<< after_send - after_merge;
if (var_name.rfind("@GRAD") != var_name.size() - 5) return;
auto recv_param = var_name.substr(0, var_name.size() - 5);
if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end())
return;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_);
auto after_recv = GetCurrentUS();
VLOG(3) << "recv " << recv_param << " use time "
<< after_recv - after_send;
};
task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
}
for (auto &task_f : task_futures) {
task_f.wait();
}
auto after_run_send_graph = GetCurrentUS();
VLOG(3) << "run send graph use time "
<< (after_run_send_graph - before_run_send_graph);
}
void HalfAsyncCommunicator::SendByCommunicator() {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
VLOG(3) << "run send graph";
int batches = BatchesCounter();
if (batches <= 0) return;
auto before_run_send_graph = GetCurrentUS();
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
auto send_task = [this, batches, &var_name, &var_queue] {
VLOG(3) << var_name << " merge and send; ";
auto before_task = GetCurrentUS();
std::vector<std::shared_ptr<Variable>> vars;
vars.reserve(batches);
for (int i = 0; i < batches; ++i) {
vars.push_back(var_queue->Pop());
}
if (var_name == STEP_COUNTER) {
SendGlobalStep(batches);
auto end_task = GetCurrentUS();
VLOG(3) << "merge " << batches << " " << var_name << " use time "
<< end_task - before_task;
return;
}
auto &ctx = send_varname_to_ctx_.at(var_name);
auto before_merge = GetCurrentUS();
MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
auto after_merge = GetCurrentUS();
VLOG(3) << "merge " << batches << " " << var_name << " use time "
<< after_merge - before_merge;
auto send_functor = distributed::ParameterSend<float>();
send_functor(ctx, *send_scope_, true, 1);
auto after_send = GetCurrentUS();
VLOG(3) << "send " << var_name << " use time "
<< after_send - before_task;
if (var_name.rfind("@GRAD") != var_name.size() - 5) return;
auto recv_param = var_name.substr(0, var_name.size() - 5);
if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end())
return;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_);
auto after_recv = GetCurrentUS();
VLOG(3) << "recv " << recv_param << " use time "
<< after_recv - after_send;
return;
};
task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
}
for (auto &task_f : task_futures) {
task_f.wait();
}
auto after_run_send_graph = GetCurrentUS();
VLOG(3) << "run send graph use time "
<< (after_run_send_graph - before_run_send_graph);
}
void AsyncCommunicator::MainThread() {
VLOG(3) << "MainThread start and wait";
while (waiting_ && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
VLOG(3) << "wait for running";
}
while (running_) {
SendByCommunicator();
BarrierSend();
}
VLOG(3) << "communicator stopped, send thread exit";
}
void HalfAsyncCommunicator::MainThread() {
VLOG(3) << "MainThread start and wait";
while (waiting_ && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
VLOG(3) << "wait for running";
}
while (running_) {
SendByCommunicator();
BarrierSend();
RecvByCommunicator();
BarrierRecv();
BarrierWeakUp();
}
VLOG(3) << "communicator stopped, send thread exit";
}
void AsyncCommunicator::RecvByCommunicator() {
VLOG(3) << "parallel run recv graph";
if (!running_) return;
RecvNoBarrier();
VLOG(3) << "run recv graph use time";
}
void AsyncCommunicator::RecvNoBarrier() {
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto before_task = GetCurrentUS();
auto &var_name = iter.first;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(iter.second, *recv_scope_);
auto end_task = GetCurrentUS();
VLOG(1) << "recv var " << var_name << " use time "
<< (end_task - before_task);
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
}
void AsyncCommunicator::Start() {
VLOG(3) << "Communicator start";
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
VLOG(3) << "start send thread and recv thread";
waiting_ = true;
running_ = true;
BarrierTriggerReset(max_merge_var_num_);
// start send and recv thread
main_thread_.reset(
new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
}
}
void AsyncCommunicator::Stop() {
VLOG(3) << "Communicator stop";
running_ = false;
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
if (main_thread_) {
VLOG(3) << "stop send thread";
main_thread_->join();
main_thread_.reset(nullptr);
}
}
VLOG(3) << "Communicator stop done";
}
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables,
const framework::Scope &scope) {
waiting_ = false;
PADDLE_ENFORCE_EQ(
var_tables.size(), 1,
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
auto table_name = var_tables[0];
if (table_name == STEP_COUNTER && !need_global_step_) return;
auto before_send_op = GetCurrentUS();
auto &queue = send_varname_to_queue_.at(table_name);
if (table_name == STEP_COUNTER) {
auto tmp_var = std::make_shared<Variable>();
auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({1}));
auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
out_d[0] = 1;
queue->Push(tmp_var);
} else {
PADDLE_ENFORCE_GE(var_names.size(), 1,
platform::errors::InvalidArgument(
"var_names.size() >= 1 is permitted"));
auto *var = scope.FindVar(var_names[0]);
PADDLE_ENFORCE_EQ(
var->IsInitialized(), true,
platform::errors::InvalidArgument("grad var should be inited"));
auto tmp_var = std::make_shared<Variable>();
if (var->IsType<framework::SelectedRows>()) {
framework::CopyVariable(*var, tmp_var.get());
queue->Push(tmp_var);
} else if (var->IsType<framework::LoDTensor>()) {
// push var into send queue by var_name
auto var_name = var_names[0];
framework::CopyVariable(*var, tmp_var.get());
queue->Push(tmp_var);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"unknown var type to copy, only support LoDTensor/SelectedRows"));
}
}
auto after_send_op = GetCurrentUS();
VLOG(3) << "send to " << table_name << " with queue size " << queue->Size()
<< ", use time " << (after_send_op - before_send_op);
}
void HalfAsyncCommunicator::Clean() {
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
while (var_queue->Size() > 0) {
var_queue->Pop();
}
VLOG(3) << "clean var: " << var_name << " done";
}
}
int HalfAsyncCommunicator::BatchesCounter() {
while (running_) {
if (barrier_counter_.load() >= barrier_trigger_.load() &&
barrier_trigger_.load() != 0) {
break;
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
return barrier_counter_.load();
}
void HalfAsyncCommunicator::Barrier() {
barrier_counter_++;
if (!running_) {
VLOG(3) << "Communicator is not running, release barrier";
return;
}
{
std::unique_lock<std::mutex> lk(barrier_mutex_);
barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
}
}
void HalfAsyncCommunicator::BarrierTriggerDecrement() {
barrier_trigger_--;
VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
<< barrier_trigger_.load();
}
void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
barrier_trigger_.store(initial_val);
VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
<< barrier_trigger_.load();
}
void HalfAsyncCommunicator::BarrierWeakUp() {
barrier_counter_.store(0);
barrier_cond_.notify_all();
}
void SyncCommunicator::BarrierSend() {
if (!running_) return;
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);
std::vector<distributed::VarHandlePtr> rets;
for (auto &ep : pserver_endpoints_) {
rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
"internal error in RPCClient"));
}
VLOG(4) << "BarrierSend with SyncCommunicator";
}
void SyncCommunicator::BarrierRecv() {
if (!running_) return;
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);
std::vector<distributed::VarHandlePtr> rets;
for (auto &ep : pserver_endpoints_) {
rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
"internal error in RPCClient"));
}
VLOG(4) << "BarrierRecv with SyncCommunicator";
}
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) {
send_varname_to_ctx_ = std::move(send_varname_to_ctx);
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_ = std::move(recv_scope);
PADDLE_ENFORCE_GT(
send_varname_to_ctx.size(), 0,
platform::errors::InvalidArgument("send var contexts can not be zero"));
send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
auto &varname = iter.first;
if (varname == STEP_COUNTER) {
send_varname_to_queue_[varname] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
send_queue_size_);
} else {
auto &send_ctx = iter.second;
send_var_nums_ += send_ctx.splited_varnames.size();
if (!send_ctx.is_sparse) {
continue;
}
int pserver_num = static_cast<int>(send_ctx.epmap.size());
for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
sparse_id_queues_.insert(
std::pair<std::string, std::shared_ptr<BlockingQueue<
std::shared_ptr<std::vector<int64_t>>>>>(
send_ctx.splited_varnames[ep_idx],
std::make_shared<
BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>(
send_queue_size_)));
}
}
}
send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
if (recv_varname_to_ctx.size() == 0) {
VLOG(0) << "nothing need to be received, will not start recv_thread";
} else {
recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
delta_scope_.reset(new Scope());
old_scope_.reset(new Scope());
pserver_scope_.reset(new Scope());
InitParams();
}
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables,
const framework::Scope &scope) {
waiting_ = false;
PADDLE_ENFORCE_EQ(
var_tables.size(), 1,
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
auto table_name = var_tables[0];
if (table_name == STEP_COUNTER) return;
auto before_send = GetCurrentUS();
size_t splited_var_nums =
send_varname_to_ctx_[table_name].splited_varnames.size();
std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
for (size_t j = 0; j < splited_var_nums; j++) {
ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
send_varname_to_ctx_[table_name].splited_varnames[j],
std::unordered_set<int64_t>()));
}
auto *var = scope.FindVar(var_names[0]);
auto &rows = var->Get<framework::SelectedRows>().rows();
// insert ids which has not been record
for (size_t j = 0; j < rows.size(); j++) {
auto ep_idx = rows[j] % splited_var_nums;
ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
.insert(rows[j]);
}
auto before_push = GetCurrentUS();
for (auto &iter : ids_table) {
auto &key = iter.first;
auto &sparse_ids_set = iter.second;
auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
sparse_id_queues_.at(key)->Push(sparse_ids_vec);
VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
<< "'s queue";
}
auto after_send = GetCurrentUS();
VLOG(3) << "run send " << table_name << " op finish. using "
<< (before_push - before_send) << "; " << (after_send - before_push);
}
void GeoCommunicator::MainThread() {
VLOG(3) << "MainThread start and wait";
while (waiting_ && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
VLOG(3) << "wait for running";
}
while (running_) {
std::vector<std::future<void>> tasks;
tasks.reserve(send_var_nums_);
for (auto &iter : send_varname_to_ctx_) {
auto &var_name = iter.first;
auto &send_ctx = iter.second;
int pserver_num = static_cast<int>(send_ctx.epmap.size());
if (send_ctx.is_sparse) {
for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
auto send_recv_task = [this, ep_idx, &var_name] {
auto before_send_sparse = GetCurrentUS();
if (var_name == STEP_COUNTER) {
return;
}
auto send_varname =
send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx];
auto sparse_ids = MergeSparseIds(send_varname);
if (sparse_ids.size() == 0) {
return;
}
SendSparse(var_name, ep_idx, sparse_ids);
auto after_send_sparse = GetCurrentUS();
RecvSparse(var_name, ep_idx);
auto after_recv_sparse = GetCurrentUS();
VLOG(3)
<< "send recv "
<< send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]
<< " finish, using " << (after_send_sparse - before_send_sparse)
<< " and " << (after_recv_sparse - after_send_sparse)
<< "; total = " << (after_recv_sparse - before_send_sparse);
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
}
} else {
auto send_recv_task = [this, &var_name, &send_ctx] {
if (var_name == STEP_COUNTER) {
return;
}
SendDense(var_name);
RecvDense(var_name);
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
}
}
for (auto &task : tasks) {
task.wait();
}
}
}
std::vector<int64_t> GeoCommunicator::MergeSparseIds(
const std::string &send_varname) {
size_t merge_num = 0, wait_times = 0;
std::unordered_set<int64_t> sparse_ids;
while (merge_num < static_cast<size_t>(max_merge_var_num_)) {
VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
if (sparse_id_queues_.at(send_varname)->Size() > 0) {
wait_times = 0;
std::shared_ptr<std::vector<int64_t>> pop_ids =
sparse_id_queues_.at(send_varname)->Pop();
for (size_t j = 0; j < pop_ids->size(); j++) {
sparse_ids.insert(pop_ids->at(j));
}
merge_num += 1;
VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed";
} else if (sparse_id_queues_.at(send_varname)->Size() == 0) {
VLOG(3) << "wait_times -> " << wait_times;
if (wait_times >= static_cast<size_t>(send_wait_times_)) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
wait_times++;
continue;
}
}
std::vector<int64_t> res;
res.assign(sparse_ids.begin(), sparse_ids.end());
return res;
}
void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx,
const std::vector<int64_t> &sparse_ids) {
auto &rpc_ctx = send_varname_to_ctx_.at(varname);
auto send_varname = rpc_ctx.splited_varnames[ep_idx];
auto trainer_id = rpc_ctx.trainer_id;
auto endpoint = rpc_ctx.epmap[ep_idx];
auto pserver_num = rpc_ctx.epmap.size();
auto *var_latest = recv_scope_->FindVar(varname);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
platform::errors::Unavailable(
"%s is not initialized, please check", varname));
auto &t_latest = var_latest->Get<framework::LoDTensor>();
auto dims1 = t_latest.dims()[1];
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto *var_delta = delta_scope_->Var(send_varname);
auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
auto *t_value = t_delta->mutable_value();
t_value->mutable_data<float>(
framework::make_ddim({static_cast<int64_t>(sparse_ids.size()), dims1}),
cpu_ctx.GetPlace());
std::vector<std::vector<std::vector<float> *>> values;
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Get(sparse_ids, {"Param"}, &values);
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
float coefficient = 1.0 / static_cast<float>(trainers_);
for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
blas.VSUB(dims1, t_latest.data<float>() + sparse_ids[j] * dims1,
values[j][0]->data(), t_value->data<float>() + j * dims1);
blas.SCAL(dims1, coefficient, t_value->data<float>() + j * dims1);
blas.VADD(dims1, values[j][0]->data(), t_value->data<float>() + j * dims1,
values[j][0]->data());
}
std::vector<int64_t> send_rows;
send_rows.reserve(sparse_ids.size());
for (auto idx : sparse_ids) {
send_rows.push_back(idx / pserver_num);
}
t_delta->set_height(rpc_ctx.height_sections[ep_idx]);
t_delta->set_rows(send_rows);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
*delta_scope_.get(), send_varname);
ret->Wait();
}
void GeoCommunicator::SendDense(const std::string &varname) {
auto *var_latest = recv_scope_->FindVar(varname);
auto *var_timestamp = old_scope_->FindVar(varname);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
platform::errors::Unavailable(
"%s is not initialized, please check", varname));
PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true,
platform::errors::Unavailable(
"%s is not initialized, please check", varname));
auto &t_latest = var_latest->Get<framework::LoDTensor>();
auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto *var_delta = delta_scope_->Var(varname);
auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
blas.VSUB(t_latest.numel(), t_latest.data<float>(),
t_timestamp->data<float>(), t_delta->data<float>());
float coefficient = 1.0 / static_cast<float>(trainers_);
blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());
blas.VADD(t_latest.numel(), t_timestamp->data<float>(),
t_delta->data<float>(), t_timestamp->data<float>());
auto &ctx = send_varname_to_ctx_.at(varname);
auto send = distributed::ParameterSend<float>();
send(ctx, *delta_scope_, true, 1);
}
void GeoCommunicator::RecvByCommunicator() { return; }
void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto train_id = recv_varname_to_ctx_.at(varname).trainer_id;
auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx];
auto splited_var_name =
recv_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id);
auto *var_psrever = pserver_scope_->Var(splited_var_name);
auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv,
*pserver_scope_.get(), splited_var_name,
splited_var_name, splited_var_name);
handle->Wait();
auto *var_latest = recv_scope_->FindVar(varname);
PADDLE_ENFORCE_EQ(
var_psrever->IsInitialized(), true,
platform::errors::Unavailable(
"%s in pserver scope is not initialized, please check", varname));
std::vector<int64_t> ids;
ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(),
var_psrever->Get<framework::SelectedRows>().rows().end());
for (size_t j = 0; j < ids.size(); j++) {
ids[j] = ids[j] * pserver_num + ep_idx;
}
VLOG(3) << "RecvSparse receive var: " << splited_var_name
<< " ids Size: " << ids.size();
auto t_psrever = var_psrever->Get<framework::SelectedRows>().value();
std::vector<std::vector<std::vector<float> *>> old_values;
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Get(ids, {"Param"}, &old_values);
auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();
auto dims1 = t_latest->dims()[1];
auto numel = ids.size() * dims1;
std::vector<float> v_delta;
v_delta.resize(numel);
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
blas.VSUB(dims1, t_psrever.data<float>() + j * dims1,
old_values[j][0]->data(), v_delta.data() + j * dims1);
blas.VADD(dims1, t_latest->data<float>() + ids[j] * dims1,
v_delta.data() + j * dims1,
t_latest->data<float>() + ids[j] * dims1);
blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1,
old_values[j][0]->data());
}
}
void GeoCommunicator::RecvDense(const std::string &varname) {
auto *var_latest = recv_scope_->FindVar(varname);
auto *var_timestamp = old_scope_->FindVar(varname);
auto *var_psrever = pserver_scope_->Var(varname);
auto &ctx = recv_varname_to_ctx_.at(varname);
auto recv = distributed::ParameterRecv<float>();
recv(ctx, *pserver_scope_);
PADDLE_ENFORCE_EQ(
var_psrever->IsInitialized(), true,
platform::errors::Unavailable(
"%s in pserver scope is not initialized, please check", varname));
auto t_psrever = var_psrever->Get<framework::LoDTensor>();
auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto *var_delta = delta_scope_->Var(varname);
auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
blas.VSUB(t_latest->numel(), t_psrever.data<float>(),
t_timestamp->data<float>(), t_delta->data<float>());
blas.VADD(t_latest->numel(), t_latest->data<float>(), t_delta->data<float>(),
t_latest->data<float>());
blas.VCOPY(t_latest->numel(), t_psrever.data<float>(),
t_timestamp->data<float>());
}
void GeoCommunicator::InitParams() {
std::vector<std::future<void>> tasks;
tasks.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto &var_name = iter.first;
auto &recv_ctx = iter.second;
auto recv_task = [this, &var_name, &recv_ctx] {
if (!recv_ctx.is_sparse) {
InitDense(var_name);
}
};
tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : tasks) {
task.wait();
}
InitSparse();
}
void GeoCommunicator::InitDense(const std::string varname) {
auto &ctx = recv_varname_to_ctx_.at(varname);
auto recv = distributed::ParameterRecv<float>();
recv(ctx, *recv_scope_);
auto *global_var = recv_scope_->FindVar(varname);
global_var->GetMutable<framework::LoDTensor>();
auto *old_var = old_scope_->Var(varname);
old_var->GetMutable<framework::LoDTensor>();
framework::CopyVariable(*global_var, old_var);
VLOG(1) << "init dense variable " << varname << " done";
}
void GeoCommunicator::InitSparse() {
auto sparse_metas = string::split_string<std::string>(sparse_attrs_, "#");
std::vector<distributed::SparseMeta> metas;
std::vector<int64_t> dicts;
for (auto &sparse_meta : sparse_metas) {
auto attrs = string::split_string<std::string>(sparse_meta, ":");
auto meta = distributed::SparseMeta();
meta.name = attrs[0];
meta.value_names = {"Param"};
auto dic = string::split_string<std::string>(attrs[1], ",");
dicts.push_back(std::stoi(dic[0]));
meta.value_dims = {std::stoi(dic[1])};
meta.mode = distributed::Mode::training;
meta.grad_name = "none";
meta.cached_varnames = {};
meta.initializer_attrs = string::split_string<std::string>(attrs[2]);
meta.entry = "none";
VLOG(3) << "add sparse meta: " << meta.ToString();
metas.push_back(meta);
}
LargeScaleKV::Init(metas);
for (auto &meta : metas) {
auto &ctx = recv_varname_to_ctx_.at(meta.name);
auto recv = distributed::ParameterRecv<float>();
auto *global_var = recv_scope_->FindVar(meta.name);
auto global_value = global_var->Get<framework::LoDTensor>();
auto rows = global_value.dims()[0];
auto dim1 = global_value.dims()[1];
recv(ctx, *recv_scope_);
VLOG(1) << "recv " << meta.name << " with global scope for init";
auto n_rows = global_var->Get<framework::LoDTensor>().dims()[0];
PADDLE_ENFORCE_EQ(
rows, n_rows,
platform::errors::InvalidArgument(
"global var: %s origin dim must equal recved rows", meta.name));
std::vector<int64_t> ids(rows);
std::iota(ids.begin(), ids.end(), 0);
auto *ins = distributed::LargeScaleKV::GetInstance();
std::vector<std::vector<std::vector<float> *>> values;
ins->Get(meta.name)->Init(ids);
ins->Get(meta.name)->Get(ids, {"Param"}, &values);
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(
paddle::platform::CPUDeviceContext());
for (auto &id : ids) {
blas.VCOPY(dim1, global_value.data<float>() + id * dim1,
values[id][0]->data());
}
}
VLOG(3) << "init sparse variable done";
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <stdint.h>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
DECLARE_bool(communicator_is_sgd_optimizer);
namespace paddle {
namespace operators {
namespace distributed {
using Scope = framework::Scope;
using Variable = framework::Variable;
template <typename T>
class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0,
platform::errors::InvalidArgument(
"The capacity must be greater than 0."));
}
bool Push(const T &elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(
queue_.size(), capacity_,
platform::errors::OutOfRange("The queue size: %s out of capacity:%s",
queue_.size(), capacity_));
queue_.push_back(elem);
}
cv_.notify_one();
return true;
}
bool Push(T &&elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(
queue_.size(), capacity_,
platform::errors::OutOfRange("The queue size: %s out of capacity:%s",
queue_.size(), capacity_));
queue_.emplace_back(std::move(elem));
}
cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
cv_.notify_one();
return rc;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
private:
const size_t capacity_;
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable cv_;
};
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope, bool merge_add = true) {
PADDLE_ENFORCE_NE(vars.empty(), true, platform::errors::InvalidArgument(
"vector vars are empty."));
auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0];
auto *out_var = scope->Var(var_name);
if (var0->IsType<framework::LoDTensor>()) {
auto dims = var0->Get<framework::LoDTensor>().dims();
VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
<< "; merge add: " << merge_add;
// init output tensor
auto *out_t = out_var->GetMutable<framework::LoDTensor>();
out_t->mutable_data<T>(dims, cpu_place);
// check the input dims
for (auto &var : vars) {
auto &var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
var_t.dims(), dims,
platform::errors::InvalidArgument("vars should have the same dims"));
}
// set output tensor to 0.
auto cpu_ctx = paddle::platform::CPUDeviceContext();
math::SetConstant<paddle::platform::CPUDeviceContext, T> constant_functor;
constant_functor(cpu_ctx, out_t, static_cast<T>(0));
// sum all vars to out
auto result = EigenVector<T>::Flatten(*out_t);
for (auto &var : vars) {
auto &in_t = var->Get<framework::LoDTensor>();
auto in = EigenVector<T>::Flatten(in_t);
result.device(*cpu_ctx.eigen_device()) = result + in;
}
if (!merge_add) {
result.device(*cpu_ctx.eigen_device()) =
result / static_cast<T>(vars.size());
}
} else if (var0->IsType<framework::SelectedRows>()) {
auto &slr0 = var0->Get<framework::SelectedRows>();
auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
std::vector<const paddle::framework::SelectedRows *> inputs;
inputs.reserve(vars.size());
for (auto &var : vars) {
inputs.push_back(&var->Get<framework::SelectedRows>());
}
auto dev_ctx = paddle::platform::CPUDeviceContext();
if (merge_add) {
math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, T> merge_add;
merge_add(dev_ctx, inputs, out_slr);
} else {
math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, T>
merge_average;
merge_average(dev_ctx, inputs, out_slr);
}
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
} else {
PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
var0->Type()));
}
}
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
using SparseValue = std::unordered_map<int64_t, std::vector<float>>;
class Communicator {
public:
Communicator();
explicit Communicator(const std::map<std::string, std::string> &envs_) {
for (auto &iter : envs_) {
envs[iter.first] = iter.second;
}
}
virtual ~Communicator() {}
virtual void Start() = 0;
virtual void Stop() = 0;
virtual bool IsRunning() { return running_; }
virtual void Clean() {}
virtual void Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables,
const framework::Scope &scope) = 0;
virtual void RecvNoBarrier() {}
virtual void Barrier() {}
virtual void BarrierTriggerDecrement() {}
virtual void BarrierTriggerReset(int init_counter) {}
virtual void InitEnvs() = 0;
virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) {}
static Communicator *GetInstance() { return communicator_.get(); }
static std::shared_ptr<Communicator> GetInstantcePtr() {
return communicator_;
}
template <typename T>
static Communicator *InitInstance(
const RpcCtxMap &send_ctx, const RpcCtxMap &recv_ctx, Scope *recv_scope,
const std::map<std::string, std::string> &envs) {
std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>, send_ctx,
recv_ctx, recv_scope, std::ref(envs));
return communicator_.get();
}
// Init is called by InitInstance.
template <typename T>
static void InitWithRpcCtx(const RpcCtxMap &send_ctx,
const RpcCtxMap &recv_ctx, Scope *recv_scope,
const std::map<std::string, std::string> &envs) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T(std::ref(envs)));
communicator_->InitEnvs();
communicator_->InitImpl(send_ctx, recv_ctx, recv_scope);
}
}
protected:
bool running_ = false;
bool waiting_ = true;
static std::shared_ptr<Communicator> communicator_;
static std::once_flag init_flag_;
std::unordered_map<std::string, std::string> envs;
};
class AsyncCommunicator : public Communicator {
public:
AsyncCommunicator() : Communicator() {}
explicit AsyncCommunicator(const std::map<std::string, std::string> &envs)
: Communicator(envs) {}
~AsyncCommunicator();
void InitEnvs() {
min_send_grad_num_before_recv_ =
std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
VLOG(0) << "AsyncCommunicator Initialized";
}
void Start() override;
void Stop() override;
void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
void InitParams();
virtual void MainThread();
void Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables,
const framework::Scope &scope) override;
virtual void SendByCommunicator();
virtual void SendGlobalStep(int batches);
virtual void RecvByCommunicator();
virtual void RecvNoBarrier();
virtual void BarrierSend() {}
virtual void BarrierRecv() {}
virtual void BarrierWeakUp() {}
protected:
int min_send_grad_num_before_recv_;
int thread_pool_size_;
int max_merge_var_num_;
int send_wait_times_;
int send_queue_size_;
int trainer_id_ = 0;
bool need_global_step_ = false;
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
RpcCtxMap send_varname_to_ctx_;
RpcCtxMap recv_varname_to_ctx_;
std::unique_ptr<std::thread> main_thread_{nullptr};
Scope *recv_scope_; // should be global scope
std::unique_ptr<Scope> send_scope_; // an independent scope
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
};
class HalfAsyncCommunicator : public AsyncCommunicator {
public:
HalfAsyncCommunicator() {}
explicit HalfAsyncCommunicator(const std::map<std::string, std::string> &envs)
: AsyncCommunicator(envs) {}
void InitEnvs() {
min_send_grad_num_before_recv_ = 0;
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
VLOG(0) << "HalfAsyncCommunicator Initialized";
}
void MainThread() override;
void SendByCommunicator() override;
void Clean() override;
void Barrier() override;
void BarrierTriggerDecrement() override;
void BarrierTriggerReset(int initial_val) override;
int BatchesCounter();
void BarrierWeakUp();
protected:
// mutex for Wait for barrier
std::mutex barrier_mutex_;
std::condition_variable barrier_cond_;
std::atomic<int64_t> barrier_trigger_{0};
std::atomic<int64_t> barrier_counter_{0};
};
class SyncCommunicator : public HalfAsyncCommunicator {
public:
SyncCommunicator() : HalfAsyncCommunicator() {}
explicit SyncCommunicator(const std::map<std::string, std::string> &envs)
: HalfAsyncCommunicator(envs) {}
void InitEnvs() {
min_send_grad_num_before_recv_ = 0;
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
trainer_id_ = std::stoi(envs.at("trainer_id"));
auto pserver_strings = envs.at("pserver_endpoints");
pserver_endpoints_ = paddle::string::Split(pserver_strings, ',');
VLOG(0) << "SyncCommunicator Initialized";
}
void BarrierSend();
void BarrierRecv();
private:
std::vector<std::string> pserver_endpoints_{};
};
class GeoCommunicator : public AsyncCommunicator {
public:
GeoCommunicator() : AsyncCommunicator() {}
explicit GeoCommunicator(const std::map<std::string, std::string> &envs)
: AsyncCommunicator(envs) {}
void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
void MainThread() override;
void InitEnvs() {
min_send_grad_num_before_recv_ = 0;
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
send_queue_size_ = max_merge_var_num_;
trainers_ = std::stoi(envs.at("trainers"));
sparse_attrs_ = envs.at("sparse_attrs");
VLOG(0) << "GeoCommunicator Initialized";
}
void Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables,
const framework::Scope &scope) override;
void SendByCommunicator() { return; }
std::vector<int64_t> MergeSparseIds(const std::string &send_varname);
void SendSparse(const std::string &varname, int ep_idx,
const std::vector<int64_t> &sparse_ids);
void SendDense(const std::string &varname);
void SendGlobalStep(int batches) override {}
void RecvByCommunicator() override;
void RecvSparse(const std::string &varname, int ep_idx);
void RecvDense(const std::string &varname);
void InitParams();
void InitSparse();
void InitDense(const std::string varname);
private:
int trainers_;
std::string sparse_attrs_;
// parameter for delta calc and send
std::shared_ptr<Scope> delta_scope_;
// parameter for storage the pserver param after last recv
std::shared_ptr<Scope> old_scope_;
// parameter on pserver
std::shared_ptr<Scope> pserver_scope_;
int send_var_nums_ = 0;
std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_;
std::unordered_map<
std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>>
sparse_id_queues_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
namespace distributed {
struct CommContext {
CommContext() = default;
CommContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections,
const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false)
: var_name(name),
splited_varnames(names),
epmap(emap),
height_sections(sections),
origin_varnames(origin_names),
trainer_id(id),
merge_add(merge_add_),
is_sparse(is_sparse_),
is_distributed(is_distributed_) {}
CommContext(const CommContext &ctx) {
var_name = ctx.var_name;
splited_varnames = ctx.splited_varnames;
epmap = ctx.epmap;
height_sections = ctx.height_sections;
trainer_id = ctx.trainer_id;
merge_add = ctx.merge_add;
is_sparse = ctx.is_sparse;
origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed;
}
std::string print() const {
std::stringstream ss;
ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
for (size_t i = 0; i < splited_varnames.size(); i++) {
ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
<< " section: " << height_sections[i] << " ";
}
ss << "origin varnames: ";
for (size_t i = 0; i < origin_varnames.size(); i++) {
ss << origin_varnames[i] << " ";
}
ss << " aggregation->add: " << merge_add << " ";
ss << " is_sparse: " << is_sparse << "\n";
ss << " is_distributed: " << is_distributed << "\n";
return ss.str();
}
std::string var_name;
std::vector<std::string> splited_varnames;
std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
std::vector<std::string> origin_varnames;
int trainer_id;
bool merge_add;
bool is_sparse;
bool is_distributed;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/operators/distributed/communicator.h"
namespace paddle {
namespace operators {
namespace distributed {
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
TEST(communicator, merge_lod_tensors) {
auto cpu_place = platform::CPUPlace();
auto dims = framework::make_ddim({2, 3});
std::vector<std::shared_ptr<framework::Variable>> in_vars;
float out_value = 0;
for (auto i = 0; i < 10; ++i) {
auto var = std::make_shared<Variable>();
in_vars.emplace_back(var);
auto *tensor = var->GetMutable<LoDTensor>();
auto *data = tensor->mutable_data<float>(dims, cpu_place);
for (auto j = 0; j < tensor->numel(); ++j) {
data[j] = static_cast<float>(i);
}
out_value += static_cast<float>(i);
}
const std::string out_name = "Out";
std::unique_ptr<framework::Scope> scope;
scope.reset(new framework::Scope());
scope->Var(out_name);
for (auto i = 0; i < 10; ++i) {
MergeVars<float>(out_name, in_vars, scope.get());
}
auto &out_tensor = scope->FindVar(out_name)->Get<LoDTensor>();
auto *out_data = out_tensor.data<float>();
ASSERT_EQ(out_tensor.dims(), dims);
for (auto i = 0; i < out_tensor.numel(); ++i) {
ASSERT_EQ(out_data[i], out_value);
}
}
TEST(communicator, merge_selected_rows) {
auto cpu_place = platform::CPUPlace();
int64_t width = 10;
std::vector<std::shared_ptr<framework::Variable>> in_vars;
const int64_t height = 100;
for (auto i = 0; i < 10; ++i) {
std::vector<int64_t> rows;
for (auto k = 0; k <= i; ++k) {
rows.push_back(k);
}
auto var = std::make_shared<Variable>();
in_vars.emplace_back(var);
auto *slr = var->GetMutable<SelectedRows>();
slr->set_height(height);
slr->set_rows(rows);
auto dims =
framework::make_ddim({static_cast<int64_t>(rows.size()), width});
auto *data = slr->mutable_value()->mutable_data<float>(dims, cpu_place);
for (size_t i = 0; i < rows.size(); ++i) {
for (auto j = 0; j < width; ++j) {
data[i * width + j] = static_cast<float>(rows[i]);
}
}
}
const std::string out_name = "Out";
std::unique_ptr<framework::Scope> scope;
scope.reset(new framework::Scope());
scope->Var(out_name);
for (auto i = 0; i < 10; ++i) {
MergeVars<float>(out_name, in_vars, scope.get());
}
auto &out_slr = scope->FindVar(out_name)->Get<SelectedRows>();
auto &out_t = out_slr.value();
auto *out_data = out_t.data<float>();
ASSERT_EQ(out_t.dims(), framework::make_ddim({10, width}));
std::vector<float> out_values;
out_values.reserve(10);
for (auto i = 0; i < 10; ++i) {
out_values.push_back(static_cast<float>(i * (10 - i)));
}
for (size_t i = 0; i < out_slr.rows().size(); ++i) {
ASSERT_EQ(out_slr.rows()[i], static_cast<int>(i));
for (auto j = 0; j < width; ++j) {
ASSERT_EQ(out_data[i * width + j], out_values[i]);
}
}
}
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_DISTRIBUTE
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_server.h"
#define RPCSERVER_T paddle::operators::distributed::AsyncGRPCServer
#define RPCCLIENT_T paddle::operators::distributed::GRPCClient
#else // PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/brpc/brpc_client.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#define RPCSERVER_T paddle::operators::distributed::AsyncBRPCServer
#define RPCCLIENT_T paddle::operators::distributed::BRPCClient
#endif // PADDLE_WITH_GRPC
#endif // PADDLE_WITH_DISTRIBUTE
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_DISTRIBUTE
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#else // PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#endif // PADDLE_WITH_GRPC
#endif // PADDLE_WITH_DISTRIBUTE
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// NOTE: This file was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h"
namespace grpc {
class ByteBuffer;
} // namespace grpc
namespace paddle {
namespace operators {
namespace distributed {
GrpcByteBufferSource::GrpcByteBufferSource() {}
bool GrpcByteBufferSource::Init(const grpc::ByteBuffer& src) {
cur_ = -1;
left_ = 0;
ptr_ = nullptr;
byte_count_ = 0;
bool ok = src.Dump(&slices_).ok();
if (!ok) {
slices_.clear();
}
return ok;
}
bool GrpcByteBufferSource::Next(const void** data, int* size) {
// Use loop instead of if in case buffer contained empty slices.
while (left_ == 0) {
// Advance to next slice.
cur_++;
if (cur_ >= slices_.size()) {
return false;
}
const ::grpc::Slice& s = slices_[cur_];
left_ = s.size();
ptr_ = reinterpret_cast<const char*>(s.begin());
}
*data = ptr_;
*size = left_;
byte_count_ += left_;
ptr_ += left_;
left_ = 0;
return true;
}
void GrpcByteBufferSource::BackUp(int count) {
ptr_ -= count;
left_ += count;
byte_count_ -= count;
}
bool GrpcByteBufferSource::Skip(int count) {
const void* data;
int size;
while (Next(&data, &size)) {
if (size >= count) {
BackUp(size - count);
return true;
}
// size < count;
count -= size;
}
// error or we have too large count;
return false;
}
google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
return byte_count_;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// NOTE: This file was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#pragma once
#include <vector>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "grpc++/grpc++.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
struct grpc_byte_buffer;
namespace grpc {
// A ZeroCopyInputStream that reads from grpc_byte_buffer
class ByteBuffer;
class GrpcBufferReader final
: public ::google::protobuf::io::ZeroCopyInputStream {
typedef void (CoreCodegenInterface::*OldReaderInitAPI)(
grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer);
typedef int (CoreCodegenInterface::*NewReaderInitAPI)(
grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer);
void ReaderInit(OldReaderInitAPI ptr, grpc_byte_buffer_reader* reader,
grpc_byte_buffer* buffer) {
(g_core_codegen_interface->*ptr)(reader, buffer);
}
void ReaderInit(NewReaderInitAPI ptr, grpc_byte_buffer_reader* reader,
grpc_byte_buffer* buffer) {
int result = (g_core_codegen_interface->*ptr)(reader, buffer);
(void)result;
}
public:
explicit GrpcBufferReader(grpc_byte_buffer* buffer)
: byte_count_(0), backup_count_(0) {
ReaderInit(&CoreCodegenInterface::grpc_byte_buffer_reader_init, &reader_,
buffer);
}
~GrpcBufferReader() override {
g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader_);
}
bool Next(const void** data, int* size) override {
if (backup_count_ > 0) {
*data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) -
backup_count_;
GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX);
*size = static_cast<int>(backup_count_);
backup_count_ = 0;
return true;
}
if (!g_core_codegen_interface->grpc_byte_buffer_reader_next(&reader_,
&slice_)) {
return false;
}
g_core_codegen_interface->grpc_slice_unref(slice_);
*data = GRPC_SLICE_START_PTR(slice_);
// On win x64, int is only 32bit
GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX);
byte_count_ += * size = static_cast<int>(GRPC_SLICE_LENGTH(slice_));
return true;
}
void BackUp(int count) override { backup_count_ = count; }
bool Skip(int count) override {
const void* data;
int size;
while (Next(&data, &size)) {
if (size >= count) {
BackUp(size - count);
return true;
}
// size < count;
count -= size;
}
// error or we have too large count;
return false;
}
::google::protobuf::int64 ByteCount() const override {
return byte_count_ - backup_count_;
}
private:
int64_t byte_count_;
int64_t backup_count_;
grpc_byte_buffer_reader reader_;
grpc_slice slice_;
};
}; // namespace grpc
namespace paddle {
namespace operators {
namespace distributed {
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
class GrpcByteBufferSource
: public ::google::protobuf::io::ZeroCopyInputStream {
public:
GrpcByteBufferSource();
bool Init(const ::grpc::ByteBuffer& src); // Can be called multiple times.
bool Next(const void** data, int* size) override;
void BackUp(int count) override;
bool Skip(int count) override;
::google::protobuf::int64 ByteCount() const override;
private:
std::vector<::grpc::Slice> slices_;
size_t cur_; // Current slice index.
int left_; // Number of bytes in slices_[cur_] left to yield.
const char* ptr_; // Address of next byte in slices_[cur_] to yield.
::google::protobuf::int64 byte_count_;
};
class GrpcByteBufferSourceWrapper : public Source {
public:
explicit GrpcByteBufferSourceWrapper(GrpcByteBufferSource* source)
: source_(source) {}
::google::protobuf::io::ZeroCopyInputStream* contents() override {
return source_;
}
private:
GrpcByteBufferSource* source_;
};
class GrpcByteSource : public Source {
public:
explicit GrpcByteSource(grpc_byte_buffer* buffer) : buffer_(buffer) {}
~GrpcByteSource() override { DeleteStream(); }
typedef ::grpc::GrpcBufferReader Reader;
::google::protobuf::io::ZeroCopyInputStream* contents() override {
DeleteStream();
stream_ = new (&space_) Reader(buffer_);
return stream_;
}
private:
void DeleteStream() {
if (stream_) {
stream_->~Reader();
}
}
grpc_byte_buffer* buffer_; // Not owned
Reader* stream_ = nullptr; // Points into space_ if non-nullptr
char space_[sizeof(Reader)];
};
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdlib.h>
#include <limits>
#include "glog/logging.h" // For VLOG
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(rpc_client_threads, 2, "");
DECLARE_bool(rpc_disable_reuse_port);
namespace paddle {
namespace operators {
namespace distributed {
void GRPCClient::InitImpl() {
// start the client process thread
// TODO(wuyi): can make this in a threadpool
client_threads_.resize(FLAGS_rpc_client_threads);
for (int i = 0; i < FLAGS_rpc_client_threads; i++) {
client_threads_[i].reset(
new std::thread(std::bind(&GRPCClient::Proceed, this)));
}
}
void GRPCClient::SendComplete() {
std::unique_lock<std::mutex> lk(completed_mutex_);
if (!completed_) {
for (auto& it : channels_) {
VLOG(3) << "send complete message to " << it.first;
this->AsyncSendComplete(it.first);
}
PADDLE_ENFORCE_EQ(this->Wait(), true, platform::errors::PreconditionNotMet(
"internal grpc service error."));
completed_ = true;
}
}
GRPCClient::~GRPCClient() {
stopped_ = true;
Wait();
cq_.Shutdown();
{
std::lock_guard<std::mutex> guard(chan_mutex_);
for (auto& it : channels_) {
it.second.reset();
}
channels_.clear();
}
for (size_t i = 0; i < client_threads_.size(); i++)
client_threads_[i]->join();
}
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const std::string method = kSendRPC;
int retry_times_ = 0;
while (true) {
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = nullptr;
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
}
return h;
}
}
void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) {
VLOG(4) << "ProcGetResponse";
framework::Variable* outvar = nullptr;
// get response's trainer_id is not used
int trainer_id;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
&trainer_id);
}
void ProcGetRecvResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) {
VLOG(4) << "ProcGetRecvResponse";
framework::Variable* outvar = nullptr;
int trainer_id;
DeserializeRecvFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
&trainer_id);
}
template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
::grpc::Slice slice(proto.ByteSizeLong());
proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
::grpc::ByteBuffer tmp(&slice, 1);
result->Swap(&tmp);
}
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
const std::string& table_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
"/sendrecv.SendRecvService/GetVariable", table_name,
time_out);
}
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname, int64_t time_out) {
std::string var_name_no_barrier =
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
return _AsyncGetVar(
ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
"/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out);
}
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
"/sendrecv.SendRecvService/GetMonomerVariable", "",
time_out);
}
VarHandlePtr GRPCClient::_AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const std::string out_varname_val = out_varname;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
int retry_times_ = 0;
while (true) {
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::Async([var_name_val, out_varname_val, table_name_val, s, method,
p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
req.set_table_name(table_name_val);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordRPCEvent record_event(method);
auto call =
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
}
return h;
}
}
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const std::string method = kPrefetchRPC;
int retry_times_ = 0;
while (true) {
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, kPrefetchTimeout);
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
0, table_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
}
return h;
}
}
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = kBatchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
const std::string method = kFetchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const std::string& var_name,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = kSendMonomerFetchBarrierRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
s->Prepare(h, time_out);
VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
sendrecv::VariableMessage req;
req.set_varname(var_name);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncGetMonomerBarrier(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = kSendCompleteRPC;
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_trainer_id(trainer_id_);
req.set_varname(COMPLETE_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dirname,
const std::string& varname,
const int mode,
int64_t time_out) {
const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
const std::string method = kCheckPointNotifyRPC;
VarHandlePtr h(
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(varname);
req.set_table_name(std::to_string(mode));
req.set_out_varname(dirname);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
VarHandlePtr GRPCClient::AsyncDistributeNotify(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const std::string method = kRequestNotify;
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = nullptr;
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
});
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& send_var_name,
const std::string& recv_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string send_var_name_val = send_var_name;
const std::string recv_var_name_val = recv_var_name;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const std::string method = kSendAndRecvRPC;
VLOG(4) << "GRPCClient::SendAndRecv Begin ,Send_var_name: "
<< send_var_name_val << " Recv_var_name: " << recv_var_name_val;
int retry_times_ = 0;
while (true) {
SendAndRecvProcessor* s = new SendAndRecvProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, method, send_var_name_val, p_ctx, p_scope));
VarHandlePtr h_recv(
new VarHandle(ep, method, recv_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
s->RecvPrepare(h_recv);
framework::Async([send_var_name_val, recv_var_name_val, table_name_val,
p_scope, p_ctx, s, method, h, this] {
auto* send_var = p_scope->FindVar(send_var_name_val);
send_var->GetMutable<framework::LoDTensor>()->set_lod({});
::grpc::ByteBuffer buf;
VLOG(4) << "SerializeToByteBuffer: send_var_name_val: "
<< send_var_name_val
<< " recv_var_name_val: " << recv_var_name_val;
SerializeToByteBuffer(send_var_name_val, send_var, *p_ctx, &buf,
recv_var_name_val, trainer_id_, table_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = ProcGetRecvResponse;
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendAndRecvVariable",
buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
}
return h;
}
}
bool GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
return ok_;
}
inline bool ShouldRetry(const std::string& method, int error_code) {
if (method == kPrefetchRPC) {
return true;
}
if (error_code == grpc::StatusCode::DEADLINE_EXCEEDED) {
return true;
}
return false;
}
void GRPCClient::Proceed() {
void* tag = nullptr;
bool ok = false;
VLOG(3) << "GRPCClient Proceed begin";
while (!stopped_ && cq_.Next(&tag, &ok)) {
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
GPR_ASSERT(ok);
PADDLE_ENFORCE_NOT_NULL(
c, platform::errors::PreconditionNotMet("Make BaseProcessor failed."));
if (c->status_.ok()) {
VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process();
} else if (ShouldRetry(c->GetVarHandlePtr()->method(),
c->status_.error_code())) {
VLOG(0) << c->GetVarHandlePtr()->String()
<< " meets grpc error, error_code:" << c->status_.error_code()
<< " error_message:" << c->status_.error_message()
<< " error_details:" << c->status_.error_details()
<< " should retry!";
c->GetVarHandlePtr()->should_retry = true;
c->Finish(false);
} else {
PADDLE_THROW(platform::errors::External(
"%s meets grpc error, error_code is %d, error message is %s, error "
"details is %s.",
c->GetVarHandlePtr()->String(), c->status_.error_code(),
c->status_.error_message(), c->status_.error_details()));
c->Finish(false);
}
bool notify = false;
{
std::lock_guard<std::mutex> lk(sync_mutex_);
req_count_--;
notify = (req_count_ <= 0 || !c->status_.ok());
}
delete c;
if (notify) {
sync_cond_.notify_all();
}
}
// Last log message
// Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a
// static Mutex log_mutex is used for synchronization, which might have been
// destructed at this moment.
if (FLAGS_v >= 3) {
std::string msg("GRPCClient Proceed end");
fwrite(msg.c_str(), msg.length(), 1, stderr);
}
}
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
return it->second;
}
// Channel configurations:
grpc::ChannelArguments args;
args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
if (FLAGS_rpc_disable_reuse_port) {
args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
}
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto ch =
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
channels_[ep] = ch;
return ch;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <time.h>
#include <atomic>
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <vector>
#include "grpc++/channel.h"
#include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h"
#include "grpc++/support/slice.h"
#include "grpc/support/log.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace grpc {
class Channel;
} // namespace grpc
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
void ProcGetRecvResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
BaseProcessor() { context_ = nullptr; }
virtual ~BaseProcessor() {}
virtual void Prepare(VarHandlePtr h, int64_t time_out) {
var_h_ = h;
context_.reset(new grpc::ClientContext());
context_->set_wait_for_ready(true);
if (time_out) {
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() +
std::chrono::milliseconds(time_out);
context_->set_deadline(deadline);
}
}
void Process() {
ProcessImpl();
var_h_->Finish(true);
}
VarHandlePtr GetVarHandlePtr() { return var_h_; }
bool Wait() { return var_h_->Wait(); }
void Finish(bool ok) { return var_h_->Finish(ok); }
virtual void ProcessImpl() = 0;
std::unique_ptr<grpc::ClientContext> context_;
grpc::Status status_;
protected:
VarHandlePtr var_h_;
};
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
RequestSendCallBack;
class SendProcessor : public BaseProcessor {
public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(), stub_g_(ch) {}
virtual ~SendProcessor() {}
void ProcessImpl() override {
if (response_call_back_) {
response_call_back_(*var_h_.get(), reply_);
}
}
::grpc::GenericStub stub_g_;
::grpc::ByteBuffer reply_;
RequestSendCallBack response_call_back_ = nullptr;
};
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
RequestGetCallBack;
class GetProcessor : public BaseProcessor {
public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(), stub_g_(ch) {}
virtual ~GetProcessor() {}
void ProcessImpl() override {
if (response_call_back_) {
response_call_back_(*var_h_.get(), reply_);
}
}
::grpc::ByteBuffer reply_;
::grpc::GenericStub stub_g_;
RequestGetCallBack response_call_back_ = ProcGetResponse;
};
class SendAndRecvProcessor : public BaseProcessor {
public:
explicit SendAndRecvProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(), stub_g_(ch) {}
virtual ~SendAndRecvProcessor() {}
void ProcessImpl() override {
if (response_call_back_) {
response_call_back_(*var_h_recv_.get(), reply_);
var_h_recv_->Finish(true);
}
}
void RecvPrepare(VarHandlePtr h_recv) { var_h_recv_ = h_recv; }
::grpc::ByteBuffer reply_;
::grpc::GenericStub stub_g_;
RequestGetCallBack response_call_back_ = ProcGetResponse;
VarHandlePtr var_h_recv_;
};
class BatchBarrierProcessor : public BaseProcessor {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~BatchBarrierProcessor() {}
void ProcessImpl() override {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class FetchBarrierProcessor : public BaseProcessor {
public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~FetchBarrierProcessor() {}
void ProcessImpl() override {}
sendrecv::VariableMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class CheckpointNotifyProcessor : public BaseProcessor {
public:
explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~CheckpointNotifyProcessor() {}
void ProcessImpl() override {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class GRPCClient : public RPCClient {
public:
GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
virtual ~GRPCClient();
VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) override;
VarHandlePtr AsyncGetMonomerBarrier(
const std::string& ep, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dirname,
const std::string& varname, const int mode,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendAndRecv(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& send_var_name,
const std::string& recv_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
void SendComplete() override;
void InitImpl() override;
private:
void Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
VarHandlePtr _AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline);
private:
grpc::CompletionQueue cq_;
std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
std::vector<std::unique_ptr<std::thread>> client_threads_;
// mutex for Wait client sync
std::mutex sync_mutex_;
std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0};
bool ok_;
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
DISABLE_COPY_AND_ASSIGN(GRPCClient);
// mutex for sending complete message only once
std::mutex completed_mutex_;
bool completed_;
volatile bool stopped_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NCCL
#include <nccl.h>
#endif
#ifdef PADDLE_WITH_RCCL
#include <rccl.h>
#endif
#include <limits>
#include <memory>
#include "grpcpp/impl/codegen/byte_buffer.h"
#include "grpcpp/impl/codegen/slice.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, const std::string& out_name,
const int trainer_id,
const std::string& table_name) {
platform::RecordRPCEvent record_event("serial");
VarMsg request;
TensorPayload* payload = nullptr;
request.set_varname(name);
request.set_trainer_id(trainer_id);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request.set_profile(platform::kEnableProfiler);
} else {
request.set_profile(platform::kDisableProfiler);
}
}
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
if (!table_name.empty()) {
request.set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
payload = new TensorPayload(GetTensorPayload(var, ctx, &request));
} else if (var->IsType<framework::SelectedRows>()) {
request.set_type(::sendrecv::SELECTED_ROWS);
payload = new TensorPayload(GetSelectedRowsPayload(var, ctx, &request));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
} else if (var->IsType<ncclUniqueId>()) {
request.set_type(::sendrecv::NCCL_ID);
#endif
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Serialize does not support type: %s", typeid(var->Type()).name()));
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (var->IsType<ncclUniqueId>()) {
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
// for serialize NCCL_ID
::grpc::Slice slices(e.size());
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
::grpc::ByteBuffer tmp(&slices, 1);
msg->Swap(&tmp);
return;
}
#endif
PADDLE_ENFORCE_NOT_NULL(
payload,
platform::errors::InvalidArgument(
"Not support type: %s, need to be LOD_TENSOR or SELECTED_ROWS",
var->Type()));
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
payload->memory_size());
if (payload->memory_size() >= std::numeric_limits<int>::max()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable %s length %d should less than %d.", name,
payload->memory_size(), std::numeric_limits<int>::max()));
}
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
slices[0] = ::grpc::Slice(e.size());
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
slices[1] = ::grpc::Slice(
grpc_slice_new_with_user_data(payload->ptr(), payload->memory_size(),
SerializeDestroyCallback, payload),
::grpc::Slice::STEAL_REF);
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
PADDLE_ENFORCE_EQ(VectorElemName(slr->rows()), typeid(int64_t).name(),
platform::errors::InvalidArgument(
"Got wrong type %s, expect type: int64_t",
VectorElemName(slr->rows())));
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
slices[3] = ::grpc::Slice(
grpc_slice_new_with_user_data(
const_cast<void*>(
reinterpret_cast<const void*>(slr->rows().data())),
rows_memory_size, [](void* backing) {},
const_cast<char*>(
reinterpret_cast<const char*>(slr->rows().data()))),
::grpc::Slice::STEAL_REF);
num_slices = 4;
}
::grpc::ByteBuffer tmp(&slices[0], num_slices);
msg->Swap(&tmp);
}
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id) {
platform::RecordRPCEvent record_event("deserial");
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE_EQ(
resp.Parse(msg), 0,
platform::errors::InvalidArgument("parse bytebuffer to tensor error!"));
*var = resp.GetVar();
*trainer_id = resp.GetTrainerId();
}
void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id) {
platform::RecordRPCEvent record_event("deserial");
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE_EQ(
resp.Parse(msg), 0,
platform::errors::InvalidArgument("parse bytebuffer to tensor error!"));
*var = resp.GetRecvVar();
*trainer_id = resp.GetTrainerId();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/port.h"
namespace grpc {
class ByteBuffer;
} // namespace grpc
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string(),
const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
int tensor_numel = 564 * 128;
math::set_constant(ctx, tensor, 32.7);
for (int i = 0; i < 564; ++i) rows->push_back(i);
::grpc::ByteBuffer msg;
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize
std::vector<::grpc::Slice> slices;
(void)msg.Dump(&slices);
std::string tmp;
for (const auto& s : slices) {
tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
}
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
// deserialize bytebuffer
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1);
const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data());
const int64_t* rows_data =
reinterpret_cast<const int64_t*>(varmsg.rows().data());
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
}
for (int i = 0; i < 564; ++i) {
EXPECT_EQ(rows_data[i], i);
}
// deserialize zero-copy
// framework::Variable var2;
// operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope;
scope.Var("myvar");
operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar();
auto* slr2 = var2->GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2->data<float>());
}
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
}
EXPECT_EQ(slr2->height(), 1000);
}
void RunTestLodTensor(platform::Place place, int from_type = 0) {
// serialize var to ByteBuffer
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
int tensor_numel = 512 * 8 * 4 * 2;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg;
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg,
"outvar", 0, "table_name");
EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize
std::vector<::grpc::Slice> slices;
(void)msg.Dump(&slices);
std::string tmp;
for (const auto& s : slices) {
tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
}
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 0);
EXPECT_EQ(varmsg.dims()[0], 512);
EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2);
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data());
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 31.9);
}
// message binary
std::string str;
varmsg.SerializeToString(&str);
// message bytebuffer
::grpc::Slice slices_2[1];
int num_slices = 1;
slices_2[0] = ::grpc::Slice(str.length());
memcpy(const_cast<uint8_t*>(slices_2[0].begin()), str.c_str(), str.length());
::grpc::ByteBuffer bytebuffer2(&slices_2[0], num_slices);
// deserialize zero-copy
framework::Scope scope;
scope.Var("myvar");
operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0);
} else {
EXPECT_EQ(resp.Parse(bytebuffer2), 0);
}
framework::Variable* var2 = resp.GetVar();
auto tensor2 = var2->Get<framework::LoDTensor>();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2.data<float>());
}
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
}
TEST(LodTensor, Run) {
platform::CPUPlace place;
RunTestLodTensor(place);
RunTestLodTensor(place, 1);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace gpu(0);
RunTestLodTensor(gpu);
RunTestLodTensor(gpu, 1);
#endif
}
TEST(SelectedRows, Run) {
platform::CPUPlace place;
RunSerdeTestSelectedRows(place);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace gpu;
RunSerdeTestSelectedRows(gpu);
#endif
}
/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <limits>
#include <memory>
#include <string>
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_server.h"
namespace grpc {
class ChannelArguments;
} // namespace grpc
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
namespace operators {
namespace distributed {
class GRPCVariableResponse;
} // namespace distributed
} // namespace operators
} // namespace paddle
using ::grpc::ServerAsyncResponseWriter;
DECLARE_bool(rpc_disable_reuse_port);
DECLARE_int32(rpc_retry_bind_port);
namespace paddle {
namespace operators {
namespace distributed {
enum CallStatus { PROCESS = 0, FINISH };
// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
public:
explicit RequestBase(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: service_(service),
cq_(cq),
status_(PROCESS),
request_handler_(request_handler),
req_id_(req_id) {
PADDLE_ENFORCE_NOT_NULL(cq_, platform::errors::InvalidArgument(
"ServerCompletionQueue cq are empty"));
}
virtual ~RequestBase() {}
virtual void Process() = 0;
std::string Status2String(const std::string& method) {
std::string status = "Process";
if (status_ == FINISH) {
status = "Finish";
}
std::ostringstream s;
s << method << " name:[" << GetReqName() << "]"
<< ", ep:[" << ctx_.peer() << "]"
<< " " << status << " using req_id:" << req_id_;
return s.str();
}
CallStatus Status() const {
std::lock_guard<std::mutex> l(status_mu_);
return status_;
}
template <typename T>
void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
std::lock_guard<std::mutex> l(status_mu_);
status_ = FINISH;
responder->Finish(reply, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
}
virtual std::string GetReqName() = 0;
protected:
mutable std::mutex status_mu_;
::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_;
CallStatus status_;
RequestHandler* request_handler_;
int req_id_;
};
class RequestSend final : public RequestBase {
public:
explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true));
int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestSend() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
std::string varname = GetReqName();
auto scope = request_->GetMutableLocalScope();
auto invar = request_->GetVar();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestSend var_name:" << varname << " trainer: " << trainer_id;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_);
}
protected:
sendrecv::VoidMessage reply_;
std::shared_ptr<GRPCVariableResponse> request_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
class RequestGet final : public RequestBase {
public:
explicit RequestGet(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestGet() {}
std::string GetReqName() override { return request_.varname(); }
void Process() override {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
std::string table_name = request_.table_name();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGet " << out_varname << " from " << varname;
auto scope = request_handler_->scope();
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
tmp_scope_ = std::move(scope->NewTmpScope());
request_handler_->Handle(varname, tmp_scope_.get(), invar, &outvar,
trainer_id, out_varname, table_name);
VLOG(1) << "before SerializeToByteBuffer";
if (outvar) {
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
VLOG(1) << "after SerializeToByteBuffer";
Finish(reply_, &responder_);
}
protected:
sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_;
std::unique_ptr<framework::Scope> tmp_scope_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};
class RequestGetNoBarrier final : public RequestBase {
public:
explicit RequestGetNoBarrier(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id =
static_cast<int>(distributed::GrpcMethod::kGetVariableNoBarrier);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestGetNoBarrier() {}
std::string GetReqName() override { return request_.varname(); }
void Process() override {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGetNoBarrier " << out_varname << " from " << varname;
auto scope = request_handler_->scope();
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
Finish(reply_, &responder_);
}
protected:
sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};
class RequestGetMonomerVariable final : public RequestBase {
public:
explicit RequestGetMonomerVariable(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler,
int req_id, RPCServer* rpc_server)
: RequestBase(service, cq, request_handler, req_id),
responder_(&ctx_),
rpc_server_(rpc_server) {
auto method_id =
static_cast<int>(distributed::GrpcMethod::kGetMonomerVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestGetMonomerVariable() {}
std::string GetReqName() override { return request_.varname(); }
void Process() override {
// proc request.
std::string varname = request_.varname();
rpc_server_->WaitVarCond(varname);
MonomerHandle h = rpc_server_->GetMonomer(varname);
auto scope = h.scope_;
auto invar = scope->FindVar(varname);
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar,
request_.trainer_id());
if (outvar) {
SerializeToByteBuffer(varname, outvar, *h.dev_ctx_, &reply_);
}
Finish(reply_, &responder_);
}
protected:
sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
RPCServer* rpc_server_{nullptr};
};
class RequestGetMonomerBarrier final : public RequestBase {
public:
explicit RequestGetMonomerBarrier(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id,
RPCServer* rpc_server)
: RequestBase(service, cq, request_handler, req_id),
responder_(&ctx_),
rpc_server_(rpc_server) {
auto method_id =
static_cast<int>(distributed::GrpcMethod::kGetMonomerBarrier);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestGetMonomerBarrier() {}
std::string GetReqName() override { return request_.varname(); }
void Process() override {
// proc request.
std::string varname = request_.varname();
VLOG(4) << "RequestGetMonomerBarrier " << varname;
rpc_server_->WaitVarCond(varname);
MonomerHandle h = rpc_server_->GetMonomer(varname);
framework::Scope* scope = nullptr;
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar,
request_.trainer_id());
Finish(reply_, &responder_);
}
protected:
sendrecv::VariableMessage request_;
sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
RPCServer* rpc_server_{nullptr};
};
class RequestPrefetch final : public RequestBase {
public:
explicit RequestPrefetch(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id),
responder_(&ctx_),
local_scope_(nullptr) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true));
int method_id =
static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestPrefetch() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
// prefetch process...
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name << " trainer: " << trainer_id;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
// out var must be created in local scope!
framework::Variable* outvar = scope->Var(out_var_name);
request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name, table_name);
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<GRPCVariableResponse> request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* local_scope_;
};
class RequestCheckpointNotify final : public RequestBase {
public:
explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx()));
int method_id =
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestCheckpointNotify() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
auto scope = request_->GetMutableLocalScope();
std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->OutVarname();
int trainer_id = request_->GetTrainerId();
std::string table_name = request_->TableName();
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;
request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
trainer_id, checkpoint_dir, table_name);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<GRPCVariableResponse> request_;
sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
class RequestNotify final : public RequestBase {
public:
explicit RequestNotify(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true));
int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestNotify() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
std::string varname = GetReqName();
VLOG(4) << "RequestNotify var_name:" << varname;
auto scope = request_->GetMutableLocalScope();
auto invar = request_->GetVar();
int trainer_id = request_->GetTrainerId();
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_);
}
protected:
sendrecv::VoidMessage reply_;
std::shared_ptr<GRPCVariableResponse> request_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
class RequestSendAndRecv final : public RequestBase {
public:
explicit RequestSendAndRecv(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true));
int method_id =
static_cast<int>(distributed::GrpcMethod::kRequestSendAndRecv);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestSendAndRecv() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestSendAndRecv, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name << " trainer: " << trainer_id;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
framework::Variable* outvar = nullptr;
request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name, table_name);
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<GRPCVariableResponse> request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};
void AsyncGRPCServer::WaitServerReady() {
VLOG(4) << "AsyncGRPCServer is waiting server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
VLOG(4) << "AsyncGRPCServer WaitSeverReady";
}
// Define an option subclass in order to disable SO_REUSEPORT for the
// server socket.
// Come from:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
class NoReusePortOption : public ::grpc::ServerBuilderOption {
public:
void UpdateArguments(::grpc::ChannelArguments* args) override {
args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
}
void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
plugins) override {}
};
void AsyncGRPCServer::StartServer() {
for (int i = 0; i < FLAGS_rpc_retry_bind_port; i++) {
::grpc::ServerBuilder builder;
std::unique_ptr<GrpcService::AsyncService> service(
new GrpcService::AsyncService());
builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
&selected_port_);
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
if (FLAGS_rpc_disable_reuse_port) {
builder.SetOption(
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
LOG(INFO) << "set FLAGS_rpc_disable_reuse_port";
}
builder.RegisterService(service.get());
for (auto t : rpc_call_map_) {
rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
}
server_ = builder.BuildAndStart();
if (selected_port_ != 0) {
LOG(INFO) << "Server listening on " << bind_address_
<< " successful, selected port: " << selected_port_;
service_.reset(service.release());
break;
}
LOG(WARNING) << "Server listening on " << bind_address_
<< " failed, selected port: " << selected_port_
<< ", retry after 3 seconds!";
sleep(3);
}
PADDLE_ENFORCE_NE(
selected_port_, 0,
platform::errors::Unavailable("can't bind to address:%s", bind_address_));
std::function<void(const std::string&, int)> f =
std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
std::placeholders::_1, std::placeholders::_2);
for (auto& t : rpc_call_map_) {
auto& rpc_name = t.first;
auto& cq = rpc_cq_[rpc_name];
auto threadnum = rpc_thread_num_[rpc_name];
auto& reqs = rpc_reqs_[rpc_name];
reqs.reserve(kRequestBufSize);
for (int i = 0; i < kRequestBufSize; i++) {
VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
TryToRegisterNewOne(rpc_name, i);
}
for (int i = 0; i < threadnum; i++) {
rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
&AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
VLOG(4) << t.first << " creates threads!";
}
}
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1;
}
condition_ready_.notify_all();
// wait server
server_->Wait();
for (auto& t : rpc_threads_) {
auto& threads = t.second;
for (size_t i = 0; i < threads.size(); ++i) {
threads[i]->join();
VLOG(4) << t.first << " threads ends!";
}
}
}
void AsyncGRPCServer::ShutdownQueue() {
for (auto& t : rpc_cq_) {
t.second->Shutdown();
VLOG(4) << t.first << " queue shutdown!";
}
}
void AsyncGRPCServer::ShutDownImpl() {
std::unique_lock<std::mutex> lock(cq_mutex_);
is_shut_down_ = true;
ShutdownQueue();
VLOG(4) << "server_ shutdown!";
server_->Shutdown();
}
void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
return;
}
VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
<< " REQ ID: " << req_id;
auto& reqs = rpc_reqs_[rpc_name];
auto& handler = rpc_call_map_[rpc_name];
auto& cq = rpc_cq_[rpc_name];
RequestBase* b = nullptr;
if (rpc_name == kRequestSend) {
b = new RequestSend(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) {
b = new RequestGet(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetNoBarrier) {
b = new RequestGetNoBarrier(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetMonomerVariable) {
b = new RequestGetMonomerVariable(service_.get(), cq.get(), handler, req_id,
this);
} else if (rpc_name == kRequestGetMonomerBarrier) {
b = new RequestGetMonomerBarrier(service_.get(), cq.get(), handler, req_id,
this);
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestNotify) {
b = new RequestNotify(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestSendAndRecv) {
b = new RequestSendAndRecv(service_.get(), cq.get(), handler, req_id);
} else {
PADDLE_THROW(
platform::errors::InvalidArgument("not supported rpc: %s", rpc_name));
}
reqs[req_id] = b;
VLOG(4) << "TryToRegisterNewOne status:" << b->Status();
}
void AsyncGRPCServer::HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(const std::string&, int)> TryToRegisterNewOne) {
void* tag = NULL;
bool ok = false;
while (true) {
VLOG(4) << "HandleRequest " << rpc_name << " wait next";
if (!cq->Next(&tag, &ok)) {
VLOG(4) << "CompletionQueue " << rpc_name << " shutdown!";
break;
}
int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
VLOG(4) << "HandleRequest " << rpc_name << ", req_id:" << req_id
<< " get next";
auto& reqs = rpc_reqs_[rpc_name];
RequestBase* base = nullptr;
{
PADDLE_ENFORCE_EQ(
(req_id >= 0 && req_id < kRequestBufSize), true,
platform::errors::OutOfRange("request id: %s out of bounds: [0, %s)",
req_id, kRequestBufSize));
std::unique_lock<std::mutex> lock(cq_mutex_);
base = reqs[req_id];
}
VLOG(3) << base->Status2String(rpc_name);
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) {
VLOG(4) << "completion queue:" << rpc_name << " recv no regular event"
<< " context:" << base->Status2String(rpc_name);
TryToRegisterNewOne(rpc_name, req_id);
delete base;
continue;
}
switch (base->Status()) {
case PROCESS: {
base->Process();
break;
}
case FINISH: {
TryToRegisterNewOne(rpc_name, req_id);
delete base;
break;
}
default: { assert(false); }
}
}
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "grpc++/grpc++.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_service.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace grpc {
class ServerCompletionQueue;
} // namespace grpc
namespace paddle {
namespace operators {
namespace distributed {
class RequestBase;
class AsyncGRPCServer final : public RPCServer {
public:
explicit AsyncGRPCServer(const std::string& address, int client_num)
: RPCServer(address, client_num), ready_(0) {}
virtual ~AsyncGRPCServer() {}
void WaitServerReady() override;
void StartServer() override;
private:
// HandleRequest needs to be thread-safe.
void HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(const std::string&, int)> TryToRegisterNewOne);
void TryToRegisterNewOne(const std::string& rpc_name, int req_id);
void ShutdownQueue();
void ShutDownImpl() override;
private:
static const int kRequestBufSize = 100;
std::mutex cq_mutex_;
volatile bool is_shut_down_ = false;
std::unique_ptr<GrpcService::AsyncService> service_;
std::unique_ptr<::grpc::Server> server_;
// condition of the sub program
std::condition_variable barrier_condition_;
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
std::map<std::string, std::unique_ptr<::grpc::ServerCompletionQueue>> rpc_cq_;
std::map<std::string, std::vector<std::unique_ptr<std::thread>>> rpc_threads_;
std::map<std::string, std::vector<RequestBase*>> rpc_reqs_;
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <grpc++/impl/codegen/async_stream.h>
#include <grpc++/impl/codegen/async_unary_call.h>
#include <grpc++/impl/codegen/proto_utils.h>
#include <grpc++/impl/codegen/rpc_method.h>
#include <grpc++/impl/codegen/service_type.h>
#include <grpc++/impl/codegen/status.h>
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h"
// NOTE: This method was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// method and did some modifications so that we can parse gRPC
// requests without too much copying of the tensor data.
namespace grpc {
class CompletionQueue;
class Channel;
class RpcService;
class ServerCompletionQueue;
class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse.
template <>
class SerializationTraits<
paddle::operators::distributed::GRPCVariableResponse> {
public:
static Status Serialize(
const paddle::operators::distributed::GRPCVariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"SerializationTraits::Serialize not implemented!"));
return Status();
}
static Status Deserialize(
grpc_byte_buffer* buffer,
paddle::operators::distributed::GRPCVariableResponse* msg,
int max_message_size = INT_MAX) {
if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload");
}
Status result = g_core_codegen_interface->ok();
if (result.ok()) {
paddle::operators::distributed::GrpcByteSource source(buffer);
int ret = msg->Parse(&source);
if (ret != 0) {
result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
}
}
g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
return result;
}
};
} // namespace grpc
namespace paddle {
namespace operators {
namespace distributed {
enum class GrpcMethod {
kSendVariable,
kGetVariable,
kPrefetchVariable,
kCheckpointNotify,
kGetVariableNoBarrier,
kGetMonomerVariable,
kGetMonomerBarrier,
kRequestNotify,
kRequestSendAndRecv,
// when you add new handler, change kGrpcNumMethods at the same time!
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kRequestSendAndRecv) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
case GrpcMethod::kSendVariable:
return "/sendrecv.SendRecvService/SendVariable";
case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kGetVariableNoBarrier:
return "/sendrecv.SendRecvService/GetVariableNoBarrier";
case GrpcMethod::kGetMonomerVariable:
return "/sendrecv.SendRecvService/GetMonomerVariable";
case GrpcMethod::kGetMonomerBarrier:
return "/sendrecv.SendRecvService/GetMonomerBarrier";
case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendRecvService/PrefetchVariable";
case GrpcMethod::kCheckpointNotify:
return "/sendrecv.SendRecvService/CheckpointNotify";
case GrpcMethod::kRequestNotify:
return "/sendrecv.SendRecvService/DistributeNotify";
case GrpcMethod::kRequestSendAndRecv:
return "/sendrecv.SendRecvService/SendAndRecvVariable";
}
// Shouldn't be reached.
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid id: not found valid method name"));
return nullptr;
}
class GrpcService final {
public:
class AsyncService : public ::grpc::Service {
public:
AsyncService() {
for (int i = 0; i < kGrpcNumMethods; ++i) {
AddMethod(new ::grpc::internal::RpcServiceMethod(
GrpcMethodName(static_cast<GrpcMethod>(i)),
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
::grpc::Service::MarkMethodAsync(i);
}
}
virtual ~AsyncService() {}
// Make RequestAsyncUnary public for grpc_call.h
using ::grpc::Service::RequestAsyncUnary;
};
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include <string>
#include <utility>
#include <vector>
#include "google/protobuf/io/coded_stream.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
namespace google {
namespace protobuf {
namespace io {
class ZeroCopyInputStream;
} // namespace io
} // namespace protobuf
} // namespace google
namespace grpc {
class ByteBuffer;
} // namespace grpc
namespace paddle {
namespace operators {
namespace distributed {
enum WireType {
WIRETYPE_VARINT = 0,
WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32_t tag) {
return static_cast<WireType>(tag & 0x7);
}
bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input,
int* result) {
uint64_t v;
if (input->ReadVarint64(&v) && v <= static_cast<uint64_t>(INT_MAX)) {
*result = static_cast<int>(v);
return true;
} else {
return false;
}
}
int GRPCVariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
GrpcByteBufferSource source;
source.Init(byte_buffer);
GrpcByteBufferSourceWrapper r(&source);
return Parse(&r);
}
bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
std::vector<int64_t>* lod) {
while (true) {
auto p = input->ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
return (tag == 0);
}
switch (tag) {
case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: {
uint64_t v;
if (wt == WIRETYPE_VARINT) {
if (!input->ReadVarint64(&v)) {
return false;
}
lod->push_back(v);
break;
}
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input->ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input->CurrentPosition();
while (input->CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input->ReadVarint64(&v)) {
return tag;
}
lod->push_back(v);
}
break;
}
return false;
}
default: { return false; }
}
}
return true;
}
int GRPCVariableResponse::Parse(Source* source) {
::google::protobuf::io::ZeroCopyInputStream* input_stream =
source->contents();
::google::protobuf::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (true) {
auto p = input.ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
if (tag != 0) {
return -1;
}
return 0;
}
switch (tag) {
case sendrecv::VariableMessage::kVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_varname(temp);
break;
}
case sendrecv::VariableMessage::kTypeFieldNumber: {
uint32_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_type(static_cast<::sendrecv::VarType>(v));
break;
}
case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint32_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v));
break;
}
case sendrecv::VariableMessage::kDimsFieldNumber: {
// not packed
if (wt == WIRETYPE_VARINT) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
break;
}
// packed
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input.ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input.CurrentPosition();
while (input.CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kLodLevelFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_lod_level(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kLodFieldNumber: {
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p =
input.IncrementRecursionDepthAndPushLimit(length);
std::vector<int64_t> lod_data;
if (p.second < 0 || !ParseLodData(&input, &lod_data)) {
return tag;
}
if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
return tag;
}
if (lod_data.size() == 0) {
break;
}
auto lod = meta_.add_lod();
for (uint32_t i = 0; i < lod_data.size(); i++) {
lod->add_lod_data(lod_data[i]);
}
break;
}
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!ProcSerializedField(tag, &input, num_bytes)) {
return tag;
}
break;
}
case sendrecv::VariableMessage::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
platform::errors::PreconditionNotMet(
"meta info should be got first!"));
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return tag;
}
break;
}
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_out_varname(temp);
break;
}
case sendrecv::VariableMessage::kProfileFieldNumber: {
uint64_t profiling = 0;
if (!input.ReadVarint64(&profiling)) {
return tag;
}
meta_.set_profile(profiling);
int64_t listener_id = platform::ListenerId();
if (listener_id <= 0) {
break;
}
if (profiling == platform::kEnableProfiler &&
!platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (profiling == platform::kDisableProfiler &&
platform::IsProfileEnabled()) {
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("%s_%lld", FLAGS_rpc_server_profile_path,
listener_id));
}
break;
}
case sendrecv::VariableMessage::kTrainerIdFieldNumber: {
uint64_t trainer_id = 0;
if (!input.ReadVarint64(&trainer_id)) {
return tag;
}
meta_.set_trainer_id(trainer_id);
break;
}
case sendrecv::VariableMessage::kTableNameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_table_name(temp);
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
}
}
}
return 0;
}
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace grpc {
class ByteBuffer;
} // namespace grpc
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class GRPCVariableResponse : public VariableResponse {
public:
GRPCVariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
bool create_scope = false)
: VariableResponse(scope, dev_ctx, create_scope) {}
virtual ~GRPCVariableResponse() {}
int Parse(Source* source) override;
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
int Parse(const ::grpc::ByteBuffer& byte_buffer);
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include <ctime>
namespace paddle {
namespace operators {
namespace distributed {
DEFINE_int32(worker_update_interval_secs, 900,
" the longest time interval between the worker update variables");
inline int GetCurrentUS() {
// current date/time based on current system
time_t t = std::time(0);
int now = static_cast<int>(t);
return now;
}
void HeartBeatMonitor::Update(const int worker_id, std::string be_monitored_var,
WorkerStatus status) {
if (status == UNINITED) {
LOG(WARNING) << "HeartBeatMonitor receive UNINITED status can not be used "
"in Update, something error";
}
if (!is_chief_) {
return;
}
if ((be_monitored_var == be_monitored_var_ && status == RUNNING) ||
status == COMPLETED) {
auto timestamp = GetCurrentUS();
UnderMonitoredWorker& worker = worker_status_map_.at(worker_id);
if (worker.status != COMPLETED) {
worker.status = status;
}
worker.timestamp = timestamp;
return;
}
}
void HeartBeatMonitor::LostWorkerMonitor() {
VLOG(1) << "worker heartbeat monitor start at No.0 parameter server";
while (running_) {
for (int id = 0; id < workers_; ++id) {
auto& worker = worker_status_map_.at(id);
if (worker.status == UNINITED) {
VLOG(4) << "worker " << worker.id << " is under UNINITED";
continue;
}
if (worker.status == COMPLETED) {
VLOG(4) << "worker " << worker.id << " is under COMPLETED";
continue;
}
auto timestamp = GetCurrentUS();
VLOG(4) << "worker " << worker.id << " status is " << worker.status
<< " timestamp is " << worker.timestamp << " the interval is "
<< timestamp - worker.timestamp;
if (timestamp - worker.timestamp >= FLAGS_worker_update_interval_secs) {
PADDLE_THROW(platform::errors::ExecutionTimeout(
"the latest update of worker %d is %d secs ago, we doubt the "
"the worker is not alive and this may have a bad effect on the "
"fitting result, please check",
worker.id, FLAGS_worker_update_interval_secs));
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(10 * 1000));
}
VLOG(1) << "worker heartbeat monitor stopped, thread exit";
}
std::once_flag HeartBeatMonitor::init_flag_;
std::unique_ptr<HeartBeatMonitor> HeartBeatMonitor::monitor_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
enum WorkerStatus { UNINITED = 0, RUNNING, COMPLETED };
struct UnderMonitoredWorker {
int id;
WorkerStatus status;
int timestamp;
UnderMonitoredWorker() {}
explicit UnderMonitoredWorker(int worker_id) {
this->id = worker_id;
this->status = UNINITED;
this->timestamp = 0;
}
};
class HeartBeatMonitor {
public:
explicit HeartBeatMonitor(int workers, bool is_chief,
std::string be_monitored_var)
: workers_(workers),
is_chief_(is_chief),
be_monitored_var_(be_monitored_var),
running_(true) {
PADDLE_ENFORCE_GT(workers, 0, platform::errors::InvalidArgument(
"workers must greater than 0."));
for (auto worker_id = 0; worker_id < workers; worker_id++) {
UnderMonitoredWorker worker(worker_id);
worker_status_map_[worker_id] = std::move(worker);
}
// we define the No.0 pserver is the first parameter server
// only No.0 will check the heartbeat of all trainers
if (is_chief) {
monitor_thread_.reset(new std::thread(
std::bind(&HeartBeatMonitor::LostWorkerMonitor, this)));
}
}
~HeartBeatMonitor() {
running_ = false;
if (monitor_thread_) monitor_thread_->join();
}
static void Init(int workers, bool is_chief, std::string be_monitored_var) {
std::call_once(init_flag_, &HeartBeatMonitor::InitImpl, workers, is_chief,
be_monitored_var);
}
static HeartBeatMonitor* GetInstance() { return monitor_.get(); }
void Stop() {
running_ = false;
if (!monitor_) {
VLOG(0) << "HeartBeatMonitor is not inited, do nothing";
} else {
if (monitor_thread_) {
monitor_thread_->join();
monitor_thread_.reset(nullptr);
}
}
}
void Update(const int worker_id, std::string be_monitored_var,
WorkerStatus status);
void LostWorkerMonitor();
private:
// Init is called by GetInstance.
static void InitImpl(int workers, bool is_chief,
std::string be_monitored_var) {
if (monitor_ == nullptr) {
monitor_.reset(new HeartBeatMonitor(workers, is_chief, be_monitored_var));
}
}
static std::once_flag init_flag_;
static std::unique_ptr<HeartBeatMonitor> monitor_;
int workers_;
bool is_chief_;
std::string be_monitored_var_;
std::unordered_map<int, UnderMonitoredWorker> worker_status_map_;
std::unique_ptr<std::thread> monitor_thread_{nullptr};
std::mutex mutex_;
bool running_ = false;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "gtest/gtest.h"
namespace paddle {
namespace operators {
namespace distributed {
void run(HeartBeatMonitor* monitor) { monitor->LostWorkerMonitor(); }
TEST(HeartBeatMonitor, All) {
int trainers = 10;
int pserver_id = 0;
std::string var = "nce_w@GRAD.block0";
std::string var2 = "nce_w@GRAD.block2";
HeartBeatMonitor::Init(trainers, pserver_id == 0, var);
auto* monitor = HeartBeatMonitor::GetInstance();
std::vector<int> ids{1, 3, 5, 7};
for (auto& id : ids) {
monitor->Update(id, var, RUNNING);
}
monitor->Update(9, var2, RUNNING);
monitor->Update(2, var, COMPLETED);
std::thread t(run, monitor);
t.detach();
std::this_thread::sleep_for(std::chrono::milliseconds(15 * 1000));
monitor->Stop();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag LargeScaleKV::init_flag_;
std::shared_ptr<LargeScaleKV> LargeScaleKV::scale_kv_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace operators {
namespace distributed {
enum Mode { training, infer };
enum InitType { uniform_random, fill_constant, gaussian_random };
inline std::vector<int> bucket(const int v_size, const int b_size) {
int remainder = v_size % b_size;
int bucket = v_size / b_size;
std::vector<int> ret_vec(b_size, bucket);
for (int i = 0; i < remainder; ++i) {
ret_vec[i] = ret_vec[i] + 1;
}
int cur_bucket = 0;
for (int &j : ret_vec) {
int tmp = j;
j = cur_bucket;
cur_bucket += tmp;
}
ret_vec.push_back(cur_bucket);
return ret_vec;
}
class Initializer {
public:
Initializer() {}
explicit Initializer(const std::vector<std::string> &attrs) {}
virtual float GetValue() = 0;
virtual ~Initializer() {}
protected:
std::string name_;
unsigned int seed_;
};
class UniformInitializer : public Initializer {
public:
explicit UniformInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
min_ = std::stof(attrs[2]);
max_ = std::stof(attrs[3]);
dist_ = std::uniform_real_distribution<float>(min_, max_);
random_engine_ = framework::GetCPURandomEngine(seed_);
}
float GetValue() override { return dist_(*random_engine_); }
private:
float min_;
float max_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::uniform_real_distribution<float> dist_;
};
template <typename T>
inline bool entry(const int count, const T threshold);
template <>
inline bool entry<std::string>(const int count, const std::string threshold) {
return true;
}
template <>
inline bool entry<int>(const int count, const int threshold) {
return count >= threshold;
}
template <>
inline bool entry<float>(const int count, const float threshold) {
UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
return uniform.GetValue() >= threshold;
}
class GaussianInitializer : public Initializer {
public:
explicit GaussianInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]);
random_engine_ = framework::GetCPURandomEngine(seed_);
dist_ = std::normal_distribution<float>(mean_, std_);
}
float GetValue() override { return dist_(*random_engine_); }
private:
float std_;
float mean_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::normal_distribution<float> dist_;
};
class FillConstantInitializer : public Initializer {
public:
explicit FillConstantInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
value_ = std::stof(attrs[1]);
}
float GetValue() override { return value_; }
private:
float value_;
};
struct SparseMeta {
std::string name;
std::string grad_name;
std::vector<std::string> value_names;
std::vector<int> value_dims;
std::vector<std::string> cached_varnames;
std::vector<std::string> initializer_attrs;
std::string entry;
Mode mode;
std::string ToString() {
std::stringstream ss;
ss << "name: " << name << " ";
ss << "mode: " << mode << " ";
for (int i = 0; i < static_cast<int>(value_names.size()); i++) {
ss << "value_name: " << value_names[i] << " dim: " << value_dims[i]
<< " ";
}
ss << " grad var: " << grad_name;
ss << " cached varnames: ";
for (int i = 0; i < static_cast<int>(cached_varnames.size()); i++) {
ss << cached_varnames[i] << " ";
}
ss << " initializer attrs: ";
for (int i = 0; i < static_cast<int>(initializer_attrs.size()); i++) {
ss << initializer_attrs[i] << " ";
}
ss << " entry attrs: " << entry;
return ss.str();
}
};
struct VALUE {
explicit VALUE(const std::vector<std::string> &names)
: names_(names), count_(0), unseen_days_(0) {
values_.resize(names.size());
for (int i = 0; i < static_cast<int>(names.size()); i++) {
places[names[i]] = i;
}
}
void set(std::vector<std::vector<float>> *values) {
values_ = std::move(*values);
}
void set(const std::vector<std::string> &names,
const std::vector<std::vector<float>> &values) {
for (int i = 0; i < static_cast<int>(names.size()); i++) {
auto idx = places[names[i]];
auto value = values[i];
values_[idx].assign(value.begin(), value.end());
}
}
std::vector<std::vector<float> *> get() {
auto pts = std::vector<std::vector<float> *>();
pts.reserve(values_.size());
for (auto &value : values_) {
pts.push_back(&value);
}
return pts;
}
int fetch_count() { return ++count_; }
void reset_unseen_days() { unseen_days_ = 0; }
void set_entry(bool is_entry) { is_entry_ = is_entry; }
bool get_entry() { return is_entry_; }
std::vector<std::vector<float> *> get(const std::vector<std::string> names) {
auto pts = std::vector<std::vector<float> *>();
pts.reserve(values_.size());
for (int i = 0; i < static_cast<int>(names.size()); i++) {
pts.push_back(&(values_[places[names[i]]]));
}
return pts;
}
std::vector<std::string> names_;
int count_;
bool seen_after_last_save_;
int unseen_days_;
bool is_entry_;
std::vector<std::vector<float>> values_;
std::unordered_map<std::string, int> places;
};
class ValueBlock {
public:
explicit ValueBlock(const std::vector<std::string> value_names,
const std::vector<int> value_dims, const Mode &mode,
const std::vector<std::string> &init_attrs,
const std::string &entry_attr)
: value_names_(value_names), value_dims_(value_dims), mode_(mode) {
// for Initializer
for (size_t i = 0; i < value_names.size(); i++) {
auto name = value_names[i];
auto slices = string::split_string<std::string>(init_attrs[i], "&");
if (slices[0] == "gaussian_random") {
initializers_[name] = new GaussianInitializer(slices);
} else if (slices[0] == "fill_constant") {
initializers_[name] = new FillConstantInitializer(slices);
} else if (slices[0] == "uniform_random") {
initializers_[name] = new UniformInitializer(slices);
} else {
PADDLE_THROW(
platform::errors::InvalidArgument("%s can not be supported", name));
}
}
// for Entry
{
if (entry_attr == "none") {
entry_func_ =
std::bind(entry<std::string>, std::placeholders::_1, "none");
} else {
auto slices = string::split_string<std::string>(entry_attr, "&");
if (slices[0] == "count_filter") {
int threshold = std::stoi(slices[1]);
entry_func_ = std::bind(entry<int>, std::placeholders::_1, threshold);
} else if (slices[0] == "probability") {
float threshold = std::stof(slices[1]);
entry_func_ =
std::bind(entry<float>, std::placeholders::_1, threshold);
}
}
}
rwlock_.reset(new framework::RWLock);
}
~ValueBlock() {
// for (auto init : initializers_) {
// delete init.second;
// initializers_.erase(init.first);
// }
//
// for (auto value : values_) {
// delete value.second;
// values_.erase(value.first);
// }
}
void Init(const int64_t &id, std::vector<std::vector<float>> *values,
int count) {
if (Has(id)) {
PADDLE_THROW(platform::errors::AlreadyExists("id already exist, error"));
}
if (values->size() != value_names_.size()) {
PADDLE_THROW(
platform::errors::AlreadyExists("values can not match, error"));
}
auto value = new VALUE(value_names_);
value->set(values);
value->seen_after_last_save_ = true;
value->count_ = count;
values_[id] = value;
}
std::vector<std::vector<float> *> Get(
const int64_t &id, const std::vector<std::string> &value_names) {
rwlock_->RDLock();
auto ret_values = values_.at(id)->get(value_names);
rwlock_->UNLock();
return ret_values;
}
void InitFromInitializer(const int64_t &id,
const std::vector<std::string> &value_names) {
rwlock_->WRLock();
if (Has(id)) {
Update(id);
rwlock_->UNLock();
return;
}
auto rets = std::vector<std::vector<float>>();
rets.resize(value_names_.size());
for (int i = 0; i < static_cast<int>(value_names_.size()); i++) {
auto name = value_names_[i];
auto *init = initializers_.at(name);
auto dim = value_dims_[i];
rets[i].resize(dim);
for (int j = 0; j < static_cast<int>(dim); j++) {
rets[i][j] = init->GetValue();
}
}
Init(id, &rets, 0);
Update(id);
rwlock_->UNLock();
}
bool GetEntry(const int64_t &id) {
rwlock_->RDLock();
auto value = values_.at(id);
auto entry = value->get_entry();
rwlock_->UNLock();
return entry;
}
void Set(const int64_t &id, const std::vector<std::string> &value_names,
const std::vector<std::vector<float>> &values) {
rwlock_->WRLock();
auto value = values_.at(id);
value->set(value_names, values);
rwlock_->UNLock();
}
void Update(const int64_t id) {
auto *value = values_.at(id);
value->reset_unseen_days();
auto count = value->fetch_count();
if (!value->get_entry()) {
value->set_entry(entry_func_(count));
}
}
private:
bool Has(const int64_t id) {
auto got = values_.find(id);
if (got == values_.end()) {
return false;
} else {
return true;
}
}
public:
std::unordered_map<int64_t, VALUE *> values_;
private:
std::vector<std::string> value_names_;
std::vector<int> value_dims_;
Mode mode_;
std::function<bool(int64_t)> entry_func_;
std::unordered_map<std::string, Initializer *> initializers_;
std::unique_ptr<framework::RWLock> rwlock_{nullptr};
};
class SparseVariable {
public:
explicit SparseVariable(const SparseMeta &meta) {
meta_.name = meta.name;
meta_.mode = meta.mode;
meta_.value_names = meta.value_names;
meta_.value_dims = meta.value_dims;
meta_.grad_name = meta.grad_name;
meta_.cached_varnames = meta.cached_varnames;
meta_.initializer_attrs = meta.initializer_attrs;
meta_.entry = meta.entry;
for (int i = 0; i < static_cast<int>(meta_.value_names.size()); i++) {
values_dims_[meta_.value_names[i]] = meta_.value_dims[i];
}
for (size_t i = 0; i < shard_num_; i++) {
auto block = std::make_shared<ValueBlock>(
meta.value_names, meta.value_dims, meta.mode, meta.initializer_attrs,
meta.entry);
shard_blocks_.emplace_back(block);
}
rwlock_.reset(new framework::RWLock);
}
void Init(const std::vector<int64_t> &ids) {
rwlock_->RDLock();
for (auto &id : ids) {
auto *block = GetShard(id);
block->InitFromInitializer(id, meta_.value_names);
}
rwlock_->UNLock();
}
void Get(const std::vector<int64_t> &ids,
const std::vector<std::string> &value_names,
std::vector<std::vector<std::vector<float> *>> *values) {
values->resize(ids.size());
auto buckets = bucket(ids.size(), 8);
std::vector<std::future<void>> fs;
for (int j = 0; j < 8; ++j) {
auto begin = buckets[j];
auto end = buckets[j + 1];
fs.push_back(
framework::Async([begin, end, &values, &ids, &value_names, this]() {
for (int x = begin; x < end; x++) {
auto id = ids[x];
auto *block = GetShard(id);
auto id_values = block->Get(id, value_names);
(*values)[x] = id_values;
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
void GetEntry(const std::vector<int64_t> &ids, std::vector<int64_t> *values) {
auto buckets = bucket(ids.size(), 8);
std::vector<std::future<void>> fs;
for (int j = 0; j < 8; ++j) {
auto begin = buckets[j];
auto end = buckets[j + 1];
fs.push_back(framework::Async([begin, end, &values, &ids, this]() {
for (int x = begin; x < end; x++) {
auto id = ids[x];
auto *block = GetShard(id);
auto is_entry = block->GetEntry(id);
if (!is_entry) {
values->push_back(id);
}
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
void Set(const std::vector<int64_t> &ids,
const std::vector<std::string> &value_names,
const std::vector<std::vector<std::vector<float>>> &values) {
for (int i = 0; i < static_cast<int>(ids.size()); i++) {
GetShard(ids[i])->Set(ids[i], value_names, values[i]);
}
}
void Dims(std::vector<std::string> value_names, std::vector<int64_t> *dims) {
for (auto &name : value_names) {
dims->push_back(values_dims_.at(name));
}
}
std::vector<std::string> CachedVarnames() const {
return meta_.cached_varnames;
}
void Load(const std::string &dirname) {
rwlock_->WRLock();
VLOG(1) << "load " << meta_.name << " from dir: " << dirname << " begin";
std::vector<std::string> filenames;
for (auto &value_name : meta_.value_names) {
auto filename = string::Sprintf("%s/%s", dirname, value_name);
filenames.push_back(filename);
}
LoadFromSelectedRows(filenames, meta_.value_names);
VLOG(1) << "load " << meta_.name << " in dir: " << dirname << " done";
rwlock_->UNLock();
}
void LoadFromSelectedRows(const std::vector<std::string> &filenames,
const std::vector<std::string> &valuenames) {
std::vector<std::shared_ptr<framework::Variable>> variables;
auto place = platform::CPUPlace();
for (int i = 0; i < static_cast<int>(filenames.size()); i++) {
auto var = std::make_shared<framework::Variable>();
variables.push_back(var);
auto &filename = filenames[i];
std::ifstream fin(filename, std::ios::binary);
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
selectedRows->SyncIndex();
}
std::vector<const float *> tensors;
for (int i = 0; i < static_cast<int>(filenames.size()); i++) {
auto &slr = variables[i]->Get<framework::SelectedRows>();
auto src_t = slr.value();
const auto *value = src_t.data<float>();
tensors.push_back(value);
}
for (int i = 1; i < static_cast<int>(filenames.size()); i++) {
auto rows_0 = variables[0]->Get<framework::SelectedRows>().rows();
auto rows_i = variables[i]->Get<framework::SelectedRows>().rows();
bool is_equal = std::equal(rows_0.begin(), rows_0.end(), rows_i.begin());
if (!is_equal) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s and %s are not equal, can not be load rightly", filenames[0],
filenames[i]));
}
}
auto rows = variables[0]->Get<framework::SelectedRows>().rows();
for (auto i = 0; i < static_cast<int64_t>(rows.size()); i++) {
auto id = rows[i];
std::vector<std::vector<float>> values;
values.resize(filenames.size());
for (int j = 0; j < static_cast<int>(filenames.size()); ++j) {
values[j].resize(meta_.value_dims[j]);
std::memcpy(values[j].data(), tensors[j] + i * meta_.value_dims[j],
sizeof(float) * meta_.value_dims[j]);
}
auto *block = GetShard(id);
block->Init(id, &values, 0);
block->Update(id);
}
}
void Save(const std::string &dirname, const int mode = 0) {
rwlock_->WRLock();
VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " begin";
MkDirRecursively(dirname.c_str());
std::vector<std::string> filenames;
for (auto &value_name : meta_.value_names) {
auto filename = string::Sprintf("%s/%s", dirname, value_name);
filenames.push_back(filename);
}
SaveToSelectedRows(filenames, meta_.value_names, mode);
VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " done";
rwlock_->UNLock();
}
void SaveToSelectedRows(const std::vector<std::string> &filenames,
const std::vector<std::string> &valuenames,
const int mode) {
for (auto &value_name : valuenames) {
auto it = std::find(meta_.value_names.begin(), meta_.value_names.end(),
value_name);
if (it == meta_.value_names.end()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"[%s] is invalid param for [%s]", value_name, meta_.name));
}
}
auto place = platform::CPUPlace();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
std::vector<int64_t> ids;
for (auto &block : shard_blocks_) {
for (auto value : block->values_) {
if (mode == 0) {
ids.push_back(value.first);
} else {
bool id_need_save = false;
// save all params
if (mode == 1) {
id_need_save = true;
} else {
id_need_save = value.second->seen_after_last_save_;
}
if (id_need_save) {
ids.push_back(value.first);
}
value.second->seen_after_last_save_ = false;
}
}
}
VLOG(3) << "save " << ids.size() << " feasigns for " << meta_.name
<< " with mode: " << mode;
std::vector<std::shared_ptr<framework::Variable>> variables;
std::vector<float *> tensors;
std::vector<int64_t> dims;
for (int i = 0; i < static_cast<int>(filenames.size()); i++) {
auto dim = values_dims_.at(valuenames[i]);
auto var = std::make_shared<framework::Variable>();
auto *slr = var->GetMutable<framework::SelectedRows>();
auto *src_t = slr->mutable_value();
src_t->Resize({static_cast<int64_t>(ids.size()), dim});
auto *value = src_t->mutable_data<float>(place);
dims.push_back(dim);
variables.push_back(var);
tensors.push_back(value);
}
std::vector<std::vector<std::vector<float> *>> values;
Get(ids, valuenames, &values);
int64_t offset = 0;
for (auto &vss : values) {
for (int i = 0; i < static_cast<int>(vss.size()); i++) {
auto &vs = vss[i];
std::memcpy(tensors[i] + offset * dims[i], vs->data(),
sizeof(float) * dims[i]);
}
offset += 1;
}
for (auto &var : variables) {
auto *slr = var->GetMutable<framework::SelectedRows>();
slr->set_rows(ids);
slr->set_height(ids.size());
}
for (int i = 0; i < static_cast<int>(filenames.size()); i++) {
auto &filename = filenames[i];
auto &selectedRows = variables[i]->Get<framework::SelectedRows>();
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
platform::errors::Unavailable(
"Cannot open %s to save variables.", filename));
framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close();
}
}
void SaveToText(const std::vector<std::string> &filenames,
const std::vector<std::string> &valuenames) {
for (auto &value_name : valuenames) {
auto it = std::find(meta_.value_names.begin(), meta_.value_names.end(),
value_name);
if (it == meta_.value_names.end()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"[%s] is invalid param for [%s]", value_name, meta_.name));
}
}
std::vector<std::unique_ptr<std::ofstream>> fouts;
for (auto filename : filenames) {
std::unique_ptr<std::ofstream> fout(new std::ofstream(filename));
fouts.push_back(std::move(fout));
}
for (auto &block : shard_blocks_) {
for (auto value : block->values_) {
std::vector<std::vector<float> *> vss = value.second->get(valuenames);
auto id = value.first;
for (int i = 0; i < static_cast<int>(vss.size()); i++) {
auto &vs = vss[i];
std::stringstream ss;
ss << id << "\t";
ss << vs->size() << "\t";
for (auto v : (*vs)) {
ss << v << " ";
}
ss << "\n";
fouts[i]->write(ss.str().c_str(), sizeof(char) * ss.str().size());
}
}
}
for (int i = 0; i < static_cast<int>(fouts.size()); i++) {
fouts[i]->close();
}
}
int64_t Size() {
int64_t cnt = 0;
for (auto &block : shard_blocks_) {
cnt += block->values_.size();
}
return cnt;
}
ValueBlock *GetShard(const int64_t id) {
return shard_blocks_[id & shard_mask_].get();
}
SparseMeta *GetMeta() { return &meta_; }
private:
std::unique_ptr<framework::RWLock> rwlock_{nullptr};
SparseMeta meta_;
std::unordered_map<std::string, int64_t> values_dims_;
const size_t shard_mask_ = 127;
const size_t shard_num_ = 128;
std::vector<std::shared_ptr<ValueBlock>> shard_blocks_;
};
class LargeScaleKV {
public:
LargeScaleKV() {}
explicit LargeScaleKV(const std::vector<SparseMeta> &table_metas) {
for (auto &sparse_meta : table_metas) {
auto table_name = sparse_meta.name;
auto meta = std::shared_ptr<SparseVariable>(
new SparseVariable(std::move(sparse_meta)));
sparse_variables[table_name] = meta;
grad_to_variables[sparse_meta.grad_name] = table_name;
grad_names_.push_back(sparse_meta.grad_name);
}
}
~LargeScaleKV() {}
static std::shared_ptr<LargeScaleKV> GetInstantcePtr() { return scale_kv_; }
static LargeScaleKV *GetInstance() { return scale_kv_.get(); }
static LargeScaleKV *InitInstance(
const std::vector<SparseMeta> &table_metas) {
std::call_once(init_flag_, &LargeScaleKV::Init, table_metas);
return scale_kv_.get();
}
static void Init(const std::vector<SparseMeta> &table_metas) {
if (scale_kv_.get() == nullptr) {
scale_kv_.reset(new LargeScaleKV(table_metas));
}
}
SparseVariable *Get(const std::string &name) {
auto variable = sparse_variables.at(name);
return variable.get();
}
bool ParamInLargeScale(const std::string &name) {
auto got = sparse_variables.find(name);
if (got == sparse_variables.end()) {
return false;
}
return true;
}
bool GradInLargeScale(const std::string &name) {
auto got = grad_to_variables.find(name);
if (got == grad_to_variables.end()) {
return false;
}
return true;
}
SparseVariable *GetByGrad(const std::string &name) {
return Get(grad_to_variables[name]);
}
const std::vector<std::string> &GetAllGrads() { return grad_names_; }
private:
std::unordered_map<std::string, std::shared_ptr<SparseVariable>>
sparse_variables;
std::unordered_map<std::string, std::string> grad_to_variables;
std::vector<std::string> grad_names_;
static std::shared_ptr<LargeScaleKV> scale_kv_;
static std::once_flag init_flag_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/distributed.h"
namespace paddle {
namespace framework {
class ExecutionContext;
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RPCClient;
using LoDTensor = framework::LoDTensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
static void SplitIdsIntoMultipleVarsBySection(
const std::vector<int64_t> &in_ids,
const std::vector<std::string> &in_varnames, const int tables,
const int pservers, const bool is_distibuted, framework::Scope *scope,
std::vector<std::vector<int64_t>> *splited_ids,
std::vector<std::vector<int64_t>> *origin_ids) {
PADDLE_ENFORCE_EQ(
in_varnames.size(), tables,
platform::errors::OutOfRange(
"send varnames size: %d not equal table number: %d, internal error",
in_varnames.size(), tables));
PADDLE_ENFORCE_LE(
tables, pservers,
platform::errors::OutOfRange("table number %d not equal or less than "
"pserver number: %d, internal error",
tables, pservers));
auto place = platform::CPUPlace();
std::set<int64_t> st(in_ids.begin(), in_ids.end());
std::vector<int64_t> all_ids;
all_ids.assign(st.begin(), st.end());
splited_ids->resize(tables);
origin_ids->resize(tables);
if (is_distibuted) {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*splited_ids)[pserver_id].push_back(id);
(*origin_ids)[pserver_id].push_back(id);
}
} else {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*origin_ids)[pserver_id].push_back(id);
id = id / pservers;
(*splited_ids)[pserver_id].push_back(id);
}
}
for (size_t i = 0; i < in_varnames.size(); ++i) {
auto *id_tensor =
scope->Var(in_varnames[i])->GetMutable<framework::LoDTensor>();
auto &ids = (*splited_ids)[i];
if (!ids.empty()) {
auto *id_tensor_data = id_tensor->mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size());
}
}
}
typedef std::vector<std::pair<std::string, std::string>> TableAndEndpoints;
void prefetch_core(
const std::vector<int64_t> &ids, const TableAndEndpoints &tables,
const framework::ExecutionContext &context, const framework::Scope &scope,
const bool is_distributed,
std::unordered_map<int64_t, std::vector<float>> *recved_vec_map) {
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
int pservers = context.Attr<int>("pserver_num");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &actual_ctx = *pool.Get(platform::CPUPlace());
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
for (size_t i = 0; i < tables.size(); ++i) {
in_var_names.push_back("prefetch_send@" + tables[i].second);
out_var_names.push_back("prefetch_recv@" + tables[i].second);
}
std::vector<std::vector<int64_t>> split_ids;
std::vector<std::vector<int64_t>> origin_ids;
SplitIdsIntoMultipleVarsBySection(ids, in_var_names, tables.size(), pservers,
is_distributed, local_scope.get(),
&split_ids, &origin_ids);
// create output var in local scope
for (auto &name : out_var_names) {
local_scope->Var(name)->GetMutable<framework::LoDTensor>();
}
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope.get(), in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << tables[i].second
<< " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(
tables[i].second, actual_ctx, *local_scope.get(), in_var_names[i],
out_var_names[i], tables[i].first));
} else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
}
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
for (size_t o_idx = 0; o_idx < out_var_names.size(); ++o_idx) {
auto &ids_in_this_section = origin_ids[o_idx];
if (!ids_in_this_section.empty()) {
auto &prefetch_out_var =
local_scope->Var(out_var_names[o_idx])->Get<framework::LoDTensor>();
const auto *out_var_data = prefetch_out_var.data<float>();
auto &dims = prefetch_out_var.dims();
PADDLE_ENFORCE_EQ(dims.size(), 2,
platform::errors::InvalidArgument(
"The size of Tensor dims must be 2."));
PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0],
platform::errors::InvalidArgument(
"The size of ids in this section must equal to "
"dims[0]: %s, but got %s",
dims[0], ids_in_this_section.size()));
auto row_numel = dims[1];
for (int64_t i = 0; i < dims[0]; ++i) {
auto origin_id = ids_in_this_section[i];
std::vector<float> vecs(row_numel);
std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin());
(*recved_vec_map)[origin_id] = vecs;
}
} else {
VLOG(3) << "ids in this section is empty";
}
}
}
void prefetch(const std::string &id_name, const std::string &out_name,
const std::string &persistable_var_name,
const bool is_distributed,
const std::vector<std::string> &table_names,
const std::vector<std::string> &endpoints,
const framework::ExecutionContext &context,
const framework::Scope &scope) {
prefetchs({id_name}, {out_name}, persistable_var_name, is_distributed,
table_names, endpoints, context, scope);
}
void prefetchs(const std::vector<std::string> &id_var_names,
const std::vector<std::string> &out_var_names,
const std::string &persistable_var_name,
const bool is_distributed,
const std::vector<std::string> &table_names,
const std::vector<std::string> &endpoints,
const framework::ExecutionContext &context,
const framework::Scope &scope) {
auto vec_dim_1 = 0;
auto vec_dim_0 = 0;
framework::Variable *var = scope.FindVar(persistable_var_name);
if (var->IsType<SelectedRows>()) {
vec_dim_1 = var->Get<framework::SelectedRows>().value().dims()[1];
} else {
vec_dim_0 = var->Get<framework::LoDTensor>().dims()[0];
vec_dim_1 = var->Get<framework::LoDTensor>().dims()[1];
}
PADDLE_ENFORCE_GT(vec_dim_1, 0,
platform::errors::InvalidArgument(
"lookup table var's dim must gather than 0"));
const auto place =
scope.FindVar(id_var_names[0])->Get<framework::LoDTensor>().place();
std::vector<std::vector<int64_t>> ids_group;
std::vector<int64_t> ids_union;
std::vector<framework::LoD> ids_lods;
TableAndEndpoints tables;
for (auto &id_name : id_var_names) {
auto &id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
std::vector<int64_t> ids;
TensorToVector(id_tensor, context.device_context(), &ids);
ids_union.insert(ids_union.end(), ids.begin(), ids.end());
ids_group.push_back(ids);
ids_lods.push_back(id_tensor.lod());
}
std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
ids_union.assign(s.begin(), s.end());
for (auto &i : ids_union) {
PADDLE_ENFORCE_GE(
i, 0, platform::errors::OutOfRange(
"each element in embedding should be larger or equal 0"));
if (!is_distributed) {
PADDLE_ENFORCE_LT(
i, vec_dim_0,
platform::errors::OutOfRange(
"embedding id must in [0, %d) when is_distributed False",
vec_dim_0));
}
}
for (size_t i = 0; i < table_names.size(); i++) {
tables.push_back(std::make_pair(table_names[i], endpoints[i]));
}
std::unordered_map<int64_t, std::vector<float>> recved_vec_map;
prefetch_core(ids_union, tables, context, scope, is_distributed,
&recved_vec_map);
auto padding_idx = distributed::kNoPadding;
if (context.HasAttr("padding_idx")) {
padding_idx = context.Attr<int64_t>("padding_idx");
}
for (size_t i = 0; i < out_var_names.size(); i++) {
std::vector<int64_t> ids = ids_group[i];
auto ids_size = ids.size();
auto *out_t =
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
out_t->set_lod(ids_lods[i]);
out_t->Resize(
framework::make_ddim({static_cast<int64_t>(ids_size), vec_dim_1}));
auto *out_d = out_t->mutable_data<float>(place);
if (platform::is_cpu_place(out_t->place())) {
for (auto idx = 0; idx < static_cast<int>(ids_size); idx++) {
const auto &id = ids[idx];
if (padding_idx != distributed::kNoPadding && id == padding_idx) {
memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
} else {
std::copy_n(recved_vec_map[id].begin(), vec_dim_1,
out_d + idx * vec_dim_1);
}
}
} else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::vector<float> ids_value_vec(ids_size * vec_dim_1);
for (auto idx = 0; idx < static_cast<int>(ids_size); idx++) {
const auto &id = ids[idx];
if (padding_idx != distributed::kNoPadding && id == padding_idx) {
memset(&ids_value_vec[idx * vec_dim_1], 0, sizeof(float) * vec_dim_1);
} else {
memcpy(&ids_value_vec[idx * vec_dim_1], &recved_vec_map[id][0],
sizeof(float) * vec_dim_1);
}
}
auto &gpu_place = BOOST_GET_CONST(platform::CUDAPlace, out_t->place());
auto &cpu_place = BOOST_GET_CONST(
platform::CPUPlace, paddle::platform::CPUDeviceContext().GetPlace());
auto stream = context.cuda_device_context().stream();
memory::Copy(gpu_place, out_d, cpu_place, &ids_value_vec[0],
sizeof(float) * ids_size * vec_dim_1, stream);
#else
PADDLE_ENFORCE(true, platform::errors::PermissionDenied(
"Paddle is not compiled with GPU!"));
#endif
}
}
}
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class ExecutionContext;
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
constexpr int64_t kNoPadding = -1;
void prefetchs(const std::vector<std::string>& id_var_names,
const std::vector<std::string>& out_var_names,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const framework::ExecutionContext& context,
const framework::Scope& scope);
void prefetch(const std::string& id_name, const std::string& out_name,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const framework::ExecutionContext& context,
const framework::Scope& scope);
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/types.h>
#include <algorithm>
#include <memory>
#include "glog/logging.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace distributed {
class RPCClient;
using LoDTensor = framework::LoDTensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
template <typename T>
void RecvSparseLodTensor(const CommContext &rpc_ctx,
const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
std::vector<const float *> tensors;
std::vector<distributed::VarHandlePtr> rets;
std::vector<std::string> recv_varnames;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
local_scope->Var(recv_var_name);
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVarNoBarrier(
rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
recv_var_name));
recv_varnames.push_back(recv_var_name);
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
auto &recv_var_name = recv_varnames[i];
auto *local_var = local_scope->FindVar(recv_var_name);
const auto *value = local_var->Get<framework::LoDTensor>().data<float>();
tensors.push_back(value);
}
auto *merged_var = scope.FindVar(rpc_ctx.var_name);
if (merged_var == nullptr || !merged_var->IsInitialized()) {
PADDLE_THROW(
platform::errors::InvalidArgument("%s must initialized at first."));
}
auto dims1 = merged_var->Get<framework::LoDTensor>().dims()[1];
int64_t height = 0;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto *splited_var = local_scope->FindVar(rpc_ctx.splited_varnames[i]);
height += splited_var->Get<framework::LoDTensor>().dims()[0];
}
PADDLE_ENFORCE_EQ(
merged_var->Get<framework::LoDTensor>().dims()[0], height,
platform::errors::InvalidArgument(
"Received variable must has same dimension with local variable."));
auto *merged_t = merged_var->GetMutable<framework::LoDTensor>();
auto *merged_d = merged_t->mutable_data<float>(cpu_place);
auto pserver_num = rpc_ctx.splited_varnames.size();
for (int x = 0; x < height; ++x) {
auto id = x % pserver_num;
auto idx = x / pserver_num;
std::memcpy(merged_d + x * dims1, tensors[id] + idx * dims1,
sizeof(float) * dims1);
}
}
template <typename T>
void RecvGeoSparseRecords(const CommContext &rpc_ctx,
const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
local_scope->Var(recv_var_name);
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
int64_t height = 0;
int64_t ids_num = 0;
int64_t width = 0;
std::vector<int64_t> all_ids;
auto pserver_num = rpc_ctx.splited_varnames.size();
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
height += recv_t.height();
ids_num += recv_t.rows().size();
width = recv_t.value().dims()[1];
if (rpc_ctx.is_distributed) {
std::copy(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(all_ids));
} else {
std::transform(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(all_ids),
[&](int64_t id) { return id * pserver_num + i; });
}
}
auto *var = scope.FindVar(rpc_ctx.var_name);
auto *t_ = var->GetMutable<framework::SelectedRows>();
T *out_data =
t_->mutable_value()->mutable_data<T>({ids_num, width}, cpu_place);
t_->set_height(height);
t_->set_rows(all_ids);
int64_t cnt = 0;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
auto rows = recv_t.rows().size();
const T *in_data = recv_t.value().data<T>();
std::copy_n(in_data, rows * width, out_data + cnt);
cnt += rows * width;
}
t_->SyncIndex();
}
template <typename T>
void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::vector<distributed::VarHandlePtr> rets;
// variable do not spilt
if (rpc_ctx.origin_varnames.size() == 1 &&
rpc_ctx.splited_varnames.size() == 1) {
auto varname = rpc_ctx.origin_varnames[0];
const auto place =
scope.FindVar(varname)->Get<framework::LoDTensor>().place();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(place);
VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0] << " in gpu? "
<< platform::is_gpu_place(place);
rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], ctx,
scope, varname, varname));
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
}
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
return;
} else {
PADDLE_ENFORCE(false, platform::errors::Unimplemented(
"ParameterRecv can not recv dense with multi "
"parts now, add it soon."));
}
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope,
bool geo_records) {
VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1,
platform::errors::InvalidArgument(
"origin_varnames.size() >= 1 is permitted"));
if (rpc_ctx.is_sparse) {
if (geo_records) {
RecvGeoSparseRecords<T>(rpc_ctx, scope);
} else {
RecvSparseLodTensor<T>(rpc_ctx, scope);
}
} else {
RecvLodTensor<T>(rpc_ctx, scope);
}
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope) {
this->operator()(rpc_ctx, scope, false);
}
template struct ParameterRecv<float>;
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
namespace paddle {
namespace operators {
namespace distributed {
template <typename T>
struct ParameterRecv {
void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
bool barrier);
void operator()(const CommContext &rpc_ctx, const framework::Scope &scope);
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include <memory>
#include <utility>
#include "glog/logging.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
class Scope;
class Tensor;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RPCClient;
using LoDTensor = framework::LoDTensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
typedef std::vector<std::pair<std::string, std::string>> EP_SPLIT_TABLE_PAIRS;
inline EP_SPLIT_TABLE_PAIRS GetMultiFieldCommContext(
const CommContext &rpc_ctx, const framework::Scope &scope,
int multi_parts) {
EP_SPLIT_TABLE_PAIRS table_pairs;
auto *send_var = scope.FindVar(rpc_ctx.var_name);
if (send_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_GE(multi_parts, 1,
platform::errors::InvalidArgument(
"multi_parts must == 1 in parameter send, now is: %d",
multi_parts));
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
table_pairs.push_back(
std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_varnames[i]));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"GetMultiFieldCommContext unsupported LoDTensor current!"));
}
return table_pairs;
} // namespace distributed
void SendByNotifyRPC(const CommContext &rpc_ctx,
const framework::Scope &scope) {
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto &send_var_name = rpc_ctx.var_name;
std::vector<distributed::VarHandlePtr> rets;
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
if (NeedSend(scope, send_var_name)) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(endpoint, cpu_ctx, scope,
send_var_name));
VLOG(4) << "send var " << send_var_name << " by notify RPC done";
}
} else {
VLOG(3) << "don't send non-initialized variable: " << rpc_ctx.var_name;
}
for (auto &handle : rets) {
PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
}
template <typename T>
void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool sync,
int multi_parts) {
if (rpc_ctx.var_name == STEP_COUNTER) {
SendByNotifyRPC(rpc_ctx, scope);
return;
}
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::vector<distributed::VarHandlePtr> rets;
auto *send_var = scope.FindVar(rpc_ctx.var_name);
if (send_var->IsType<framework::LoDTensor>()) {
size_t out_num = rpc_ctx.splited_varnames.size();
if (out_num > 1) {
auto &send_tensor = send_var->Get<framework::LoDTensor>();
auto &send_tensor_dims = send_tensor.dims();
std::vector<framework::DDim> outs_dims;
outs_dims.reserve(out_num);
// infer output shape
PADDLE_ENFORCE_EQ(
rpc_ctx.height_sections.size(), out_num,
platform::errors::InvalidArgument("tensor split sections size"
"should be equal to output size."));
for (size_t i = 0; i < out_num; ++i) {
auto dim = send_tensor_dims;
dim[0] = rpc_ctx.height_sections[i];
outs_dims.push_back(dim);
}
// create output var in local scope
size_t row_offset = 0;
for (size_t i = 0; i < out_num; ++i) {
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[i])
->GetMutable<framework::LoDTensor>();
*out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
row_offset += outs_dims[i][0];
}
} else {
auto &send_tensor = send_var->Get<framework::LoDTensor>();
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[0])
->GetMutable<framework::LoDTensor>();
out->ShareDataWith(send_tensor);
}
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &send_var_name = rpc_ctx.splited_varnames[i];
auto &endpoint = rpc_ctx.epmap[i];
VLOG(4) << " send var name: " << send_var_name
<< "endpoint: " << endpoint;
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_varnames[i];
}
}
} else if (send_var->IsType<framework::SelectedRows>()) {
auto &send_slr = send_var->Get<framework::SelectedRows>();
auto &send_rows = send_slr.rows();
if (send_rows.size() == 0) {
LOG(WARNING)
<< "WARNING: The variable sent to pserver is empty, which "
"may cause an unknown error. Please check the state of "
"use_double_buffer in pyreader/dataloader async mode, you need to "
"turn it false.";
}
std::vector<std::vector<size_t>> outs_rows_idx;
std::vector<std::vector<size_t>> outs_dense_idx;
auto table_pairs = GetMultiFieldCommContext(rpc_ctx, scope, 1);
outs_rows_idx.resize(table_pairs.size());
outs_dense_idx.resize(table_pairs.size());
auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
auto *src = send_slr.value().data<T>();
// create output var in local scope
std::vector<framework::SelectedRows *> outs;
for (auto &table : table_pairs) {
auto *out =
local_scope->Var(table.second)->GetMutable<framework::SelectedRows>();
outs.push_back(out);
}
if (!rpc_ctx.is_distributed) {
auto pserver_num = rpc_ctx.epmap.size();
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
auto ep_idx = send_rows[i] % pserver_num;
auto id = send_rows[i] / pserver_num;
outs_rows_idx[ep_idx].push_back(id);
outs_dense_idx[ep_idx].push_back(i);
}
auto place = platform::CPUPlace();
for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
out_idx++) {
auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx);
}
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(platform::CPUPlace(), dst + j * row_numel,
platform::CPUPlace(),
src + outs_dense_idx[out_idx][j] * row_numel,
sizeof(T) * row_numel);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("do not support GPU now"));
}
}
}
PADDLE_ENFORCE_EQ(
rows_idx.size(), outs[out_idx]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0"));
}
} else {
auto pserver_num = rpc_ctx.epmap.size();
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
auto out_idx = send_rows[i] % pserver_num;
outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i);
}
auto place = platform::CPUPlace();
for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
out_idx++) {
auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx);
}
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(platform::CPUPlace(), dst + j * row_numel,
platform::CPUPlace(),
src + outs_dense_idx[out_idx][j] * row_numel,
sizeof(T) * row_numel);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("do not support GPU now"));
}
}
}
PADDLE_ENFORCE_EQ(
rows_idx.size(), outs[out_idx]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0"));
}
}
for (size_t i = 0; i < table_pairs.size(); i++) {
auto &send_var_name = table_pairs[i].second;
auto &endpoint = table_pairs[i].first;
auto need_send = NeedSend(*local_scope.get(), send_var_name);
VLOG(4) << "send var name: " << send_var_name
<< " send var endpoint: " << endpoint
<< " need send: " << need_send;
if (need_send) {
VLOG(4) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(4) << "don't send non-initialized variable: "
<< rpc_ctx.splited_varnames[i];
}
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported var type: %s to send!", send_var->Type()));
}
VLOG(4) << "Prepare to send var " << rpc_ctx.var_name;
if (sync) {
for (auto &handle : rets) {
VLOG(4) << "Wait send var to pserver handle: " << handle;
PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
}
}
template struct ParameterSend<float>;
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
namespace paddle {
namespace operators {
namespace distributed {
template <typename T>
struct ParameterSend {
void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
bool sync, int multi_parts);
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// NOTE: This file was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#pragma once
#include <string>
#include "grpc++/grpc++.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
char* EncodeVarint32(char* dst, uint32_t v) {
// Operate on characters as unsigneds
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
static const int B = 128;
if (v < (1 << 7)) {
*(ptr++) = v;
} else if (v < (1 << 14)) {
*(ptr++) = v | B;
*(ptr++) = v >> 7;
} else if (v < (1 << 21)) {
*(ptr++) = v | B;
*(ptr++) = (v >> 7) | B;
*(ptr++) = v >> 14;
} else if (v < (1 << 28)) {
*(ptr++) = v | B;
*(ptr++) = (v >> 7) | B;
*(ptr++) = (v >> 14) | B;
*(ptr++) = v >> 21;
} else {
*(ptr++) = v | B;
*(ptr++) = (v >> 7) | B;
*(ptr++) = (v >> 14) | B;
*(ptr++) = (v >> 21) | B;
*(ptr++) = v >> 28;
}
return reinterpret_cast<char*>(ptr);
}
char* EncodeVarint64(char* dst, uint64_t v) {
static const int B = 128;
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
while (v >= B) {
*(ptr++) = (v & (B - 1)) | B;
v >>= 7;
}
*(ptr++) = static_cast<unsigned char>(v);
return reinterpret_cast<char*>(ptr);
}
int VarintLength(uint64_t v) {
int len = 1;
while (v >= 128) {
v >>= 7;
len++;
}
return len;
}
class ProtoEncodeHelper {
public:
ProtoEncodeHelper(char* buf, int max_size)
: base_(buf), p_(buf), limit_(base_ + max_size) {}
~ProtoEncodeHelper() {}
const char* data() const { return base_; }
size_t size() const { return p_ - base_; }
void WriteUint64(int tag, uint64_t v) {
Encode32(combine(tag, WIRETYPE_VARINT));
Encode64(v);
}
void WriteBool(int tag, bool v) {
Encode32(combine(tag, WIRETYPE_VARINT));
EncodeBool(v);
}
void WriteString(int tag, const std::string& v) {
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
Encode32(v.size());
EncodeBytes(v.data(), v.size());
}
void WriteVarlengthBeginning(int tag, uint32_t len) {
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
Encode32(len);
}
void WriteRawBytes(const std::string& v) { EncodeBytes(v.data(), v.size()); }
private:
// Note: this module's behavior must match the protocol buffer wire encoding
// format.
enum {
WIRETYPE_VARINT = 0,
WIRETYPE_LENGTH_DELIMITED = 2,
};
static uint32_t combine(uint32_t tag, uint32_t type) {
return ((tag << 3) | type);
}
inline void Encode32(uint32_t v) {
if (v < 128) {
// Fast path for single-byte values. Many of the calls will use a
// constant value for v, so the comparison will get optimized away
// when Encode32 is inlined into the caller.
*p_ = v;
p_++;
} else {
p_ = EncodeVarint32(p_, v);
}
}
void Encode64(uint64_t v) { p_ = EncodeVarint64(p_, v); }
void EncodeBool(bool v) {
*p_ = (v ? 1 : 0); // Equal to varint32 encoding of 0 or 1
p_++;
}
void EncodeBytes(const char* bytes, int N) {
memcpy(p_, bytes, N);
p_ += N;
}
char* base_;
char* p_;
char* limit_; // Just for CHECKs
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <condition_variable> // NOLINT
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
namespace distributed {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestGetMonomerVariable[] = "RequestGetMonomerVariable";
constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kRequestNotify[] = "RequestNotify";
constexpr char kRequestSendAndRecv[] = "RequestSendAndRecv";
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
constexpr char kSendAndRecvRPC[] = "SendAndRecvRPC";
constexpr int64_t kPrefetchTimeout = 60000;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
enum DistributedMode { kSync = 0, kAsync = 1, kHalfAsync = 2, kGeo = 3 };
class RPCServer;
class VarHandle {
public:
VarHandle(const std::string ep, const std::string& method,
const std::string& name,
const platform::DeviceContext* p_ctx = nullptr,
const framework::Scope* p_scope = nullptr)
: status_(kDefaultState) {
ep_ = ep;
ctx_ = p_ctx;
scope_ = p_scope;
name_ = name;
method_ = method;
}
virtual ~VarHandle() {}
public:
bool should_retry = false;
bool Wait() {
int ret = kDefaultState;
{
std::unique_lock<std::mutex> lk(sync_mutex_);
wait_cond_.wait(lk, [this] { return status_ != kDefaultState; });
ret = status_;
}
VLOG(7) << "VarHandle wait:" << ret;
return ret != kErrorState;
}
void Finish(bool ok) {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
status_ = ok ? kFinishState : kErrorState;
}
VLOG(7) << "VarHandle finish:" << ok;
wait_cond_.notify_all();
}
std::string String() const {
std::ostringstream s;
s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:["
<< status_ << "]";
return s.str();
}
std::string ep() const { return ep_; }
const platform::DeviceContext* ctx() const { return ctx_; }
const framework::Scope* scope() const { return scope_; }
std::string name() const { return name_; }
std::string method() const { return method_; }
protected:
// RPC endpoint.
std::string ep_;
const platform::DeviceContext* ctx_;
const framework::Scope* scope_;
// Variable name.
std::string name_;
// RPC method name.
std::string method_;
protected:
std::mutex sync_mutex_;
std::condition_variable wait_cond_;
enum VarHandleStatus {
kDefaultState = -1,
kErrorState = 0,
kFinishState = 1,
};
VarHandleStatus status_;
private:
DISABLE_COPY_AND_ASSIGN(VarHandle);
};
typedef std::shared_ptr<VarHandle> VarHandlePtr;
class RequestHandler {
public:
explicit RequestHandler(int distributed_mode)
: distributed_mode_(distributed_mode),
dev_ctx_(nullptr),
executor_(nullptr),
scope_(nullptr),
program_(nullptr),
rpc_server_(nullptr) {}
virtual ~RequestHandler() {}
// Set attributes.
void SetScope(framework::Scope* scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
// Used for dist lookup table prefetch
void SetPrefetchPreparedCtx(
std::unordered_map<
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
prefetch_var_name_to_prepared_ctx_ = g;
}
void SetCheckpointNotifyPreparedCtx(
std::shared_ptr<framework::ExecutorPrepareContext> g) {
checkpoint_prepared_ctx_ = g;
}
// Used for async.
void SetGradToPreparedCtx(
std::unordered_map<
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
grad_to_prepared_ctx_ = g;
}
void SetSparseGradToParam(std::unordered_map<std::string, std::string>* g) {
sparse_grad_to_param_ = g;
}
void SetLrDecayPreparedCtx(
std::shared_ptr<framework::ExecutorPrepareContext> g) {
lr_decay_prepared_ctx_ = g;
}
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
// Get attributes.
int distributed_mode() { return distributed_mode_; }
framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; }
// This function processes user's rpc request.
// The implemention is in request_handler_impl.
// example:
// std::string varname = request_.varname();
//
// auto scope = request_handler_->scope();
// auto invar = scope->FindVar(varname);
// framework::Variable* outvar = nullptr;
//
// request_handler_->Handle(varname, scope, invar, &outvar);
// if (outvar) {
// SerializeToByteBuffer(varname, outvar,
// *request_handler_->dev_ctx(), &reply_);
// }
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "",
const std::string& table_name = "") = 0;
protected:
const int distributed_mode_;
const platform::DeviceContext* dev_ctx_;
framework::Executor* executor_;
framework::Scope* scope_;
framework::ProgramDesc* program_;
// used for distribute lookup table prefetch
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// used for checkpoint notify
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
// Used for async.
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_;
std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
// used for lr decay
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_prepared_ctx_;
RPCServer* rpc_server_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
namespace paddle {
namespace operators {
namespace distributed {
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestSendHandler:" << varname;
// Sync
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
if (HeartBeatMonitor::GetInstance() != nullptr) {
HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED);
}
rpc_server_->Complete();
} else {
// Async
if (distributed_mode_ != DistributedMode::kSync) {
VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE) {
PADDLE_THROW(platform::errors::InvalidArgument(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"));
}
HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);
std::string run_varname = varname;
string::Piece part_piece("@PIECE");
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, part_piece)) {
auto varname_splits = paddle::string::Split(varname, '@');
PADDLE_ENFORCE_EQ(
varname_splits.size(), 3,
platform::errors::InvalidArgument(
"varname: %s should be separated into 3 parts by @", varname));
run_varname = varname_splits[0];
scope->Rename(varname, run_varname);
}
auto *var = scope->FindVar(run_varname);
// for sparse ids
if (var->IsType<framework::SelectedRows>()) {
if (distributed_mode_ == DistributedMode::kAsync ||
distributed_mode_ == DistributedMode::kHalfAsync) {
auto *ins = distributed::LargeScaleKV::GetInstance();
if (ins->GradInLargeScale(run_varname)) {
auto *large_scale_var = ins->GetByGrad(run_varname);
for (auto name : large_scale_var->CachedVarnames()) {
scope->Var(name);
}
}
}
if (distributed_mode_ == DistributedMode::kGeo) {
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(
run_varname)) {
auto &grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(
run_varname, grad_slr.rows());
}
}
}
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope);
return true;
} else { // sync
rpc_server_->WaitCond(kRequestSend);
VLOG(3) << "sync: processing received var: " << varname;
PADDLE_ENFORCE_NOT_NULL(
invar, platform::errors::NotFound(
"sync: Can not find server side var %s.", varname));
}
}
return true;
}
bool RequestGetHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name;
if (distributed_mode_ == DistributedMode::kSync) {
if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
} else {
rpc_server_->WaitCond(kRequestGet);
*outvar = scope_->FindVar(varname);
}
} else {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
if (enable_dc_asgd_) {
// NOTE: the format is determined by distribute_transpiler.py
std::string param_bak_name =
string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
auto var = scope_->FindVar(varname);
auto t_orig = var->Get<framework::LoDTensor>();
auto param_bak = scope_->Var(param_bak_name);
auto t = param_bak->GetMutable<framework::LoDTensor>();
t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
VLOG(3) << "AsyncSparseParamUpdateRecorder " << varname << " exist ";
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto &row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto &origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto *origin_tensor_data = origin_tensor.data<float>();
auto &dims = origin_tensor.dims();
*outvar = scope->Var();
auto *out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto *data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(
updated_rows[i], dims[0],
platform::errors::OutOfRange(
"The value of updated_rows: %s out of Tensor %s dims[0]: %s",
updated_rows[i], varname, dims[0]));
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
}
} else {
*outvar = scope_->FindVar(varname);
}
}
}
return true;
}
bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name;
// get var from pserver immediately without barriers
string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, without_barrier_piece)) {
var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece);
VLOG(4) << "Get var " << var_name_piece << " with "
<< WITHOUT_BARRIER_MESSAGE;
*outvar = scope_->FindVar(var_name_piece.ToString());
return true;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE));
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
(*outvar)->GetMutable<framework::LoDTensor>();
VLOG(1) << "Prefetch "
<< "tablename: " << table_name << " ids:" << varname
<< " out: " << out_var_name;
paddle::platform::CPUPlace cpu_place;
auto *ins = distributed::LargeScaleKV::GetInstance();
if (ins->ParamInLargeScale(table_name)) {
auto lookup_table_op = PullLargeScaleOp(table_name, varname, out_var_name);
lookup_table_op->Run(*scope, cpu_place);
} else {
auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name);
lookup_table_op->Run(*scope, cpu_place);
}
return true;
}
bool RequestCheckpointHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "receive save var " << varname << " with path " << out_var_name
<< " mode " << table_name;
int mode = std::stoi(table_name);
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Save(out_var_name, mode);
return true;
}
bool RequestNotifyHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "RequestNotifyHandler: " << varname
<< ", trainer_id: " << trainer_id;
string::Piece decay_piece(STEP_COUNTER);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, decay_piece)) {
VLOG(3) << "LearningRate Decay Counter Update";
auto *send_var = scope->FindVar(varname);
auto send_var_tensor = send_var->Get<framework::LoDTensor>();
auto *send_value =
send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
auto counter = decay_counters.at(trainer_id);
counter += send_value[0];
decay_counters.at(trainer_id) = counter;
auto *global_step_var = this->scope()->FindVar(LEARNING_RATE_DECAY_COUNTER);
if (global_step_var == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not find LEARNING_RATE_DECAY_COUNTER "));
}
auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());
auto global_counter = 0;
for (auto &trainer_counter : decay_counters) {
global_counter += trainer_counter.second;
}
value[0] = global_counter;
if (lr_decay_prepared_ctx_.get() == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not find decay block for executor"));
}
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
}
return true;
}
bool RequestSendAndRecvHandler::Handle(const std::string &varname,
framework::Scope *Scope,
framework::Variable *var,
framework::Variable **outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "SendAndRecvHandle: " << varname
<< " out_var_name: " << out_var_name
<< " , trainer_id: " << trainer_id;
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), Scope);
*outvar = Scope->FindVar(out_var_name);
return true;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RequestSendHandler final : public RequestHandler {
public:
explicit RequestSendHandler(int distributed_mode, bool enable_dc_asgd = false)
: RequestHandler(distributed_mode) {
enable_dc_asgd_ = enable_dc_asgd;
}
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
bool enable_dc_asgd_;
};
class RequestGetHandler final : public RequestHandler {
public:
explicit RequestGetHandler(int distributed_mode, bool enable_dc_asgd = false)
: RequestHandler(distributed_mode) {
enable_dc_asgd_ = enable_dc_asgd;
}
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
bool enable_dc_asgd_;
};
class RequestGetNoBarrierHandler final : public RequestHandler {
public:
RequestGetNoBarrierHandler() : RequestHandler(false) {}
virtual ~RequestGetNoBarrierHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
};
static inline void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}
class RequestPrefetchHandler final : public RequestHandler {
public:
explicit RequestPrefetchHandler(int distributed_mode)
: RequestHandler(distributed_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
std::unique_ptr<paddle::framework::OperatorBase> PullLargeScaleOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
framework::OpDesc desc;
desc.SetType("lookup_sparse_table_read");
desc.SetInput("Ids", {id_name});
desc.SetOutput("Out", std::vector<std::string>({out_name}));
desc.SetAttr("tablename", {table_name});
desc.SetAttr("init", true);
desc.SetAttr("value_names", std::vector<std::string>({"Param"}));
auto op = paddle::framework::OpRegistry::CreateOp(desc);
return op;
}
std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("lookup_table");
BuildVar("W", {table_name.data()}, op_desc.add_inputs());
BuildVar("Ids", {id_name.data()}, op_desc.add_inputs());
BuildVar("Out", {out_name.data()}, op_desc.add_outputs());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
return op;
}
};
class RequestCheckpointHandler final : public RequestHandler {
public:
explicit RequestCheckpointHandler(int distributed_mode)
: RequestHandler(distributed_mode) {}
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
std::unique_ptr<paddle::framework::OperatorBase> BuildCheckpointOp(
const std::string& varname, const std::string& file_path) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("save");
BuildVar("X", {varname.data()}, op_desc.add_inputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("file_path");
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s(file_path);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
return op;
}
};
class RequestNotifyHandler final : public RequestHandler {
public:
explicit RequestNotifyHandler(int distributed_mode, int trainers)
: RequestHandler(distributed_mode) {
this->trainers = trainers;
for (int i = 0; i < trainers; i++) {
decay_counters[i] = 0;
}
}
virtual ~RequestNotifyHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
int trainers;
std::unordered_map<int, int64_t> decay_counters;
};
class RequestSendAndRecvHandler final : public RequestHandler {
public:
explicit RequestSendAndRecvHandler(int distributed_mode)
: RequestHandler(distributed_mode) {}
virtual ~RequestSendAndRecvHandler() {}
bool Handle(const std::string& varname, framework::Scope* Scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "gflags/gflags.h"
// default to 3min to avoid temprary network failures.
DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
DEFINE_int32(rpc_retry_times, 3, "retry times for rpc");
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
int RPCClient::trainer_id_ = 0;
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable> // NOLINT
#include <memory>
#include <string>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
DECLARE_int32(rpc_deadline);
DECLARE_int32(rpc_retry_times);
namespace paddle {
namespace operators {
namespace distributed {
class RPCClient {
public:
RPCClient() {}
virtual ~RPCClient() {}
virtual VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncPrefetchVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& in_var_name,
const std::string& out_var_name, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetMonomerBarrier(
const std::string& ep, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dirname,
const std::string& varname, const int mode,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendAndRecv(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& send_var_name,
const std::string& recv_var_name, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
// Complete tells all the pserver instances that finishe the training,
// the pserver can reduce it's barrier count, and continue to train
// with other trainers.
virtual void SendComplete() = 0;
virtual bool Wait() = 0;
template <typename T>
static RPCClient* GetInstance(int trainer_id) {
std::call_once(init_flag_, &RPCClient::Init<T>, trainer_id);
return rpc_client_.get();
}
// Init is called by GetInstance.
template <typename T>
static void Init(int trainer_id) {
VLOG(1) << "init rpc client with trainer_id " << trainer_id;
trainer_id_ = trainer_id;
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new T());
rpc_client_->InitImpl();
}
}
virtual void InitImpl() {}
protected:
// each trainer have exact one trainer id, it should be static
static int trainer_id_;
private:
static std::once_flag init_flag_;
static std::unique_ptr<RPCClient> rpc_client_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include <fstream>
#include <string>
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RequestHandler;
void RPCServer::ShutDown() {
VLOG(3) << "RPCServer ShutDown ";
ShutDownImpl();
exit_flag_ = true;
barrier_cond_.notify_all();
rpc_cond_.notify_all();
}
void RPCServer::SavePort() const {
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_;
port_file.close();
VLOG(3) << "selected port written to " << file_path;
}
void RPCServer::WaitBarrier(const std::string& rpc_name) {
VLOG(3) << "WaitBarrier in: " << rpc_name;
std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [this, &rpc_name] {
return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
exit_flag_.load());
});
VLOG(3) << "WaitBarrier out: " << rpc_name
<< " counter: " << barrier_counter_[rpc_name];
}
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
// barrier msg should make sure that it's in the right cond(send|recv)
WaitCond(rpc_name);
int b = 0;
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
VLOG(3) << rpc_name << " barrier_counter: " << b;
if (b >= client_num_) {
lock.unlock();
VLOG(3) << "BatchBarrier counter reach " << client_num_ << " for "
<< rpc_name;
barrier_cond_.notify_all();
lock.lock();
}
}
void RPCServer::Complete() {
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
need_reset_all_vars_ = true;
VLOG(3) << "decrease client_num to: " << client_num_;
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--;
}
}
barrier_cond_.notify_all();
}
bool RPCServer::NeedResetAllVars() {
std::unique_lock<std::mutex> lock(mutex_);
return need_reset_all_vars_;
}
int RPCServer::GetClientNum() {
std::unique_lock<std::mutex> lock(mutex_);
return client_num_;
}
void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_);
for (auto& t : barrier_counter_) {
t.second = 0;
}
need_reset_all_vars_ = false;
}
void RPCServer::RegisterRPC(const std::string& rpc_name,
RequestHandler* handler, int thread_num) {
rpc_call_map_[rpc_name] = handler;
rpc_thread_num_[rpc_name] = thread_num;
static int cond = -1;
rpc_cond_map_[rpc_name] = ++cond;
VLOG(3) << "RegisterRPC rpc_name: " << rpc_name << ", handler: " << handler
<< ", cond: " << rpc_cond_map_[rpc_name];
}
void RPCServer::SetCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer SetCond " << rpc_name;
{
std::unique_lock<std::mutex> lock(mutex_);
cur_cond_ = rpc_cond_map_[rpc_name];
}
rpc_cond_.notify_all();
}
void RPCServer::WaitCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer WaitCond in " << rpc_name;
int cond = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
cond = rpc_cond_map_[rpc_name];
}
std::unique_lock<std::mutex> lock(mutex_);
rpc_cond_.wait(
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
VLOG(3) << "RPCServer WaitCond out " << rpc_name;
}
void RPCServer::RegisterVar(const std::string& var_name,
const std::string& rpc_name,
framework::Scope* scope,
platform::DeviceContext* dev_ctx) {
MonomerHandle h;
h.var_name_ = var_name;
h.rpc_name_ = rpc_name;
h.scope_ = scope;
h.dev_ctx_ = dev_ctx;
{
std::unique_lock<std::mutex> lock(mutex_);
PADDLE_ENFORCE_EQ(
var_map_.find(var_name), var_map_.end(),
platform::errors::AlreadyExists("%s already in var_map.", var_name));
var_map_[var_name] = h;
}
rpc_cond_.notify_all();
VLOG(3) << "RegisterVar context:" << h.String();
}
void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
int b = 0;
MonomerHandle h;
{
std::unique_lock<std::mutex> lock(mutex_);
b = ++var_map_[var_name].barrier_;
h = var_map_[var_name];
}
if (b >= client_num_) {
barrier_cond_.notify_all();
}
VLOG(3) << "IncreaseVarBarrier context:" << h.String();
}
void RPCServer::WaitVarBarrier(const std::string& var_name) {
VLOG(3) << "WaitVarBarrier var_name:" << var_name;
std::unique_lock<std::mutex> lock(mutex_);
barrier_cond_.wait(lock, [&]() {
return ((var_map_[var_name].barrier_ >= client_num_ && client_num_ != 0) ||
exit_flag_.load());
});
VLOG(3) << "WaitVarBarrier context: " << var_map_[var_name].String();
}
void RPCServer::SetVarCond(const std::string& var_name) {
VLOG(3) << "SetVarCond var_name:" << var_name;
{
std::unique_lock<std::mutex> lock(mutex_);
if (var_map_.find(var_name) != var_map_.end()) {
rpc_cond_.notify_all();
}
}
}
void RPCServer::WaitVarCond(const std::string& var_name) {
VLOG(3) << "WaitVarCond var_name:" << var_name;
std::unique_lock<std::mutex> lock(mutex_);
rpc_cond_.wait(lock, [=] {
return (var_map_.find(var_name) != var_map_.end() || exit_flag_.load());
});
VLOG(3) << "WaitVarCond var_name:" << var_name << " end";
}
MonomerHandle RPCServer::GetMonomer(const std::string& var_name) {
MonomerHandle h;
{
std::unique_lock<std::mutex> lock(mutex_);
h = var_map_[var_name];
}
return h;
}
void RPCServer::ClearRegisteredVars() {
std::unique_lock<std::mutex> lock(mutex_);
var_map_.clear();
}
void RPCServer::ClearVar(const std::string& var_name) {
std::unique_lock<std::mutex> lock(mutex_);
var_map_.erase(var_name);
}
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <set>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RequestHandler;
struct MonomerHandle {
std::string var_name_;
std::string rpc_name_;
framework::Scope* scope_{nullptr};
platform::DeviceContext* dev_ctx_{nullptr};
int64_t barrier_{0};
std::string String() {
std::stringstream ss;
ss << "var_name:" << var_name_ << ", rpc_name:" << rpc_name_
<< ", scope:" << scope_ << ", dev_ctx:" << dev_ctx_
<< ", barrier_:" << barrier_;
return ss.str();
}
};
class RPCServer {
public:
explicit RPCServer(const std::string& address, int client_num)
: cur_cond_(0),
bind_address_(address),
exit_flag_(false),
selected_port_(0),
client_num_(client_num),
need_reset_all_vars_(false) {}
virtual ~RPCServer() {}
virtual void StartServer() = 0;
virtual void WaitServerReady() = 0;
void ShutDown();
bool IsExit() { return exit_flag_.load(); }
int GetSelectedPort() const { return selected_port_; }
int GetClientNum();
void SavePort() const;
// RegisterRPC, register the rpc method name to a handler
// class, and auto generate a condition id for this call
// to be used for the barrier.
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
int thread_num = 1);
int GetThreadNum(const std::string& rpc_name) {
return rpc_thread_num_[rpc_name];
}
// Wait util all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
// synchronous mode.
void WaitBarrier(const std::string& rpc_name);
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void RegisterVar(const std::string& var_name, const std::string& rpc_name,
framework::Scope* scope, platform::DeviceContext* dev_ctx);
void IncreaseVarBarrier(const std::string& var_name);
void WaitVarBarrier(const std::string& var_name);
void SetVarCond(const std::string& var_name);
void WaitVarCond(const std::string& var_name);
void ClearRegisteredVars();
void ClearVar(const std::string& var_name);
MonomerHandle GetMonomer(const std::string& var_name);
void Complete();
void ResetBarrierCounter();
bool NeedResetAllVars();
protected:
virtual void ShutDownImpl() = 0;
private:
std::mutex mutex_;
std::unordered_map<std::string, int> barrier_counter_;
std::condition_variable barrier_cond_;
std::unordered_map<std::string, int> rpc_cond_map_;
std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_;
protected:
std::string bind_address_;
std::atomic<int> exit_flag_;
int selected_port_;
int client_num_;
bool need_reset_all_vars_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_;
friend class RequestHandler;
// TODO(gongwb): use more cond to notify or wait;
std::unordered_map<std::string, MonomerHandle> var_map_;
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdlib.h>
#include <unistd.h>
#include <chrono> // NOLINT
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;
USE_NO_KERNEL_OP(lookup_sparse_table_read);
USE_NO_KERNEL_OP(checkpoint_notify);
USE_OP(scale);
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
framework::OpDesc* op = block->AppendOp();
op->SetType("scale");
op->SetInput("X", {"x"});
op->SetOutput("Out", {"res"});
op->SetAttr("scale", 0.5f);
auto& out = *root_block->Var("res");
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetShape({1, 10});
return block;
}
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto w_var = scope->Var("w");
w_var->GetMutable<framework::SelectedRows>();
auto out_var = scope->Var("out");
out_var->GetMutable<framework::LoDTensor>();
auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::LoDTensor>();
auto x_var = scope->Var("x");
x_var->GetMutable<framework::LoDTensor>();
auto res_var = scope->Var("res");
res_var->GetMutable<framework::LoDTensor>();
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
int64_t* ids_ptr =
ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
}
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
auto w_value = w->mutable_value();
w_value->Resize({rows_numel, 10});
for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);
auto ptr = w_value->mutable_data<float>(*place);
for (int64_t i = 0; i < w_value->numel(); ++i) {
ptr[i] = static_cast<float>(i / 10);
}
}
void StartServer(const std::string& rpc_name) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
// distributed::HeartBeatMonitor::Init(1, true, "w@grad");
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join();
}
void StartSendAndRecvServer(const std::string& rpc_name) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto block = AppendSendAndRecvBlock(&program);
std::string in_var_name("x");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx;
grad_to_prepared_ctx[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program);
g_req_handler->SetGradToPreparedCtx(&grad_to_prepared_ctx);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join();
}
TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE_NE(client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
std::thread server_thread(StartServer, distributed::kRequestSend);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
client->AsyncSendComplete(ep);
client->Wait();
EXPECT_EQ(g_rpc_service->GetClientNum(), 1);
g_rpc_service->ShutDown();
server_thread.join();
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
TEST(SENDANDRECV, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestSendAndRecvHandler(
distributed::DistributedMode::kAsync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE_NE(client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
std::thread server_thread(StartSendAndRecvServer,
distributed::kRequestSendAndRecv);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
int64_t rows_numel = 10;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("x");
std::string out_var_name("res");
client->AsyncSendAndRecv(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::LoDTensor>();
auto ptr = value->mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[i], 0.5);
}
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
void StartCheckpointServer(const std::string& rpc_name) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
std::vector<distributed::SparseMeta> metas;
auto meta = distributed::SparseMeta();
meta.name = "embedding.block0";
meta.value_names = {"Param"};
meta.value_dims = {64};
meta.mode = distributed::Mode::training;
meta.grad_name = "embedding@Grad";
meta.cached_varnames = {"kSparseIds"};
meta.initializer_attrs = {"fill_constant&1.0"};
meta.entry = "none";
metas.push_back(meta);
distributed::LargeScaleKV::Init(metas);
auto* ins = distributed::LargeScaleKV::GetInstance();
ins->Get("embedding.block0")->Init({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join();
}
TEST(LARGE_SCALE_CHECKPOINT, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
g_req_handler.reset(new distributed::RequestCheckpointHandler(
distributed::DistributedMode::kAsync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE_NE(client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
std::thread server_thread(StartCheckpointServer,
distributed::kRequestCheckpoint);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
auto save_path =
paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/base",
"embedding", "embedding.block0");
int mode = 0;
client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode);
client->Wait();
save_path =
paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/delta",
"embedding", "embedding.block0");
mode = 1;
client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode);
client->Wait();
paddle::framework::AttributeMap attrs;
std::vector<std::string> eps = {ep};
attrs["endpoints"] = eps;
attrs["dirname"] = std::string("/tmp/large_scale_table/delta1");
attrs["varname"] = std::string("embedding");
attrs["mode"] = 2;
std::vector<std::string> slices = {"embedding.block0"};
attrs["slice_varnames"] = slices;
std::vector<std::string> remotes = {"embedding.block0"};
attrs["remote_varnames"] = remotes;
auto ops =
framework::OpRegistry::CreateOp("checkpoint_notify", {}, {}, attrs, true);
ops->Run(scope, place);
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under
the Apache License, Version 2.0 (the "License"); you may not use this file
except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto3";
package sendrecv;
option cc_generic_services = @cc_generic_services@;
service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
rpc GetVariableNoBarrier(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
rpc DistributeNotify(VariableMessage) returns (VoidMessage) {}
rpc SendAndRecvVariable(VariableMessage) returns (VariableMessage) {}
rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {}
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
}
// It can be: LoDTensorSelectedRows or NCCL_ID
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
NCCL_ID = 2;
}
// VariableMessage is serialized paddle variable message.
// NOTICE(gongwb):don't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message LodData { repeated int64 lod_data = 1; }
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
// If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
// when profile switches from 1 to 2.
int64 profile = 11;
int64 trainer_id = 12;
string table_name = 13;
}
message VoidMessage {}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
namespace memory {
namespace allocation {
class Allocation;
} // namespace allocation
} // namespace memory
} // namespace paddle
DEFINE_bool(rpc_disable_reuse_port, false, "Disable SO_REUSEPORT or not.");
DEFINE_int32(rpc_retry_bind_port, 3,
"Retry to bind the address if address is already used.");
namespace paddle {
namespace operators {
namespace distributed {
using VarMsg = sendrecv::VariableMessage;
static TensorPayload GetCommunicationAllocationFromTensor(
const platform::DeviceContext& ctx, const framework::Tensor& tensor) {
if (is_gpu_place(ctx.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_EQ(
is_gpu_place(tensor.place()), true,
platform::errors::PreconditionNotMet("Please run in gpu place."));
auto& gpu_dev_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
platform::CUDAPinnedPlace cuda_pinned;
auto result = memory::AllocShared(cuda_pinned, copy_size);
memory::Copy(cuda_pinned, result->ptr(),
BOOST_GET_CONST(platform::CUDAPlace, tensor.place()),
tensor.data<void>(), copy_size, gpu_dev_ctx.stream());
ctx.Wait();
return TensorPayload(result);
#else
PADDLE_THROW(
platform::errors::Unavailable("This situation should not be happened"));
#endif
} else {
return TensorPayload(tensor);
}
}
TensorPayload GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request) {
auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
request->set_data_type(static_cast<VarMsg::Type>(tensor.type()));
for (auto& dim : framework::vectorize(tensor.dims())) {
request->add_dims(dim);
}
const framework::LoD lod = tensor.lod();
if (lod.size() > 0) {
request->set_lod_level(lod.size());
for (auto& each : lod) {
VarMsg::LodData* lod_inner = request->add_lod();
for (auto& d : each) {
lod_inner->add_lod_data(d);
}
}
}
return GetCommunicationAllocationFromTensor(ctx, tensor);
}
TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request) {
auto* slr = var->GetMutable<framework::SelectedRows>();
request->set_data_type(static_cast<VarMsg::Type>(slr->value().type()));
request->set_lod_level(0);
request->set_slr_height(slr->height());
for (auto& dim : framework::vectorize(slr->value().dims())) {
request->add_dims(dim);
}
auto* tensor = slr->mutable_value();
return GetCommunicationAllocationFromTensor(ctx, *tensor);
}
TensorPayload::TensorPayload(std::shared_ptr<memory::Allocation> allocation)
: allocation_(allocation), offset_(0), memory_size_(allocation->size()) {}
TensorPayload::TensorPayload(const framework::Tensor& tensor)
: allocation_(tensor.Holder()),
offset_(tensor.offset()),
memory_size_(tensor.numel() * framework::SizeOfType(tensor.type())) {}
void* TensorPayload::ptr() const {
return reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(allocation_->ptr()) + offset_);
}
size_t TensorPayload::memory_size() const { return memory_size_; }
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <memory>
#include <string>
#include <typeindex>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace framework {
class Tensor;
class Variable;
} // namespace framework
namespace memory {
namespace allocation {
class Allocation;
} // namespace allocation
} // namespace memory
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
using VarMsg = sendrecv::VariableMessage;
class TensorPayload final {
public:
explicit TensorPayload(const framework::Tensor& tensor);
explicit TensorPayload(std::shared_ptr<memory::Allocation> allocation);
TensorPayload(const TensorPayload& o) = default;
TensorPayload& operator=(const TensorPayload& o) = default;
void* ptr() const;
size_t memory_size() const;
private:
std::shared_ptr<memory::Allocation> allocation_;
size_t offset_;
size_t memory_size_;
};
inline void SerializeDestroyCallback(void* payload) {
if (payload != nullptr) {
auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
delete shared_payload;
}
}
TensorPayload GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request);
TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request);
inline framework::proto::VarType::Type ToVarType(
sendrecv::VariableMessage::Type type) {
switch (type) {
case sendrecv::VariableMessage::FP32:
return framework::proto::VarType::FP32; // NOLINT
case sendrecv::VariableMessage::FP64:
return framework::proto::VarType::FP64; // NOLINT
case sendrecv::VariableMessage::INT32:
return framework::proto::VarType::INT32; // NOLINT
case sendrecv::VariableMessage::INT64:
return framework::proto::VarType::INT64; // NOLINT
case sendrecv::VariableMessage::BOOL:
return framework::proto::VarType::BOOL; // NOLINT
default:
PADDLE_THROW(
platform::errors::InvalidArgument("Not support type id: %d.", type));
}
}
template <template <typename> class T, typename Elem>
std::string VectorElemName(const T<Elem>& arg) {
return typeid(Elem).name();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
using paddle::operators::distributed::VarHandlePtr;
using paddle::operators::distributed::VarHandle;
void WaitTrue(VarHandlePtr s) { EXPECT_TRUE(s->Wait()); }
void WaitFalse(VarHandlePtr s) { EXPECT_FALSE(s->Wait()); }
TEST(VarHandle, Run) {
std::vector<VarHandlePtr> a;
for (int i = 0; i < 12; i++) {
VarHandlePtr s(new VarHandle("", "", "", nullptr, nullptr));
a.push_back(s);
}
std::vector<std::unique_ptr<std::thread>> t;
for (int i = 0; i < 6; i++) {
t.emplace_back(new std::thread(WaitFalse, a[i]));
}
for (int i = 0; i < 6; i++) {
a[i]->Finish(false);
t[i]->join();
}
for (int i = 6; i < 12; i++) {
t.emplace_back(new std::thread(WaitTrue, a[i]));
}
for (int i = 6; i < 12; i++) {
a[i]->Finish(true);
t[i]->join();
}
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/variable_response.h"
#include <vector>
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
DEFINE_string(rpc_server_profile_path, "./profile_ps",
"the profile log file path");
namespace paddle {
namespace operators {
namespace distributed {
bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& dev_ctx,
platform::Place place, void* dest,
int64_t size) {
const void* data = NULL;
int size_to_write = 0;
int64_t length = size;
int total_written = 0;
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
char* p = reinterpret_cast<char*>(dest);
while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
// NOTE: if raw buffer is large and have two neighbor fields of raw
// buffers GetDirectBufferPointer can get all of them, use length to
// truncate it.
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
// This log is useful to see how long a internal block size is of rpc.
VLOG(7) << "copy " << size_to_write << " data to CUDAPlace";
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place),
reinterpret_cast<void*>(p), cpu, data, size_to_write,
gpu_dev_ctx.stream());
p += size_to_write;
total_written += size_to_write;
input->Skip(size_to_write);
}
gpu_dev_ctx.Wait();
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Unexpected branch, please compile with WITH_GPU or WITH_ROCM"));
#endif
return true;
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
auto& xpu_dev_ctx = static_cast<const platform::XPUDeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
char* p = reinterpret_cast<char*>(dest);
while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, place),
reinterpret_cast<void*>(p), cpu, data, size_to_write);
p += size_to_write;
total_written += size_to_write;
input->Skip(size_to_write);
}
xpu_dev_ctx.Wait();
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr,
platform::errors::Unimplemented(
"Not supported XPU, please compile with option WITH_XPU=ON."));
#endif
return true;
}
char* p = reinterpret_cast<char*>(dest);
while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
// NOTE: if raw buffer is large and have two neighbor fields of raw buffers
// GetDirectBufferPointer can get all of them, use length to truncate it.
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
// TODO(gongwb): can we avoid copy?
platform::CPUPlace cpu;
// This log is useful to see how long a internal block size is of rpc.
VLOG(7) << "copy " << size_to_write << " data to CPUPlace";
memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);
p += size_to_write;
total_written += size_to_write;
input->Skip(size_to_write);
}
return true;
}
bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) {
auto server_var = GetVar();
if (!server_var) {
LOG(ERROR) << "recved var should not on current server: "
<< meta_.varname();
return false;
}
auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
tensor->Resize(dims);
framework::LoD lod;
for (int i = 0; i < meta_.lod_level(); ++i) {
framework::Vector<size_t> v;
for (int j = 0; j < meta_.lod(i).lod_data_size(); ++j) {
v.push_back(meta_.lod(i).lod_data(j));
}
lod.push_back(v);
}
tensor->set_lod(lod);
void* tensor_data =
tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
<< ", Buffer Size = " << length << ", dims:" << dims
<< ", numel:" << tensor->numel();
PADDLE_ENFORCE_GE(
tensor->memory_size(), static_cast<unsigned int>(length),
platform::errors::InvalidArgument(
"The memory size of tensor: %s should greater than length: %s",
tensor->memory_size(), length));
return ReadRaw(input, ctx, tensor->place(), tensor_data, length);
}
inline framework::DDim GetDims(
const ::google::protobuf::RepeatedField<::google::protobuf::int64>& dims) {
std::vector<int> vecdims;
for (auto& d : dims) {
vecdims.push_back(d);
}
return framework::make_ddim(vecdims);
}
bool VariableResponse::CopySelectRowsTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) {
auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
PADDLE_ENFORCE_EQ(
static_cast<size_t>(tensor->numel()),
length / framework::SizeOfType(paddle::operators::distributed::ToVarType(
meta_.data_type())),
platform::errors::InvalidArgument(
"length: %s should equal to memory size of tensor: %s", length,
tensor->numel() *
framework::SizeOfType(paddle::operators::distributed::ToVarType(
meta_.data_type()))));
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::distributed::ToVarType(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false;
}
return true;
}
bool VariableResponse::CopySelectRowsData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length) {
auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear();
slr->mutable_rows()->resize(length / sizeof(int64_t)); // int64
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
platform::CPUPlace cpu;
if (!ReadRaw(input, ctx, cpu, rows_data, length)) {
return false;
}
return true;
}
bool VariableResponse::ProcSerializedField(
int tag, ::google::protobuf::io::CodedInputStream* input,
int64_t num_bytes) {
PADDLE_ENFORCE(
(meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR ||
meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "",
platform::errors::PreconditionNotMet("meta info should be got first!"));
if (meta_.type() == sendrecv::NCCL_ID) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto* var = scope_->FindVar(meta_.varname());
if (var != nullptr) {
ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
if (!ReadRaw(input, *dev_ctx_, platform::CPUPlace(), id->internal,
num_bytes)) {
return false;
}
}
return true;
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Please compiled with CUDA!"));
return false;
#endif
}
VLOG(7) << "ProcSerializedField:" << meta_.varname()
<< ", type:" << meta_.type() << std::endl;
framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE_GE(
meta_.lod_size(), 0,
platform::errors::PreconditionNotMet("lod info should be got first!"));
if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
return false;
}
return true;
}
if (meta_.type() == sendrecv::SELECTED_ROWS) {
if (!CopySelectRowsTensorData(input, *dev_ctx_, dims, num_bytes)) {
return false;
}
return true;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"The type: %s of var: %s is not supported", meta_.type(),
meta_.varname()));
return false;
}
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
namespace google {
namespace protobuf {
namespace io {
class CodedInputStream;
class ZeroCopyInputStream;
} // namespace io
} // namespace protobuf
} // namespace google
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
DECLARE_string(rpc_server_profile_path);
namespace paddle {
namespace operators {
namespace distributed {
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class Source {
public:
virtual ~Source() {}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
};
class VariableResponse {
public:
VariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) {
local_scope_ = scope->NewTmpScope().release();
}
}
virtual ~VariableResponse() {
if (local_scope_) {
delete local_scope_;
local_scope_ = nullptr;
}
}
int Parse(Source* source, const sendrecv::VariableMessage& meta) {
meta_ = meta;
return Parse(source);
}
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
virtual int Parse(Source* source) = 0;
inline const framework::Scope& GetLocalScope() const { return *local_scope_; }
inline framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); }
inline std::string TableName() const { return meta_.table_name(); }
// should call parse first.
framework::Variable* GetVar() {
if (create_scope_) {
return local_scope_->Var(meta_.varname());
}
return scope_->FindVar(meta_.varname());
}
framework::Variable* GetRecvVar() {
if (create_scope_) {
return local_scope_->Var(meta_.out_varname());
}
return scope_->FindVar(meta_.out_varname());
}
int GetTrainerId() { return static_cast<int>(meta_.trainer_id()); }
protected:
bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& dev_ctx, platform::Place place,
void* dest, int64_t size);
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx,
const framework::DDim& dims, int length);
bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length);
bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx,
const framework::DDim& dims, int length);
bool ProcSerializedField(int tag,
::google::protobuf::io::CodedInputStream* input,
int64_t num_bytes);
protected:
const framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
bool create_scope_ = false;
framework::Scope* local_scope_ = nullptr;
sendrecv::VariableMessage meta_;
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
include(operators)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY})
find_library(RDMACM_LIBRARY NAMES rdmacm)
ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY})
set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} ibverbs rdmacm)
endif()
endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
list(REMOVE_DUPLICATES OPS)
foreach(src ${OPS})
set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endforeach()
register_operators(EXCLUDES gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS})
if(WITH_NCCL OR WITH_RCCL)
set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} nccl_common)
endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${DISTRIBUTE_DEPS} PARENT_SCOPE)
set(GLOB_DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} CACHE INTERNAL "distributed dependency")
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"
namespace paddle {
namespace operators {
class AllReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class AllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor), tensor to be allreduced.");
AddOutput("Out", "(Tensor) the result of allreduced.");
AddAttr<int>("reduce_type", "(int) determin the reduce type.")
.SetDefault(0);
AddAttr<bool>(
"sync_mode",
"(bool) whether to synchronize the CUDA stream after nccl call.")
.SetDefault(false);
AddComment(R"DOC(
***AllReduce Operator***
Call NCCL AllReduce internally. Note that this op must be used when one
thread is managing one GPU device.
For speed reasons, reduce_type should be an integer:
0: sum
1: prod
2: max
3: min
If input and output are the same variable, in-place allreduce will be used.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(allreduce, ops::AllReduceOp,
ops::AllReduceOpMaker);
REGISTER_OP_CPU_KERNEL(
allreduce, ops::AllReduceOpKernel<plat::CPUDeviceContext, float>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, double>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, int>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, int64_t>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
allreduce, ops::AllReduceOpKernel<plat::CUDADeviceContext, float>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, double>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, int>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, int64_t>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class AllReduceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::PreconditionNotMet(
"AllReduce op can run on gpu place only for now."));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
auto* sendbuff = in->data<void>();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
auto* comm = dev_ctx.nccl_comm();
// FIXME(typhoonzero): should use nccl stream here.
auto stream = dev_ctx.stream();
PADDLE_ENFORCE_NOT_NULL(
stream, platform::errors::NotFound("Should initialize NCCL firstly."));
int reduce_type = ctx.Attr<int>("reduce_type");
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case 0:
red_type = ncclSum;
break;
case 1:
red_type = ncclProd;
break;
case 2:
red_type = ncclMax;
break;
case 3:
red_type = ncclMin;
break;
}
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
comm, stream));
if (ctx.Attr<bool>("sync_mode")) {
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <ostream>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class BroadcastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of BroadcastOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Output) of ConvOp should not be null."));
}
};
class BroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor), tensor to be broadcast.");
AddOutput("Out", "(Tensor) the result of broadcast.");
AddAttr<bool>(
"sync_mode",
"(bool) whether to synchronize the CUDA stream after nccl call.")
.SetDefault(false);
AddAttr<int>("root", "(int).").SetDefault(0).EqualGreaterThan(0);
AddComment(R"DOC(
***Broadcast Operator***
Call NCCL Broadcast internally. Note that this op must be used when one
thread is managing one GPU device.
)DOC");
}
};
template <typename T>
class BroadcastOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Broadcast op can run on gpu place only for now."));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(broadcast, ops::BroadcastOp,
ops::BroadcastOpMaker);
REGISTER_OP_CPU_KERNEL(broadcast, ops::BroadcastOpKernel<float>,
ops::BroadcastOpKernel<double>,
ops::BroadcastOpKernel<int>,
ops::BroadcastOpKernel<int64_t>,
ops::BroadcastOpKernel<plat::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"The place of ExecutionContext should be CUDAPlace."));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int dev_id = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).device;
int root_dev_id = ctx.Attr<int>("root");
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
PADDLE_ENFORCE_EQ(
out->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"Currently, the output of broadcast op must be initialized,"
"because this op can only be an In-Place operation."));
void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(
send_recv_buffer, in->data<void>(),
platform::errors::PreconditionNotMet("Currently, the broadcast op can "
"only be an In-Place operation."));
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto comm = dev_ctx.nccl_comm();
auto stream = dev_ctx.stream();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
send_recv_buffer, static_cast<size_t>(in->numel()),
platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));
VLOG(3) << "Bcast " << ctx.InputNames("X")[0] << ", (" << in->numel() << ")"
<< " From " << root_dev_id << " to " << dev_id;
if (ctx.Attr<bool>("sync_mode")) {
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(broadcast, ops::NCCLBroadcastOpKernel<float>,
ops::NCCLBroadcastOpKernel<double>,
ops::NCCLBroadcastOpKernel<int>,
ops::NCCLBroadcastOpKernel<int64_t>,
ops::NCCLBroadcastOpKernel<plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
class CheckpointNotifyOp : public framework::OperatorBase {
public:
CheckpointNotifyOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> epmap =
Attr<std::vector<std::string>>("endpoints");
std::string dirname = Attr<std::string>("dirname");
std::string varname = Attr<std::string>("varname");
auto mode = Attr<int>("mode");
if (mode != 0 && mode != 1 && mode != 2) {
PADDLE_THROW(platform::errors::InvalidArgument(
"mode expected in [0/1/2], but got %d", mode));
}
std::vector<std::string> slice_varnames =
Attr<std::vector<std::string>>("slice_varnames");
std::vector<std::string> remote_varnames =
Attr<std::vector<std::string>>("remote_varnames");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (size_t i = 0; i < epmap.size(); i++) {
auto save_path =
string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]);
rpc_client->AsyncCheckpointNotify(epmap[i], save_path, remote_varnames[i],
mode);
VLOG(3) << "checkpoint notify sending with path: " << save_path
<< " and var:" << slice_varnames[i] << " to " << epmap[i]
<< " with mode " << mode;
}
PADDLE_ENFORCE_EQ(
rpc_client->Wait(), true,
platform::errors::Fatal("Fail to notify checkpoint."
" Internal error occurs in RPCClient."));
}
};
class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<std::vector<std::string>>(
"endpoints",
"(string vector)"
"Parameter Server endpoints in the order");
AddAttr<std::string>("dirname",
"(string) indicate the folder checkpoint will use");
AddAttr<std::string>("varname", "(string) the var need to be saved");
AddAttr<std::vector<std::string>>(
"slice_varnames", "(string vector) the slice vars need to be saved");
AddAttr<std::vector<std::string>>(
"remote_varnames", "(string vector) the slice vars need to be saved");
AddAttr<int>("mode", "mode=0/1/2 means nothing/save base/save delta")
.SetDefault(0);
AddComment(R"DOC(
CheckpointNotify operator
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
the parameter server.
)DOC");
}
};
class CheckpointNotifyOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
checkpoint_notify, ops::CheckpointNotifyOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::CheckpointNotifyOpMaker, ops::CheckpointNotifyOpShapeInference);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#include "paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
class DistributedLookupTableOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInputs("Ids"), true,
platform::errors::InvalidArgument(
"Input(Ids) of LookupTableOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
platform::errors::InvalidArgument(
"Input(W) of LookupTableOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutputs("Outputs"), true,
platform::errors::InvalidArgument(
"Output(Outs) of LookupTableOp should not be null."));
auto ids_dims = ctx->GetInputsDim("Ids");
auto table_dims = ctx->GetInputDim("W");
PADDLE_ENFORCE_EQ(
table_dims.size(), 2,
platform::errors::InvalidArgument(
"Only 2 dimensions of the 'Embedding' is supported."));
for (auto &ids_dim : ids_dims) {
PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
platform::errors::InvalidArgument(
"The dimension of the 'Ids' tensor must be 2."));
}
auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
// for fluid.embedding
auto lookup_table_version =
ctx->Attrs().Get<std::string>("lookup_table_version");
auto outputs_dims = std::vector<framework::DDim>();
for (auto &ids_dim : ids_dims) {
if (lookup_table_version == "lookup_table") {
outputs_dims.push_back(
framework::make_ddim({ids_dim[0], table_dims[1]}));
} else if (lookup_table_version == "lookup_table_v2") {
outputs_dims.push_back(framework::make_ddim(
{static_cast<int64_t>(ids_dim[0]), static_cast<int64_t>(ids_dim[1]),
static_cast<int64_t>(table_dims[1])}));
}
}
ctx->SetOutputsDim("Outputs", outputs_dims);
ctx->ShareLoD("Ids", /*->*/ "Outputs");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W.")
.AsDuplicable();
AddInput("W",
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
AddOutput("Outputs",
"(LoDTensor) The lookup results, which have the same type as W.")
.AsDuplicable();
AddAttr<std::vector<std::string>>(
"table_names",
"(string vector, such as emb_block0, emb_block1)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({""});
AddAttr<std::vector<std::string>>(
"endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({"127.0.0.1:6164"});
AddAttr<int>("pserver_num", "the number of pserver").SetDefault(0);
AddAttr<bool>("is_distributed",
"(boolean, default false) distributed lookup table.")
.SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::string>(
"lookup_table_version",
"(string, default lookup_table) "
"To distinguish between different versions of embedding OP")
.SetDefault(std::string("lookup_table"));
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(distributed::kNoPadding);
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC(
Lookup Tablel Prefetch Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(distributed_lookup_table, ops::DistributedLookupTableOp,
ops::DistributedLookupTableOpMaker);
REGISTER_OP_CPU_KERNEL(distributed_lookup_table,
ops::DistributedLookupTableKernel<
paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
distributed_lookup_table,
ops::DistributedLookupTableKernel<plat::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class DistributedLookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto ids_vars = context.MultiInputVar("Ids");
auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");
auto id_names = context.InputNames("Ids");
auto embedding_name = context.InputNames("W").front();
auto out_names = context.OutputNames("Outputs");
auto lookup_tables = context.Attr<std::vector<std::string>>("table_names");
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
auto is_distributed = context.Attr<bool>("is_distributed");
auto lookup_table_version =
context.Attr<std::string>("lookup_table_version");
operators::distributed::prefetchs(id_names, out_names, embedding_name,
is_distributed, lookup_tables, endpoints,
context, context.scope());
if (lookup_table_version == "lookup_table_v2") {
auto &scope = context.scope();
auto emb_dim =
scope.FindVar(embedding_name)->Get<framework::LoDTensor>().dims()[1];
for (size_t i = 0; i < id_names.size(); ++i) {
auto *id_var = scope.FindVar(id_names[i]);
auto *out_var = scope.FindVar(out_names[i]);
auto *id_tensor = id_var->GetMutable<framework::LoDTensor>();
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
auto id_dims = id_tensor->dims();
out_tensor->Resize(framework::make_ddim(
{static_cast<int64_t>(id_dims[0]), static_cast<int64_t>(id_dims[1]),
static_cast<int64_t>(emb_dim)}));
}
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
class FakeInitInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FakeInit");
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
ctx->SetOutputDim("Out", framework::make_ddim(shape));
}
};
class FakeInitOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
framework::Tensor *tensor = nullptr;
auto &out_var = *scope.FindVar(Output("Out"));
if (out_var.IsType<framework::LoDTensor>()) {
tensor = out_var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else if (out_var.IsType<framework::SelectedRows>()) {
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"fake init op's output only"
"supports SelectedRows and LoDTensor"));
}
}
};
class FakeInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {}
};
class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddComment(R"DOC(
FakeInit Operator.
Init an variable but not alloc memory for it, it is used for init the
table parameter at trainer side in distributed lookup table.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fake_init, ops::FakeInitOp, ops::FakeInitInferShape, ops::FakeInitOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::FakeInitOpVarTypeInference);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RPCClient;
} // namespace distributed
class FetchBarrierOp : public framework::OperatorBase {
public:
FetchBarrierOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U,
platform::errors::Unavailable(
"Internal error occurred in RPCClient."));
}
}
};
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Any) Dummy inputs, used for control dependency")
.AsDispensable()
.AsDuplicable();
AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
.AsDuplicable();
AddComment(R"DOC(
SendBarrier operator
This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
}
};
class FetchBarrierOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fetch_barrier, ops::FetchBarrierOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::FetchBarrierOpMaker, ops::FetchBarrierOpShapeInference);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h> // for removing the port file
#include <csignal>
#include <cstdlib>
#include <fstream>
#include <thread> // NOLINT
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(flrpc_send_thread_num, 12, "number of threads for rpc send");
DEFINE_int32(flrpc_get_thread_num, 12, "number of threads for rpc get");
namespace paddle {
namespace operators {
void FlRunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer();
}
static void flsplit(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
if (str.empty()) {
return;
}
size_t pos = 0;
size_t next = str.find(sep, pos);
while (next != std::string::npos) {
pieces->push_back(str.substr(pos, next - pos));
pos = next + 1;
next = str.find(sep, pos);
}
if (!str.substr(pos).empty()) {
pieces->push_back(str.substr(pos));
}
}
static void FlParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
&prepared,
framework::ProgramDesc *program, framework::Scope *scope) {
std::vector<std::future<void>> fs;
for (size_t idx : parallel_blkids) {
fs.push_back(framework::Async([&executor, &prepared, &scope, idx]() {
int run_block = idx; // thread local
try {
VLOG(3) << "running server block: " << run_block
<< "pointer: " << prepared[run_block].get();
executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (const std::exception &e) {
PADDLE_THROW(platform::errors::Fatal(
"Run %d-th sub program failed. The exception is:\n%s.", idx,
e.what()));
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
FlListenAndServOp::FlListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
FlListenAndServOp::~FlListenAndServOp() {}
void FlListenAndServOp::SavePort() const {
// NOTE: default write file to /tmp/paddle.selected_port
rpc_service_->SavePort();
}
static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
void FlListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
platform::DeviceContext *dev_ctx) const {
VLOG(2) << "RunSyncLoop";
size_t num_blocks = program->Size();
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(num_blocks, 2,
platform::errors::InvalidArgument(
"server program should have at least 2 blocks"));
// Prepare all the server block
std::vector<int> optimize_blocks_list;
for (size_t i = 1; i < program->Size(); ++i) {
optimize_blocks_list.push_back(i);
}
auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list);
// Insert placeholder for block0 which holds current op itself,
// NOTE the first block in `optimize_prepared` should never be ran.
optimize_prepared.insert(
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
VLOG(3) << "wait all clients to get pserver parameters back";
rpc_service_->SetCond(distributed::kRequestGet);
VLOG(3) << "wait all clients to send fetch_barrier";
rpc_service_->WaitBarrier(distributed::kRequestGet);
if (rpc_service_->IsExit()) {
rpc_service_->SetCond(distributed::kRequestGet);
break;
}
VLOG(3) << "wait all clients to send after_optimizer parameters";
rpc_service_->SetCond(distributed::kRequestSend);
VLOG(3) << "wait all clients to send send_barrier";
rpc_service_->WaitBarrier(distributed::kRequestSend);
VLOG(3) << "ResetBarrierCounter";
rpc_service_->ResetBarrierCounter();
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.
// The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = optimize_blocks[0]->Parent();
std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(optimize_blocks[0]->ID());
double ts = GetTimestamp();
for (size_t i = 1; i < optimize_blocks.size(); ++i) {
// skip the first optimize block because it is already in the
// parallel_blkids.
int blkid = optimize_blocks[i]->ID();
if (program->Block(blkid).Parent() != last_parent_blkid) {
FlParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
}
FlParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
VLOG(3) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
} // while(true)
}
static void FillRequestCtx(distributed::RequestHandler *h,
framework::Scope *scope,
platform::DeviceContext *dev_ctx,
framework::Executor *executor,
framework::ProgramDesc *program,
distributed::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
h->SetExecutor(executor);
h->SetProgram(program);
h->SetRPCServer(rpc_server);
}
void FlListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer.
platform::SetProfileListener();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin");
auto inputs = Inputs("X");
PADDLE_ENFORCE_EQ(!rpc_service_, true, platform::errors::InvalidArgument(
"rpc_service_ must null"));
std::string endpoint = Attr<std::string>("endpoint");
VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
request_send_handler_.reset(
new distributed::RequestSendHandler(!sync_mode, false));
request_get_handler_.reset(
new distributed::RequestGetHandler(!sync_mode, false));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(),
FLAGS_flrpc_send_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestGet,
request_get_handler_.get(),
FLAGS_flrpc_get_thread_num);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(
optimize_blocks.size(), 1,
platform::errors::InvalidArgument(
"optimize blocks should be 1 at least on the pserver side."));
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place);
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(FlRunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
// register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
signal(SIGINT, FlSignalHandler::StopAndExit);
signal(SIGTERM, FlSignalHandler::StopAndExit);
// Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
// so that we can reset them at the end of each iteration.
// NOTE: only used in sync update
// Write to a file of server selected port for python use.
SavePort();
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx);
}
class FlListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(" + "ListenAndServ operator" + "\n" + "This operator" +
" will start a RPC server which can receive variables from send_op and send" +
"back variables to recv_op.)DOC");
AddAttr<std::string>("endpoint",
"(string, default 127.0.0.1:6164)"
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
AddAttr<std::vector<framework::BlockDesc *>>(
kOptimizeBlocks, "Optimize blocks to run on server side.")
.SetDefault({});
}
};
void FlSignalHandler::StopAndExit(int signal_num) {
// Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces.
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
remove(file_path.c_str());
exit(0);
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fl_listen_and_serv, ops::FlListenAndServOp,
ops::FlListenAndServOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <atomic>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RPCServer;
class RequestHandler;
} // namespace distributed
constexpr char kOptimizeBlocks[] = "optimize_blocks";
void FlRunServer(std::shared_ptr<distributed::RPCServer> service);
template <class TKey, class TValue>
class DoubleFindMap : public std::unordered_map<TKey, TValue> {
public:
typename std::unordered_map<TKey, TValue>::iterator find_value(TValue v) {
return std::find_if(this->begin(), this->end(),
[&v](const std::pair<const std::string, int> p) {
return p.second == v;
});
}
};
class FlListenAndServOp : public framework::OperatorBase {
public:
FlListenAndServOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
virtual ~FlListenAndServOp();
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx) const;
void SavePort() const;
int GetSelectedPort() { return rpc_service_->GetSelectedPort(); }
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
mutable std::vector<std::string> dense_vars_;
};
class FlSignalHandler {
public:
static void StopAndExit(int signal_num);
private:
DISABLE_COPY_AND_ASSIGN(FlSignalHandler);
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
class GenNCCLIdOp : public framework::OperatorBase {
public:
GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// put nccl id in CPUPlace
auto& dev_ctx = *pool.Get(platform::CPUPlace());
int trainer_id = Attr<int>("trainer_id");
std::vector<std::string> trainers =
Attr<std::vector<std::string>>("trainers");
PADDLE_ENFORCE_GE(trainer_id, 0, platform::errors::InvalidArgument(
"trainer_id %d is less than 0. Its "
"valid range is [0, trainer_size)"));
PADDLE_ENFORCE_LT(
trainer_id, static_cast<int>(trainers.size()),
platform::errors::OutOfRange("trainer_id %d is out of range. Its valid "
"range is [0, trainer_size)",
trainer_id));
std::string endpoint = trainers[trainer_id];
framework::Scope& local_scope = scope.NewScope();
int nccl_comm_num = Attr<int>("nccl_comm_num");
int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
int inter_nranks = Attr<int>("hierarchical_allreduce_inter_nranks");
int inter_trainer_id = -1;
int exter_trainer_id = -1;
if (use_hierarchical_allreduce) {
PADDLE_ENFORCE_GT(
trainers.size(), 1,
platform::errors::PreconditionNotMet(
"The number of collective trainers %llu <= 1", trainers.size()));
PADDLE_ENFORCE_GT(
inter_nranks, 1,
platform::errors::PreconditionNotMet(
"inter_nranks %d <= 1 while in hierarchical allreduce mode",
inter_nranks));
PADDLE_ENFORCE_EQ(
trainers.size() % inter_nranks, 0,
platform::errors::PreconditionNotMet(
"The number of trainers %llu mod inter_nranks %d is not equal 0",
trainers.size(), inter_nranks));
inter_trainer_id = trainer_id % inter_nranks;
if (trainer_id % inter_nranks == 0) {
exter_trainer_id = trainer_id / inter_nranks;
}
}
if (trainer_id != 0) {
GetIdByServer(endpoint, &local_scope, dev_ctx, nccl_comm_num,
use_hierarchical_allreduce, trainer_id, inter_trainer_id,
exter_trainer_id);
}
std::ostringstream ss;
for (size_t i = 0; i < trainers.size(); i++) {
ss << trainers[i] << ",";
}
VLOG(1) << "trainer_id:" << trainer_id
<< ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
<< ", inter_nranks:" << inter_nranks
<< ", inter_trainer_id:" << inter_trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< ", trainers:" << ss.str();
// init flat
if (trainer_id == 0) {
std::vector<std::string> flat_endpoints;
flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1,
trainers.end());
// flat nccl_id
for (int i = 0; i < nccl_comm_num; i++) {
std::string var_name = platform::GetFlatNCCLVarName(i);
GenerateAndSend(&local_scope, dev_ctx, var_name, flat_endpoints);
}
}
if (!use_hierarchical_allreduce) {
return;
}
PADDLE_ENFORCE_EQ(
trainers.size() % inter_nranks, 0,
platform::errors::PreconditionNotMet(
"The number of trainers %llu mod inter_nranks %d is not equal 0",
trainers.size(), inter_nranks));
PADDLE_ENFORCE_GT(
inter_nranks, 1,
platform::errors::PreconditionNotMet(
"inter_nranks %d <= 1 while in hierarchical allreduce mode",
inter_nranks));
// hierarchical inter ncclid
if (inter_trainer_id == 0) {
std::ostringstream ss;
ss << endpoint;
std::vector<std::string> inter_endpoints;
for (int i = trainer_id + 1; i < trainer_id + inter_nranks &&
i < static_cast<int>(trainers.size());
i++) {
ss << ",";
inter_endpoints.push_back(trainers[i]);
ss << trainers[i];
}
VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str();
for (int i = 0; i < nccl_comm_num; i++) {
std::string nccl_var_name =
platform::GetHierarchicalInterNCCLVarName(i);
GenerateAndSend(&local_scope, dev_ctx, nccl_var_name, inter_endpoints);
}
}
// hierarchical exter ncclid
if (exter_trainer_id == 0) {
std::ostringstream ss;
std::vector<std::string> exter_endpoints;
ss << endpoint;
for (size_t i = inter_nranks; i < trainers.size(); i += inter_nranks) {
ss << ",";
exter_endpoints.push_back(trainers[i]);
ss << trainers[i];
}
VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str();
for (int i = 0; i < nccl_comm_num; i++) {
std::string nccl_var_name =
platform::GetHierarchicalExterNCCLVarName(i);
GenerateAndSend(&local_scope, dev_ctx, nccl_var_name, exter_endpoints);
}
}
}
private:
void GenerateAndSend(framework::Scope* scope,
const platform::DeviceContext& dev_ctx,
const std::string& nccl_id_name,
const std::vector<std::string>& endpoint_list) const {
auto var = scope->FindVar(nccl_id_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
nccl_id_name.c_str()));
auto id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(id));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl_id_var:" << nccl_id_name << " to " << ep;
client->AsyncSendVar(ep, dev_ctx, *scope, nccl_id_name);
}
client->Wait();
for (auto& ep : endpoint_list) {
client->AsyncSendBatchBarrier(ep);
}
client->Wait();
VLOG(3) << "sending completed...";
}
void GetIdByServer(const std::string& endpoint, framework::Scope* scope,
const platform::DeviceContext& dev_ctx, int nccl_comm_num,
bool use_hierarchical_allreduce, int trainer_id,
int inter_trainer_id, int exter_trainer_id) const {
// std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "trainer_id:" << trainer_id
<< " start getting nccl id from trainer 0, nccl_comm_no:" << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
if (use_hierarchical_allreduce) {
if (inter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "trainer_id:" << trainer_id
<< ", inter_trainer_id:" << inter_trainer_id
<< " start getting nccl id from inter_trainer:" << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
}
if (exter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3)
<< "trainer_id:" << trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< " start getting nccl id from exter_trainer 0, nccl_comm_no:"
<< i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
}
}
VLOG(3) << "traier_id:" << trainer_id
<< ", inter_trainer_id:" << inter_trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< " got nccl id and stop server...";
rpc_service->ShutDown();
VLOG(3) << "rpc server stopped";
server_thread.join();
}
};
class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
GenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC");
AddAttr<std::vector<std::string>>(
"trainers",
"['trainer0_ip:port', 'trainer1_ip:port', ...] "
"list of all trainer endpoints")
.SetDefault({});
AddAttr<int>("trainer_id",
"(int) "
"The index of the trainer in distributed training.");
AddAttr<int>("nccl_comm_num",
"(int default 1) "
"The number of nccl communicator num.")
.SetDefault(1);
AddAttr<bool>("use_hierarchical_allreduce",
"(bool default false) "
"Wheter to use hierarchical allreduce.")
.SetDefault(false);
AddAttr<int>("hierarchical_allreduce_inter_nranks",
"(int default 1) "
"Wheter to use hierarchical allreduce.")
.SetDefault(-1);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h> // for removing the port file
#include <csignal>
#include <cstdlib>
#include <fstream>
#include <thread> // NOLINT
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 12, "number of threads for rpc get");
DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");
namespace paddle {
namespace operators {
void RunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer();
VLOG(4) << "RunServer thread end";
}
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
if (str.empty()) {
return;
}
size_t pos = 0;
size_t next = str.find(sep, pos);
while (next != std::string::npos) {
pieces->push_back(str.substr(pos, next - pos));
pos = next + 1;
next = str.find(sep, pos);
}
if (!str.substr(pos).empty()) {
pieces->push_back(str.substr(pos));
}
}
static void ParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
&prepared,
framework::ProgramDesc *program, framework::Scope *scope) {
std::vector<std::future<void>> fs;
for (size_t idx : parallel_blkids) {
fs.push_back(framework::Async([&executor, &prepared, &scope, idx]() {
int run_block = idx; // thread local
try {
VLOG(3) << "running server block: " << run_block
<< "pointer: " << prepared[run_block].get();
executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (const std::exception &e) {
PADDLE_THROW(platform::errors::Fatal(
"Run %d-th sub program failed. The exception is:\n%s.", idx,
e.what()));
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
ListenAndServOp::~ListenAndServOp() { Stop(); }
void ListenAndServOp::Stop() {
rpc_service_->ShutDown();
server_thread_->join();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
remove(file_path.c_str());
}
void ListenAndServOp::SavePort() const {
// NOTE: default write file to /tmp/paddle.selected_port
rpc_service_->SavePort();
}
static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
// For sync, sparse variables need recover grad type from LodTensor to
// SelectedRows
void ResetSparseVarsType(framework::Scope *recv_scope) {
auto *ins = distributed::LargeScaleKV::GetInstance();
auto grads = ins->GetAllGrads();
for (auto &grad : grads) {
auto *v = recv_scope->FindVar(grad);
v->Clear();
v->GetMutable<framework::SelectedRows>();
}
}
void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
const std::vector<int> &prefetch_block_id_list,
const int checkpoint_point_block_id) const {
VLOG(2) << "RunSyncLoop";
size_t num_blocks = program->Size();
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(num_blocks, 2,
platform::errors::PreconditionNotMet(
"Invalid number of blocks in server program. Expected "
"equal or greater than 2. Recieved %zu",
num_blocks));
// Prepare all the server block
std::vector<int> optimize_blocks_list;
for (size_t i = 1; i < program->Size(); ++i) {
optimize_blocks_list.push_back(i);
}
auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list);
// Insert placeholder for block0 which holds current op itself,
// NOTE the first block in `optimize_prepared` should never be ran.
optimize_prepared.insert(
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
// Trainers will get all parameters from pserver in the
// startup program, so we will wait RequestGet first
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter();
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
VLOG(3) << "wait all clients to send gradient";
rpc_service_->SetCond(distributed::kRequestSend);
VLOG(3) << "wait all clients to send send_barrier";
rpc_service_->WaitBarrier(distributed::kRequestSend);
if (rpc_service_->IsExit()) {
LOG(WARNING) << "get exit!rpc_processor break!";
rpc_service_->SetCond(distributed::kRequestGet);
break;
}
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.
// The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = optimize_blocks[0]->Parent();
std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(optimize_blocks[0]->ID());
double ts = GetTimestamp();
for (size_t i = 1; i < optimize_blocks.size(); ++i) {
// skip the first optimize block because it is already in the
// parallel_blkids.
int blkid = optimize_blocks[i]->ID();
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
}
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
recv_scope);
VLOG(3) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
VLOG(3) << "ResetReceivedVars";
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
ResetSparseVarsType(recv_scope);
VLOG(3) << "wait all clients to get parameters back";
rpc_service_->SetCond(distributed::kRequestGet);
VLOG(3) << "wait all clients to send fetch_barrier";
rpc_service_->WaitBarrier(distributed::kRequestGet);
VLOG(3) << "ResetBarrierCounter";
rpc_service_->ResetBarrierCounter();
} // while(true)
}
void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
platform::DeviceContext *dev_ctx,
bool reset_all) const {
for (auto &varname : sparse_vars_) {
auto var = recv_scope->FindVar(varname);
if (var == nullptr) {
VLOG(2) << "can not find var " << varname << " in received scope";
continue;
}
if (var->IsType<framework::SelectedRows>()) {
VLOG(3) << "reset sparse var: " << varname;
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"The type of sparse var should be SelectedRows"));
}
}
if (UNLIKELY(reset_all)) {
for (auto &varname : dense_vars_) {
auto var = recv_scope->FindVar(varname);
if (var == nullptr) {
VLOG(2) << "can not find var " << varname << " in received scope";
continue;
}
if (var->IsType<framework::LoDTensor>()) {
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
static_cast<float>(0));
} else if (var->IsType<framework::Tensor>()) {
math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
static_cast<float>(0));
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"The type of dense var should be in [LoDTensor, Tensor]"));
}
}
}
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope) const {
VLOG(2) << "RunAsyncLoop";
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
DoubleFindMap<std::string, int32_t> grad_to_block_id;
auto append_block_maps = [](DoubleFindMap<std::string, int32_t> *out_map,
const std::string &grad_and_id) {
std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, key = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2,
platform::errors::PreconditionNotMet(
"Invalid format of grad_and_id argument. "
"Expected \"grad:block_id\". Recieved %s",
grad_and_id.c_str()));
PADDLE_ENFORCE_EQ(out_map->count(pieces[0]), 0,
platform::errors::AlreadyExists(
"The gradient name %s has already existed in out_map",
pieces[0].c_str()));
int block_id = std::stoi(pieces[1]);
(*out_map)[pieces[0]] = block_id;
};
for (const auto &grad_and_id : grad_to_block_id_str) {
append_block_maps(&grad_to_block_id, grad_and_id);
}
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
platform::errors::PreconditionNotMet(
"Invalid number of blocks in server program. Expected "
"equal or greater than 2. Recieved %zu",
num_blocks));
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
}
auto optimize_prepared = executor->Prepare(*program, block_list);
// execute global block if needed, block id 1 in the program is global
// block if it's not bind to a grad var for it's update.
if (block_list[0] == 1 &&
grad_to_block_id.find_value(static_cast<int32_t>(1)) ==
grad_to_block_id.end()) {
executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
}
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx, param_to_prepared_ctx;
for (size_t i = 0; i < block_list.size(); ++i) {
auto blkid = block_list[i];
auto it = grad_to_block_id.find_value(blkid);
if (it != grad_to_block_id.end()) {
grad_to_prepared_ctx[it->first] = optimize_prepared[i];
}
}
request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_send_and_recv_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
while (true) {
if (rpc_service_->IsExit()) {
VLOG(4) << "get exit!rpc_processor break!";
break;
}
sleep(1);
} // while(true)
}
static void FillRequestCtx(
distributed::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx, framework::Executor *executor,
framework::ProgramDesc *program,
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx,
std::unordered_map<std::string, std::string>
*sparse_grad_name_to_param_name,
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx,
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_ctx,
distributed::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
h->SetExecutor(executor);
h->SetProgram(program);
h->SetPrefetchPreparedCtx(prefetch_ctx);
h->SetSparseGradToParam(sparse_grad_name_to_param_name);
h->SetRPCServer(rpc_server);
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
h->SetLrDecayPreparedCtx(lr_decay_ctx);
}
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
const framework::Scope &scope) const {
for (const auto &varname : varnames) {
auto var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::PreconditionNotMet(
"Received var is not initialized in the received scope."));
if (var->IsType<framework::SelectedRows>()) {
sparse_vars_.push_back(varname);
} else if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
dense_vars_.push_back(varname);
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"The type of received var should be in [SelectedRows, LoDTensor, "
"Tensor]."));
}
}
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer.
platform::SetProfileListener();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
int distributed_mode = Attr<int>("distributed_mode");
bool dc_sgd = Attr<bool>("dc_asgd");
auto fan_in = Attr<int>("Fanin");
auto pserver_id = Attr<int>("pserver_id");
auto inputs = Inputs("X");
PADDLE_ENFORCE_EQ(rpc_service_, nullptr,
platform::errors::PreconditionNotMet(
"RPC service has been created unexpectedly."));
std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
int lr_decay_block_id = Attr<int>(kLRDecayBlockId);
VLOG(4) << "pserver_id: " << pserver_id
<< ", distributed_mode:" << distributed_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint
<< ", checkpoint_block_id: " << checkpoint_block_id
<< ", lr_decay_block_id: " << lr_decay_block_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
auto rpc_get_thread_num = Attr<int>("rpc_get_thread_num");
auto rpc_send_thread_num = Attr<int>("rpc_send_thread_num");
auto rpc_prefetch_thread_num = Attr<int>("rpc_prefetch_thread_num");
request_send_handler_.reset(
new distributed::RequestSendHandler(distributed_mode, dc_sgd));
request_get_handler_.reset(
new distributed::RequestGetHandler(distributed_mode, dc_sgd));
request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(distributed_mode));
request_checkpoint_handler_.reset(
new distributed::RequestCheckpointHandler(distributed_mode));
request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler());
request_notify_handler_.reset(
new distributed::RequestNotifyHandler(distributed_mode, fan_in));
request_send_and_recv_handler_.reset(
new distributed::RequestSendAndRecvHandler(distributed_mode));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(), rpc_send_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestGet,
request_get_handler_.get(), rpc_get_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get(),
rpc_prefetch_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), rpc_send_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestSendAndRecv,
request_send_and_recv_handler_.get(),
rpc_get_thread_num);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(optimize_blocks.size(), 1,
platform::errors::PreconditionNotMet(
"optimize blocks is less than 1. Optimize blocks "
"should be 1 at least on the pserver side."));
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
if (checkpoint_block_id != -1) {
auto ctx = executor.Prepare(*program, checkpoint_block_id);
// see: https://stackoverflow.com/a/14856553
ckpt_pre_context = std::move(ctx);
}
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_context = nullptr;
if (lr_decay_block_id != -1) {
auto ctx = executor.Prepare(*program, lr_decay_block_id);
// see: https://stackoverflow.com/a/14856553
lr_decay_context = std::move(ctx);
}
// prepare for prefetch
std::vector<int> prefetch_block_id_list;
std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
auto prefetch_var_name_to_block_id_str =
Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId);
for (const auto &prefetch_var_name_and_id :
prefetch_var_name_to_block_id_str) {
std::vector<std::string> pieces;
split(prefetch_var_name_and_id, ':', &pieces);
VLOG(3) << "after split, prefetch_var = " << pieces[0]
<< ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(
pieces.size(), 2,
platform::errors::PreconditionNotMet(
"Invalid format of prefetch_var_name_and_id argument. "
"Expected \"xxx:xxx\". Recieved %s",
prefetch_var_name_and_id.c_str()));
int block_id = std::stoi(pieces[1]);
prefetch_block_id_list.push_back(block_id);
block_id_to_prefetch_var_name[block_id] = pieces[0];
}
auto prefetch_prepared = executor.Prepare(*program, prefetch_block_id_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared_ctx;
for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) {
auto block_id = prefetch_block_id_list[i];
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
}
// parse attr of kSparseGradToParam sparse_grad_name -> param_name
std::unordered_map<std::string, std::string> sparse_grad_name_to_param_name;
auto sparse_grad_name_to_param_name_str =
Attr<std::vector<std::string>>(kSparseGradToParam);
for (const auto &sparse_grad_name_and_param_name :
sparse_grad_name_to_param_name_str) {
std::vector<std::string> pieces;
split(sparse_grad_name_and_param_name, ':', &pieces);
PADDLE_ENFORCE_EQ(
pieces.size(), 2,
platform::errors::PreconditionNotMet(
"Invalid format of sparse_grad_name_and_param_name argument. "
"Expected \"xxx:xxx\". Recieved %s",
sparse_grad_name_and_param_name.c_str()));
VLOG(3) << "after split, sparse_grad_name = " << pieces[0]
<< ", param_name = " << pieces[1];
sparse_grad_name_to_param_name[pieces[0]] = pieces[1];
}
auto f =
std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
&executor, program, &prefetch_var_name_to_prepared_ctx,
&sparse_grad_name_to_param_name, ckpt_pre_context,
lr_decay_context, rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get());
f(request_get_no_barrier_handler_.get());
f(request_notify_handler_.get());
f(request_send_and_recv_handler_.get());
// register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit);
if (distributed_mode == distributed::DistributedMode::kSync) {
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
CacheVarsType(inputs, recv_scope);
// Write to a file of server selected port for python use.
SavePort();
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id);
} else {
if (distributed_mode == distributed::DistributedMode::kGeo) {
distributed::AsyncSparseParamUpdateRecorder::Init(
fan_in, sparse_grad_name_to_param_name);
}
VLOG(2) << "RunAsyncLoop";
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
if (grad_to_block_id_str.size() == 0) {
VLOG(0) << "there are no gradients on this parameter server";
} else {
std::vector<std::string> pieces;
split(grad_to_block_id_str[0], ':', &pieces);
distributed::HeartBeatMonitor::Init(fan_in, pserver_id == 0, pieces[0]);
}
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
// Write to a file of server selected port for python use.
SavePort();
RunAsyncLoop(&executor, program, &recv_scope);
}
}
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(" + "ListenAndServ operator" + "\n" + "This operator" +
" will start a RPC server which can receive variables from send_op and send" +
"back variables to recv_op.)DOC");
AddAttr<std::string>("endpoint",
"(string, default 127.0.0.1:6164)"
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<int>("pserver_id",
"(int, default -1), the parameter server index id")
.SetDefault(-1);
AddAttr<std::vector<std::string>>(
"grad_to_block_id",
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id")
.SetDefault({});
AddAttr<int>("distributed_mode",
"indicate distriubte training mode, 0 is sync, 1 is "
"fully-async, 2 is half-async, 3 is geo")
.SetDefault(0);
AddAttr<bool>("dc_asgd", "set to true will enable DC-ASGD training.")
.SetDefault(false);
AddAttr<std::vector<framework::BlockDesc *>>(
kOptimizeBlocks, "Optimize blocks to run on server side.")
.SetDefault({});
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch blocks to run on server side.")
.SetDefault({});
AddAttr<std::vector<std::string>>(
kSparseGradToParam,
"sparse grad name to param name. like: 'emb@Grad:emb'")
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
AddAttr<int>(kCheckpointBlockId,
"BolckID to run save checkpoint on pserer.")
.SetDefault(-1);
AddAttr<int>(kLRDecayBlockId, "BolckID to run lr decay on pserer.")
.SetDefault(-1);
AddAttr<int>("rpc_get_thread_num", "pserver get thread num.").SetDefault(1);
AddAttr<int>("rpc_send_thread_num", "pserver send thread num.")
.SetDefault(1);
AddAttr<int>("rpc_prefetch_thread_num", "pserver prefetch thread num.")
.SetDefault(1);
}
};
void SignalHandler::StopAndExit(int signal_num) {
// Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces.
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
remove(file_path.c_str());
exit(0);
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(listen_and_serv, ops::ListenAndServOp,
ops::ListenAndServOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <atomic>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RPCServer;
class RequestHandler;
} // namespace distributed
constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
constexpr char kCheckpointBlockId[] = "checkpint_block_id";
constexpr char kLRDecayBlockId[] = "lr_decay_block_id";
constexpr char kSparseGradToParam[] = "sparse_grad_to_param";
void RunServer(std::shared_ptr<distributed::RPCServer> service);
template <class TKey, class TValue>
class DoubleFindMap : public std::unordered_map<TKey, TValue> {
public:
typename std::unordered_map<TKey, TValue>::iterator find_value(TValue v) {
return std::find_if(this->begin(), this->end(),
[&v](const std::pair<const std::string, int> p) {
return p.second == v;
});
}
};
class ListenAndServOp : public framework::OperatorBase {
public:
ListenAndServOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
virtual ~ListenAndServOp();
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
const std::vector<int>& prefetch_block_id_list,
const int checkpoint_point_block_id) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope) const;
void SavePort() const;
int GetSelectedPort() { return rpc_service_->GetSelectedPort(); }
void Stop() override;
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
void ResetReceivedVars(framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
bool reset_all = false) const;
void CacheVarsType(const std::vector<std::string>& varnames,
const framework::Scope& scope) const;
protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_get_no_barrier_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_checkpoint_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_notify_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_send_and_recv_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
mutable std::vector<std::string> dense_vars_;
};
class SignalHandler {
public:
static void StopAndExit(int signal_num);
private:
DISABLE_COPY_AND_ASSIGN(SignalHandler);
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h"
#include <string>
namespace paddle {
namespace operators {
class LargeScaleFuseAdamOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("Grad"),
platform::errors::InvalidArgument(
"Input(Grad) of LargeScaleFuseAdamOp should not be null."));
PADDLE_ENFORCE(
ctx->HasInput("LearningRate"),
platform::errors::InvalidArgument(
"Input(LearningRate) of LargeScaleFuseAdamOp should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
platform::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
platform::errors::InvalidArgument(
"Learning rate should have 1 element"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Grad");
return framework::OpKernelType(data_type, ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (var_name == "LearningRate") {
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class LargeScaleFuseAdamOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto in_var_type = ctx->GetInputType("Grad");
PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
true, platform::errors::InvalidArgument(
"The input Var's type should be LoDtensor or "
"SelectedRows, but the received type is %s",
in_var_type));
}
};
class LargeScaleFuseAdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Grad",
"(SelectedRows) Ids's type should be SelectedRows"
"THe ids to be looked up in W.");
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
AddAttr<float>("beta1",
"(float, default 0.9) "
"Exponential decay rate for the "
"first moment estimates.")
.SetDefault(0.9f);
AddAttr<float>("beta2",
"(float, default 0.999) "
"exponential decay rate for the "
"second moment estimates.")
.SetDefault(0.999f);
AddAttr<float>("epsilon",
"(float, default 1.0e-8) "
"Constant for numerical stability")
.SetDefault(1.0e-8f);
AddAttr<bool>("is_entry",
"(bool)"
"sparse table need entry");
AddAttr<std::string>("tablename",
"(string)"
"sparse table name");
AddAttr<std::vector<std::string>>("value_names",
"(strings)"
"sparse table name");
AddComment(R"DOC(
Adam Optimizer.
This implements the Adam optimizer from Section 2 of the Adam
paper : https://arxiv.org/abs/1412.6980.
Adam is a first-order gradient-based optimization method based on
adaptive estimates of lower-order moments.
Adam updates:
$$
moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\
moment\_2_\out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\
learning\_rate = learning\_rate *
\frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\
param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lookup_sparse_table_fuse_adam, ops::LargeScaleFuseAdamOp,
ops::LargeScaleFuseAdamOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LargeScaleFuseAdamOpInferVarType);
REGISTER_OP_CPU_KERNEL(
lookup_sparse_table_fuse_adam,
ops::LargeScaleFuseAdamOpKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h> // for sqrt in CPU and CUDA
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LargeScaleFuseAdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
template <typename T>
class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using paddle::framework::LoDTensor;
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(
grad_var->IsType<framework::SelectedRows>(),
platform::errors::InvalidArgument(
"in large scale optimize, gradient should only be SelectedRows"));
const auto &grad = grad_var->Get<framework::SelectedRows>();
// for distributed training, a sparse var may be empty,
// just skip updating.
if (grad.rows().size() == 0) {
return;
}
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows *grad_merge_ptr;
math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
&tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
std::vector<int64_t> in_rows;
in_rows.reserve(grad_merge_ptr->rows().size());
std::copy(grad_merge_ptr->rows().begin(), grad_merge_ptr->rows().end(),
std::back_inserter(in_rows));
const auto *lr = learning_rate->data<T>();
auto grad_v = grad_merge_ptr->value();
auto grad_width = grad_v.dims()[1];
// auto is_entry = context.Attr<bool>("is_entry");
auto tablename = ctx.Attr<std::string>("tablename");
auto value_names = ctx.Attr<std::vector<std::string>>("value_names");
auto *beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
auto *beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
auto *beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto *beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta1 pow output size should be 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta2 pow output size should be 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
// update beta1 and beta2
beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta1 * beta1_pow->data<T>()[0];
beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta2 * beta2_pow->data<T>()[0];
std::vector<std::vector<std::vector<float> *>> values;
std::vector<int64_t> dims;
auto *ins = distributed::LargeScaleKV::GetInstance();
auto *table = ins->Get(tablename);
table->Get(in_rows, value_names, &values);
table->Dims({"Param"}, &dims);
PADDLE_ENFORCE_EQ(dims[0], grad_width,
platform::errors::InvalidArgument(
"param_row should have the same size with grad_row"));
T lr_ = lr[0];
T beta1_pow_ = beta1_pow->data<T>()[0];
T beta2_pow_ = beta2_pow->data<T>()[0];
lr_ *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);
for (size_t i = 0; i < in_rows.size(); i++) {
auto &params = values[i][0];
auto &moment_1 = values[i][1];
auto &moment_2 = values[i][2];
auto *p_data = params->data();
auto *m1_data = moment_1->data();
auto *m2_data = moment_2->data();
for (int x = 0; x < grad_width; ++x) {
auto g = grad_v.data<T>()[grad_width * i + x];
m1_data[x] = beta1 * m1_data[x] + (1 - beta1) * g;
m2_data[x] = beta2 * m2_data[x] + (1 - beta2) * g * g;
p_data[x] -= lr_ * (m1_data[x] / (sqrt(m2_data[x]) + epsilon));
}
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h"
#include <string>
namespace paddle {
namespace operators {
class LargeScaleFuseSGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("Grad"),
platform::errors::InvalidArgument(
"Input(Grad) of LargeScaleFuseSGDOp should not be null."));
PADDLE_ENFORCE(
ctx->HasInput("LearningRate"),
platform::errors::InvalidArgument(
"Input(LearningRate) of LargeScaleFuseSGDOp should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
platform::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
platform::errors::InvalidArgument(
"Learning rate should have 1 element"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Grad");
return framework::OpKernelType(data_type, ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (var_name == "LearningRate") {
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class LargeScaleFuseSGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto in_var_type = ctx->GetInputType("Grad");
PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
true, platform::errors::InvalidArgument(
"The input Var's type should be LoDtensor or "
"SelectedRows, but the received type is %s",
in_var_type));
}
};
class LargeScaleFuseSGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Grad",
"(SelectedRows) Ids's type should be SelectedRows"
"THe ids to be looked up in W.");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddAttr<bool>("is_entry",
"(bool)"
"sparse table need entry");
AddAttr<std::string>("tablename",
"(string)"
"sparse table name");
AddAttr<std::vector<std::string>>("value_names",
"(strings)"
"sparse table name");
AddComment(R"DOC(
LargeScaleFuseSGD operator
This operator implements one step of the stochastic gradient descent algorithm.
$$param\_out = param - learning\_rate * grad$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lookup_sparse_table_fuse_sgd, ops::LargeScaleFuseSGDOp,
ops::LargeScaleFuseSGDOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LargeScaleFuseSGDOpInferVarType);
REGISTER_OP_CPU_KERNEL(
lookup_sparse_table_fuse_sgd,
ops::LargeScaleFuseSGDOpKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LargeScaleFuseSGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
template <typename T>
class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(
grad_var->IsType<framework::SelectedRows>(),
platform::errors::InvalidArgument(
"in large scale optimize, gradient should only be SelectedRows"));
const auto &grad = grad_var->Get<framework::SelectedRows>();
// for distributed training, a sparse var may be empty,
// just skip updating.
if (grad.rows().size() == 0) {
return;
}
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows *grad_merge_ptr;
math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
&tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
std::vector<int64_t> in_rows;
in_rows.reserve(grad_merge_ptr->rows().size());
std::copy(grad_merge_ptr->rows().begin(), grad_merge_ptr->rows().end(),
std::back_inserter(in_rows));
const auto *lr = learning_rate->data<T>();
auto grad_v = grad_merge_ptr->value();
auto grad_width = grad_v.dims()[1];
// auto is_entry = context.Attr<bool>("is_entry");
auto tablename = ctx.Attr<std::string>("tablename");
auto value_names = ctx.Attr<std::vector<std::string>>("value_names");
std::vector<std::vector<std::vector<float> *>> values;
std::vector<int64_t> dims;
auto *ins = distributed::LargeScaleKV::GetInstance();
auto *table = ins->Get(tablename);
table->Get(in_rows, value_names, &values);
table->Dims({"Param"}, &dims);
PADDLE_ENFORCE_EQ(dims[0], grad_width,
platform::errors::InvalidArgument(
"param_row should have the same size with grad_row"));
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
std::vector<T> grads;
framework::TensorToVector(grad_v, ctx.device_context(), &grads);
blas.SCAL(grads.size(), lr[0], grads.data());
for (int x = 0; x < static_cast<int>(in_rows.size()); ++x) {
auto &params = values[x][0];
blas.VSUB(grad_width, params->data(), grads.data() + grad_width * x,
params->data());
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_grad_split_op.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
class LookupSparseTableGradSplitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {}
};
class LookupSparseTableGradSplitOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Grad",
"(SelectedRows) Ids's type should be SelectedRows"
"THe ids to be looked up in W.");
AddAttr<bool>("is_entry",
"(bool)"
"sparse table need entry");
AddAttr<std::string>("tablename",
"(string)"
"sparse table name");
AddOutput("Row",
"(LoDTensor) The lookup results, which have the "
"same type as W.");
AddOutput("Value",
"(LoDTensor) The lookup results, which have the "
"same type as W.");
AddComment(R"DOC(
Lookup Sprase Tablel Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lookup_sparse_table_grad_split, ops::LookupSparseTableGradSplitOp,
ops::LookupSparseTableGradSplitOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
lookup_sparse_table_grad_split,
ops::LookupSparseTableGradSplitKernel<paddle::platform::CPUDeviceContext,
float>,
ops::LookupSparseTableGradSplitKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using SelectedRows = framework::SelectedRows;
template <typename DeviceContext, typename T>
class LookupSparseTableGradSplitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const SelectedRows* in_grad = context.Input<SelectedRows>("Grad");
// merge duplicated rows if any.
// The rows of grad_merge_ptr have been sorted inside MergeAdd functor
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows* grad_merge_ptr;
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(context.template device_context<DeviceContext>(), *in_grad,
&tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
std::vector<int64_t> in_rows;
in_rows.reserve(grad_merge_ptr->rows().size());
std::copy(grad_merge_ptr->rows().begin(), grad_merge_ptr->rows().end(),
std::back_inserter(in_rows));
auto* out_row = context.Output<Tensor>("Row");
out_row->Resize(
framework::make_ddim({static_cast<int64_t>(in_rows.size()), 1}));
out_row->mutable_data<int64_t>(context.GetPlace());
framework::TensorFromVector(in_rows, context.device_context(), out_row);
auto in_value = grad_merge_ptr->value();
std::vector<T> ins_vector;
framework::TensorToVector(in_value, context.device_context(), &ins_vector);
auto dims = in_value.dims();
auto is_entry = context.Attr<bool>("is_entry");
auto tablename = context.Attr<std::string>("tablename");
if (is_entry) {
auto* ins = distributed::LargeScaleKV::GetInstance();
std::vector<int64_t> ids;
ins->Get(tablename)->GetEntry(in_rows, &ids);
for (auto& id : ids) {
auto it = std::find(in_rows.begin(), in_rows.end(), id);
if (it == in_rows.end()) {
PADDLE_THROW(platform::errors::OutOfRange(
"the input key should be exists. But received %d.", id));
}
auto distance =
static_cast<int64_t>(std::distance(in_rows.begin(), it));
std::fill(ins_vector.data() + distance * dims[1],
ins_vector.data() + dims[1], 0.0);
}
}
auto* out_v = context.OutputVar("Value");
out_v->Clear();
auto* out_t = out_v->GetMutable<framework::LoDTensor>();
out_t->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(ins_vector, context.device_context(), out_t);
out_t->Resize(dims);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
// examples: embedding:Param,Moment1,Moment2:64,64,64:0
constexpr char kLargeScaleKV[] = "large_scale_metas";
constexpr int64_t kNoPadding = -1;
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
if (str.empty()) {
return;
}
size_t pos = 0;
size_t next = str.find(sep, pos);
while (next != std::string::npos) {
pieces->push_back(str.substr(pos, next - pos));
pos = next + 1;
next = str.find(sep, pos);
}
if (!str.substr(pos).empty()) {
pieces->push_back(str.substr(pos));
}
}
class LookupSparseTableInitInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
void InitLargeScaleKV(std::vector<std::string> kv_attrs) {
std::vector<distributed::SparseMeta> metas;
for (auto attrs : kv_attrs) {
std::vector<std::string> pieces;
split(attrs, ':', &pieces);
PADDLE_ENFORCE_EQ(
pieces.size(), 8,
platform::errors::InvalidArgument(
"param, names, dims, mode, grad, cached_var, init_attrs"));
std::string name;
std::string grad_name;
std::vector<std::string> value_names;
std::vector<int> value_dims;
distributed::Mode mode;
std::vector<std::string> cached_names;
std::vector<std::string> init_attrs;
std::string entry_attr;
name = pieces[0];
split(pieces[1], ',', &value_names);
std::vector<std::string> value_dims_str;
split(pieces[2], ',', &value_dims_str);
for (auto &str : value_dims_str) {
value_dims.push_back(std::stoi(str));
}
mode = pieces[3] == "0" ? distributed::Mode::training
: distributed::Mode::infer;
grad_name = pieces[4];
split(pieces[5], ',', &cached_names);
split(pieces[6], ',', &init_attrs);
entry_attr = pieces[7];
auto meta = distributed::SparseMeta();
meta.name = name;
meta.value_names = value_names;
meta.value_dims = value_dims;
meta.mode = mode;
meta.grad_name = grad_name;
meta.cached_varnames = cached_names;
meta.initializer_attrs = init_attrs;
meta.entry = entry_attr;
VLOG(3) << "add sparse meta: " << meta.ToString();
metas.push_back(meta);
}
distributed::LargeScaleKV::Init(metas);
VLOG(3) << "init large scale kv with " << metas.size() << " params";
}
class LookupSparseTableInitOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto kv_attrs = Attr<std::vector<std::string>>(kLargeScaleKV);
InitLargeScaleKV(kv_attrs);
}
};
class LookupSparseTableInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddAttr<std::vector<std::string>>(kLargeScaleKV,
"(string)"
"sparse table name");
AddComment(R"DOC(
Lookup Sprase Tablel Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lookup_sparse_table_init, ops::LookupSparseTableInitOp,
ops::LookupSparseTableInitInferShape, ops::LookupSparseTableInitOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_merge_op.h"
namespace paddle {
namespace operators {
class LookupSparseTableMergeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInputs("X"), true,
platform::errors::InvalidArgument("Input(X) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument("Output(Out) should not be null."));
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X").front(),
framework::proto::VarType::SELECTED_ROWS,
platform::errors::InvalidArgument(
"Input X only should be SelectedRows."));
PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(),
framework::proto::VarType::SELECTED_ROWS,
platform::errors::InvalidArgument(
"Output Y only should be SelectedRows."));
ctx->ShareDim("X", /*->*/ "Out");
}
};
class LookupSparseTableMergeMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input type is SelectedRows, and the selected rows may be "
"duplicated.")
.AsDuplicable();
AddOutput("Out",
"The output type is SelectedRows, and the selected rows are not "
"duplicated.");
AddComment(
R"DOC(
Merge sparse lookup table(selected rows as parameter).
)DOC");
}
};
class LookupSparseTableMergeOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(lookup_sparse_table_merge, ops::LookupSparseTableMergeOp,
ops::LookupSparseTableMergeMaker,
ops::LookupSparseTableMergeOpInferVarType);
REGISTER_OP_CPU_KERNEL(
lookup_sparse_table_merge,
ops::LookupSparseTableMergeKernel<plat::CPUDeviceContext, float>,
ops::LookupSparseTableMergeKernel<plat::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
int64_t GetDelimiterForShard(const std::vector<int64_t>& rows, int start_idx,
int shard_id, int shard_num) {
int64_t rows_num = rows.size() / 2;
for (int64_t i = start_idx; i < rows_num; ++i) {
if (rows[i] % shard_num != shard_id) {
return i;
}
}
return rows_num;
}
template <typename DeviceContext, typename T>
class LookupSparseTableMergeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto inputs = ctx.MultiInput<framework::SelectedRows>("X");
auto* out = ctx.Output<framework::SelectedRows>("Out");
int64_t height = 0;
int64_t ids_num = 0;
int64_t width = 0;
height = inputs[0]->height();
width = inputs[0]->value().dims()[1];
for (auto& in : inputs) {
ids_num += in->rows().size();
height += in->height();
}
T* out_data = out->mutable_value()->mutable_data<T>({ids_num, width},
platform::CPUPlace());
out->set_height(height);
std::vector<int64_t> all_ids;
all_ids.reserve(ids_num);
for (auto& in : inputs) {
all_ids.insert(all_ids.end(), in->rows().begin(), in->rows().end());
}
out->set_rows(all_ids);
int64_t cnt = 0;
for (auto& in : inputs) {
auto rows = in->rows().size();
const T* in_data = in->value().data<T>();
std::copy_n(in_data, rows * width, out_data + cnt);
cnt += rows * width;
}
out->SyncIndex();
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
constexpr int64_t kNoPadding = -1;
class LookupSparseTableReadInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
class LookupSparseTableReadOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto init = Attr<bool>("init");
auto &id_tensor = scope.FindVar(Input("Ids"))->Get<framework::LoDTensor>();
auto *id_data = id_tensor.data<int64_t>();
auto tablename = Attr<std::string>("tablename");
auto value_names = Attr<std::vector<std::string>>("value_names");
auto out_names = Outputs("Out");
std::vector<int64_t> ids;
for (int64_t i = 0; i < id_tensor.numel(); ++i) {
ids.push_back(id_data[i]);
}
std::vector<std::vector<std::vector<float> *>> values;
std::vector<int64_t> dims;
auto *ins = distributed::LargeScaleKV::GetInstance();
if (init) {
ins->Get(tablename)->Init(ids);
ins->Get(tablename)->Get(ids, value_names, &values);
} else {
ins->Get(tablename)->Get(ids, value_names, &values);
}
ins->Get(tablename)->Dims(value_names, &dims);
platform::CPUPlace cpu;
std::vector<float *> tensors;
for (int i = 0; i < static_cast<int>(value_names.size()); i++) {
auto out_var = scope.FindVar(out_names[i]);
auto out_t = out_var->GetMutable<framework::LoDTensor>();
std::vector<int64_t> o_dims;
o_dims.push_back(static_cast<int64_t>(ids.size()));
o_dims.push_back(dims[i]);
out_t->Resize(framework::make_ddim(o_dims));
auto *out_d = out_t->mutable_data<float>(cpu);
tensors.push_back(out_d);
}
for (int i = 0; i < static_cast<int>(values.size()); i++) {
for (int j = 0; j < static_cast<int>(tensors.size()); j++) {
std::memcpy(tensors[j] + i * dims[j], values[i][j]->data(),
sizeof(float) * dims[j]);
}
}
}
};
class LookupSparseTableReadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W.");
AddOutput("Out",
"(LoDTensor) The lookup results, which have the "
"same type as W.")
.AsDuplicable();
AddAttr<std::string>("tablename",
"(string)"
"sparse table name");
AddAttr<std::vector<std::string>>("value_names",
"(strings)"
"sparse table name");
AddAttr<bool>("init", " for test init large scale kv").SetDefault(false);
AddComment(R"DOC(
Lookup Sprase Tablel Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lookup_sparse_table_read, ops::LookupSparseTableReadOp,
ops::LookupSparseTableReadInferShape, ops::LookupSparseTableReadOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
constexpr int64_t kNoPadding = -1;
class LookupSparseTableWriteInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
class LookupSparseTableWriteOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto &id_tensor = scope.FindVar(Input("Ids"))->Get<framework::LoDTensor>();
auto *id_data = id_tensor.data<int64_t>();
std::vector<int64_t> ids;
for (int64_t i = 0; i < id_tensor.numel(); ++i) {
ids.push_back(id_data[i]);
}
auto tablename = Attr<std::string>("tablename");
auto value_names = Attr<std::vector<std::string>>("value_names");
std::vector<const float *> tensors;
std::vector<int64_t> dims;
std::vector<std::vector<std::vector<float>>> values;
values.resize(ids.size());
auto in_names = Inputs("In");
for (int i = 0; i < static_cast<int>(in_names.size()); i++) {
auto *in = scope.FindVar(in_names[i]);
auto in_t = in->Get<framework::LoDTensor>();
dims.push_back(in_t.dims()[1]);
tensors.push_back(in_t.data<float>());
}
for (int i = 0; i < static_cast<int>(ids.size()); i++) {
values[i].resize(tensors.size());
for (int j = 0; j < static_cast<int>(tensors.size()); j++) {
values[i][j].resize(dims[j]);
std::memcpy(values[i][j].data(), tensors[j] + i * dims[j],
sizeof(float) * dims[j]);
}
}
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(tablename)->Set(ids, value_names, values);
}
};
class LookupSparseTableWriteOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W.");
AddInput("In",
"(LoDTensor) The lookup results, which have the "
"same type as W.")
.AsDuplicable();
AddAttr<std::string>("tablename",
"(string)"
"sparse table name");
AddAttr<std::vector<std::string>>("value_names",
"(strings)"
"sparse table name");
AddComment(R"DOC(
Lookup Sprase Tablel Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lookup_sparse_table_write, ops::LookupSparseTableWriteOp,
ops::LookupSparseTableWriteInferShape, ops::LookupSparseTableWriteOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/merge_ids_op.h"
namespace paddle {
namespace operators {
class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}")
.AsDuplicable();
AddInput("Rows", "(LoDTensor) the input ids with shape{row_size, 1}, ")
.AsDuplicable();
AddInput("X",
"(LoDTensors) multi input tensor with shape{Rows, N}, N is the "
"size of embedding table")
.AsDuplicable();
AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.")
.AsDuplicable();
AddComment(R"DOC(
Merge multi LoDTensor's into one according to Ids's shard num.
split_ids_op -> prefetch_op -> merge_ids_op
merge_ids_op should be used after split_ids_op and prefetch_op, split_ids_op
will split input Ids into multiple tensors according to Id's shard number.
prefetch_op will send them to parameter server to prefetch embedding value
back. During split, the order of ids is disordered. In merge_ids_op we use
the original Ids to restore the order of the fetched embedding value and
also pass the lod information to the merged output.
Example:
Ids = [1,2,3,4,5,6] # 3 shared
split_ids_op ->
Id0 = [3, 6] # id % 3 == 0
Id1 = [1, 4] # id % 3 == 1
Id2 = [2, 5] # id % 3 == 2
prefetch_op ->
X0 = [[0.3 0.3] # 3
[0.6 0.6]] # 6
X1 = [[0.1 0.1] # 1
[0.4 0.4]] # 4
X2 = [[0.2 0.2] # 2
[0.5 0.5]] # 5
merge_ids_op ->
Out = [[0.1 0.1] # 1
[0.2 0.2] # 2
[0.3 0.3] # 3
[0.4 0.4] # 4
[0.5 0.5] # 5
[0.6 0.6]] # 6
)DOC");
}
};
class MergeIdsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Ids"), "Input", "Ids", "MergeIds");
OP_INOUT_CHECK(ctx->HasInputs("Rows"), "Input", "Rows", "MergeIds");
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "MergeIds");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "MergeIds");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
ids_dims[0].size(), 2,
platform::errors::InvalidArgument(
"the ids size must be 2, but received %d", ids_dims[0].size()));
PADDLE_ENFORCE_EQ(
ids_dims[0][1], 1,
platform::errors::InvalidArgument(
"the ids dim must be 1, but received %d", ids_dims[0][1]));
}
auto x_var_type = ctx->GetInputsVarType("X");
for (auto &var_type : x_var_type) {
PADDLE_ENFORCE_EQ(var_type, framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"input X only support lod tensors"));
}
ctx->ShareLoD("Ids", "Out");
}
private:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class MergeIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetInputType("Ids");
ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker,
ops::MergeIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class MergeIdsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::InvalidArgument(
"MergeIds do not support GPU kernel"));
}
const auto ids = ctx.MultiInput<framework::LoDTensor>("Ids");
const auto row_ids = ctx.MultiInput<framework::LoDTensor>("Rows");
const auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
PADDLE_ENFORCE_EQ(row_ids.size(), x_tensors.size(),
platform::errors::InvalidArgument(
"the number of Rows and X should be the same"));
PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
platform::errors::InvalidArgument(
"the number of Ids and Out should be the same"));
int64_t row_ids_size = 0;
int64_t row_size = 0;
int64_t embedding_size = 0;
for (size_t i = 0; i < x_tensors.size(); ++i) {
const auto *x_tensor = x_tensors[i];
const auto *row_id = row_ids[i];
if (embedding_size == 0) {
embedding_size = x_tensor->dims()[1];
}
PADDLE_ENFORCE_EQ(embedding_size, x_tensor->dims()[1],
platform::errors::InvalidArgument(
"embedding size of all input should be the same"));
row_size += x_tensor->dims()[0];
row_ids_size += row_id->dims()[0];
}
PADDLE_ENFORCE_EQ(
row_size, row_ids_size,
platform::errors::InvalidArgument(
"the merged X dim[0] and merged Rows dim[0] should be the same"));
std::unordered_map<int64_t, std::tuple<int64_t, int64_t>>
selected_rows_idx_map;
for (size_t i = 0; i < x_tensors.size(); ++i) {
const auto *row_id = row_ids[i];
for (auto j = 0; j < row_id->numel(); ++j) {
int64_t key = row_id->data<int64_t>()[j];
std::tuple<int64_t, int64_t> val = std::make_tuple(i, j);
selected_rows_idx_map.insert(std::make_pair(key, val));
}
}
PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(),
platform::errors::InvalidArgument(
"the rows and tensor map size should be the same"));
for (size_t i = 0; i < outs.size(); ++i) {
auto *out_ids = ids[i];
auto *out = outs[i];
out->set_lod(out_ids->lod());
auto nums = out_ids->dims()[0];
auto *out_data = out->mutable_data<T>(
framework::make_ddim({nums, embedding_size}), place);
for (auto j = 0; j < nums; ++j) {
auto id = out_ids->data<int64_t>()[j];
auto row_tuple = selected_rows_idx_map.at(id);
auto row_idx = std::get<1>(row_tuple);
const auto *x_tensor = x_tensors[std::get<0>(row_tuple)];
memcpy(out_data + embedding_size * j,
x_tensor->data<T>() + row_idx * embedding_size,
sizeof(T) * embedding_size);
}
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RPCClient;
} // namespace distributed
class PrefetchOp : public framework::OperatorBase {
public:
PrefetchOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto ins = Inputs("X");
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope,
ins[i], outs[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_EQ(
rets[i]->Wait(), true,
platform::errors::Fatal(
"It's a fatal error of RPCClient that RPCClient can't "
"get the wait result. It may happen when trainers or "
"parameter servers exit un normally or the network "
"issue!"));
}
}
};
class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
AddOutput("Out",
"(LoDTensor) result "
"to be fetched from parameter server")
.AsDuplicable();
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>(
"epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({"127.0.0.1:6164"});
AddComment(R"DOC(
Prefetch operator
This operator will send Ids variables to listen_and_serve op at
the parameter server and fetch result back.
)DOC");
}
};
class PrefetchOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
prefetch, ops::PrefetchOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::PrefetchOpMaker, ops::PrefetchOpShapeInference);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/distributed.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RPCClient;
} // namespace distributed
class RecvOp : public framework::OperatorBase {
public:
RecvOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(place);
auto trainer_id = Attr<int>("trainer_id");
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
std::vector<std::string> recv_varnames =
Attr<std::vector<std::string>>("recv_varnames");
if (recv_varnames.size() > 0) {
auto *communicator = distributed::Communicator::GetInstance();
if (communicator != nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"execute startup program must before fleet.init_worker"));
}
} else {
std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) {
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVar";
rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
}
} else {
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVarNoBarrier";
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
}
}
}
};
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Any) Dummy inputs, used for control dependency")
.AsDuplicable();
AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
AddComment(R"DOC(
Recv operator
This operator can get variables from server side.
)DOC");
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately")
.SetDefault(true);
AddAttr<std::vector<std::string>>(
"varnames",
"(string vector, default {}) "
"sometimes we need to put received var in another name "
"for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. ")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"recv_varnames",
"(vector<string>) "
"the split parameter varnames to be recved from pserver")
.SetDefault(std::vector<std::string>{});
AddAttr<int>("do_not_run", "if recv need to really run").SetDefault(0);
}
};
class RecvOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
recv, ops::RecvOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::RecvOpMaker, ops::RecvOpShapeInference);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <fstream>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace operators {
class RecvSaveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
platform::CPUPlace());
}
};
class RecvSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(
Recv Save operator
This operator will serialize and write LoDTensor variable to file on disk.
)DOC");
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::VarType::FP32);
AddAttr<bool>("overwrite",
"(boolean, default true)"
"Overwrite the output file if exist")
.SetDefault(true);
AddAttr<std::string>("file_path",
"(string)"
"The \"file_path\" where the variable will be saved.")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"slice_varnames",
"(string vector, default {}) "
"sometimes we need to put received var in another name "
"for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. ")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"remote_varnames",
"(string vector, default {}) "
"sometimes we need to put received var in another name "
"for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. ")
.SetDefault({});
AddAttr<std::vector<std::string>>("slice_shapes",
"(vector<int>) "
"the length of each output along the "
"specified axis.")
.SetDefault({});
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<bool>("is_sparse", "sparse or dense param");
AddAttr<int>("pserver_num", "the number of pserver").SetDefault(0);
AddAttr<bool>("is_distributed", "sparse id range [0, N) or [0, INT64]")
.SetDefault(false);
}
};
template <typename DeviceContext, typename T>
class RecvSaveOpKernel : public framework::OpKernel<T> {
private:
void SerializeVersionToStream(std::ostream &os) const {
{ // the 1st field, uint32_t version for LoDTensor
os.write(reinterpret_cast<const char *>(&framework::kCurTensorVersion),
sizeof(framework::kCurTensorVersion));
}
// the 2st field, LoD information
// in this scene, skip LoD information.
uint64_t size = 0;
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
}
void SerializeTensorHeaderToStream(
std::ostream &os, const framework::proto::VarType::Type &type,
const framework::DDim &dims) const {
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework::proto::VarType::TensorDesc desc;
desc.set_data_type(type);
auto tensor_dims = framework::vectorize(dims);
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(tensor_dims.size()), 0);
std::copy(tensor_dims.begin(), tensor_dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
}
void SerializeTensorAppendToStream(std::ostream &os,
const framework::Tensor &tensor) const {
uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type());
auto *data_ptr = tensor.data<void>();
PADDLE_ENFORCE_LT(size, std::numeric_limits<std::streamsize>::max(),
platform::errors::ResourceExhausted(
"tensor size %d overflow when writing tensor", size));
os.write(static_cast<const char *>(data_ptr),
static_cast<std::streamsize>(size));
}
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW(platform::errors::AlreadyExists(
"%s is existed, cannot save to it when overwrite=false", filename));
}
MkDirRecursively(DirName(filename).c_str());
auto origin_shape = ctx.Attr<std::vector<int64_t>>("shape");
auto slice_shapes = ctx.Attr<std::vector<std::string>>("slice_shapes");
auto slice_varnames = ctx.Attr<std::vector<std::string>>("slice_varnames");
auto remote_varnames =
ctx.Attr<std::vector<std::string>>("remote_varnames");
auto endpoints = ctx.Attr<std::vector<std::string>>("endpoints");
auto trainer_id = ctx.Attr<int>("trainer_id");
auto is_sparse = ctx.Attr<bool>("is_sparse");
auto pserver_num = ctx.Attr<int>("pserver_num");
// auto is_distributed = ctx.Attr<int>("is_distributed");
PADDLE_ENFORCE_EQ(slice_shapes.size(), slice_varnames.size(),
platform::errors::InvalidArgument(
"Expected attr len(slice_shapes) must be equal to "
"len(slice_varnames)"));
PADDLE_ENFORCE_EQ(
slice_shapes.size(), endpoints.size(),
platform::errors::InvalidArgument(
"Expected attr len(slice_shapes) must be equal to len(endpoints)"));
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fout), true,
platform::errors::NotFound("Cannot open %s to write", filename));
SerializeVersionToStream(fout);
SerializeTensorHeaderToStream(fout, data_type,
framework::make_ddim(origin_shape));
framework::Scope &local_scope = ctx.scope().NewScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto place = ctx.GetPlace();
auto &device_ctx = *pool.Get(place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
if (!is_sparse) {
for (size_t i = 0; i < slice_varnames.size(); i++) {
auto &varname = slice_varnames[i];
auto *var = local_scope.Var(varname);
auto *tensor = var->GetMutable<framework::LoDTensor>();
auto slice_string =
string::split_string<std::string>(slice_shapes[i], ",");
std::vector<int64_t> slice_shape;
for (auto &dim : slice_string) {
slice_shape.push_back(static_cast<int64_t>(std::stoull(dim)));
}
tensor->Resize(framework::make_ddim(slice_shape));
distributed::VarHandlePtr ret;
ret = rpc_client->AsyncGetVarNoBarrier(
endpoints[i], device_ctx, local_scope, remote_varnames[i], varname);
PADDLE_ENFORCE_NE(
ret->Wait(), 0U,
platform::errors::ExecutionTimeout(
"rpc error when communication with %s", endpoints[i]));
auto &c_tensor = var->Get<framework::LoDTensor>();
SerializeTensorAppendToStream(fout, c_tensor);
local_scope.EraseVars({varname});
}
} else {
PADDLE_ENFORCE_GT(
pserver_num, 0,
platform::errors::InvalidArgument(
"Expected attr len(pserver_num) must gather than 0"));
std::vector<std::string> varnames;
auto *var = local_scope.Var("tmp_for_sparse_merge");
auto *o_t = var->GetMutable<framework::LoDTensor>();
o_t->Resize(framework::make_ddim(origin_shape));
auto *out_d = o_t->mutable_data<float>(place);
varnames.push_back("tmp_for_sparse_merge");
for (size_t i = 0; i < slice_varnames.size(); i++) {
varnames.push_back(slice_varnames[i]);
}
std::vector<const float *> tensors;
for (size_t i = 0; i < slice_varnames.size(); i++) {
auto &varname = slice_varnames[i];
auto *local_var = local_scope.Var(varname);
auto *tensor = local_var->GetMutable<framework::LoDTensor>();
auto slice_string =
string::split_string<std::string>(slice_shapes[i], ",");
std::vector<int64_t> slice_shape;
for (auto &dim : slice_string) {
slice_shape.push_back(static_cast<int64_t>(std::stoull(dim)));
}
tensor->Resize(framework::make_ddim(slice_shape));
distributed::VarHandlePtr ret;
ret = rpc_client->AsyncGetVarNoBarrier(
endpoints[i], device_ctx, local_scope, remote_varnames[i], varname);
PADDLE_ENFORCE_NE(
ret->Wait(), 0U,
platform::errors::ExecutionTimeout(
"rpc error when communication with %s", endpoints[i]));
const auto *value =
local_var->Get<framework::LoDTensor>().data<float>();
tensors.push_back(value);
}
auto dims1 = origin_shape[1];
for (int j = 0; j < origin_shape[0]; ++j) {
auto id = j % pserver_num;
auto idx = j / pserver_num;
std::memcpy(out_d + j * dims1, tensors[id] + idx * dims1,
sizeof(float) * dims1);
}
auto &c_tensor = var->Get<framework::LoDTensor>();
SerializeTensorAppendToStream(fout, c_tensor);
local_scope.EraseVars(varnames);
}
fout.close();
ctx.scope().DeleteScope(&local_scope);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(recv_save, ops::RecvSaveOp, ops::RecvSaveOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
recv_save, ops::RecvSaveOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::RecvSaveOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::RecvSaveOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::RecvSaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h"
#include <string>
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
template <typename T>
class EmptyGradOpMaker;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
class RefByTrainerIdOp : public framework::OperatorWithKernel {
public:
RefByTrainerIdOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
platform::errors::InvalidArgument(
"Input(X) of RefByTrainerIdOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("TrainerId"), true,
platform::errors::InvalidArgument(
"Input(TrainerId) of RefByTrainerIdOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of RefByTrainerIdOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("TrainerId").size(), 1,
platform::errors::InvalidArgument("TrainerId should be a scalar."));
// Out's shape is determined at runtime.
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class RefByTrainerIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor list.").AsDuplicable();
AddInput("TrainerId", "(Tensor) Scalar int, the trainer id runtime value.");
AddOutput("Out", "(Tensor) Return one tensor reference of X[trainer_id]");
AddComment(R"DOC(
**RefByTrainerId operator**
Return a reference of a tensor, using trainer_id as the index to find from the input.
$$Out = X[TrainerId]$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ref_by_trainer_id, ops::RefByTrainerIdOp,
ops::RefByTrainerIdOpMaker);
REGISTER_OP_CPU_KERNEL(
ref_by_trainer_id,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, float>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, double>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, int>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h"
REGISTER_OP_CUDA_KERNEL(
ref_by_trainer_id,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
int>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdio.h>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class RefByTrainerIdKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
auto* out = context.Output<framework::Tensor>("Out");
auto in_list = context.MultiInput<framework::Tensor>("X");
auto* trainer_id_t = context.Input<framework::Tensor>("TrainerId");
int64_t trainer_id = 0;
auto* trainer_id_data = trainer_id_t->data<int64_t>();
if (platform::is_gpu_place(context.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto stream = context.cuda_device_context().stream();
memory::Copy<>(platform::CPUPlace(), &trainer_id,
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
trainer_id_data, sizeof(int64_t), stream);
#endif
} else {
trainer_id = *trainer_id_data;
}
PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size(),
platform::errors::InvalidArgument(
"X' size must >= TrainerId: [%s], but received [%s]",
trainer_id, in_list.size()));
out->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*(in_list[trainer_id]), in_list[trainer_id]->place(),
out);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SendAndRecvKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& scope = ctx.scope();
const auto& place = ctx.GetPlace();
auto send_var_name = ctx.Attr<std::string>("send_var_name");
auto recv_var_name = ctx.Attr<std::string>("recv_var_name");
auto epmap = ctx.Attr<std::string>("endpoint");
auto trainer_id = ctx.Attr<int>("trainer_id");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& context = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
VLOG(3) << "SendAndRecvOp Send_var_name: " << send_var_name
<< " Recv_var_name: " << recv_var_name;
distributed::VarHandlePtr rets = rpc_client->AsyncSendAndRecv(
epmap, context, scope, send_var_name, recv_var_name);
rets->Wait();
}
};
class SendAndRecvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, platform::CPUPlace());
}
};
class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "Tensor Input variable to be sent").AsDuplicable();
AddOutput("Out", "Tensor Output varibale to be recv").AsDuplicable();
AddAttr<std::string>("send_var_name", "Send Tensor's name")
.SetDefault(std::string(""));
AddAttr<std::string>("recv_var_name", "Recv Tensor's name")
.SetDefault(std::string(""));
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::string>("endpoint", "Server endpoint")
.SetDefault({"127.0.0.1:6164"});
AddComment(R"DOC(
SendAndRecv operator
This operator will send variables to listen_and_serve op at the parameter server.
And recv variable from parameter server of send variable's scope.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker);
REGISTER_OP_CPU_KERNEL(
send_and_recv,
ops::SendAndRecvKernel<paddle::platform::CPUDeviceContext, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RPCClient;
} // namespace distributed
class SendBarrierOp : public framework::OperatorBase {
public:
SendBarrierOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto is_half_async = Attr<bool>("half_async");
if (is_half_async) {
distributed::Communicator::GetInstance()->Barrier();
return;
}
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
VLOG(3) << "SendBarrierOp sync";
std::vector<distributed::VarHandlePtr> rets;
for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
}
}
};
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Any) Dummy inputs, used for control dependency")
.AsDuplicable();
AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
.AsDuplicable();
AddComment(R"DOC(
SendBarrier operator
This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
AddAttr<bool>(
"half_async",
"(bool, default false)"
"half_async=True is for half_async mode, this will send signal "
"to HalfAsyncCommunicator Instance")
.SetDefault(false);
}
};
class SendBarrierOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
send_barrier, ops::SendBarrierOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::SendBarrierOpMaker, ops::SendBarrierOpShapeInference);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
namespace distributed {
class RPCClient;
} // namespace distributed
class SendOp : public framework::OperatorBase {
public:
SendOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto ins = Inputs("X");
auto epmap = Attr<std::vector<std::string>>("endpoints");
auto trainer_id = Attr<int>("trainer_id");
auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto height_sections = Attr<std::vector<int64_t>>("sections");
auto use_send_handler = Attr<bool>("use_send_handler");
if (send_varnames.size() > 0) {
distributed::Communicator::GetInstance()->Send(ins, send_varnames, scope);
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
std::vector<distributed::VarHandlePtr> rets;
if (use_send_handler) {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rets.push_back(
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
} else {
for (size_t i = 0; i < ins.size(); i++) {
for (size_t j = 0; j < epmap.size(); j++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[j];
rets.push_back(rpc_client->AsyncDistributeNotify(epmap[j], ctx,
scope, ins[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
}
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
}
}
};
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
.AsDuplicable();
AddComment(R"DOC(
Send operator
This operator will send variables to listen_and_serve op at the parameter server.
)DOC");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::vector<int64_t>>("sections",
"(vector<int>) "
"the length of each output along the "
"specified axis.")
.SetDefault(std::vector<int64_t>{});
AddAttr<std::vector<std::string>>(
"send_varnames",
"(vector<string>) "
"the split output varnames to send to pserver")
.SetDefault(std::vector<std::string>{});
AddAttr<int>("num",
"(int, default 0)"
"Number of sub-tensors. This must evenly divide "
"Input.dims()[axis]")
.SetDefault(0);
AddAttr<bool>("merge_add",
"(bool, default 0)"
"merge method, true represent add, false represent average")
.SetDefault(false);
AddAttr<bool>(
"use_send_handler",
"(bool, default 1)"
"if it's true, use send handler, other wise, use notify handler")
.SetDefault(true);
}
};
class SendOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
send, ops::SendOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::SendOpMaker, ops::SendOpShapeInference);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/string/printf.h"
USE_NO_KERNEL_OP(send);
USE_NO_KERNEL_OP(listen_and_serv);
USE_OP(sum);
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
namespace d = paddle::operators::distributed
// global for simplicity.
std::unique_ptr<f::OperatorBase>
listen_and_serv_op;
int selected_port;
void InitTensorsInScope(const p::CPUPlace &place, f::Scope *scope) {
p::CPUDeviceContext ctx(place);
for (int i = 0; i < 2; ++i) {
auto var_name = paddle::string::Sprintf("x%d", i);
auto var = scope->Var(var_name);
auto tensor = var->GetMutable<f::LoDTensor>();
tensor->Resize({10, 10});
float *expect = tensor->mutable_data<float>(place);
for (int64_t i = 0; i < tensor->numel(); ++i) {
expect[i] = static_cast<float>(i);
}
}
auto out_var = scope->Var("Out");
auto out_tensor = out_var->GetMutable<f::LoDTensor>();
out_tensor->Resize({10, 10});
out_tensor->mutable_data<float>(place); // allocate
}
void InitSelectedRowsInScope(const p::CPUPlace &place, f::Scope *scope) {
p::CPUDeviceContext ctx(place);
int64_t height = 10;
int64_t row_numel = 10;
m::SetConstant<p::CPUDeviceContext, float> set_one;
// init x0
std::vector<int64_t> rows0{0, 4, 7};
auto x0_var = scope->Var("x0");
auto x0 = x0_var->GetMutable<f::SelectedRows>();
x0->set_rows(rows0);
x0->set_height(height);
auto x0_value = x0->mutable_value();
x0_value->mutable_data<float>(
f::make_ddim({static_cast<int64_t>(rows0.size()), row_numel}), place);
set_one(ctx, x0_value, 1.0);
// init x1
std::vector<int64_t> rows1{2, 9};
auto x1_var = scope->Var("x1");
auto x1 = x1_var->GetMutable<f::SelectedRows>();
x1->set_rows(rows1);
x1->set_height(height);
auto x1_value = x1->mutable_value();
x1_value->mutable_data<float>(
f::make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), place);
set_one(ctx, x1_value, 1.0);
auto out_var = scope->Var("Out");
auto out = out_var->GetMutable<f::SelectedRows>();
auto out_value = out->mutable_value();
out->set_height(height);
out_value->mutable_data<float>(f::make_ddim({5, 10}), place);
}
void AddOp(const std::string &type, const f::VariableNameMap &inputs,
const f::VariableNameMap &outputs, f::AttributeMap attrs,
f::BlockDesc *block, bool is_sparse) {
// insert output
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->Var(v);
var->SetDataType(f::proto::VarType::FP32);
var->SetPersistable(true);
if (is_sparse) {
var->SetType(f::proto::VarType::SELECTED_ROWS);
}
}
}
// insert op
auto op = block->AppendOp();
op->SetType(type);
for (auto &kv : inputs) {
op->SetInput(kv.first, kv.second);
}
for (auto &kv : outputs) {
op->SetOutput(kv.first, kv.second);
}
op->SetAttrMap(attrs);
}
void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
f::Scope scope;
p::CPUPlace place;
VLOG(4) << "before init tensor";
if (is_sparse) {
InitSelectedRowsInScope(place, &scope);
} else {
InitTensorsInScope(place, &scope);
}
// sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program;
const auto &root_block = program.Block(0);
std::vector<framework::BlockDesc *> optimize_blocks;
auto *optimize_block = program.AppendBlock(root_block);
optimize_blocks.push_back(optimize_block);
auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block,
is_sparse);
f::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:0")});
attrs.insert({"Fanin", 1});
attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"optimize_blocks", optimize_blocks});
attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"distributed_mode", d::DistributedMode::kSync});
VLOG(4) << "before init op";
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
*initialized = true;
listen_and_serv_op->Run(scope, place);
LOG(INFO) << "server exit";
}
TEST(SendRecvOp, CPUDense) {
std::atomic<bool> initialized{false};
std::thread server_thread(StartServerNet, false, &initialized);
while (!initialized) {
}
static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
->WaitServerReady();
// local net
f::Scope scope;
p::CPUPlace place;
InitTensorsInScope(place, &scope);
// create rpc client var
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
auto *listen_and_serv_op_ptr =
static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
const f::VariableNameMap &inputs = {{"X", {"x1"}}};
const f::VariableNameMap &outputs = {{"Out", {"Out"}}};
auto send_op = f::OpRegistry::CreateOp("send", inputs, outputs, attrs);
send_op->Run(scope, place);
auto in_var = scope.Var("x1");
auto tensor = in_var->GetMutable<f::LoDTensor>();
float *expected = tensor->data<float>();
auto out_var = scope.Var("Out");
auto target = out_var->GetMutable<f::LoDTensor>();
// x1 * 2 == x0
EXPECT_NE(target->memory_size(), size_t(0));
float *actual = target->data<float>();
for (int64_t i = 0; i < target->numel(); ++i) {
EXPECT_EQ(expected[i] * 2, actual[i]);
}
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset(nullptr);
paddle::operators::ListenAndServOp::ResetPort();
}
TEST(SendRecvOp, CPUSparse) {
std::atomic<bool> initialized;
initialized = false;
std::thread server_thread(StartServerNet, true, &initialized);
while (!initialized) {
}
auto *listen_and_serv_op_ptr =
static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
listen_and_serv_op_ptr->WaitServerReady();
// local net
f::Scope scope;
p::CPUPlace place;
p::CPUDeviceContext ctx(place);
InitSelectedRowsInScope(place, &scope);
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
{{"Out", {"Out"}}}, attrs);
send_op->Run(scope, place);
auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
auto x1 = scope.Var("x1")->GetMutable<f::SelectedRows>();
auto out = scope.Var("Out")->GetMutable<f::SelectedRows>();
auto actual = out->mutable_value();
std::unique_ptr<f::SelectedRows> expect{new f::SelectedRows()};
auto expect_value = expect->mutable_value();
expect_value->mutable_data<float>(f::make_ddim({5, 10}), place);
m::SelectedRowsAdd<p::CPUDeviceContext, float> add_functor;
add_functor(ctx, *x0, *x1, expect.get());
EXPECT_EQ(actual->numel(), expect_value->numel());
EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size());
for (int64_t i = 0; i < expect_value->numel(); ++i) {
EXPECT_EQ(expect_value->mutable_data<float>(place)[i],
actual->mutable_data<float>(place)[i]);
}
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset();
paddle::operators::ListenAndServOp::ResetPort();
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
namespace paddle {
namespace operators {
inline bool NeedSend(const framework::Scope& scope,
const std::string& varname) {
// dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it.
// TODO(paddle-dev): Why would parallel executor logic leaked into here?
if (varname.find(framework::ir::Node::kControlDepVarName) !=
std::string::npos)
return false;
auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound(
"Can not find variable '%s' in the send side.", varname));
if (var->IsType<framework::LoDTensor>()) {
return var->Get<framework::LoDTensor>().IsInitialized();
} else if (var->IsType<framework::SelectedRows>()) {
return var->Get<framework::SelectedRows>().rows().size() > 0UL;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Variable type in send side should be LodTensor or SelectedRows."));
}
return false;
}
inline std::vector<int64_t> ToAbsoluteSection(
const std::vector<int64_t>& height_sections) {
std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size());
abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) {
abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
}
return abs_sections;
}
inline size_t GetSectionIndex(int64_t id,
const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) {
if (id < abs_sections[i]) {
return i - 1;
}
}
return abs_sections.size() - 1;
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
struct DeserializedDataFunctor {
DeserializedDataFunctor(void **buf, Tensor *tensor,
const platform::Place &place)
: buf_(buf), tensor_(tensor), place_(place) {}
template <typename T>
void apply() {
*buf_ = tensor_->mutable_data<T>(place_);
}
void **buf_;
Tensor *tensor_;
platform::Place place_;
};
template <typename DeviceContext, typename T>
class SparseTensorLoadKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto filename = ctx.Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
platform::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
auto name = ctx.OutputNames("Out")[0];
VLOG(4) << "Sparse Load Var name: " << name;
auto *out_var = ctx.OutputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::InvalidArgument(
"The variable %s to be loaded cannot be found.", name));
PADDLE_ENFORCE_EQ(out_var->IsType<paddle::framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"SparseLoad OP only support LoDTensor"));
LoadLodTensor(fin, place, out_var, ctx);
}
void LoadLodTensor(std::istream &is, const platform::Place &place,
paddle::framework::Variable *var,
const paddle::framework::ExecutionContext &ctx) const {
auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
auto node_index = ctx.Attr<int64_t>("node_index");
auto node_num = ctx.Attr<int64_t>("node_num");
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
VLOG(4) << "Sparse LoadLodTensor node_num" << node_num;
VLOG(4) << "Sparse LoadLodTensor node_index" << node_index;
VLOG(4) << "Sparse LoadLodTensor shape[0]" << shape[0];
PADDLE_ENFORCE_GE(node_index, 0, platform::errors::InvalidArgument(
"node_num great than or equal to 0"));
PADDLE_ENFORCE_GE(node_num, 1, platform::errors::InvalidArgument(
"node_num great than or equal to 1"));
{
// the 1st field, unit32_t version for LoDTensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version),
true,
platform::errors::InvalidArgument(
"Tensor version %u is not supported.", version));
PADDLE_ENFORCE_EQ(version, 0U, platform::errors::InvalidArgument(
"Tensor version %u is not supported, "
"only version 0 is supported.",
version));
}
{
// the 2st field, LoD information
// Todo sparse load need change LoDTensor's lod level
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
}
// the 3st filed, Tensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(
version, 0U,
platform::errors::InvalidArgument(
"tensor version %u is not supported, Only version 0 is supported",
version));
paddle::framework::proto::VarType::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size);
PADDLE_ENFORCE_EQ(
desc.ParseFromArray(buf.get(), size), true,
platform::errors::InvalidArgument("Cannot parse tensor desc"));
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(),
std::back_inserter(dims));
int64_t line_numel = 1;
for (size_t dim = 1; dim < dims.size(); dim++) {
line_numel *= dims[dim];
}
auto total_line = dims[0];
tensor->Resize(paddle::framework::make_ddim(shape));
void *buf;
auto ctx = platform::CPUDeviceContext();
paddle::framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
auto line_size =
line_numel * paddle::framework::SizeOfType(desc.data_type());
char *cur_buf = static_cast<char *>(buf);
char *temp_row = new char[line_size];
VLOG(4) << "TensorFromStream: line_size " << line_size;
VLOG(4) << "TensorFromStream: total_line " << total_line;
for (size_t line_index = 0; line_index < static_cast<size_t>(total_line);
++line_index) {
is.read(temp_row, line_size);
if (static_cast<int64_t>(line_index) % node_num == node_index) {
memcpy(cur_buf, temp_row, line_size);
cur_buf += line_size;
}
}
}
}
};
class SparseTensorLoadOp : public paddle::framework::OperatorWithKernel {
public:
using paddle::framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(paddle::framework::InferShapeContext *ctx) const override {}
protected:
paddle::framework::OpKernelType GetExpectedKernelType(
const paddle::framework::ExecutionContext &ctx) const override {
paddle::framework::OpKernelType kt = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::FP32, ctx.GetPlace());
return kt;
}
};
class SparseTensorLoadOpMaker
: public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded");
AddAttr<std::string>("file_path",
R"(Variable will be loaded from "file_path")")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
AddAttr<int64_t>("node_index", "role id from 0 ~ node_num.").SetDefault(0);
AddAttr<int64_t>("node_num", "role nums which need load current varibale.")
.SetDefault(0);
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output")
.SetDefault({});
AddComment(R"DOC(
SparseTensorLoad OP, Load sprase tensor on parameter server
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sparse_tensor_load, ops::SparseTensorLoadOp,
ops::SparseTensorLoadOpMaker);
REGISTER_OP_CPU_KERNEL(
sparse_tensor_load,
ops::SparseTensorLoadKernel<paddle::platform::CPUDeviceContext, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/split_byref_op.h"
#include "paddle/fluid/operators/split_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class SplitByrefOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Ids", "SplitByrefOp");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "SplitByrefOp");
auto in_dims = ctx->GetInputDim("X");
auto outs_names = ctx->Outputs("Out");
size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
auto sections = ctx->Attrs().Get<std::vector<int>>("sections");
const size_t outs_number = outs_names.size();
std::vector<framework::DDim> outs_dims;
outs_dims.reserve(outs_number);
if (num > 0) {
int64_t in_axis_dim = 0;
if (ctx->IsRuntime()) {
in_axis_dim = in_dims[0];
}
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, platform::errors::InvalidArgument(
"tensor split does not result"
" in an equal division"));
size_t out_axis_dim = in_axis_dim / num;
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[0] = out_axis_dim;
outs_dims.push_back(dim);
}
} else if (sections.size() > 0) {
PADDLE_ENFORCE_EQ(
sections.size(), outs_number,
platform::errors::InvalidArgument("tensor split sections size"
"should be equal to output size"));
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[0] = sections[i];
outs_dims.push_back(dim);
}
}
ctx->SetOutputsDim("Out", outs_dims);
}
};
class SplitByrefOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the split operator.")
.AsDuplicable();
AddComment(R"DOC(
SplitByref operator
Split source tensor to sevaral tensors by axis 0. No copy in this operator
is performed, output tensor shares the same blocks of memory.
)DOC");
AddAttr<std::vector<int>>("sections",
"(vector<int>) "
"the length of each output along the "
"specified axis.")
.SetDefault(std::vector<int>{});
AddAttr<int>("num",
"(int, default 0)"
"Number of sub-tensors. This must evenly divide "
"Input.dims()[axis]")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// NOTE: concat op default axis must be 0!
USE_CPU_ONLY_OP(concat);
REGISTER_OPERATOR(split_byref, ops::SplitByrefOp, ops::SplitByrefOpMaker,
ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
split_byref, ops::SplitByrefOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/split_byref_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
split_byref,
ops::SplitByrefOpKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SplitByrefOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto place = ctx.GetPlace();
size_t row_offset = 0;
for (size_t i = 0; i < outs.size(); ++i) {
// NOTE: no need to call mutable_data here to allocate memory.
auto* out = outs[i];
VLOG(3) << "spliting by ref: " << row_offset << " " << out->dims()[0];
*out = in->Slice(row_offset, row_offset + out->dims()[0]);
row_offset += out->dims()[0];
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
#include <memory>
namespace paddle {
namespace operators {
class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}")
.AsDuplicable();
AddOutput("Out", "(LoDTensors) The outputs of the input Ids.")
.AsDuplicable();
AddComment(R"DOC(
Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number
Example:
Input:
X = [[1,2,3,4,5,6],[2,3]]
Out(3 output):
if compress is True:
out0 = [3, 3, 6]
out1 = [1, 4]
out2 = [2, 2, 5]
else:
out0 = [3, 6]
out1 = [1, 4]
out2 = [2, 5]
)DOC");
}
};
class SplitIdsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Ids"), "Input", "Ids", "SplitIdsOp");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "SplitIdsOp");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
ids_dims[0].size(), 2,
platform::errors::InvalidArgument(
"ShapeError: The dimensions of the 'split_ids' must be 2. "
"But received split_ids's dimensions = %d, "
"split_ids's shape = [%s].",
ids_dims[0].size(), ids_dims[0]));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Ids"), ctx.GetPlace());
}
};
class SplitIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetInputType("Ids");
ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
ops::SplitIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iterator>
#include <set>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SplitIdsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"SplitIds do not support GPU kernel"));
}
const auto ids_vars = ctx.MultiInputVar("Ids");
PADDLE_ENFORCE_GT(
ids_vars.size(), 0,
platform::errors::InvalidArgument(
ids_vars.size(), 0, "The number of Ids expected > 0, but got %d",
ids_vars.size()));
auto *ids_var = ids_vars[0];
if (ids_var->IsType<framework::LoDTensor>()) {
int batch_size = 0;
const auto ids_tensors = ctx.MultiInput<framework::LoDTensor>("Ids");
for (size_t i = 0; i < ids_tensors.size(); ++i) {
batch_size += ids_tensors[i]->dims()[0];
}
VLOG(4) << "Get Total BatchSize is: " << batch_size;
std::vector<T> all_ids(batch_size);
int offset = 0;
for (size_t i = 0; i < ids_tensors.size(); ++i) {
const auto *ids = ids_tensors[i];
std::memcpy(all_ids.data() + offset, ids->data<T>(),
ids->numel() * sizeof(T));
offset += ids->numel();
}
std::set<T> st(all_ids.begin(), all_ids.end());
all_ids.assign(st.begin(), st.end());
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size();
std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size());
// split id by their shard_num.
for (size_t i = 0; i < all_ids.size(); ++i) {
T id = all_ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id);
}
// create tensor for each shard and send to parameter server
for (size_t i = 0; i < out_ids.size(); ++i) {
auto *shard_t = outs[i];
std::vector<T> ids = out_ids[i];
auto *shard_data = shard_t->mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
for (size_t i = 0; i < ids.size(); ++i) {
shard_data[i] = ids[i];
}
}
} else if (ids_var->IsType<framework::SelectedRows>()) {
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
auto &ids_dims = ids_selected_rows->value().dims();
const T *ids_data = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size();
for (auto &out : outs) {
out->mutable_rows()->clear();
}
// get rows for outputs
std::unordered_map<int64_t, size_t> id_to_index;
for (size_t i = 0; i < ids_rows.size(); ++i) {
id_to_index[ids_rows[i]] = i;
size_t shard_id = static_cast<size_t>(ids_rows[i]) % shard_num;
outs[shard_id]->mutable_rows()->push_back(ids_rows[i]);
}
int64_t row_width = ids_dims[1];
for (auto &out : outs) {
out->set_height(ids_selected_rows->height());
framework::DDim ddim = framework::make_ddim(
{static_cast<int64_t>(out->rows().size()), row_width});
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (int64_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width,
ids_data + id_to_index[out->rows()[i]] * row_width,
row_width * sizeof(T));
}
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"% should be LoDTensor or SelectedRows, but the received type is %s",
ctx.InputNames("Ids")[0], framework::ToTypeName(ids_var->Type())));
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/printf.h"
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#endif
USE_NO_KERNEL_OP(listen_and_serv);
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
namespace distributed = paddle::operators::distributed;
namespace string = paddle::string;
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
void StartServer() {
f::Scope scope;
p::CPUPlace place;
scope.Var(NCCL_ID_VARNAME);
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace());
g_req_handler->SetScope(&scope);
g_req_handler->SetDevCtx(&dev_ctx);
g_req_handler->SetProgram(&empty_program);
g_req_handler->SetExecutor(&executor);
g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(distributed::kRequestSend);
g_rpc_service->WaitBarrier(distributed::kRequestSend);
LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown();
server_thread.join();
}
TEST(SendNcclId, RPCServer) {
g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
f::Scope scope;
p::CPUPlace place;
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
auto var = scope.Var(NCCL_ID_VARNAME);
auto id = var->GetMutable<ncclUniqueId>();
p::dynload::ncclGetUniqueId(id);
int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port);
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
LOG(INFO) << "connect to server" << ep;
client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client->Wait();
client->AsyncSendBatchBarrier(ep);
client->Wait();
server_thread.join();
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/split_selected_rows_op.h"
#include <memory>
namespace paddle {
namespace operators {
class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int64_t>({}));
AddComment(R"DOC(
Split a SelectedRows with a specified rows section.
height_sections is only needed when need to split the dims of the original tensor.
Example:
Input:
X.rows = {7, 5}
X.height = 12
Attr:
height_sections = {4, 8}
Out:
out0.rows = {}
out0.height = 4
out1.rows = {5, 7}
out2.height = 8
)DOC");
}
};
class SplitSelectedRowsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"SplitSelectedRowsOp must have input X."));
PADDLE_ENFORCE_EQ(ctx->HasOutputs("Out"), true,
platform::errors::InvalidArgument(
"SplitSelectedRowsOp must have output Out."));
}
};
class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SetOutputType("Out", framework::proto::VarType::SELECTED_ROWS,
framework::ALL_ELEMENTS);
}
};
template <typename T>
class SplitSelectedRowsGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("sum");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp,
ops::SplitSelectedRowsOpMaker,
ops::SplitSelectedRowsGradMaker<paddle::framework::OpDesc>,
ops::SplitSelectedRowsGradMaker<paddle::imperative::OpBase>,
ops::SplitSelectedRowsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
split_selected_rows,
ops::SplitSelectedRowsOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/split_selected_rows_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
split_selected_rows,
ops::SplitSelectedRowsOpKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::SelectedRows>("X");
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
auto abs_sections = ToAbsoluteSection(height_sections);
auto& x_rows = x->rows();
auto height = x->height();
std::vector<std::vector<int>> outs_rows_idx;
std::vector<std::vector<int>> outs_dense_idx;
outs_rows_idx.resize(outs.size());
outs_dense_idx.resize(outs.size());
auto row_numel = x->value().numel() / x->value().dims()[0];
auto src = x->value().data<T>();
// split rows index into output sparse vars
for (size_t i = 0; i < x_rows.size(); ++i) {
auto& id = x_rows[i];
PADDLE_ENFORCE_LT(id, height,
platform::errors::OutOfRange(
"Each row_id in x.rows must be less than x.height. "
"But received x.rows[%d] = %d, x.height = %d",
i, id, height));
int out_idx = GetSectionIndex(id, abs_sections);
outs_rows_idx[out_idx].push_back(id);
outs_dense_idx[out_idx].push_back(i);
}
auto place = ctx.GetPlace();
for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
auto rows_idx = outs_rows_idx[i];
outs[i]->set_height(height_sections[i]);
auto dims = x->GetCompleteDims();
dims[0] = rows_idx.size();
outs[i]->mutable_value()->mutable_data<T>(dims, x->place());
outs[i]->mutable_rows()->clear();
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
auto id_offset = idx - abs_sections[i];
PADDLE_ENFORCE_LT(
id_offset, height_sections[i],
platform::errors::OutOfRange("Each row_id in out.rows must be "
"less than out.height. But recived "
"out.rows = [%d], out.height = [%d]",
id_offset, height_sections[i]));
outs[i]->mutable_rows()->push_back(id_offset);
}
auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace());
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(
platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
} else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto stream = ctx.cuda_device_context().stream();
memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
platform::CUDAPlace(),
src + outs_dense_idx[i][j] * row_numel,
sizeof(T) * row_numel, stream);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot visit cuda device"));
#endif
}
}
}
PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0. "
"But received rows = %d, tensor's dim[0] = %d.",
rows_idx.size(), outs[i]->rows().size()));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -28,7 +28,6 @@ list(APPEND MIXED_DIST_TEST_OPS test_dgc_momentum_op)
list(APPEND MIXED_DIST_TEST_OPS test_dgc_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_simple_dist_transpiler)
list(APPEND MIXED_DIST_TEST_OPS test_recv_save_op)
list(APPEND MIXED_DIST_TEST_OPS test_transpiler_ops)
list(APPEND MIXED_DIST_TEST_OPS test_c_comm_init_op)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_async)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_geo)
......@@ -483,7 +482,6 @@ if(WITH_DISTRIBUTE)
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_gloo")
py_test_modules(test_recv_save_op MODULES test_recv_save_op ENVS ${dist_ENVS})
py_test_modules(test_transpiler_ops MODULES test_transpiler_ops ENVS ${dist_ENVS})
py_test_modules(test_communicator_async MODULES test_communicator_async ENVS ${dist_ENVS})
py_test_modules(test_communicator_geo MODULES test_communicator_geo ENVS ${dist_ENVS})
py_test_modules(test_communicator_half_async MODULES test_communicator_half_async ENVS ${dist_ENVS} FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid.core as core
import numpy as np
from paddle.fluid.op import Operator
class TestSpliteSelectedRows(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def test_check_output(self):
for place in self.get_places():
self.check_with_place(place)
def test_check_grad(self):
for place in self.get_places():
self.check_grad_with_place(place)
def check_with_place(self, place):
scope = core.Scope()
rows = [0, 5, 7, 4, 20]
height = 21
row_numel = 2
# initialize input variable X
x = scope.var('X').get_selected_rows()
x.set_rows(rows)
x.set_height(height)
np_array = np.ones((len(rows), row_numel)).astype("float32")
np_array[0, 0] = 2.0
np_array[2, 1] = 4.0
np_array[4, 1] = 8.0
x_tensor = x.get_tensor()
x_tensor.set(np_array, place)
height_sections = [5, 5, 5, 5, 3]
# initialize output variables [out0, out1]
outs_name = ["out%d" % i for i in range(len(height_sections))]
outs = [
scope.var(var_name).get_selected_rows() for var_name in outs_name
]
# expected output selected rows
expected_out0_rows = [0, 4]
expected_out1_rows = [0, 2]
expected_out2_rows = []
expected_out4_rows = [0]
op = Operator(
"split_selected_rows",
X="X",
Out=outs_name,
height_sections=height_sections)
op.run(scope, place)
self.assertEqual(outs[0].rows(), expected_out0_rows)
self.assertEqual(outs[1].rows(), expected_out1_rows)
self.assertEqual(outs[2].rows(), expected_out2_rows)
self.assertEqual(outs[4].rows(), expected_out4_rows)
self.assertEqual(outs[0].height(), height_sections[0])
self.assertEqual(outs[4].height(), height_sections[4])
self.assertAlmostEqual(2.0, np.array(outs[0].get_tensor())[0, 0])
self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 1])
self.assertAlmostEqual(8.0, np.array(outs[4].get_tensor())[0, 1])
self.assertEqual(outs[2].numel(), 0)
self.assertEqual(outs[3].numel(), 0)
def check_grad_with_place(self, place):
scope = core.Scope()
height = 10
row_numel = 2
# attr
height_sections = [5, 5]
# initialize input variable X
out0_grad = scope.var("out0@GRAD").get_selected_rows()
rows0 = [0, 5]
out0_grad.set_rows(rows0)
out0_grad.set_height(height)
out0_grad_tensor = out0_grad.get_tensor()
np_array = np.ones((len(rows0), row_numel)).astype("float32")
out0_grad_tensor.set(np_array, place)
out1_grad = scope.var("out1@GRAD").get_selected_rows()
rows1 = [2, 0]
out1_grad.set_rows(rows1)
out1_grad.set_height(height)
out1_grad_tensor = out1_grad.get_tensor()
np_array = np.ones((len(rows1), row_numel)).astype("float32")
out1_grad_tensor.set(np_array, place)
x_grad = scope.var("X@GRAD").get_selected_rows()
grad_op = Operator(
"sum",
X=["out0@GRAD", "out1@GRAD"],
Out="X@GRAD",
height_sections=height_sections)
grad_op.run(scope, place)
merged_rows = set(rows0 + rows1)
self.assertEqual(set(x_grad.rows()), set(rows0 + rows1))
self.assertEqual(x_grad.height(), height)
print(np.array(x_grad.get_tensor()))
self.assertAlmostEqual(2.0, np.array(x_grad.get_tensor())[0, 0])
self.assertAlmostEqual(1.0, np.array(x_grad.get_tensor())[2, 1])
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import traceback
import math
import collections
import six
import unittest
import numpy as np
import gc
gc.set_debug(gc.DEBUG_COLLECTABLE)
import paddle.fluid as fluid
from test_dist_transpiler import TranspilerTest
class TestFakeInit(TranspilerTest):
def net_conf(self):
dict_size, embedding_size, neg_num = 10000, 8, 5
input_word = fluid.layers.data(
name="input_word", shape=[1], dtype='int64', lod_level=1)
true_word = fluid.layers.data(
name='true_label', shape=[1], dtype='int64', lod_level=1)
neg_word = fluid.layers.data(
name="neg_label", shape=[1], dtype='int64', lod_level=1)
inputs = [input_word, true_word, neg_word]
init_width = 0.5 / embedding_size
input_emb = fluid.layers.embedding(
input=inputs[0],
is_sparse=True,
size=[dict_size, embedding_size],
param_attr=fluid.ParamAttr(
name='emb',
initializer=fluid.initializer.Uniform(-init_width, init_width)))
true_emb_w = fluid.layers.embedding(
input=inputs[1],
is_sparse=True,
size=[dict_size, embedding_size],
param_attr=fluid.ParamAttr(
name='emb_w',
initializer=fluid.initializer.Constant(value=0.0)))
true_emb_b = fluid.layers.embedding(
input=inputs[1],
is_sparse=True,
size=[dict_size, 1],
param_attr=fluid.ParamAttr(
name='emb_b',
initializer=fluid.initializer.Constant(value=0.0)))
neg_word_reshape = fluid.layers.reshape(inputs[2], shape=[-1, 1])
neg_word_reshape.stop_gradient = True
neg_emb_w = fluid.layers.embedding(
input=neg_word_reshape,
is_sparse=True,
size=[dict_size, embedding_size],
param_attr=fluid.ParamAttr(
name='emb_w', learning_rate=1.0))
neg_emb_w_re = fluid.layers.reshape(
neg_emb_w, shape=[-1, neg_num, embedding_size])
neg_emb_b = fluid.layers.embedding(
input=neg_word_reshape,
is_sparse=True,
size=[dict_size, 1],
param_attr=fluid.ParamAttr(
name='emb_b', learning_rate=1.0))
neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num])
true_logits = fluid.layers.elementwise_add(
fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(input_emb, true_emb_w),
dim=1,
keep_dim=True),
true_emb_b)
input_emb_re = fluid.layers.reshape(
input_emb, shape=[-1, 1, embedding_size])
neg_matmul = fluid.layers.matmul(
input_emb_re, neg_emb_w_re, transpose_y=True)
neg_matmul_re = fluid.layers.reshape(neg_matmul, shape=[-1, neg_num])
neg_logits = fluid.layers.elementwise_add(neg_matmul_re, neg_emb_b_vec)
# nce loss
label_ones = fluid.layers.fill_constant_batch_size_like(
true_logits, shape=[-1, 1], value=1.0, dtype='float32')
label_zeros = fluid.layers.fill_constant_batch_size_like(
true_logits, shape=[-1, neg_num], value=0.0, dtype='float32')
true_xent = fluid.layers.sigmoid_cross_entropy_with_logits(true_logits,
label_ones)
neg_xent = fluid.layers.sigmoid_cross_entropy_with_logits(neg_logits,
label_zeros)
cost = fluid.layers.elementwise_add(
fluid.layers.reduce_sum(
true_xent, dim=1),
fluid.layers.reduce_sum(
neg_xent, dim=1))
avg_cost = fluid.layers.reduce_mean(cost)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.exponential_decay(
learning_rate=1.0,
decay_steps=2100,
decay_rate=0.1,
staircase=True))
sgd_optimizer.minimize(avg_cost)
def transpiler_test_impl(self):
trainer, startup = self.get_trainer()
fake_init_ops = []
for op in startup.global_block().ops:
if op.type == "fake_init":
fake_init_ops.append(op)
self.assertEqual(len(fake_init_ops), 3)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册