Unverified commit 2ef4ec71, authored by huangjiyi, committed by GitHub

Add paddle custom flags support (#56256)

* update

* replace gflags header

* replace DEFINE_<type> with PD_DEFINE_<type>

* fix bug

* fix bug

* fix bug

* update cmake

* add :: before some paddle namespace

* fix link error

* fix CI-Py3

* allow commandline parse

* fix SetFlagsFromEnv

* fix bug

* fix bug

* fix CI-CINN

* fix CI-Coverage-build

* fix CI-Windows-build

* fix CI-Inference

* fix bug

* fix bug

* fix CI-CINN

* fix inference api test

* fix infer_ut test

* revert infer_ut gflags usage

* update

* fix inference

* remove flags export macro

* revert inference demo_ci gflags usage

* update

* update

* update

* update

* update

* update

* update

* update

* fix bug when turn on WITH_GFLAGS

* turn on WITH_GFLAGS

* fix bug when turn on WITH_GFLAGS

* fix bug when turn on WITH_GFLAGS

* update

* update and add unittest

* add unittest

* fix conflict

* rerun ci

* update

* resolve conflict
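
The bullets above describe replacing the gflags DEFINE_&lt;type&gt;/DECLARE_&lt;type&gt; macros with PD_-prefixed equivalents provided by paddle/utils/flags.h. A minimal sketch of the resulting usage, based only on the macro names, the FLAGS_&lt;name&gt; symbols, and the include visible in the diff below (the header's contents are not part of this commit page):

#include "paddle/utils/flags.h"

// Definition site: previously DEFINE_int32(...) from gflags.
PD_DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms");

// Other translation units that read the flag declare it instead of defining it.
PD_DECLARE_int32(pserver_timeout_ms);

// The value is still accessed through the usual FLAGS_<name> symbol.
int timeout_ms = FLAGS_pserver_timeout_ms;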
Parent 1c858591
...@@ -254,6 +254,7 @@ option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) ...@@ -254,6 +254,7 @@ option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
option(ON_INFER "Turn on inference optimization and inference-lib generation" option(ON_INFER "Turn on inference optimization and inference-lib generation"
ON) ON)
option(WITH_CPP_DIST "Install PaddlePaddle C++ distribution" OFF) option(WITH_CPP_DIST "Install PaddlePaddle C++ distribution" OFF)
option(WITH_GFLAGS "Compile PaddlePaddle with gflags support" OFF)
################################ Internal Configurations ####################################### ################################ Internal Configurations #######################################
option(WITH_NV_JETSON "Compile PaddlePaddle with NV JETSON" OFF) option(WITH_NV_JETSON "Compile PaddlePaddle with NV JETSON" OFF)
option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools"
......
...@@ -183,6 +183,11 @@ if(WITH_MKL) ...@@ -183,6 +183,11 @@ if(WITH_MKL)
endif() endif()
endif() endif()
if(NOT WITH_GFLAGS)
target_link_libraries(cinnapi gflags)
add_dependencies(cinnapi gflags)
endif()
if(WITH_GPU) if(WITH_GPU)
target_link_libraries( target_link_libraries(
cinnapi cinnapi
...@@ -237,6 +242,11 @@ function(gen_cinncore LINKTYPE) ...@@ -237,6 +242,11 @@ function(gen_cinncore LINKTYPE)
endif() endif()
endif() endif()
if(NOT WITH_GFLAGS)
target_link_libraries(${CINNCORE_TARGET} gflags)
add_dependencies(${CINNCORE_TARGET} gflags)
endif()
if(WITH_GPU) if(WITH_GPU)
target_link_libraries( target_link_libraries(
${CINNCORE_TARGET} ${CINNCORE_TARGET}
......
...@@ -201,6 +201,10 @@ if(WITH_DISTRIBUTE) ...@@ -201,6 +201,10 @@ if(WITH_DISTRIBUTE)
add_definitions(-DPADDLE_WITH_DISTRIBUTE) add_definitions(-DPADDLE_WITH_DISTRIBUTE)
endif() endif()
if(WITH_GFLAGS)
add_definitions(-DPADDLE_WITH_GFLAGS)
endif()
if(WITH_PSCORE) if(WITH_PSCORE)
add_definitions(-DPADDLE_WITH_PSCORE) add_definitions(-DPADDLE_WITH_PSCORE)
endif() endif()
......
...@@ -91,3 +91,16 @@ add_dependencies(brpc extern_brpc) ...@@ -91,3 +91,16 @@ add_dependencies(brpc extern_brpc)
add_definitions(-DBRPC_WITH_GLOG) add_definitions(-DBRPC_WITH_GLOG)
list(APPEND external_project_dependencies brpc) list(APPEND external_project_dependencies brpc)
set(EXTERNAL_BRPC_DEPS
brpc
protobuf
ssl
crypto
leveldb
glog
snappy)
if(NOT WITH_GFLAGS)
set(EXTERNAL_BRPC_DEPS ${EXTERNAL_BRPC_DEPS} gflags)
endif()
...@@ -102,3 +102,14 @@ if(WIN32) ...@@ -102,3 +102,14 @@ if(WIN32)
set_property(GLOBAL PROPERTY OS_DEPENDENCY_MODULES shlwapi.lib) set_property(GLOBAL PROPERTY OS_DEPENDENCY_MODULES shlwapi.lib)
endif() endif()
endif() endif()
# We have implemented a custom flags tool, paddle_flags, to replace gflags.
# Users can also choose to use gflags by setting WITH_GFLAGS=ON. But even when
# using paddle_flags, gflags is still needed by other third-party libraries,
# including glog and brpc, so we cannot remove gflags completely.
set(flags_dep)
if(WITH_GFLAGS)
list(APPEND flags_dep gflags)
else()
list(APPEND flags_dep paddle_flags)
endif()
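
The comment above explains the split: paddle_flags is the default, while WITH_GFLAGS=ON switches back to gflags and (per the earlier CMake change) defines PADDLE_WITH_GFLAGS. One plausible shape for the dispatch inside paddle/utils/flags.h, shown purely as an illustration — the actual header is not included in this diff and may differ:

// Hypothetical sketch, not the real paddle/utils/flags.h.
#ifdef PADDLE_WITH_GFLAGS
#include "gflags/gflags.h"
#define PD_DEFINE_int32(name, val, txt) DEFINE_int32(name, val, txt)
#define PD_DECLARE_int32(name) DECLARE_int32(name)
#else
// flags_native.h is the paddle_flags implementation; it is only installed
// when WITH_GFLAGS=OFF (see the inference_lib copy rule later in this diff).
#include "paddle/utils/flags_native.h"
#endif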
...@@ -336,11 +336,22 @@ copy( ...@@ -336,11 +336,22 @@ copy(
inference_lib_dist inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/flat_hash_map.h SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/flat_hash_map.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/utils/) DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/utils/)
copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/flags.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/utils/)
copy( copy(
inference_lib_dist inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/extension.h SRCS ${PADDLE_SOURCE_DIR}/paddle/extension.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/) DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/)
if(NOT WITH_GFLAGS)
copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/flags_native.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/utils/)
endif()
# the include path of phi needs to be changed to adapt to inference api path # the include path of phi needs to be changed to adapt to inference api path
add_custom_command( add_custom_command(
TARGET inference_lib_dist TARGET inference_lib_dist
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
constexpr int64_t kWaitBlockTImeout = 10; constexpr int64_t kWaitBlockTImeout = 10;
DECLARE_bool(use_stream_safe_cuda_allocator); PD_DECLARE_bool(use_stream_safe_cuda_allocator);
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/comm_context_manager.h"
PHI_DECLARE_bool(nccl_blocking_wait); PHI_DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator); PD_DECLARE_bool(use_stream_safe_cuda_allocator);
// set this flag to `true` and recompile to enable dynamic checks // set this flag to `true` and recompile to enable dynamic checks
constexpr bool FLAGS_enable_nccl_dynamic_check = false; constexpr bool FLAGS_enable_nccl_dynamic_check = false;
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
DECLARE_bool(use_stream_safe_cuda_allocator); PD_DECLARE_bool(use_stream_safe_cuda_allocator);
PHI_DECLARE_string(allocator_strategy); PHI_DECLARE_string(allocator_strategy);
namespace paddle { namespace paddle {
......
...@@ -7,16 +7,7 @@ proto_library(interceptor_message_proto SRCS interceptor_message.proto) ...@@ -7,16 +7,7 @@ proto_library(interceptor_message_proto SRCS interceptor_message.proto)
if(WITH_ARM_BRPC) if(WITH_ARM_BRPC)
set(BRPC_DEPS arm_brpc snappy phi glog) set(BRPC_DEPS arm_brpc snappy phi glog)
elseif(WITH_DISTRIBUTE AND NOT WITH_PSLIB) elseif(WITH_DISTRIBUTE AND NOT WITH_PSLIB)
set(BRPC_DEPS set(BRPC_DEPS ${EXTERNAL_BRPC_DEPS} zlib phi)
brpc
ssl
crypto
protobuf
zlib
leveldb
snappy
phi
glog)
else() else()
set(BRPC_DEPS "") set(BRPC_DEPS "")
endif() endif()
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -29,6 +28,7 @@ ...@@ -29,6 +28,7 @@
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/flags.h" #include "paddle/fluid/platform/flags.h"
#include "paddle/utils/flags.h"
PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_bool(
fleet_executor_with_standalone, fleet_executor_with_standalone,
false, false,
......
...@@ -3,34 +3,11 @@ set_source_files_properties(${BRPC_SRCS}) ...@@ -3,34 +3,11 @@ set_source_files_properties(${BRPC_SRCS})
if(WITH_HETERPS) if(WITH_HETERPS)
set(BRPC_DEPS set(BRPC_DEPS ${EXTERNAL_BRPC_DEPS} phi zlib device_context rocksdb)
brpc
ssl
crypto
protobuf
phi
glog
zlib
leveldb
snappy
glog
device_context
rocksdb)
else() else()
set(BRPC_DEPS set(BRPC_DEPS ${EXTERNAL_BRPC_DEPS} phi zlib device_context)
brpc
ssl
crypto
protobuf
phi
glog
zlib
leveldb
snappy
glog
device_context)
endif() endif()
......
...@@ -34,49 +34,53 @@ class Variable; ...@@ -34,49 +34,53 @@ class Variable;
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
DEFINE_int32(pserver_push_dense_merge_limit, PD_DEFINE_int32(pserver_push_dense_merge_limit,
12, 12,
"limit max push_dense local merge requests"); "limit max push_dense local merge requests");
DEFINE_int32(pserver_push_sparse_merge_limit, PD_DEFINE_int32(pserver_push_sparse_merge_limit,
12, 12,
"limit max push_sparse local merge requests"); "limit max push_sparse local merge requests");
DEFINE_int32(pserver_pull_dense_limit, PD_DEFINE_int32(pserver_pull_dense_limit,
12, 12,
"limit max push_sparse local merge requests"); "limit max push_sparse local merge requests");
DEFINE_int32(pserver_async_push_dense_interval_ms, PD_DEFINE_int32(pserver_async_push_dense_interval_ms,
10, 10,
"async push_dense to server interval"); "async push_dense to server interval");
DEFINE_int32(pserver_async_push_sparse_interval_ms, PD_DEFINE_int32(pserver_async_push_sparse_interval_ms,
10, 10,
"async push_sparse to server interval"); "async push_sparse to server interval");
DEFINE_bool(pserver_scale_gradient_by_merge, PD_DEFINE_bool(pserver_scale_gradient_by_merge,
false, false,
"scale dense gradient when merged"); "scale dense gradient when merged");
DEFINE_int32(pserver_communicate_compress_type, PD_DEFINE_int32(pserver_communicate_compress_type,
0, 0,
"none:0 snappy:1 gzip:2 zlib:3 lz4:4"); "none:0 snappy:1 gzip:2 zlib:3 lz4:4");
DEFINE_int32(pserver_max_async_call_num, PD_DEFINE_int32(pserver_max_async_call_num,
13, 13,
"max task num in async_call_server"); "max task num in async_call_server");
DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms"); PD_DEFINE_int32(pserver_timeout_ms,
500000,
"pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms, PD_DEFINE_int32(pserver_connect_timeout_ms,
10000, 10000,
"pserver connect server timeout_ms"); "pserver connect server timeout_ms");
DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num"); PD_DEFINE_int32(pserver_sparse_merge_thread,
1,
"pserver sparse merge thread num");
DEFINE_int32(pserver_sparse_table_shard_num, PD_DEFINE_int32(pserver_sparse_table_shard_num,
1000, 1000,
"sparse table shard for save & load"); "sparse table shard for save & load");
inline size_t get_sparse_shard(uint32_t shard_num, inline size_t get_sparse_shard(uint32_t shard_num,
uint32_t server_num, uint32_t server_num,
...@@ -140,7 +144,7 @@ int32_t BrpcPsClient::StartFlClientService(const std::string &self_endpoint) { ...@@ -140,7 +144,7 @@ int32_t BrpcPsClient::StartFlClientService(const std::string &self_endpoint) {
if (_fl_server.Start(self_endpoint.c_str(), &options) != 0) { if (_fl_server.Start(self_endpoint.c_str(), &options) != 0) {
VLOG(0) << "fl-ps > StartFlClientService failed. Try again."; VLOG(0) << "fl-ps > StartFlClientService failed. Try again.";
auto ip_port = paddle::string::Split(self_endpoint, ':'); auto ip_port = ::paddle::string::Split(self_endpoint, ':');
std::string ip = ip_port[0]; std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]); int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port); std::string int_ip_port = GetIntTypeEndpoint(ip, port);
...@@ -202,8 +206,7 @@ int32_t BrpcPsClient::InitializeFlWorker(const std::string &self_endpoint) { ...@@ -202,8 +206,7 @@ int32_t BrpcPsClient::InitializeFlWorker(const std::string &self_endpoint) {
options.protocol = "baidu_std"; options.protocol = "baidu_std";
options.timeout_ms = FLAGS_pserver_timeout_ms; options.timeout_ms = FLAGS_pserver_timeout_ms;
options.connection_type = "pooled"; options.connection_type = "pooled";
options.connect_timeout_ms = options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms;
paddle::distributed::FLAGS_pserver_connect_timeout_ms;
options.max_retry = 3; options.max_retry = 3;
// Get the coordinator list and connect // Get the coordinator list and connect
std::string coordinator_ip_port; std::string coordinator_ip_port;
...@@ -336,11 +339,11 @@ int32_t BrpcPsClient::Initialize() { ...@@ -336,11 +339,11 @@ int32_t BrpcPsClient::Initialize() {
auto table_id = worker_param.downpour_table_param(i).table_id(); auto table_id = worker_param.downpour_table_param(i).table_id();
if (type == PS_DENSE_TABLE) { if (type == PS_DENSE_TABLE) {
_push_dense_task_queue_map[table_id] = _push_dense_task_queue_map[table_id] =
paddle::framework::MakeChannel<DenseAsyncTask *>(); ::paddle::framework::MakeChannel<DenseAsyncTask *>();
} }
if (type == PS_SPARSE_TABLE) { if (type == PS_SPARSE_TABLE) {
_push_sparse_task_queue_map[table_id] = _push_sparse_task_queue_map[table_id] =
paddle::framework::MakeChannel<SparseAsyncTask *>(); ::paddle::framework::MakeChannel<SparseAsyncTask *>();
_push_sparse_merge_count_map[table_id] = 0; _push_sparse_merge_count_map[table_id] = 0;
} }
} }
...@@ -446,7 +449,7 @@ std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) { ...@@ -446,7 +449,7 @@ std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
int ret = 0; int ret = 0;
uint64_t feasign_size = 0; uint64_t feasign_size = 0;
uint64_t mf_size = 0; uint64_t mf_size = 0;
paddle::framework::BinaryArchive ar; ::paddle::framework::BinaryArchive ar;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done); auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) { if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) {
......
...@@ -30,15 +30,15 @@ class RpcController; ...@@ -30,15 +30,15 @@ class RpcController;
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
DEFINE_int32(pserver_timeout_ms_s2s, PD_DEFINE_int32(pserver_timeout_ms_s2s,
10000, 10000,
"pserver request server timeout_ms"); "pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms_s2s, PD_DEFINE_int32(pserver_connect_timeout_ms_s2s,
10000, 10000,
"pserver connect server timeout_ms"); "pserver connect server timeout_ms");
DEFINE_string(pserver_connection_type_s2s, PD_DEFINE_string(pserver_connection_type_s2s,
"pooled", "pooled",
"pserver connection_type[pooled:single]"); "pserver connection_type[pooled:single]");
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -169,7 +169,7 @@ int32_t BrpcPsServer::ReceiveFromPServer(int msg_type, ...@@ -169,7 +169,7 @@ int32_t BrpcPsServer::ReceiveFromPServer(int msg_type,
LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response"; LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
return 0; return 0;
} }
paddle::framework::BinaryArchive ar; ::paddle::framework::BinaryArchive ar;
ar.SetReadBuffer(const_cast<char *>(msg.c_str()), msg.length(), nullptr); ar.SetReadBuffer(const_cast<char *>(msg.c_str()), msg.length(), nullptr);
if (ar.Cursor() == ar.Finish()) { if (ar.Cursor() == ar.Finish()) {
LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response"; LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response";
...@@ -598,7 +598,7 @@ int32_t BrpcPsService::PrintTableStat(Table *table, ...@@ -598,7 +598,7 @@ int32_t BrpcPsService::PrintTableStat(Table *table,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->PrintTableStat(); std::pair<int64_t, int64_t> ret = table->PrintTableStat();
paddle::framework::BinaryArchive ar; ::paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second; ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length()); std::string table_info(ar.Buffer(), ar.Length());
response.set_data(table_info); response.set_data(table_info);
...@@ -723,7 +723,7 @@ int32_t BrpcPsService::CacheShuffle(Table *table, ...@@ -723,7 +723,7 @@ int32_t BrpcPsService::CacheShuffle(Table *table,
table->Flush(); table->Flush();
double cache_threshold = std::stod(request.params(2)); double cache_threshold = std::stod(request.params(2));
LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold; LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold;
// auto shuffled_ins = paddle::ps::make_channel<std::pair<uint64_t, // auto shuffled_ins = ::paddle::ps::make_channel<std::pair<uint64_t,
// std::string>>(); // std::string>>();
// shuffled_ins->set_block_size(80000); // shuffled_ins->set_block_size(80000);
_server->StartS2S(); _server->StartS2S();
......
...@@ -16,11 +16,11 @@ limitations under the License. */ ...@@ -16,11 +16,11 @@ limitations under the License. */
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h" #include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/wrapper/fleet.h" #include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/utils/flags.h"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@" #define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@" #define STEP_COUNTER "@PS_STEP_COUNTER@"
...@@ -42,7 +42,7 @@ Communicator::Communicator() = default; ...@@ -42,7 +42,7 @@ Communicator::Communicator() = default;
void Communicator::InitGFlag(const std::string &gflags) { void Communicator::InitGFlag(const std::string &gflags) {
VLOG(3) << "Init With Gflags:" << gflags; VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags); std::vector<std::string> flags = ::paddle::string::split_string(gflags);
if (flags.empty()) { if (flags.empty()) {
flags.push_back("-max_body_size=314217728"); flags.push_back("-max_body_size=314217728");
flags.push_back("-bthread_concurrency=40"); flags.push_back("-bthread_concurrency=40");
...@@ -57,7 +57,7 @@ void Communicator::InitGFlag(const std::string &gflags) { ...@@ -57,7 +57,7 @@ void Communicator::InitGFlag(const std::string &gflags) {
} }
int params_cnt = flags.size(); int params_cnt = flags.size();
char **params_ptr = &(flags_ptr[0]); char **params_ptr = &(flags_ptr[0]);
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true); ::paddle::flags::ParseCommandLineFlags(&params_cnt, &params_ptr);
} }
std::once_flag Communicator::init_flag_; std::once_flag Communicator::init_flag_;
...@@ -66,7 +66,7 @@ std::shared_ptr<Communicator> Communicator::communicator_(nullptr); ...@@ -66,7 +66,7 @@ std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
void Communicator::InitBrpcClient( void Communicator::InitBrpcClient(
const std::string &dist_desc, const std::string &dist_desc,
const std::vector<std::string> &host_sign_list) { const std::vector<std::string> &host_sign_list) {
auto fleet = paddle::distributed::FleetWrapper::GetInstance(); auto fleet = ::paddle::distributed::FleetWrapper::GetInstance();
if (_worker_ptr.get() == nullptr) { if (_worker_ptr.get() == nullptr) {
_worker_ptr = fleet->worker_ptr_; _worker_ptr = fleet->worker_ptr_;
} }
...@@ -92,7 +92,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames, ...@@ -92,7 +92,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
platform::RecordEvent record_event("Communicator->RpcRecvDense", platform::RecordEvent record_event("Communicator->RpcRecvDense",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
std::vector<paddle::distributed::Region> regions; std::vector<::paddle::distributed::Region> regions;
regions.reserve(varnames.size()); regions.reserve(varnames.size());
for (auto &t : varnames) { for (auto &t : varnames) {
Variable *var = scope->Var(t); Variable *var = scope->Var(t);
...@@ -103,7 +103,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames, ...@@ -103,7 +103,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
phi::DenseTensor *temp_tensor = temp_var->GetMutable<phi::DenseTensor>(); phi::DenseTensor *temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
temp_tensor->Resize(tensor->dims()); temp_tensor->Resize(tensor->dims());
float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace()); float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
paddle::distributed::Region reg(temp_data, tensor->numel()); ::paddle::distributed::Region reg(temp_data, tensor->numel());
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id " VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id "
<< table_id << " Temp_data[0] " << temp_data[0] << table_id << " Temp_data[0] " << temp_data[0]
...@@ -111,7 +111,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames, ...@@ -111,7 +111,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
#endif #endif
} else { } else {
float *w = tensor->mutable_data<float>(tensor->place()); float *w = tensor->mutable_data<float>(tensor->place());
paddle::distributed::Region reg(w, tensor->numel()); ::paddle::distributed::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
} }
} }
...@@ -152,7 +152,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames, ...@@ -152,7 +152,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
std::vector<paddle::distributed::Region> regions; std::vector<::paddle::distributed::Region> regions;
for (auto &t : varnames) { for (auto &t : varnames) {
Variable *var = scope.FindVar(t); Variable *var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found"; CHECK(var != nullptr) << "var[" << t << "] not found";
...@@ -164,7 +164,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames, ...@@ -164,7 +164,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
temp_tensor->Resize(tensor->dims()); temp_tensor->Resize(tensor->dims());
float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace()); float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor); framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor);
paddle::distributed::Region reg(temp_data, tensor->numel()); ::paddle::distributed::Region reg(temp_data, tensor->numel());
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
VLOG(1) << "rpc_send_dense_param Var " << t << " table_id " << table_id VLOG(1) << "rpc_send_dense_param Var " << t << " table_id " << table_id
<< " Temp_data[0] " << temp_data[0] << " Temp_data[-1] " << " Temp_data[0] " << temp_data[0] << " Temp_data[-1] "
...@@ -172,7 +172,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames, ...@@ -172,7 +172,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
#endif #endif
} else { } else {
float *w = tensor->mutable_data<float>(place); float *w = tensor->mutable_data<float>(place);
paddle::distributed::Region reg(w, tensor->numel()); ::paddle::distributed::Region reg(w, tensor->numel());
regions.emplace_back(reg); regions.emplace_back(reg);
VLOG(1) << "rpc_send_dense_param Var " << t << " table_id " << table_id VLOG(1) << "rpc_send_dense_param Var " << t << " table_id " << table_id
<< " Temp_data[0] " << w[0] << " Temp_data[-1] " << " Temp_data[0] " << w[0] << " Temp_data[-1] "
...@@ -1096,10 +1096,10 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, ...@@ -1096,10 +1096,10 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
parallel_task_nums_ += 1; parallel_task_nums_ += 1;
sparse_id_queues_.insert( sparse_id_queues_.insert(
std::pair<std::string, std::pair<std::string,
paddle::framework::Channel< ::paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>( std::shared_ptr<std::vector<int64_t>>>>(
splited_var, splited_var,
paddle::framework::MakeChannel< ::paddle::framework::MakeChannel<
std::shared_ptr<std::vector<int64_t>>>(send_queue_size_))); std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
} }
} }
...@@ -1509,7 +1509,7 @@ void GeoCommunicator::MainThread() { ...@@ -1509,7 +1509,7 @@ void GeoCommunicator::MainThread() {
void FLCommunicator::InitBrpcClient( void FLCommunicator::InitBrpcClient(
const std::string &dist_desc, const std::string &dist_desc,
const std::vector<std::string> &host_sign_list) { const std::vector<std::string> &host_sign_list) {
auto fleet = paddle::distributed::FleetWrapper::GetInstance(); auto fleet = ::paddle::distributed::FleetWrapper::GetInstance();
if (_worker_ptr.get() == nullptr) { if (_worker_ptr.get() == nullptr) {
VLOG(0) << "fl-ps > FLCommunicator::InitBrpcClient get _worker_ptr"; VLOG(0) << "fl-ps > FLCommunicator::InitBrpcClient get _worker_ptr";
_worker_ptr = _worker_ptr =
......
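
The InitGFlag change in this file swaps gflags' ParseCommandLineFlags for ::paddle::flags::ParseCommandLineFlags, which, judging from the calls above, takes only the argc/argv pair. A standalone sketch of that pattern; the argv[0] placeholder and the string-splitting helper are assumptions, not part of this diff:

#include <string>
#include <vector>
#include "paddle/utils/flags.h"

void ParseFlagString(const std::string& flag_str) {
  // Assumed: split the space-separated string into "-name=value" tokens.
  std::vector<std::string> flags = {"program"};  // placeholder argv[0] (assumption)
  size_t pos = 0, next;
  while ((next = flag_str.find(' ', pos)) != std::string::npos) {
    if (next > pos) flags.push_back(flag_str.substr(pos, next - pos));
    pos = next + 1;
  }
  if (pos < flag_str.size()) flags.push_back(flag_str.substr(pos));

  // Build a mutable char** view, as InitGFlag does, and hand it to the parser.
  std::vector<char*> argv;
  for (auto& f : flags) argv.push_back(const_cast<char*>(f.c_str()));
  int argc = static_cast<int>(argv.size());
  char** argv_ptr = argv.data();
  ::paddle::flags::ParseCommandLineFlags(&argc, &argv_ptr);
}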
...@@ -29,7 +29,6 @@ limitations under the License. */ ...@@ -29,7 +29,6 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/coordinator_client.h" #include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h" #include "paddle/fluid/distributed/ps/service/ps_client.h"
...@@ -45,6 +44,7 @@ limitations under the License. */ ...@@ -45,6 +44,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h" #include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -410,8 +410,8 @@ class Communicator { ...@@ -410,8 +410,8 @@ class Communicator {
} }
void InitGFlag(const std::string &gflags); void InitGFlag(const std::string &gflags);
paddle::distributed::PSParameter _ps_param; ::paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env; ::paddle::distributed::PaddlePSEnvironment _ps_env;
int servers_ = 0; int servers_ = 0;
int trainers_; int trainers_;
int trainer_id_ = 0; int trainer_id_ = 0;
...@@ -661,7 +661,7 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -661,7 +661,7 @@ class GeoCommunicator : public AsyncCommunicator {
std::unordered_map< std::unordered_map<
std::string, std::string,
paddle::framework::Channel<std::shared_ptr<std::vector<int64_t>>>> ::paddle::framework::Channel<std::shared_ptr<std::vector<int64_t>>>>
sparse_id_queues_; sparse_id_queues_;
}; };
......
...@@ -28,8 +28,8 @@ static const int MAX_PORT = 65535; ...@@ -28,8 +28,8 @@ static const int MAX_PORT = 65535;
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
DEFINE_uint64(total_fl_client_size, 100, "supported total fl client size"); PD_DEFINE_uint64(total_fl_client_size, 100, "supported total fl client size");
DEFINE_uint32(coordinator_wait_all_clients_max_time, 60, "uint32: s"); PD_DEFINE_uint32(coordinator_wait_all_clients_max_time, 60, "uint32: s");
void CoordinatorService::FLService( void CoordinatorService::FLService(
::google::protobuf::RpcController* controller, ::google::protobuf::RpcController* controller,
...@@ -62,10 +62,10 @@ int32_t CoordinatorClient::Initialize( ...@@ -62,10 +62,10 @@ int32_t CoordinatorClient::Initialize(
const std::vector<std::string>& trainer_endpoints) { const std::vector<std::string>& trainer_endpoints) {
brpc::ChannelOptions options; brpc::ChannelOptions options;
options.protocol = "baidu_std"; options.protocol = "baidu_std";
options.timeout_ms = paddle::distributed::FLAGS_pserver_timeout_ms; options.timeout_ms = ::paddle::distributed::FLAGS_pserver_timeout_ms;
options.connection_type = "pooled"; options.connection_type = "pooled";
options.connect_timeout_ms = options.connect_timeout_ms =
paddle::distributed::FLAGS_pserver_connect_timeout_ms; ::paddle::distributed::FLAGS_pserver_connect_timeout_ms;
options.max_retry = 3; options.max_retry = 3;
std::string server_ip_port; std::string server_ip_port;
...@@ -109,7 +109,7 @@ int32_t CoordinatorClient::Initialize( ...@@ -109,7 +109,7 @@ int32_t CoordinatorClient::Initialize(
} }
for (size_t i = 0; i < trainer_endpoints.size(); i++) { for (size_t i = 0; i < trainer_endpoints.size(); i++) {
std::vector<std::string> addr = std::vector<std::string> addr =
paddle::string::Split(trainer_endpoints[i], ':'); ::paddle::string::Split(trainer_endpoints[i], ':');
fl_client_list[i].ip = addr[0]; fl_client_list[i].ip = addr[0];
fl_client_list[i].port = std::stol(addr[1]); fl_client_list[i].port = std::stol(addr[1]);
fl_client_list[i].rank = i; // TO CHECK fl_client_list[i].rank = i; // TO CHECK
...@@ -152,7 +152,7 @@ int32_t CoordinatorClient::StartClientService() { ...@@ -152,7 +152,7 @@ int32_t CoordinatorClient::StartClientService() {
LOG(ERROR) << "fl-ps > coordinator server endpoint not set"; LOG(ERROR) << "fl-ps > coordinator server endpoint not set";
return -1; return -1;
} }
auto addr = paddle::string::Split(_endpoint, ':'); auto addr = ::paddle::string::Split(_endpoint, ':');
std::string ip = addr[0]; std::string ip = addr[0];
std::string port = addr[1]; std::string port = addr[1];
std::string rank = addr[2]; std::string rank = addr[2];
......
...@@ -34,10 +34,10 @@ ...@@ -34,10 +34,10 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
DECLARE_int32(pserver_timeout_ms); PD_DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(pserver_connect_timeout_ms); PD_DECLARE_int32(pserver_connect_timeout_ms);
DECLARE_uint64(total_fl_client_size); PD_DECLARE_uint64(total_fl_client_size);
DECLARE_uint32(coordinator_wait_all_clients_max_time); PD_DECLARE_uint32(coordinator_wait_all_clients_max_time);
using CoordinatorServiceFunc = using CoordinatorServiceFunc =
std::function<int32_t(const CoordinatorReqMessage& request, std::function<int32_t(const CoordinatorReqMessage& request,
...@@ -91,10 +91,9 @@ class CoordinatorServiceHandle { ...@@ -91,10 +91,9 @@ class CoordinatorServiceHandle {
timeline.Start(); timeline.Start();
auto f = [&]() -> bool { auto f = [&]() -> bool {
while (query_wait_time < while (query_wait_time <
paddle::distributed:: FLAGS_coordinator_wait_all_clients_max_time) { // in case that
FLAGS_coordinator_wait_all_clients_max_time) { // in case that // some
// some // clients down
// clients down
if (_is_all_clients_info_collected == true) { if (_is_all_clients_info_collected == true) {
// LOG(INFO) << "fl-ps > _is_all_clients_info_collected"; // LOG(INFO) << "fl-ps > _is_all_clients_info_collected";
return true; return true;
......
...@@ -25,8 +25,8 @@ ...@@ -25,8 +25,8 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/phi/core/macros.h" #include "paddle/phi/core/macros.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
DEFINE_int32(heter_world_size, 100, "group size"); // group max size PD_DEFINE_int32(heter_world_size, 100, "group size"); // group max size
DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s"); PD_DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s");
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr; std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
std::mutex HeterClient::mtx_; std::mutex HeterClient::mtx_;
...@@ -85,7 +85,7 @@ void HeterClient::CreateClient2XpuConnection() { ...@@ -85,7 +85,7 @@ void HeterClient::CreateClient2XpuConnection() {
xpu_channels_[i].reset(new brpc::Channel()); xpu_channels_[i].reset(new brpc::Channel());
if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) { if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again"; VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = paddle::string::Split(xpu_list_[i], ':'); auto ip_port = ::paddle::string::Split(xpu_list_[i], ':');
std::string ip = ip_port[0]; std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]); int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port); std::string int_ip_port = GetIntTypeEndpoint(ip, port);
...@@ -100,7 +100,7 @@ void HeterClient::CreateClient2XpuConnection() { ...@@ -100,7 +100,7 @@ void HeterClient::CreateClient2XpuConnection() {
if (previous_xpu_channels_[i]->Init( if (previous_xpu_channels_[i]->Init(
previous_xpu_list_[i].c_str(), "", &options) != 0) { previous_xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again"; VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = paddle::string::Split(previous_xpu_list_[i], ':'); auto ip_port = ::paddle::string::Split(previous_xpu_list_[i], ':');
std::string ip = ip_port[0]; std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]); int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port); std::string int_ip_port = GetIntTypeEndpoint(ip, port);
...@@ -181,11 +181,11 @@ void HeterClient::SendAndRecvAsync( ...@@ -181,11 +181,11 @@ void HeterClient::SendAndRecvAsync(
std::future<int32_t> HeterClient::SendCmd( std::future<int32_t> HeterClient::SendCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string>& params) { uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
size_t request_call_num = xpu_channels_.size(); size_t request_call_num = xpu_channels_.size();
paddle::distributed::DownpourBrpcClosure* closure = ::paddle::distributed::DownpourBrpcClosure* closure =
new paddle::distributed::DownpourBrpcClosure( new ::paddle::distributed::DownpourBrpcClosure(
request_call_num, [request_call_num, cmd_id](void* done) { request_call_num, [request_call_num, cmd_id](void* done) {
int ret = 0; int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; auto* closure = (::paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) { if (closure->check_response(i, cmd_id) != 0) {
ret = -1; ret = -1;
......
...@@ -42,7 +42,7 @@ class Scope; ...@@ -42,7 +42,7 @@ class Scope;
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
DECLARE_int32(pserver_timeout_ms); PD_DECLARE_int32(pserver_timeout_ms);
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
...@@ -116,7 +116,7 @@ class HeterClient { ...@@ -116,7 +116,7 @@ class HeterClient {
if ((*client_channels)[i]->Init(node_list[i].c_str(), "", &options) != if ((*client_channels)[i]->Init(node_list[i].c_str(), "", &options) !=
0) { 0) {
VLOG(0) << "client channel init failed! try again"; VLOG(0) << "client channel init failed! try again";
auto ip_port = paddle::string::Split(node_list[i], ':'); auto ip_port = ::paddle::string::Split(node_list[i], ':');
std::string ip = ip_port[0]; std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]); int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port); std::string int_ip_port = GetIntTypeEndpoint(ip, port);
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
// DEFINE_string(cert_path, "./cert.pem", "cert.pem path"); // PD_DEFINE_string(cert_path, "./cert.pem", "cert.pem path");
// DEFINE_string(key_path, "./key.pem", "key.pem path"); // PD_DEFINE_string(key_path, "./key.pem", "key.pem path");
std::shared_ptr<HeterServer> HeterServer::s_instance_ = nullptr; std::shared_ptr<HeterServer> HeterServer::s_instance_ = nullptr;
std::mutex HeterServer::mtx_; std::mutex HeterServer::mtx_;
...@@ -37,7 +37,7 @@ void HeterServer::StartHeterService(bool neeed_encrypt) { ...@@ -37,7 +37,7 @@ void HeterServer::StartHeterService(bool neeed_encrypt) {
} }
if (server_.Start(endpoint_.c_str(), &options) != 0) { if (server_.Start(endpoint_.c_str(), &options) != 0) {
VLOG(0) << "HeterServer start fail. Try again."; VLOG(0) << "HeterServer start fail. Try again.";
auto ip_port = paddle::string::Split(endpoint_, ':'); auto ip_port = ::paddle::string::Split(endpoint_, ':');
std::string ip = ip_port[0]; std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]); int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port); std::string int_ip_port = GetIntTypeEndpoint(ip, port);
...@@ -72,7 +72,7 @@ void HeterServer::StartHeterInterService(bool neeed_encrypt) { ...@@ -72,7 +72,7 @@ void HeterServer::StartHeterInterService(bool neeed_encrypt) {
} }
if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) { if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) {
VLOG(4) << "switch inter server start fail. Try again."; VLOG(4) << "switch inter server start fail. Try again.";
auto ip_port = paddle::string::Split(endpoint_inter_, ':'); auto ip_port = ::paddle::string::Split(endpoint_inter_, ':');
std::string ip = ip_port[0]; std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]); int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port); std::string int_ip_port = GetIntTypeEndpoint(ip, port);
......
...@@ -57,9 +57,9 @@ PHI_DECLARE_double(eager_delete_tensor_gb); ...@@ -57,9 +57,9 @@ PHI_DECLARE_double(eager_delete_tensor_gb);
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
DECLARE_int32(pserver_timeout_ms); PD_DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(heter_world_size); PD_DECLARE_int32(heter_world_size);
DECLARE_int32(switch_send_recv_timeout_s); PD_DECLARE_int32(switch_send_recv_timeout_s);
using MultiVarMsg = MultiVariableMessage; using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage; using VarMsg = VariableMessage;
...@@ -216,8 +216,8 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase { ...@@ -216,8 +216,8 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
// get microID from request // get microID from request
// deserialize variable to micro scope // deserialize variable to micro scope
// Push to heter worker's task_queue // Push to heter worker's task_queue
std::unique_ptr<paddle::framework::Scope> local_scope_ptr( std::unique_ptr<::paddle::framework::Scope> local_scope_ptr(
new paddle::framework::Scope()); new ::paddle::framework::Scope());
auto& local_scope = *(local_scope_ptr.get()); auto& local_scope = *(local_scope_ptr.get());
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::CPUPlace cpu_place; platform::CPUPlace cpu_place;
...@@ -257,7 +257,7 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase { ...@@ -257,7 +257,7 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
auto* minibatch_scope = &(scope_->NewScope()); auto* minibatch_scope = &(scope_->NewScope());
(*mini_scopes_)[minibatch_index] = minibatch_scope; (*mini_scopes_)[minibatch_index] = minibatch_scope;
(*micro_scopes_)[minibatch_index].reset( (*micro_scopes_)[minibatch_index].reset(
new std::vector<paddle::framework::Scope*>{}); new std::vector<::paddle::framework::Scope*>{});
for (int i = 0; i < num_microbatch_; i++) { for (int i = 0; i < num_microbatch_; i++) {
auto* micro_scope = &(minibatch_scope->NewScope()); auto* micro_scope = &(minibatch_scope->NewScope());
(*((*micro_scopes_)[minibatch_index])).push_back(micro_scope); (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
...@@ -300,7 +300,7 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase { ...@@ -300,7 +300,7 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
public: public:
using shard_type = SparseTableShard<std::string, ValueInSwitch>; using shard_type = SparseTableShard<std::string, ValueInSwitch>;
std::shared_ptr<paddle::framework::Scope> local_scope_ptr; // for switch std::shared_ptr<::paddle::framework::Scope> local_scope_ptr; // for switch
std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>> std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>>
vars_ready_flag; vars_ready_flag;
std::unique_ptr<shard_type[]> _local_shards; std::unique_ptr<shard_type[]> _local_shards;
...@@ -344,7 +344,7 @@ class HeterService : public PsService { ...@@ -344,7 +344,7 @@ class HeterService : public PsService {
std::placeholders::_3); std::placeholders::_3);
service_handler_.local_scope_ptr = service_handler_.local_scope_ptr =
std::make_shared<paddle::framework::Scope>(); std::make_shared<::paddle::framework::Scope>();
} }
virtual ~HeterService() {} virtual ~HeterService() {}
...@@ -613,7 +613,7 @@ class HeterServer { ...@@ -613,7 +613,7 @@ class HeterServer {
void SetLocalScope() { void SetLocalScope() {
request_handler_->local_scope_ptr = request_handler_->local_scope_ptr =
std::make_shared<paddle::framework::Scope>(); std::make_shared<::paddle::framework::Scope>();
} }
void SetInterEndpoint(const std::string& endpoint) { void SetInterEndpoint(const std::string& endpoint) {
......
...@@ -37,7 +37,8 @@ REGISTER_PSCORE_CLASS(PSClient, PsGraphClient); ...@@ -37,7 +37,8 @@ REGISTER_PSCORE_CLASS(PSClient, PsGraphClient);
int32_t PSClient::Configure( // called in FleetWrapper::InitWorker int32_t PSClient::Configure( // called in FleetWrapper::InitWorker
const PSParameter &config, const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions, const std::map<uint64_t, std::vector<::paddle::distributed::Region>>
&regions,
PSEnvironment &env, PSEnvironment &env,
size_t client_id) { size_t client_id) {
_env = &env; _env = &env;
...@@ -88,7 +89,7 @@ PSClient *PSClientFactory::Create(const PSParameter &ps_config) { ...@@ -88,7 +89,7 @@ PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
PSClient *client = NULL; PSClient *client = NULL;
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH) #if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
auto gloo = paddle::framework::GlooWrapper::GetInstance(); auto gloo = ::paddle::framework::GlooWrapper::GetInstance();
if (client_name == "PsLocalClient" && gloo->Size() > 1) { if (client_name == "PsLocalClient" && gloo->Size() > 1) {
client = CREATE_PSCORE_CLASS(PSClient, "PsGraphClient"); client = CREATE_PSCORE_CLASS(PSClient, "PsGraphClient");
LOG(WARNING) << "change PsLocalClient to PsGraphClient"; LOG(WARNING) << "change PsLocalClient to PsGraphClient";
......
...@@ -143,13 +143,13 @@ class GraphPyServer : public GraphPyService { ...@@ -143,13 +143,13 @@ class GraphPyServer : public GraphPyService {
void start_server(bool block = true); void start_server(bool block = true);
::paddle::distributed::PSParameter GetServerProto(); ::paddle::distributed::PSParameter GetServerProto();
std::shared_ptr<paddle::distributed::GraphBrpcServer> get_ps_server() { std::shared_ptr<::paddle::distributed::GraphBrpcServer> get_ps_server() {
return pserver_ptr; return pserver_ptr;
} }
protected: protected:
int rank; int rank;
std::shared_ptr<paddle::distributed::GraphBrpcServer> pserver_ptr; std::shared_ptr<::paddle::distributed::GraphBrpcServer> pserver_ptr;
std::thread* server_thread; std::thread* server_thread;
}; };
class GraphPyClient : public GraphPyService { class GraphPyClient : public GraphPyService {
...@@ -162,14 +162,14 @@ class GraphPyClient : public GraphPyService { ...@@ -162,14 +162,14 @@ class GraphPyClient : public GraphPyService {
set_client_id(client_id); set_client_id(client_id);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types); GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
} }
std::shared_ptr<paddle::distributed::GraphBrpcClient> get_ps_client() { std::shared_ptr<::paddle::distributed::GraphBrpcClient> get_ps_client() {
return worker_ptr; return worker_ptr;
} }
void bind_local_server(int local_channel_index, void bind_local_server(int local_channel_index,
GraphPyServer& server) { // NOLINT GraphPyServer& server) { // NOLINT
worker_ptr->set_local_channel(local_channel_index); worker_ptr->set_local_channel(local_channel_index);
worker_ptr->set_local_graph_service( worker_ptr->set_local_graph_service(
(paddle::distributed::GraphBrpcService*)server.get_ps_server() (::paddle::distributed::GraphBrpcService*)server.get_ps_server()
->get_service()); ->get_service());
} }
void StopServer(); void StopServer();
...@@ -209,7 +209,7 @@ class GraphPyClient : public GraphPyService { ...@@ -209,7 +209,7 @@ class GraphPyClient : public GraphPyService {
protected: protected:
mutable std::mutex mutex_; mutable std::mutex mutex_;
int client_id; int client_id;
std::shared_ptr<paddle::distributed::GraphBrpcClient> worker_ptr; std::shared_ptr<::paddle::distributed::GraphBrpcClient> worker_ptr;
std::thread* client_thread; std::thread* client_thread;
bool stoped_ = false; bool stoped_ = false;
}; };
......
...@@ -28,9 +28,9 @@ using namespace std; // NOLINT ...@@ -28,9 +28,9 @@ using namespace std; // NOLINT
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
paddle::distributed::PSParameter load_from_prototxt( ::paddle::distributed::PSParameter load_from_prototxt(
const std::string& filename) { const std::string& filename) {
paddle::distributed::PSParameter param; ::paddle::distributed::PSParameter param;
int file_descriptor = open(filename.c_str(), O_RDONLY); int file_descriptor = open(filename.c_str(), O_RDONLY);
if (file_descriptor == -1) { if (file_descriptor == -1) {
...@@ -50,7 +50,7 @@ paddle::distributed::PSParameter load_from_prototxt( ...@@ -50,7 +50,7 @@ paddle::distributed::PSParameter load_from_prototxt(
void PSCore::InitGFlag(const std::string& gflags) { void PSCore::InitGFlag(const std::string& gflags) {
VLOG(3) << "Init With Gflags:" << gflags; VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags); std::vector<std::string> flags = ::paddle::string::split_string(gflags);
if (flags.empty()) { if (flags.empty()) {
flags.push_back("-max_body_size=314217728"); flags.push_back("-max_body_size=314217728");
flags.push_back("-socket_max_unwritten_bytes=2048000000"); flags.push_back("-socket_max_unwritten_bytes=2048000000");
...@@ -64,7 +64,7 @@ void PSCore::InitGFlag(const std::string& gflags) { ...@@ -64,7 +64,7 @@ void PSCore::InitGFlag(const std::string& gflags) {
} }
int params_cnt = flags.size(); int params_cnt = flags.size();
char** params_ptr = &(flags_ptr[0]); char** params_ptr = &(flags_ptr[0]);
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true); ::paddle::flags::ParseCommandLineFlags(&params_cnt, &params_ptr);
} }
int PSCore::InitServer( int PSCore::InitServer(
...@@ -76,12 +76,12 @@ int PSCore::InitServer( ...@@ -76,12 +76,12 @@ int PSCore::InitServer(
const std::vector<framework::ProgramDesc>& server_sub_program) { const std::vector<framework::ProgramDesc>& server_sub_program) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
InitGFlag(_ps_param.init_gflags()); InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = ::paddle::distributed::PaddlePSEnvironment();
_ps_env.SetPsServers(host_sign_list, node_num); _ps_env.SetPsServers(host_sign_list, node_num);
_ps_env.SetTrainers(trainers); _ps_env.SetTrainers(trainers);
int ret = 0; int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>( _server_ptr = std::shared_ptr<::paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::Create(_ps_param)); ::paddle::distributed::PSServerFactory::Create(_ps_param));
ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program); ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
CHECK(ret == 0) << "failed to configure server"; CHECK(ret == 0) << "failed to configure server";
return ret; return ret;
...@@ -89,13 +89,14 @@ int PSCore::InitServer( ...@@ -89,13 +89,14 @@ int PSCore::InitServer(
int PSCore::InitWorker( int PSCore::InitWorker(
const std::string& dist_desc, const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions, const std::map<uint64_t, std::vector<::paddle::distributed::Region>>&
regions,
const std::vector<std::string>* host_sign_list, const std::vector<std::string>* host_sign_list,
int node_num, int node_num,
int index) { int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
InitGFlag(_ps_param.init_gflags()); InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = ::paddle::distributed::PaddlePSEnvironment();
_ps_env.SetPsServers(host_sign_list, node_num); _ps_env.SetPsServers(host_sign_list, node_num);
int ret = 0; int ret = 0;
VLOG(1) << "PSCore::InitWorker"; VLOG(1) << "PSCore::InitWorker";
...@@ -132,6 +133,6 @@ int PSCore::StopServer() { ...@@ -132,6 +133,6 @@ int PSCore::StopServer() {
stop_status.wait(); stop_status.wait();
return 0; return 0;
} }
paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; } ::paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -33,9 +33,9 @@ class PsRequestMessage; ...@@ -33,9 +33,9 @@ class PsRequestMessage;
class PsResponseMessage; class PsResponseMessage;
class PsService; class PsService;
using paddle::distributed::PsRequestMessage; using ::paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage; using ::paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService; using ::paddle::distributed::PsService;
class PSCore { class PSCore {
public: public:
...@@ -51,7 +51,7 @@ class PSCore { ...@@ -51,7 +51,7 @@ class PSCore {
const std::vector<framework::ProgramDesc>& server_sub_program = {}); const std::vector<framework::ProgramDesc>& server_sub_program = {});
virtual int InitWorker( virtual int InitWorker(
const std::string& dist_desc, const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& const std::map<uint64_t, std::vector<::paddle::distributed::Region>>&
regions, regions,
const std::vector<std::string>* host_sign_list, const std::vector<std::string>* host_sign_list,
int node_num, int node_num,
...@@ -63,16 +63,16 @@ class PSCore { ...@@ -63,16 +63,16 @@ class PSCore {
virtual int CreateClient2ClientConnection(int pserver_timeout_ms, virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
int max_retry); int max_retry);
std::shared_ptr<paddle::distributed::PSServer> std::shared_ptr<::paddle::distributed::PSServer>
_server_ptr; // pointer to server _server_ptr; // pointer to server
std::shared_ptr<paddle::distributed::PSClient> std::shared_ptr<::paddle::distributed::PSClient>
_worker_ptr; // pointer to worker _worker_ptr; // pointer to worker
virtual paddle::distributed::PSParameter* GetParam(); virtual ::paddle::distributed::PSParameter* GetParam();
private: private:
void InitGFlag(const std::string& gflags); void InitGFlag(const std::string& gflags);
paddle::distributed::PSParameter _ps_param; ::paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env; ::paddle::distributed::PaddlePSEnvironment _ps_env;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -57,8 +57,8 @@ namespace distributed { ...@@ -57,8 +57,8 @@ namespace distributed {
class Table; class Table;
using paddle::distributed::PsRequestMessage; using ::paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage; using ::paddle::distributed::PsResponseMessage;
class PSServer { class PSServer {
public: public:
...@@ -134,7 +134,7 @@ class PSServer { ...@@ -134,7 +134,7 @@ class PSServer {
return -1; return -1;
} }
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins; ::paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
protected: protected:
virtual int32_t Initialize() = 0; virtual int32_t Initialize() = 0;
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
namespace brpc { namespace brpc {
DECLARE_uint64(max_body_size); PD_DECLARE_uint64(max_body_size);
DECLARE_int64(socket_max_unwritten_bytes); PD_DECLARE_int64(socket_max_unwritten_bytes);
} // namespace brpc } // namespace brpc
namespace paddle { namespace paddle {
......
...@@ -58,14 +58,14 @@ int32_t GraphTable::Load_to_ssd(const std::string &path, ...@@ -58,14 +58,14 @@ int32_t GraphTable::Load_to_ssd(const std::string &path,
return 0; return 0;
} }
paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( ::paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
int gpu_id, std::vector<uint64_t> &node_ids, int slot_num) { int gpu_id, std::vector<uint64_t> &node_ids, int slot_num) {
size_t shard_num = 64; size_t shard_num = 64;
std::vector<std::vector<uint64_t>> bags(shard_num); std::vector<std::vector<uint64_t>> bags(shard_num);
std::vector<uint64_t> feature_array[shard_num]; std::vector<uint64_t> feature_array[shard_num];
std::vector<uint8_t> slot_id_array[shard_num]; std::vector<uint8_t> slot_id_array[shard_num];
std::vector<uint64_t> node_id_array[shard_num]; std::vector<uint64_t> node_id_array[shard_num];
std::vector<paddle::framework::GpuPsFeaInfo> node_fea_info_array[shard_num]; std::vector<::paddle::framework::GpuPsFeaInfo> node_fea_info_array[shard_num];
for (size_t i = 0; i < shard_num; i++) { for (size_t i = 0; i < shard_num; i++) {
auto predsize = node_ids.size() / shard_num; auto predsize = node_ids.size() / shard_num;
bags[i].reserve(predsize * 1.2); bags[i].reserve(predsize * 1.2);
...@@ -92,7 +92,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( ...@@ -92,7 +92,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
if (bags[i].size() > 0) { if (bags[i].size() > 0) {
tasks.push_back(_cpu_worker_pool[gpu_id]->enqueue([&, i, this]() -> int { tasks.push_back(_cpu_worker_pool[gpu_id]->enqueue([&, i, this]() -> int {
uint64_t node_id; uint64_t node_id;
paddle::framework::GpuPsFeaInfo x; ::paddle::framework::GpuPsFeaInfo x;
std::vector<uint64_t> feature_ids; std::vector<uint64_t> feature_ids;
for (size_t j = 0; j < bags[i].size(); j++) { for (size_t j = 0; j < bags[i].size(); j++) {
Node *v = find_node(GraphTableType::FEATURE_TABLE, bags[i][j]); Node *v = find_node(GraphTableType::FEATURE_TABLE, bags[i][j]);
...@@ -134,7 +134,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( ...@@ -134,7 +134,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
tasks.clear(); tasks.clear();
paddle::framework::GpuPsCommGraphFea res; ::paddle::framework::GpuPsCommGraphFea res;
uint64_t tot_len = 0; uint64_t tot_len = 0;
for (size_t i = 0; i < shard_num; i++) { for (size_t i = 0; i < shard_num; i++) {
tot_len += feature_array[i].size(); tot_len += feature_array[i].size();
...@@ -165,7 +165,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( ...@@ -165,7 +165,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
return res; return res;
} }
paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph( ::paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
int idx, const std::vector<uint64_t> &ids) { int idx, const std::vector<uint64_t> &ids) {
std::vector<std::vector<uint64_t>> bags(task_pool_size_); std::vector<std::vector<uint64_t>> bags(task_pool_size_);
for (int i = 0; i < task_pool_size_; i++) { for (int i = 0; i < task_pool_size_; i++) {
...@@ -179,7 +179,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph( ...@@ -179,7 +179,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
std::vector<std::future<int>> tasks; std::vector<std::future<int>> tasks;
std::vector<uint64_t> node_array[task_pool_size_]; // node id list std::vector<uint64_t> node_array[task_pool_size_]; // node id list
std::vector<paddle::framework::GpuPsNodeInfo> info_array[task_pool_size_]; std::vector<::paddle::framework::GpuPsNodeInfo> info_array[task_pool_size_];
std::vector<uint64_t> edge_array[task_pool_size_]; // edge id list std::vector<uint64_t> edge_array[task_pool_size_]; // edge id list
for (size_t i = 0; i < bags.size(); i++) { for (size_t i = 0; i < bags.size(); i++) {
...@@ -215,7 +215,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph( ...@@ -215,7 +215,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
tot_len += edge_array[i].size(); tot_len += edge_array[i].size();
} }
paddle::framework::GpuPsCommGraph res; ::paddle::framework::GpuPsCommGraph res;
res.init_on_cpu(tot_len, ids.size()); res.init_on_cpu(tot_len, ids.size());
int64_t offset = 0, ind = 0; int64_t offset = 0, ind = 0;
for (int i = 0; i < task_pool_size_; i++) { for (int i = 0; i < task_pool_size_; i++) {
...@@ -516,13 +516,13 @@ void GraphTable::release_graph() { ...@@ -516,13 +516,13 @@ void GraphTable::release_graph() {
build_graph_type_keys(); build_graph_type_keys();
if (FLAGS_gpugraph_storage_mode == if (FLAGS_gpugraph_storage_mode ==
paddle::framework::GpuGraphStorageMode::WHOLE_HBM) { ::paddle::framework::GpuGraphStorageMode::WHOLE_HBM) {
build_graph_total_keys(); build_graph_total_keys();
} }
// clear graph // clear graph
if (FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode:: if (FLAGS_gpugraph_storage_mode == ::paddle::framework::GpuGraphStorageMode::
MEM_EMB_FEATURE_AND_GPU_GRAPH || MEM_EMB_FEATURE_AND_GPU_GRAPH ||
FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode:: FLAGS_gpugraph_storage_mode == ::paddle::framework::GpuGraphStorageMode::
SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) { SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) {
clear_edge_shard(); clear_edge_shard();
} else { } else {
...@@ -532,7 +532,7 @@ void GraphTable::release_graph() { ...@@ -532,7 +532,7 @@ void GraphTable::release_graph() {
void GraphTable::release_graph_edge() { void GraphTable::release_graph_edge() {
if (FLAGS_gpugraph_storage_mode == if (FLAGS_gpugraph_storage_mode ==
paddle::framework::GpuGraphStorageMode::WHOLE_HBM) { ::paddle::framework::GpuGraphStorageMode::WHOLE_HBM) {
build_graph_total_keys(); build_graph_total_keys();
} }
clear_edge_shard(); clear_edge_shard();
...@@ -543,10 +543,12 @@ void GraphTable::release_graph_node() { ...@@ -543,10 +543,12 @@ void GraphTable::release_graph_node() {
if (FLAGS_graph_metapath_split_opt) { if (FLAGS_graph_metapath_split_opt) {
clear_feature_shard(); clear_feature_shard();
} else { } else {
if (FLAGS_gpugraph_storage_mode != paddle::framework::GpuGraphStorageMode:: if (FLAGS_gpugraph_storage_mode !=
MEM_EMB_FEATURE_AND_GPU_GRAPH && ::paddle::framework::GpuGraphStorageMode::
FLAGS_gpugraph_storage_mode != paddle::framework::GpuGraphStorageMode:: MEM_EMB_FEATURE_AND_GPU_GRAPH &&
SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) { FLAGS_gpugraph_storage_mode !=
::paddle::framework::GpuGraphStorageMode::
SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) {
clear_feature_shard(); clear_feature_shard();
} else { } else {
merge_feature_shard(); merge_feature_shard();
...@@ -666,7 +668,7 @@ int32_t GraphTable::load_edges_to_ssd(const std::string &path, ...@@ -666,7 +668,7 @@ int32_t GraphTable::load_edges_to_ssd(const std::string &path,
idx = edge_to_id[edge_type]; idx = edge_to_id[edge_type];
} }
total_memory_cost = 0; total_memory_cost = 0;
auto paths = paddle::string::split_string<std::string>(path, ";"); auto paths = ::paddle::string::split_string<std::string>(path, ";");
int64_t count = 0; int64_t count = 0;
std::string sample_type = "random"; std::string sample_type = "random";
for (auto path : paths) { for (auto path : paths) {
...@@ -674,11 +676,12 @@ int32_t GraphTable::load_edges_to_ssd(const std::string &path, ...@@ -674,11 +676,12 @@ int32_t GraphTable::load_edges_to_ssd(const std::string &path,
std::string line; std::string line;
while (std::getline(file, line)) { while (std::getline(file, line)) {
VLOG(0) << "get a line from file " << line; VLOG(0) << "get a line from file " << line;
auto values = paddle::string::split_string<std::string>(line, "\t"); auto values = ::paddle::string::split_string<std::string>(line, "\t");
count++; count++;
if (values.size() < 2) continue; if (values.size() < 2) continue;
auto src_id = std::stoll(values[0]); auto src_id = std::stoll(values[0]);
auto dist_ids = paddle::string::split_string<std::string>(values[1], ";"); auto dist_ids =
::paddle::string::split_string<std::string>(values[1], ";");
std::vector<uint64_t> dist_data; std::vector<uint64_t> dist_data;
for (auto x : dist_ids) { for (auto x : dist_ids) {
dist_data.push_back(std::stoll(x)); dist_data.push_back(std::stoll(x));
...@@ -798,7 +801,7 @@ int CompleteGraphSampler::run_graph_sampling() { ...@@ -798,7 +801,7 @@ int CompleteGraphSampler::run_graph_sampling() {
sample_nodes.resize(gpu_num); sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num); sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num); sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>> std::vector<std::vector<std::vector<::paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_); sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex( std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_); graph_table->task_pool_size_);
...@@ -812,7 +815,7 @@ int CompleteGraphSampler::run_graph_sampling() { ...@@ -812,7 +815,7 @@ int CompleteGraphSampler::run_graph_sampling() {
graph_table->_shards_task_pool[i % graph_table->task_pool_size_] graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int { ->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0; if (this->status == GraphSamplerStatus::terminating) return 0;
paddle::framework::GpuPsGraphNode node; ::paddle::framework::GpuPsGraphNode node;
std::vector<Node *> &v = std::vector<Node *> &v =
this->graph_table->shards[i]->get_bucket(); this->graph_table->shards[i]->get_bucket();
size_t ind = i % this->graph_table->task_pool_size_; size_t ind = i % this->graph_table->task_pool_size_;
...@@ -962,7 +965,7 @@ int BasicBfsGraphSampler::run_graph_sampling() { ...@@ -962,7 +965,7 @@ int BasicBfsGraphSampler::run_graph_sampling() {
sample_nodes.resize(gpu_num); sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num); sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num); sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>> std::vector<std::vector<std::vector<::paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_); sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex( std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_); graph_table->task_pool_size_);
...@@ -977,7 +980,7 @@ int BasicBfsGraphSampler::run_graph_sampling() { ...@@ -977,7 +980,7 @@ int BasicBfsGraphSampler::run_graph_sampling() {
if (this->status == GraphSamplerStatus::terminating) { if (this->status == GraphSamplerStatus::terminating) {
return 0; return 0;
} }
paddle::framework::GpuPsGraphNode node; ::paddle::framework::GpuPsGraphNode node;
auto iter = sample_neighbors_map[i].begin(); auto iter = sample_neighbors_map[i].begin();
size_t ind = i; size_t ind = i;
for (; iter != sample_neighbors_map[i].end(); iter++) { for (; iter != sample_neighbors_map[i].end(); iter++) {
...@@ -1237,7 +1240,7 @@ int32_t GraphTable::Load(const std::string &path, const std::string &param) { ...@@ -1237,7 +1240,7 @@ int32_t GraphTable::Load(const std::string &path, const std::string &param) {
} }
std::string GraphTable::get_inverse_etype(std::string &etype) { std::string GraphTable::get_inverse_etype(std::string &etype) {
auto etype_split = paddle::string::split_string<std::string>(etype, "2"); auto etype_split = ::paddle::string::split_string<std::string>(etype, "2");
std::string res; std::string res;
if (etype_split.size() == 3) { if (etype_split.size() == 3) {
res = etype_split[2] + "2" + etype_split[1] + "2" + etype_split[0]; res = etype_split[2] + "2" + etype_split[1] + "2" + etype_split[0];
...@@ -1253,13 +1256,13 @@ int32_t GraphTable::parse_type_to_typepath( ...@@ -1253,13 +1256,13 @@ int32_t GraphTable::parse_type_to_typepath(
std::vector<std::string> &res_type, std::vector<std::string> &res_type,
std::unordered_map<std::string, std::string> &res_type2path) { std::unordered_map<std::string, std::string> &res_type2path) {
auto type2files_split = auto type2files_split =
paddle::string::split_string<std::string>(type2files, ","); ::paddle::string::split_string<std::string>(type2files, ",");
if (type2files_split.empty()) { if (type2files_split.empty()) {
return -1; return -1;
} }
for (auto one_type2file : type2files_split) { for (auto one_type2file : type2files_split) {
auto one_type2file_split = auto one_type2file_split =
paddle::string::split_string<std::string>(one_type2file, ":"); ::paddle::string::split_string<std::string>(one_type2file, ":");
auto type = one_type2file_split[0]; auto type = one_type2file_split[0];
auto type_dir = one_type2file_split[1]; auto type_dir = one_type2file_split[1];
res_type.push_back(type); res_type.push_back(type);
...@@ -1304,17 +1307,17 @@ int32_t GraphTable::parse_edge_and_load( ...@@ -1304,17 +1307,17 @@ int32_t GraphTable::parse_edge_and_load(
VLOG(1) << "only_load_reverse_edge is False, etype[" << etypes[i] VLOG(1) << "only_load_reverse_edge is False, etype[" << etypes[i]
<< "], file_path[" << etype_path << "]"; << "], file_path[" << etype_path << "]";
} }
auto etype_path_list = paddle::framework::localfs_list(etype_path); auto etype_path_list = ::paddle::framework::localfs_list(etype_path);
std::string etype_path_str; std::string etype_path_str;
if (part_num > 0 && if (part_num > 0 &&
part_num < static_cast<int>(etype_path_list.size())) { part_num < static_cast<int>(etype_path_list.size())) {
std::vector<std::string> sub_etype_path_list( std::vector<std::string> sub_etype_path_list(
etype_path_list.begin(), etype_path_list.begin() + part_num); etype_path_list.begin(), etype_path_list.begin() + part_num);
etype_path_str = etype_path_str =
paddle::string::join_strings(sub_etype_path_list, delim); ::paddle::string::join_strings(sub_etype_path_list, delim);
} else { } else {
etype_path_str = etype_path_str =
paddle::string::join_strings(etype_path_list, delim); ::paddle::string::join_strings(etype_path_list, delim);
} }
if (!only_load_reverse_edge) { if (!only_load_reverse_edge) {
this->load_edges(etype_path_str, false, etypes[i]); this->load_edges(etype_path_str, false, etypes[i]);
...@@ -1345,14 +1348,14 @@ int32_t GraphTable::parse_node_and_load(std::string ntype2files, ...@@ -1345,14 +1348,14 @@ int32_t GraphTable::parse_node_and_load(std::string ntype2files,
} }
std::string delim = ";"; std::string delim = ";";
std::string npath = node_to_nodedir[ntypes[0]]; std::string npath = node_to_nodedir[ntypes[0]];
auto npath_list = paddle::framework::localfs_list(npath); auto npath_list = ::paddle::framework::localfs_list(npath);
std::string npath_str; std::string npath_str;
if (part_num > 0 && part_num < static_cast<int>(npath_list.size())) { if (part_num > 0 && part_num < static_cast<int>(npath_list.size())) {
std::vector<std::string> sub_npath_list(npath_list.begin(), std::vector<std::string> sub_npath_list(npath_list.begin(),
npath_list.begin() + part_num); npath_list.begin() + part_num);
npath_str = paddle::string::join_strings(sub_npath_list, delim); npath_str = ::paddle::string::join_strings(sub_npath_list, delim);
} else { } else {
npath_str = paddle::string::join_strings(npath_list, delim); npath_str = ::paddle::string::join_strings(npath_list, delim);
} }
if (ntypes.empty()) { if (ntypes.empty()) {
...@@ -1425,17 +1428,18 @@ int32_t GraphTable::load_node_and_edge_file( ...@@ -1425,17 +1428,18 @@ int32_t GraphTable::load_node_and_edge_file(
VLOG(1) << "only_load_reverse_edge is False, etype[" << etypes[i] VLOG(1) << "only_load_reverse_edge is False, etype[" << etypes[i]
<< "], file_path[" << etype_path << "]"; << "], file_path[" << etype_path << "]";
} }
auto etype_path_list = paddle::framework::localfs_list(etype_path); auto etype_path_list =
::paddle::framework::localfs_list(etype_path);
std::string etype_path_str; std::string etype_path_str;
if (part_num > 0 && if (part_num > 0 &&
part_num < static_cast<int>(etype_path_list.size())) { part_num < static_cast<int>(etype_path_list.size())) {
std::vector<std::string> sub_etype_path_list( std::vector<std::string> sub_etype_path_list(
etype_path_list.begin(), etype_path_list.begin() + part_num); etype_path_list.begin(), etype_path_list.begin() + part_num);
etype_path_str = etype_path_str =
paddle::string::join_strings(sub_etype_path_list, delim); ::paddle::string::join_strings(sub_etype_path_list, delim);
} else { } else {
etype_path_str = etype_path_str =
paddle::string::join_strings(etype_path_list, delim); ::paddle::string::join_strings(etype_path_list, delim);
} }
if (!only_load_reverse_edge) { if (!only_load_reverse_edge) {
this->load_edges(etype_path_str, false, etypes[i]); this->load_edges(etype_path_str, false, etypes[i]);
...@@ -1448,15 +1452,15 @@ int32_t GraphTable::load_node_and_edge_file( ...@@ -1448,15 +1452,15 @@ int32_t GraphTable::load_node_and_edge_file(
} }
} else { } else {
std::string npath = node_to_nodedir[ntypes[0]]; std::string npath = node_to_nodedir[ntypes[0]];
auto npath_list = paddle::framework::localfs_list(npath); auto npath_list = ::paddle::framework::localfs_list(npath);
std::string npath_str; std::string npath_str;
if (part_num > 0 && if (part_num > 0 &&
part_num < static_cast<int>(npath_list.size())) { part_num < static_cast<int>(npath_list.size())) {
std::vector<std::string> sub_npath_list( std::vector<std::string> sub_npath_list(
npath_list.begin(), npath_list.begin() + part_num); npath_list.begin(), npath_list.begin() + part_num);
npath_str = paddle::string::join_strings(sub_npath_list, delim); npath_str = ::paddle::string::join_strings(sub_npath_list, delim);
} else { } else {
npath_str = paddle::string::join_strings(npath_list, delim); npath_str = ::paddle::string::join_strings(npath_list, delim);
} }
if (ntypes.empty()) { if (ntypes.empty()) {
...@@ -1553,14 +1557,14 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file( ...@@ -1553,14 +1557,14 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
uint64_t local_valid_count = 0; uint64_t local_valid_count = 0;
int num = 0; int num = 0;
std::vector<paddle::string::str_ptr> vals; std::vector<::paddle::string::str_ptr> vals;
size_t n = node_type.length(); size_t n = node_type.length();
while (std::getline(file, line)) { while (std::getline(file, line)) {
if (strncmp(line.c_str(), node_type.c_str(), n) != 0) { if (strncmp(line.c_str(), node_type.c_str(), n) != 0) {
continue; continue;
} }
vals.clear(); vals.clear();
num = paddle::string::split_string_ptr( num = ::paddle::string::split_string_ptr(
line.c_str() + n + 1, line.length() - n - 1, '\t', &vals); line.c_str() + n + 1, line.length() - n - 1, '\t', &vals);
if (num == 0) { if (num == 0) {
continue; continue;
...@@ -1603,15 +1607,15 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file( ...@@ -1603,15 +1607,15 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
uint64_t local_valid_count = 0; uint64_t local_valid_count = 0;
int idx = 0; int idx = 0;
auto path_split = paddle::string::split_string<std::string>(path, "/"); auto path_split = ::paddle::string::split_string<std::string>(path, "/");
auto path_name = path_split[path_split.size() - 1]; auto path_name = path_split[path_split.size() - 1];
int num = 0; int num = 0;
std::vector<paddle::string::str_ptr> vals; std::vector<::paddle::string::str_ptr> vals;
while (std::getline(file, line)) { while (std::getline(file, line)) {
vals.clear(); vals.clear();
num = paddle::string::split_string_ptr( num = ::paddle::string::split_string_ptr(
line.c_str(), line.length(), '\t', &vals); line.c_str(), line.length(), '\t', &vals);
if (vals.empty()) { if (vals.empty()) {
continue; continue;
...@@ -1654,7 +1658,7 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file( ...@@ -1654,7 +1658,7 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
// // TODO(danleifeng): opt load all node_types in once reading // // TODO(danleifeng): opt load all node_types in once reading
int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
auto paths = paddle::string::split_string<std::string>(path, ";"); auto paths = ::paddle::string::split_string<std::string>(path, ";");
uint64_t count = 0; uint64_t count = 0;
uint64_t valid_count = 0; uint64_t valid_count = 0;
int idx = 0; int idx = 0;
...@@ -1725,8 +1729,8 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_edge_file( ...@@ -1725,8 +1729,8 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_edge_file(
uint64_t local_valid_count = 0; uint64_t local_valid_count = 0;
uint64_t part_num = 0; uint64_t part_num = 0;
if (FLAGS_graph_load_in_parallel) { if (FLAGS_graph_load_in_parallel) {
auto path_split = paddle::string::split_string<std::string>(path, "/"); auto path_split = ::paddle::string::split_string<std::string>(path, "/");
auto part_name_split = paddle::string::split_string<std::string>( auto part_name_split = ::paddle::string::split_string<std::string>(
path_split[path_split.size() - 1], "-"); path_split[path_split.size() - 1], "-");
part_num = std::stoull(part_name_split[part_name_split.size() - 1]); part_num = std::stoull(part_name_split[part_name_split.size() - 1]);
} }
...@@ -1793,7 +1797,7 @@ int32_t GraphTable::load_edges(const std::string &path, ...@@ -1793,7 +1797,7 @@ int32_t GraphTable::load_edges(const std::string &path,
idx = edge_to_id[edge_type]; idx = edge_to_id[edge_type];
} }
auto paths = paddle::string::split_string<std::string>(path, ";"); auto paths = ::paddle::string::split_string<std::string>(path, ";");
uint64_t count = 0; uint64_t count = 0;
uint64_t valid_count = 0; uint64_t valid_count = 0;
...@@ -1865,7 +1869,7 @@ Node *GraphTable::find_node(GraphTableType table_type, uint64_t id) { ...@@ -1865,7 +1869,7 @@ Node *GraphTable::find_node(GraphTableType table_type, uint64_t id) {
table_type == GraphTableType::EDGE_TABLE ? edge_shards : feature_shards; table_type == GraphTableType::EDGE_TABLE ? edge_shards : feature_shards;
for (auto &search_shard : search_shards) { for (auto &search_shard : search_shards) {
PADDLE_ENFORCE_NOT_NULL(search_shard[index], PADDLE_ENFORCE_NOT_NULL(search_shard[index],
paddle::platform::errors::InvalidArgument( ::paddle::platform::errors::InvalidArgument(
"search_shard[%d] should not be null.", index)); "search_shard[%d] should not be null.", index));
node = search_shard[index]->find_node(id); node = search_shard[index]->find_node(id);
if (node != nullptr) { if (node != nullptr) {
...@@ -1885,7 +1889,7 @@ Node *GraphTable::find_node(GraphTableType table_type, int idx, uint64_t id) { ...@@ -1885,7 +1889,7 @@ Node *GraphTable::find_node(GraphTableType table_type, int idx, uint64_t id) {
? edge_shards[idx] ? edge_shards[idx]
: feature_shards[idx]; : feature_shards[idx];
PADDLE_ENFORCE_NOT_NULL(search_shards[index], PADDLE_ENFORCE_NOT_NULL(search_shards[index],
paddle::platform::errors::InvalidArgument( ::paddle::platform::errors::InvalidArgument(
"search_shard[%d] should not be null.", index)); "search_shard[%d] should not be null.", index));
Node *node = search_shards[index]->find_node(id); Node *node = search_shards[index]->find_node(id);
return node; return node;
...@@ -2164,8 +2168,8 @@ void string_vector_2_string(std::vector<std::string>::iterator strs_begin, ...@@ -2164,8 +2168,8 @@ void string_vector_2_string(std::vector<std::string>::iterator strs_begin,
} }
void string_vector_2_string( void string_vector_2_string(
std::vector<paddle::string::str_ptr>::iterator strs_begin, std::vector<::paddle::string::str_ptr>::iterator strs_begin,
std::vector<paddle::string::str_ptr>::iterator strs_end, std::vector<::paddle::string::str_ptr>::iterator strs_end,
char delim, char delim,
std::string *output) { std::string *output) {
size_t i = 0; size_t i = 0;
...@@ -2184,19 +2188,19 @@ int GraphTable::parse_feature(int idx, ...@@ -2184,19 +2188,19 @@ int GraphTable::parse_feature(int idx,
FeatureNode *node) { FeatureNode *node) {
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1, // Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
// "") // "")
thread_local std::vector<paddle::string::str_ptr> fields; thread_local std::vector<::paddle::string::str_ptr> fields;
fields.clear(); fields.clear();
char c = slot_feature_separator_.at(0); char c = slot_feature_separator_.at(0);
paddle::string::split_string_ptr(feat_str, len, c, &fields); ::paddle::string::split_string_ptr(feat_str, len, c, &fields);
thread_local std::vector<paddle::string::str_ptr> fea_fields; thread_local std::vector<::paddle::string::str_ptr> fea_fields;
fea_fields.clear(); fea_fields.clear();
c = feature_separator_.at(0); c = feature_separator_.at(0);
paddle::string::split_string_ptr(fields[1].ptr, ::paddle::string::split_string_ptr(fields[1].ptr,
fields[1].len, fields[1].len,
c, c,
&fea_fields, &fea_fields,
FLAGS_gpugraph_slot_feasign_max_num); FLAGS_gpugraph_slot_feasign_max_num);
std::string name = fields[0].to_string(); std::string name = fields[0].to_string();
auto it = feat_id_map[idx].find(name); auto it = feat_id_map[idx].find(name);
if (it != feat_id_map[idx].end()) { if (it != feat_id_map[idx].end()) {
...@@ -2522,14 +2526,14 @@ int32_t GraphTable::Initialize(const TableParameter &config, ...@@ -2522,14 +2526,14 @@ int32_t GraphTable::Initialize(const TableParameter &config,
} }
void GraphTable::load_node_weight(int type_id, int idx, std::string path) { void GraphTable::load_node_weight(int type_id, int idx, std::string path) {
auto paths = paddle::string::split_string<std::string>(path, ";"); auto paths = ::paddle::string::split_string<std::string>(path, ";");
int64_t count = 0; int64_t count = 0;
auto &weight_map = node_weight[type_id][idx]; auto &weight_map = node_weight[type_id][idx];
for (auto path : paths) { for (auto path : paths) {
std::ifstream file(path); std::ifstream file(path);
std::string line; std::string line;
while (std::getline(file, line)) { while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t"); auto values = ::paddle::string::split_string<std::string>(line, "\t");
count++; count++;
if (values.size() < 2) continue; if (values.size() < 2) continue;
auto src_id = std::stoull(values[0]); auto src_id = std::stoull(values[0]);
...@@ -2546,7 +2550,7 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) { ...@@ -2546,7 +2550,7 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
_db = NULL; _db = NULL;
search_level = graph.search_level(); search_level = graph.search_level();
if (search_level >= 2) { if (search_level >= 2) {
_db = paddle::distributed::RocksDBHandler::GetInstance(); _db = ::paddle::distributed::RocksDBHandler::GetInstance();
_db->initialize("./temp_gpups_db", task_pool_size_); _db->initialize("./temp_gpups_db", task_pool_size_);
} }
// gpups_mode = true; // gpups_mode = true;
......
...@@ -712,9 +712,9 @@ class GraphTable : public Table { ...@@ -712,9 +712,9 @@ class GraphTable : public Table {
int &actual_size); // NOLINT int &actual_size); // NOLINT
virtual int32_t add_node_to_ssd( virtual int32_t add_node_to_ssd(
int type_id, int idx, uint64_t src_id, char *data, int len); int type_id, int idx, uint64_t src_id, char *data, int len);
virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph( virtual ::paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
int idx, const std::vector<uint64_t> &ids); int idx, const std::vector<uint64_t> &ids);
virtual paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea( virtual ::paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea(
int gpu_id, std::vector<uint64_t> &node_ids, int slot_num); // NOLINT int gpu_id, std::vector<uint64_t> &node_ids, int slot_num); // NOLINT
int32_t Load_to_ssd(const std::string &path, const std::string &param); int32_t Load_to_ssd(const std::string &path, const std::string &param);
int64_t load_graph_to_memory_from_ssd(int idx, int64_t load_graph_to_memory_from_ssd(int idx,
...@@ -786,7 +786,7 @@ class GraphTable : public Table { ...@@ -786,7 +786,7 @@ class GraphTable : public Table {
std::shared_ptr<pthread_rwlock_t> rw_lock; std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
// paddle::framework::GpuPsGraphTable gpu_graph_table; // paddle::framework::GpuPsGraphTable gpu_graph_table;
paddle::distributed::RocksDBHandler *_db; ::paddle::distributed::RocksDBHandler *_db;
// std::shared_ptr<::ThreadPool> graph_sample_pool; // std::shared_ptr<::ThreadPool> graph_sample_pool;
// std::shared_ptr<GraphSampler> graph_sampler; // std::shared_ptr<GraphSampler> graph_sampler;
// REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler) // REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
...@@ -847,8 +847,8 @@ class BasicBfsGraphSampler : public GraphSampler { ...@@ -847,8 +847,8 @@ class BasicBfsGraphSampler : public GraphSampler {
namespace std { namespace std {
template <> template <>
struct hash<paddle::distributed::SampleKey> { struct hash<::paddle::distributed::SampleKey> {
size_t operator()(const paddle::distributed::SampleKey &s) const { size_t operator()(const ::paddle::distributed::SampleKey &s) const {
return s.idx ^ s.node_key ^ s.sample_size; return s.idx ^ s.node_key ^ s.sample_size;
} }
}; };
......
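The std::hash specialization above (an XOR of idx, node_key and sample_size) is what lets SampleKey be used directly as an unordered-container key. A self-contained analogue with a stand-in key type, since SampleKey itself is defined elsewhere in the Paddle sources:

// DemoKey is illustrative only; it mirrors the fields hashed for SampleKey.
#include <cstdint>
#include <functional>
#include <unordered_map>

struct DemoKey {
  int idx;
  uint64_t node_key;
  int sample_size;
  bool operator==(const DemoKey& o) const {
    return idx == o.idx && node_key == o.node_key && sample_size == o.sample_size;
  }
};

namespace std {
template <>
struct hash<DemoKey> {
  size_t operator()(const DemoKey& s) const {
    // Same XOR-combine scheme the graph table uses for SampleKey.
    return static_cast<size_t>(s.idx) ^ s.node_key ^ s.sample_size;
  }
};
}  // namespace std

int main() {
  std::unordered_map<DemoKey, int> cache;
  cache[DemoKey{0, 42, 10}] = 7;  // usable as a key thanks to hash<> plus operator==
  return cache.count(DemoKey{0, 42, 10}) == 1 ? 0 : 1;
}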
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
#include "paddle/fluid/distributed/ps/table/ctr_accessor.h" #include "paddle/fluid/distributed/ps/table/ctr_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
#include "paddle/fluid/distributed/ps/table/ctr_double_accessor.h" #include "paddle/fluid/distributed/ps/table/ctr_double_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h" #include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/common/utils.h" #include "paddle/fluid/distributed/common/utils.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
#pragma once #pragma once
#include <mct/hash-map.hpp>
#include <vector> #include <vector>
#include "gflags/gflags.h" #include <mct/hash-map.hpp>
#include "paddle/fluid/distributed/common/chunk_allocator.h" #include "paddle/fluid/distributed/common/chunk_allocator.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h" #include "paddle/fluid/operators/truncated_gaussian_random_op.h"
#include "paddle/phi/core/generator.h" #include "paddle/phi/core/generator.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -124,13 +124,13 @@ class TruncatedGaussianInitializer : public Initializer { ...@@ -124,13 +124,13 @@ class TruncatedGaussianInitializer : public Initializer {
} }
float GetValue() override { float GetValue() override {
paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_); ::paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
float value = truncated_normal(dist_(*random_engine_)); float value = truncated_normal(dist_(*random_engine_));
return value; return value;
} }
void GetValue(float *value, int numel) { void GetValue(float *value, int numel) {
paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_); ::paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
for (int x = 0; x < numel; ++x) { for (int x = 0; x < numel; ++x) {
value[x] = truncated_normal(dist_(*random_engine_)); value[x] = truncated_normal(dist_(*random_engine_));
} }
......
...@@ -36,4 +36,4 @@ ...@@ -36,4 +36,4 @@
#define DECLARE_11_FRIEND_CLASS(a, ...) \ #define DECLARE_11_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_10_FRIEND_CLASS(__VA_ARGS__) DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_10_FRIEND_CLASS(__VA_ARGS__)
#define REGISTER_GRAPH_FRIEND_CLASS(n, ...) \ #define REGISTER_GRAPH_FRIEND_CLASS(n, ...) \
DECLARE_##n##_FRIEND_CLASS(__VA_ARGS__) PD_DECLARE_##n##_FRIEND_CLASS(__VA_ARGS__)
...@@ -26,16 +26,18 @@ ...@@ -26,16 +26,18 @@
// #include "boost/lexical_cast.hpp" // #include "boost/lexical_cast.hpp"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
DEFINE_bool(pserver_print_missed_key_num_every_push, PD_DEFINE_bool(pserver_print_missed_key_num_every_push,
false, false,
"pserver_print_missed_key_num_every_push"); "pserver_print_missed_key_num_every_push");
DEFINE_bool(pserver_create_value_when_push, PD_DEFINE_bool(pserver_create_value_when_push,
true, true,
"pserver create value when push"); "pserver create value when push");
DEFINE_bool(pserver_enable_create_feasign_randomly, PD_DEFINE_bool(pserver_enable_create_feasign_randomly,
false, false,
"pserver_enable_create_feasign_randomly"); "pserver_enable_create_feasign_randomly");
DEFINE_int32(pserver_table_save_max_retry, 3, "pserver_table_save_max_retry"); PD_DEFINE_int32(pserver_table_save_max_retry,
3,
"pserver_table_save_max_retry");
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
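The four pserver flags above are now defined with PD_DEFINE_bool/PD_DEFINE_int32 instead of the gflags macros. A sketch of how such a flag is typically read back at runtime; the retry helper below is illustrative only and is not the real MemorySparseTable save path:

// Sketch only: consulting a PD_DEFINE'd flag from another translation unit.
#include <functional>
#include "paddle/utils/flags.h"

PD_DECLARE_int32(pserver_table_save_max_retry);

bool SaveShardWithRetry(const std::function<bool()>& save_once) {
  for (int attempt = 0; attempt <= FLAGS_pserver_table_save_max_retry; ++attempt) {
    if (save_once()) return true;  // default of 3 retries unless overridden
  }
  return false;
}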
...@@ -333,7 +335,7 @@ int32_t MemorySparseTable::Save(const std::string &dirname, ...@@ -333,7 +335,7 @@ int32_t MemorySparseTable::Save(const std::string &dirname,
TopkCalculator tk(_real_local_shard_num, tk_size); TopkCalculator tk(_real_local_shard_num, tk_size);
std::string table_path = TableDir(dirname); std::string table_path = TableDir(dirname);
_afs_client.remove(paddle::string::format_string( _afs_client.remove(::paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx)); "%s/part-%03d-*", table_path.c_str(), _shard_idx));
std::atomic<uint32_t> feasign_size_all{0}; std::atomic<uint32_t> feasign_size_all{0};
...@@ -350,15 +352,15 @@ int32_t MemorySparseTable::Save(const std::string &dirname, ...@@ -350,15 +352,15 @@ int32_t MemorySparseTable::Save(const std::string &dirname,
FsChannelConfig channel_config; FsChannelConfig channel_config;
if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) { if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
channel_config.path = channel_config.path =
paddle::string::format_string("%s/part-%03d-%05d.gz", ::paddle::string::format_string("%s/part-%03d-%05d.gz",
table_path.c_str(), table_path.c_str(),
_shard_idx, _shard_idx,
file_start_idx + i); file_start_idx + i);
} else { } else {
channel_config.path = paddle::string::format_string("%s/part-%03d-%05d", channel_config.path = ::paddle::string::format_string("%s/part-%03d-%05d",
table_path.c_str(), table_path.c_str(),
_shard_idx, _shard_idx,
file_start_idx + i); file_start_idx + i);
} }
channel_config.converter = _value_accesor->Converter(save_param).converter; channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter = channel_config.deconverter =
...@@ -385,7 +387,7 @@ int32_t MemorySparseTable::Save(const std::string &dirname, ...@@ -385,7 +387,7 @@ int32_t MemorySparseTable::Save(const std::string &dirname,
if (_value_accesor->Save(it.value().data(), save_param)) { if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = _value_accesor->ParseToString( std::string format_value = _value_accesor->ParseToString(
it.value().data(), it.value().size()); it.value().data(), it.value().size());
if (0 != write_channel->write_line(paddle::string::format_string( if (0 != write_channel->write_line(::paddle::string::format_string(
"%lu %s", it.key(), format_value.c_str()))) { "%lu %s", it.key(), format_value.c_str()))) {
++retry_num; ++retry_num;
is_write_failed = true; is_write_failed = true;
...@@ -432,7 +434,7 @@ int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) { ...@@ -432,7 +434,7 @@ int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) {
} }
size_t file_start_idx = _m_avg_local_shard_num * _shard_idx; size_t file_start_idx = _m_avg_local_shard_num * _shard_idx;
std::string table_path = TableDir(path); std::string table_path = TableDir(path);
_afs_client.remove(paddle::string::format_string( _afs_client.remove(::paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx)); "%s/part-%03d-*", table_path.c_str(), _shard_idx));
int thread_num = _m_real_local_shard_num < 20 ? _m_real_local_shard_num : 20; int thread_num = _m_real_local_shard_num < 20 ? _m_real_local_shard_num : 20;
...@@ -442,10 +444,10 @@ int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) { ...@@ -442,10 +444,10 @@ int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) {
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int i = 0; i < _m_real_local_shard_num; ++i) { for (int i = 0; i < _m_real_local_shard_num; ++i) {
FsChannelConfig channel_config; FsChannelConfig channel_config;
channel_config.path = paddle::string::format_string("%s/part-%03d-%05d", channel_config.path = ::paddle::string::format_string("%s/part-%03d-%05d",
table_path.c_str(), table_path.c_str(),
_shard_idx, _shard_idx,
file_start_idx + i); file_start_idx + i);
channel_config.converter = _value_accesor->Converter(save_param).converter; channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter = channel_config.deconverter =
...@@ -469,8 +471,9 @@ int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) { ...@@ -469,8 +471,9 @@ int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) {
if (_value_accesor->Save(it.value().data(), save_param)) { if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = _value_accesor->ParseToString( std::string format_value = _value_accesor->ParseToString(
it.value().data(), it.value().size()); it.value().data(), it.value().size());
if (0 != write_channel->write_line(paddle::string::format_string( if (0 !=
"%lu %s", it.key(), format_value.c_str()))) { write_channel->write_line(::paddle::string::format_string(
"%lu %s", it.key(), format_value.c_str()))) {
++retry_num; ++retry_num;
is_write_failed = true; is_write_failed = true;
LOG(ERROR) << "MemorySparseTable save failed, retry it! path:" LOG(ERROR) << "MemorySparseTable save failed, retry it! path:"
...@@ -503,10 +506,10 @@ int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) { ...@@ -503,10 +506,10 @@ int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) {
feasign_size_all += feasign_size; feasign_size_all += feasign_size;
} }
LOG(INFO) << "MemorySparseTable save patch success, path:" LOG(INFO) << "MemorySparseTable save patch success, path:"
<< paddle::string::format_string("%s/%03d/part-%03d-", << ::paddle::string::format_string("%s/%03d/part-%03d-",
path.c_str(), path.c_str(),
_config.table_id(), _config.table_id(),
_shard_idx) _shard_idx)
<< " from " << file_start_idx << " to " << " from " << file_start_idx << " to "
<< file_start_idx + _m_real_local_shard_num - 1 << file_start_idx + _m_real_local_shard_num - 1
<< ", feasign size: " << feasign_size_all; << ", feasign size: " << feasign_size_all;
...@@ -519,7 +522,7 @@ int64_t MemorySparseTable::CacheShuffle( ...@@ -519,7 +522,7 @@ int64_t MemorySparseTable::CacheShuffle(
double cache_threshold, double cache_threshold,
std::function<std::future<int32_t>( std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string &msg)> send_msg_func, int msg_type, int to_pserver_id, std::string &msg)> send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>> ::paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel, &shuffled_channel,
const std::vector<Table *> &table_ptrs) { const std::vector<Table *> &table_ptrs) {
LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold; LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold;
...@@ -536,24 +539,24 @@ int64_t MemorySparseTable::CacheShuffle( ...@@ -536,24 +539,24 @@ int64_t MemorySparseTable::CacheShuffle(
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
std::vector< std::vector<
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>> ::paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>>
writers(_real_local_shard_num); writers(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, std::string>>> datas( std::vector<std::vector<std::pair<uint64_t, std::string>>> datas(
_real_local_shard_num); _real_local_shard_num);
int feasign_size = 0; int feasign_size = 0;
std::vector<paddle::framework::Channel<std::pair<uint64_t, std::string>>> std::vector<::paddle::framework::Channel<std::pair<uint64_t, std::string>>>
tmp_channels; tmp_channels;
for (int i = 0; i < _real_local_shard_num; ++i) { for (int i = 0; i < _real_local_shard_num; ++i) {
tmp_channels.push_back( tmp_channels.push_back(
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>()); ::paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>());
} }
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int i = 0; i < _real_local_shard_num; ++i) { for (int i = 0; i < _real_local_shard_num; ++i) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>> &writer = ::paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>
writers[i]; &writer = writers[i];
writer.Reset(tmp_channels[i].get()); writer.Reset(tmp_channels[i].get());
for (auto table_ptr : table_ptrs) { for (auto table_ptr : table_ptrs) {
...@@ -579,15 +582,15 @@ int64_t MemorySparseTable::CacheShuffle( ...@@ -579,15 +582,15 @@ int64_t MemorySparseTable::CacheShuffle(
// shard num: " << _real_local_shard_num; // shard num: " << _real_local_shard_num;
std::vector<std::pair<uint64_t, std::string>> local_datas; std::vector<std::pair<uint64_t, std::string>> local_datas;
for (int idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) { for (int idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>> &writer = ::paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>
writers[idx_shard]; &writer = writers[idx_shard];
auto channel = writer.channel(); auto channel = writer.channel();
std::vector<std::pair<uint64_t, std::string>> &data = datas[idx_shard]; std::vector<std::pair<uint64_t, std::string>> &data = datas[idx_shard];
std::vector<paddle::framework::BinaryArchive> ars(shuffle_node_num); std::vector<::paddle::framework::BinaryArchive> ars(shuffle_node_num);
while (channel->Read(data)) { while (channel->Read(data)) {
for (auto &t : data) { for (auto &t : data) {
auto pserver_id = auto pserver_id =
paddle::distributed::local_random_engine()() % shuffle_node_num; ::paddle::distributed::local_random_engine()() % shuffle_node_num;
if (pserver_id != _shard_idx) { if (pserver_id != _shard_idx) {
ars[pserver_id] << t; ars[pserver_id] << t;
} else { } else {
...@@ -618,7 +621,7 @@ int64_t MemorySparseTable::CacheShuffle( ...@@ -618,7 +621,7 @@ int64_t MemorySparseTable::CacheShuffle(
t.wait(); t.wait();
} }
ars.clear(); ars.clear();
ars = std::vector<paddle::framework::BinaryArchive>(shuffle_node_num); ars = std::vector<::paddle::framework::BinaryArchive>(shuffle_node_num);
data = std::vector<std::pair<uint64_t, std::string>>(); data = std::vector<std::pair<uint64_t, std::string>>();
} }
} }
...@@ -629,20 +632,20 @@ int64_t MemorySparseTable::CacheShuffle( ...@@ -629,20 +632,20 @@ int64_t MemorySparseTable::CacheShuffle(
int32_t MemorySparseTable::SaveCache( int32_t MemorySparseTable::SaveCache(
const std::string &path, const std::string &path,
const std::string &param, const std::string &param,
paddle::framework::Channel<std::pair<uint64_t, std::string>> ::paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel) { &shuffled_channel) {
if (_shard_idx >= _config.sparse_table_cache_file_num()) { if (_shard_idx >= _config.sparse_table_cache_file_num()) {
return 0; return 0;
} }
int save_param = atoi(param.c_str()); // batch_model:0 xbox:1 int save_param = atoi(param.c_str()); // batch_model:0 xbox:1
std::string table_path = paddle::string::format_string( std::string table_path = ::paddle::string::format_string(
"%s/%03d_cache/", path.c_str(), _config.table_id()); "%s/%03d_cache/", path.c_str(), _config.table_id());
_afs_client.remove(paddle::string::format_string( _afs_client.remove(::paddle::string::format_string(
"%s/part-%03d", table_path.c_str(), _shard_idx)); "%s/part-%03d", table_path.c_str(), _shard_idx));
uint32_t feasign_size = 0; uint32_t feasign_size = 0;
FsChannelConfig channel_config; FsChannelConfig channel_config;
// not compress cache model // not compress cache model
channel_config.path = paddle::string::format_string( channel_config.path = ::paddle::string::format_string(
"%s/part-%03d", table_path.c_str(), _shard_idx); "%s/part-%03d", table_path.c_str(), _shard_idx);
channel_config.converter = _value_accesor->Converter(save_param).converter; channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter = channel_config.deconverter =
...@@ -654,7 +657,7 @@ int32_t MemorySparseTable::SaveCache( ...@@ -654,7 +657,7 @@ int32_t MemorySparseTable::SaveCache(
while (shuffled_channel->Read(data)) { while (shuffled_channel->Read(data)) {
for (auto &t : data) { for (auto &t : data) {
++feasign_size; ++feasign_size;
if (0 != write_channel->write_line(paddle::string::format_string( if (0 != write_channel->write_line(::paddle::string::format_string(
"%lu %s", t.first, t.second.c_str()))) { "%lu %s", t.first, t.second.c_str()))) {
LOG(ERROR) << "Cache Table save failed, " LOG(ERROR) << "Cache Table save failed, "
"path:" "path:"
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
#include "paddle/fluid/distributed/ps/table/sparse_accessor.h" #include "paddle/fluid/distributed/ps/table/sparse_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -300,7 +299,7 @@ std::string SparseAccessor::ParseToString(const float* v, int param) { ...@@ -300,7 +299,7 @@ std::string SparseAccessor::ParseToString(const float* v, int param) {
int SparseAccessor::ParseFromString(const std::string& str, float* value) { int SparseAccessor::ParseFromString(const std::string& str, float* value) {
_embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(), _embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(),
value + sparse_feature_value.EmbedxG2SumIndex()); value + sparse_feature_value.EmbedxG2SumIndex());
auto ret = paddle::string::str_to_float(str.data(), value); auto ret = ::paddle::string::str_to_float(str.data(), value);
CHECK(ret >= 6) << "expect more than 6 real:" << ret; CHECK(ret >= 6) << "expect more than 6 real:" << ret;
return ret; return ret;
} }
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h" #include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
#include <gflags/gflags.h>
#include "glog/logging.h" #include "glog/logging.h"
DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient"); #include "paddle/utils/flags.h"
PD_DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient");
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h" #include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h" #include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -118,7 +118,7 @@ class Table { ...@@ -118,7 +118,7 @@ class Table {
virtual int32_t SaveCache( virtual int32_t SaveCache(
const std::string &path UNUSED, const std::string &path UNUSED,
const std::string &param UNUSED, const std::string &param UNUSED,
paddle::framework::Channel<std::pair<uint64_t, std::string>> ::paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel UNUSED) { &shuffled_channel UNUSED) {
return 0; return 0;
} }
...@@ -130,7 +130,7 @@ class Table { ...@@ -130,7 +130,7 @@ class Table {
std::function<std::future<int32_t>( std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string &msg)> // NOLINT int msg_type, int to_pserver_id, std::string &msg)> // NOLINT
send_msg_func UNUSED, send_msg_func UNUSED,
paddle::framework::Channel<std::pair<uint64_t, std::string>> ::paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel UNUSED, &shuffled_channel UNUSED,
const std::vector<Table *> &table_ptrs UNUSED) { const std::vector<Table *> &table_ptrs UNUSED) {
return 0; return 0;
...@@ -161,7 +161,7 @@ class Table { ...@@ -161,7 +161,7 @@ class Table {
virtual int32_t InitializeAccessor(); virtual int32_t InitializeAccessor();
virtual int32_t InitializeShard() = 0; virtual int32_t InitializeShard() = 0;
virtual std::string TableDir(const std::string &model_dir) { virtual std::string TableDir(const std::string &model_dir) {
return paddle::string::format_string( return ::paddle::string::format_string(
"%s/%03d/", model_dir.c_str(), _config.table_id()); "%s/%03d/", model_dir.c_str(), _config.table_id());
} }
......
...@@ -30,8 +30,10 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100; ...@@ -30,8 +30,10 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL; std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false; bool FleetWrapper::is_initialized_ = false;
std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL; std::shared_ptr<::paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ =
std::shared_ptr<paddle::distributed::PSClient> FleetWrapper::worker_ptr_ = NULL; NULL;
std::shared_ptr<::paddle::distributed::PSClient> FleetWrapper::worker_ptr_ =
NULL;
int FleetWrapper::RegisterHeterCallback(HeterCallBackFunc handler) { int FleetWrapper::RegisterHeterCallback(HeterCallBackFunc handler) {
VLOG(0) << "RegisterHeterCallback support later"; VLOG(0) << "RegisterHeterCallback support later";
...@@ -76,8 +78,8 @@ void FleetWrapper::InitServer( ...@@ -76,8 +78,8 @@ void FleetWrapper::InitServer(
const std::vector<framework::ProgramDesc>& server_sub_program) { const std::vector<framework::ProgramDesc>& server_sub_program) {
if (!is_initialized_) { if (!is_initialized_) {
VLOG(3) << "Going to init server"; VLOG(3) << "Going to init server";
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>( pserver_ptr_ = std::shared_ptr<::paddle::distributed::PSCore>(
new paddle::distributed::PSCore()); new ::paddle::distributed::PSCore());
pserver_ptr_->InitServer(dist_desc, pserver_ptr_->InitServer(dist_desc,
&host_sign_list, &host_sign_list,
host_sign_list.size(), host_sign_list.size(),
...@@ -92,7 +94,7 @@ void FleetWrapper::InitServer( ...@@ -92,7 +94,7 @@ void FleetWrapper::InitServer(
void FleetWrapper::InitGFlag(const std::string& gflags) { void FleetWrapper::InitGFlag(const std::string& gflags) {
VLOG(3) << "Init With Gflags:" << gflags; VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags); std::vector<std::string> flags = ::paddle::string::split_string(gflags);
if (flags.empty()) { if (flags.empty()) {
flags.push_back("-max_body_size=314217728"); flags.push_back("-max_body_size=314217728");
flags.push_back("-bthread_concurrency=40"); flags.push_back("-bthread_concurrency=40");
...@@ -107,7 +109,7 @@ void FleetWrapper::InitGFlag(const std::string& gflags) { ...@@ -107,7 +109,7 @@ void FleetWrapper::InitGFlag(const std::string& gflags) {
} }
int params_cnt = flags.size(); int params_cnt = flags.size();
char** params_ptr = &(flags_ptr[0]); char** params_ptr = &(flags_ptr[0]);
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true); ::paddle::flags::ParseCommandLineFlags(&params_cnt, &params_ptr);
} }
void FleetWrapper::InitWorker(const std::string& dist_desc, void FleetWrapper::InitWorker(const std::string& dist_desc,
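InitGFlag now hands the assembled argument vector to ::paddle::flags::ParseCommandLineFlags, which takes only the count and the argv pointer — there is no gflags-style remove_flags third argument. A hedged, self-contained sketch of the same call shape (the argv[0] placeholder and C++17 std::string::data() are assumptions, not taken from this diff):

// Sketch: parse "-name=value" strings through the new flags entry point.
#include <string>
#include <vector>
#include "paddle/utils/flags.h"

void ParseFromStrings(std::vector<std::string> flags) {
  std::string argv0 = "dummy";              // conventional argv[0] placeholder
  std::vector<char*> argv;
  argv.push_back(argv0.data());
  for (auto& f : flags) argv.push_back(f.data());
  int argc = static_cast<int>(argv.size());
  char** argv_ptr = argv.data();
  ::paddle::flags::ParseCommandLineFlags(&argc, &argv_ptr);
}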
...@@ -116,17 +118,17 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, ...@@ -116,17 +118,17 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
if (!is_initialized_) { if (!is_initialized_) {
// not used, just for psclient's init // not used, just for psclient's init
// TODO(zhaocaibei123): remove this later // TODO(zhaocaibei123): remove this later
std::map<uint64_t, std::vector<paddle::distributed::Region>> std::map<uint64_t, std::vector<::paddle::distributed::Region>>
dense_pull_regions; dense_pull_regions;
if (worker_ptr_.get() == nullptr) { if (worker_ptr_.get() == nullptr) {
paddle::distributed::PSParameter ps_param; ::paddle::distributed::PSParameter ps_param;
google::protobuf::TextFormat::ParseFromString(dist_desc, &ps_param); google::protobuf::TextFormat::ParseFromString(dist_desc, &ps_param);
InitGFlag(ps_param.init_gflags()); InitGFlag(ps_param.init_gflags());
int servers = host_sign_list.size(); int servers = host_sign_list.size();
ps_env_.SetPsServers(&host_sign_list, servers); ps_env_.SetPsServers(&host_sign_list, servers);
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>( worker_ptr_ = std::shared_ptr<::paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::Create(ps_param)); ::paddle::distributed::PSClientFactory::Create(ps_param));
worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index); worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index);
} }
dist_desc_ = dist_desc; dist_desc_ = dist_desc;
...@@ -392,7 +394,7 @@ void FleetWrapper::PullDenseVarsAsync( ...@@ -392,7 +394,7 @@ void FleetWrapper::PullDenseVarsAsync(
Variable* var = scope.FindVar(varname); Variable* var = scope.FindVar(varname);
phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>(); phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
float* w = tensor->data<float>(); float* w = tensor->data<float>();
paddle::distributed::Region reg(w, tensor->numel()); ::paddle::distributed::Region reg(w, tensor->numel());
regions[i] = std::move(reg); regions[i] = std::move(reg);
} }
...@@ -412,7 +414,7 @@ void FleetWrapper::PullDenseVarsSync( ...@@ -412,7 +414,7 @@ void FleetWrapper::PullDenseVarsSync(
phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>(); phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
if (!platform::is_gpu_place(tensor->place())) { if (!platform::is_gpu_place(tensor->place())) {
float* w = tensor->data<float>(); float* w = tensor->data<float>();
paddle::distributed::Region reg(w, tensor->numel()); ::paddle::distributed::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
} }
} }
...@@ -425,14 +427,14 @@ void FleetWrapper::PushDenseParamSync( ...@@ -425,14 +427,14 @@ void FleetWrapper::PushDenseParamSync(
const uint64_t table_id, const uint64_t table_id,
const std::vector<std::string>& var_names) { const std::vector<std::string>& var_names) {
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
std::vector<paddle::distributed::Region> regions; std::vector<::paddle::distributed::Region> regions;
for (auto& t : var_names) { for (auto& t : var_names) {
Variable* var = scope.FindVar(t); Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found"; CHECK(var != nullptr) << "var[" << t << "] not found";
phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>(); phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
if (!platform::is_gpu_place(tensor->place())) { if (!platform::is_gpu_place(tensor->place())) {
float* g = tensor->mutable_data<float>(place); float* g = tensor->mutable_data<float>(place);
paddle::distributed::Region reg(g, tensor->numel()); ::paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
} }
} }
...@@ -456,7 +458,7 @@ void FleetWrapper::PushDenseVarsAsync( ...@@ -456,7 +458,7 @@ void FleetWrapper::PushDenseVarsAsync(
float scale_datanorm, float scale_datanorm,
int batch_size) { int batch_size) {
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
std::vector<paddle::distributed::Region> regions; std::vector<::paddle::distributed::Region> regions;
for (auto& t : var_names) { for (auto& t : var_names) {
Variable* var = scope.FindVar(t); Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found"; CHECK(var != nullptr) << "var[" << t << "] not found";
...@@ -479,7 +481,7 @@ void FleetWrapper::PushDenseVarsAsync( ...@@ -479,7 +481,7 @@ void FleetWrapper::PushDenseVarsAsync(
} }
} }
paddle::distributed::Region reg(g, tensor->numel()); ::paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id " VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id "
<< table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] " << table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] "
...@@ -774,7 +776,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id, ...@@ -774,7 +776,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id,
std::vector<std::string> var_list, std::vector<std::string> var_list,
float decay, float decay,
int emb_dim) { int emb_dim) {
std::vector<paddle::distributed::Region> regions; std::vector<::paddle::distributed::Region> regions;
for (std::string& name : var_list) { for (std::string& name : var_list) {
if (name.find("batch_sum") != std::string::npos) { if (name.find("batch_sum") != std::string::npos) {
Variable* var = scope->FindVar(name); Variable* var = scope->FindVar(name);
...@@ -795,14 +797,14 @@ void FleetWrapper::ShrinkDenseTable(int table_id, ...@@ -795,14 +797,14 @@ void FleetWrapper::ShrinkDenseTable(int table_id,
for (int k = 0; k < tensor->numel(); k += emb_dim) { for (int k = 0; k < tensor->numel(); k += emb_dim) {
g[k] = g[k] + g_size[k] * log(decay); g[k] = g[k] + g_size[k] * log(decay);
} }
paddle::distributed::Region reg(g, tensor->numel()); ::paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
} else { } else {
Variable* var = scope->FindVar(name); Variable* var = scope->FindVar(name);
CHECK(var != nullptr) << "var[" << name << "] not found"; CHECK(var != nullptr) << "var[" << name << "] not found";
phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>(); phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
float* g = tensor->data<float>(); float* g = tensor->data<float>();
paddle::distributed::Region reg(g, tensor->numel()); ::paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
} }
} }
......
...@@ -295,7 +295,7 @@ class FleetWrapper { ...@@ -295,7 +295,7 @@ class FleetWrapper {
// FleetWrapper singleton // FleetWrapper singleton
static std::shared_ptr<FleetWrapper> GetInstance() { static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) { if (NULL == s_instance_) {
s_instance_.reset(new paddle::distributed::FleetWrapper()); s_instance_.reset(new ::paddle::distributed::FleetWrapper());
} }
return s_instance_; return s_instance_;
} }
...@@ -322,13 +322,13 @@ class FleetWrapper { ...@@ -322,13 +322,13 @@ class FleetWrapper {
std::string PullFlStrategy(); std::string PullFlStrategy();
//********** //**********
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_; static std::shared_ptr<::paddle::distributed::PSCore> pserver_ptr_;
static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_; static std::shared_ptr<::paddle::distributed::PSClient> worker_ptr_;
private: private:
static std::shared_ptr<FleetWrapper> s_instance_; static std::shared_ptr<FleetWrapper> s_instance_;
std::string dist_desc_; std::string dist_desc_;
paddle::distributed::PaddlePSEnvironment ps_env_; ::paddle::distributed::PaddlePSEnvironment ps_env_;
size_t GetAbsoluteSum(size_t start, size_t GetAbsoluteSum(size_t start,
size_t end, size_t end,
size_t level, size_t level,
...@@ -336,7 +336,7 @@ class FleetWrapper { ...@@ -336,7 +336,7 @@ class FleetWrapper {
protected: protected:
static bool is_initialized_; static bool is_initialized_;
std::map<uint64_t, std::vector<paddle::distributed::Region>> regions_; std::map<uint64_t, std::vector<::paddle::distributed::Region>> regions_;
bool scale_sparse_gradient_with_batch_size_; bool scale_sparse_gradient_with_batch_size_;
int32_t sleep_seconds_before_fail_exit_; int32_t sleep_seconds_before_fail_exit_;
int client2client_request_timeout_ms_; int client2client_request_timeout_ms_;
......
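Note: the only change in the two fleet_wrapper files above is the added leading `::` on the `paddle::distributed::...` names, which forces name lookup to start at the global namespace. A minimal, self-contained sketch of the ambiguity this guards against (all namespaces and types below are hypothetical, not Paddle APIs):

#include <iostream>

namespace paddle { namespace distributed { struct Region { int id = 1; }; } }

namespace host {  // imagine the calling code compiled inside another namespace
namespace paddle { namespace distributed { struct Region { int id = 2; }; } }

int Which() {
  paddle::distributed::Region r;    // unqualified: resolves to host::paddle::distributed::Region
  ::paddle::distributed::Region g;  // leading :: always picks the top-level paddle namespace
  return r.id * 10 + g.id;          // prints 21 below
}
}  // namespace host

int main() { std::cout << host::Which() << std::endl; }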
...@@ -12,17 +12,7 @@ set_source_files_properties( ...@@ -12,17 +12,7 @@ set_source_files_properties(
set_source_files_properties(rpc_agent.cc PROPERTIES COMPILE_FLAGS set_source_files_properties(rpc_agent.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
set(PADDLE_RPC_DEPS set(PADDLE_RPC_DEPS ${EXTERNAL_BRPC_DEPS} zlib phi pybind)
brpc
ssl
crypto
protobuf
zlib
leveldb
snappy
phi
glog
pybind)
proto_library(paddle_rpc_proto SRCS rpc.proto) proto_library(paddle_rpc_proto SRCS rpc.proto)
cc_library( cc_library(
paddle_rpc paddle_rpc
......
...@@ -228,13 +228,13 @@ cc_test( ...@@ -228,13 +228,13 @@ cc_test(
set(BRPC_DEPS "") set(BRPC_DEPS "")
if(WITH_PSCORE) if(WITH_PSCORE)
set(BRPC_DEPS brpc ssl crypto) set(BRPC_DEPS ${EXTERNAL_BRPC_DEPS})
endif() endif()
if(WITH_PSLIB) if(WITH_PSLIB)
if(WITH_PSLIB_BRPC) if(WITH_PSLIB_BRPC)
set(BRPC_DEPS pslib_brpc) set(BRPC_DEPS pslib_brpc)
elseif(NOT WITH_HETERPS) elseif(NOT WITH_HETERPS)
set(BRPC_DEPS brpc ssl crypto) set(BRPC_DEPS ${EXTERNAL_BRPC_DEPS})
endif() endif()
if(WITH_ARM_BRPC) if(WITH_ARM_BRPC)
set(BRPC_DEPS arm_brpc) set(BRPC_DEPS arm_brpc)
...@@ -833,7 +833,7 @@ if(WITH_DISTRIBUTE) ...@@ -833,7 +833,7 @@ if(WITH_DISTRIBUTE)
heter_service_proto heter_service_proto
fleet fleet
heter_server heter_server
brpc ${EXTERNAL_BRPC_DEPS}
phi) phi)
set(DISTRIBUTE_COMPILE_FLAGS "") set(DISTRIBUTE_COMPILE_FLAGS "")
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/async_executor.h" #include "paddle/fluid/framework/async_executor.h"
#include "gflags/gflags.h"
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
...@@ -32,6 +31,7 @@ limitations under the License. */ ...@@ -32,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#include "paddle/utils/flags.h"
// phi // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
......
...@@ -16,15 +16,15 @@ ...@@ -16,15 +16,15 @@
#include <random> #include <random>
#include "gflags/gflags.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/utils/flags.h"
DECLARE_bool(use_system_allocator); PD_DECLARE_bool(use_system_allocator);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#include "xpu/bkcl.h" #include "xpu/bkcl.h"
DECLARE_bool(sync_bkcl_allreduce); PD_DECLARE_bool(sync_bkcl_allreduce);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -22,10 +22,10 @@ limitations under the License. */ ...@@ -22,10 +22,10 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
DECLARE_bool(convert_all_blocks); PD_DECLARE_bool(convert_all_blocks);
PHI_DECLARE_bool(use_mkldnn); PHI_DECLARE_bool(use_mkldnn);
#ifdef PADDLE_WITH_CINN #ifdef PADDLE_WITH_CINN
DECLARE_bool(use_cinn); PD_DECLARE_bool(use_cinn);
#endif #endif
namespace paddle { namespace paddle {
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
DECLARE_bool(convert_all_blocks); PD_DECLARE_bool(convert_all_blocks);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "paddle/phi/backends/device_memory_aligment.h" #include "paddle/phi/backends/device_memory_aligment.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
DEFINE_bool(skip_fused_all_reduce_check, false, ""); // NOLINT PD_DEFINE_bool(skip_fused_all_reduce_check, false, ""); // NOLINT
PHI_DECLARE_bool(allreduce_record_one_event); PHI_DECLARE_bool(allreduce_record_one_event);
namespace paddle { namespace paddle {
......
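Across the framework sources in this patch, gflags' DEFINE_<type>/DECLARE_<type> macros are replaced by the PD_DEFINE_<type>/PD_DECLARE_<type> macros from paddle/utils/flags.h. A rough usage sketch, assuming the new macros keep the familiar gflags-style semantics (define once in a .cc file, declare wherever the FLAGS_<name> variable is read); the flag name below is illustrative, not a real Paddle flag:

// fast_path.cc -- define the flag exactly once
#include "paddle/utils/flags.h"
PD_DEFINE_bool(enable_fast_path, false, "Illustrative flag, not part of Paddle.");

// user.cc -- declare it where the value is consumed
#include "paddle/utils/flags.h"
PD_DECLARE_bool(enable_fast_path);

bool UseFastPath() { return FLAGS_enable_fast_path; }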
...@@ -31,7 +31,7 @@ limitations under the License. */ ...@@ -31,7 +31,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
DECLARE_bool(benchmark); PD_DECLARE_bool(benchmark);
PHI_DECLARE_bool(use_mkldnn); PHI_DECLARE_bool(use_mkldnn);
namespace paddle { namespace paddle {
......
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include "gflags/gflags.h"
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
...@@ -34,6 +33,7 @@ limitations under the License. */ ...@@ -34,6 +33,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#include "paddle/utils/flags.h"
// phi // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
......
...@@ -3,7 +3,7 @@ if(WITH_PSLIB) ...@@ -3,7 +3,7 @@ if(WITH_PSLIB)
set(BRPC_DEPS pslib_brpc) set(BRPC_DEPS pslib_brpc)
else() else()
if(NOT WITH_HETERPS) if(NOT WITH_HETERPS)
set(BRPC_DEPS brpc) set(BRPC_DEPS ${EXTERNAL_BRPC_DEPS})
endif() endif()
endif() endif()
cc_library( cc_library(
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#endif #endif
#include "gflags/gflags.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
#include "paddle/utils/flags.h"
PHI_DECLARE_double(eager_delete_tensor_gb); PHI_DECLARE_double(eager_delete_tensor_gb);
PHI_DECLARE_double(memory_fraction_of_eager_deletion); PHI_DECLARE_double(memory_fraction_of_eager_deletion);
......
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <utility> #include <utility>
#include "gflags/gflags.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/stream_callback_manager.h" #include "paddle/fluid/platform/stream_callback_manager.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -14,8 +14,6 @@ limitations under the License. */ ...@@ -14,8 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include <gflags/gflags.h>
#include <map> #include <map>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -27,8 +25,9 @@ limitations under the License. */ ...@@ -27,8 +25,9 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/utils/any.h" #include "paddle/utils/any.h"
#include "paddle/utils/flags.h"
DECLARE_bool(convert_all_blocks); PD_DECLARE_bool(convert_all_blocks);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -29,7 +29,7 @@ limitations under the License. */ ...@@ -29,7 +29,7 @@ limitations under the License. */
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#endif #endif
#include "paddle/fluid/platform/flags.h" #include "paddle/fluid/platform/flags.h"
DECLARE_bool(convert_all_blocks); PD_DECLARE_bool(convert_all_blocks);
PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir, PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir,
"", "",
"FLAGS_print_sub_graph_dir is used " "FLAGS_print_sub_graph_dir is used "
......
...@@ -14,11 +14,10 @@ limitations under the License. */ ...@@ -14,11 +14,10 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h" #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <gflags/gflags.h>
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include "paddle/fluid/platform/os_info.h" #include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/utils.h" #include "paddle/fluid/platform/profiler/utils.h"
DECLARE_bool(use_stream_safe_cuda_allocator); PD_DECLARE_bool(use_stream_safe_cuda_allocator);
PADDLE_DEFINE_EXPORTED_string(static_executor_perfstat_filepath, PADDLE_DEFINE_EXPORTED_string(static_executor_perfstat_filepath,
"", "",
"FLAGS_static_executor_perfstat_filepath " "FLAGS_static_executor_perfstat_filepath "
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
DECLARE_bool(new_executor_sequential_run); PD_DECLARE_bool(new_executor_sequential_run);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/xpu/xpu_info.h" #include "paddle/phi/backends/xpu/xpu_info.h"
DECLARE_bool(new_executor_serial_run); PD_DECLARE_bool(new_executor_serial_run);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/platform/flags.h" #include "paddle/fluid/platform/flags.h"
#include "paddle/utils/flags.h"
#include "paddle/fluid/framework/details/exception_holder.h" #include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h" #include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
...@@ -38,14 +38,14 @@ ...@@ -38,14 +38,14 @@
#include "paddle/fluid/platform/device_event.h" #include "paddle/fluid/platform/device_event.h"
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
DECLARE_bool(new_executor_serial_run); PD_DECLARE_bool(new_executor_serial_run);
DECLARE_bool(new_executor_static_build); PD_DECLARE_bool(new_executor_static_build);
DECLARE_bool(new_executor_use_inplace); PD_DECLARE_bool(new_executor_use_inplace);
DECLARE_bool(new_executor_use_local_scope); PD_DECLARE_bool(new_executor_use_local_scope);
PHI_DECLARE_bool(check_nan_inf); PHI_DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark); PD_DECLARE_bool(benchmark);
DECLARE_uint64(executor_log_deps_every_microseconds); PD_DECLARE_uint64(executor_log_deps_every_microseconds);
PHI_DECLARE_bool(new_executor_use_cuda_graph); PHI_DECLARE_bool(new_executor_use_cuda_graph);
PHI_DECLARE_bool(enable_new_ir_in_executor); PHI_DECLARE_bool(enable_new_ir_in_executor);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "paddle/fluid/framework/new_executor/interpreter_base_impl.h" #include "paddle/fluid/framework/new_executor/interpreter_base_impl.h"
DECLARE_bool(new_executor_use_local_scope); PD_DECLARE_bool(new_executor_use_local_scope);
namespace ir { namespace ir {
class Program; class Program;
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <unordered_set> #include <unordered_set>
#include "gflags/gflags.h" #include "paddle/utils/flags.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h" #include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
......
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include <string> #include <string>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/framework/library_type.h"
...@@ -25,8 +24,9 @@ limitations under the License. */ ...@@ -25,8 +24,9 @@ limitations under the License. */
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include "paddle/utils/flags.h"
DECLARE_bool(use_stride_kernel); PD_DECLARE_bool(use_stride_kernel);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
...@@ -43,6 +42,7 @@ limitations under the License. */ ...@@ -43,6 +42,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h" #include "paddle/phi/ops/compat/signatures.h"
#include "paddle/utils/flags.h"
namespace phi { namespace phi {
class DenseTensor; class DenseTensor;
...@@ -62,9 +62,9 @@ class DenseTensor; ...@@ -62,9 +62,9 @@ class DenseTensor;
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif #endif
DECLARE_bool(benchmark); PD_DECLARE_bool(benchmark);
PHI_DECLARE_bool(check_nan_inf); PHI_DECLARE_bool(check_nan_inf);
DECLARE_bool(enable_unused_var_check); PD_DECLARE_bool(enable_unused_var_check);
PHI_DECLARE_bool(run_kp_kernel); PHI_DECLARE_bool(run_kp_kernel);
PHI_DECLARE_bool(enable_host_event_recorder_hook); PHI_DECLARE_bool(enable_host_event_recorder_hook);
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
DECLARE_bool(enable_unused_var_check); PD_DECLARE_bool(enable_unused_var_check);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -24,7 +24,6 @@ limitations under the License. */ ...@@ -24,7 +24,6 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/cinn/frontend/op_mapper_registry.h" #include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/use_op_mappers.h" #include "paddle/cinn/frontend/op_mappers/use_op_mappers.h"
...@@ -38,9 +37,10 @@ limitations under the License. */ ...@@ -38,9 +37,10 @@ limitations under the License. */
#include "paddle/fluid/operators/cinn/cinn_launch_op.h" #include "paddle/fluid/operators/cinn/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "paddle/utils/flags.h"
DECLARE_string(allow_cinn_ops); PD_DECLARE_string(allow_cinn_ops);
DECLARE_string(deny_cinn_ops); PD_DECLARE_string(deny_cinn_ops);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "gflags/gflags.h"
#include "paddle/cinn/auto_schedule/auto_tuner.h" #include "paddle/cinn/auto_schedule/auto_tuner.h"
#include "paddle/cinn/auto_schedule/tuning.h" #include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/common/target.h" #include "paddle/cinn/common/target.h"
...@@ -52,6 +51,7 @@ ...@@ -52,6 +51,7 @@
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/value.h" #include "paddle/ir/core/value.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
#include "paddle/utils/flags.h"
PHI_DECLARE_bool(enable_pe_launch_cinn); PHI_DECLARE_bool(enable_pe_launch_cinn);
PHI_DECLARE_bool(enable_cinn_auto_tune); PHI_DECLARE_bool(enable_cinn_auto_tune);
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/cinn/common/target.h" #include "paddle/cinn/common/target.h"
...@@ -38,6 +37,7 @@ ...@@ -38,6 +37,7 @@
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
#include "paddle/utils/flags.h"
PHI_DECLARE_string(allow_cinn_ops); PHI_DECLARE_string(allow_cinn_ops);
PHI_DECLARE_string(deny_cinn_ops); PHI_DECLARE_string(deny_cinn_ops);
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/platform/flags.h" #include "paddle/fluid/platform/flags.h"
DECLARE_bool(benchmark); PD_DECLARE_bool(benchmark);
PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_bool(
eager_delete_scope, eager_delete_scope,
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include "gflags/gflags.h" #include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(check_nan_inf); PHI_DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark); PD_DECLARE_bool(benchmark);
PHI_DECLARE_bool(run_kp_kernel); PHI_DECLARE_bool(run_kp_kernel);
namespace paddle { namespace paddle {
......
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
PHI_DECLARE_bool(use_mkldnn); PHI_DECLARE_bool(use_mkldnn);
PHI_DECLARE_string(tracer_mkldnn_ops_on); PHI_DECLARE_string(tracer_mkldnn_ops_on);
PHI_DECLARE_string(tracer_mkldnn_ops_off); PHI_DECLARE_string(tracer_mkldnn_ops_off);
DECLARE_bool(use_stride_kernel); PD_DECLARE_bool(use_stride_kernel);
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/arg_map_context.h"
DECLARE_bool(use_stride_kernel); PD_DECLARE_bool(use_stride_kernel);
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
......
...@@ -37,6 +37,10 @@ get_property(ir_targets GLOBAL PROPERTY IR_TARGETS) ...@@ -37,6 +37,10 @@ get_property(ir_targets GLOBAL PROPERTY IR_TARGETS)
get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES) get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES)
set(utils_modules pretty_log string_helper benchmark utf8proc) set(utils_modules pretty_log string_helper benchmark utf8proc)
if(NOT WITH_GFLAGS)
set(utils_modules ${utils_modules} paddle_flags)
endif()
add_subdirectory(api) add_subdirectory(api)
# Create static inference library if needed # Create static inference library if needed
......
...@@ -37,8 +37,8 @@ limitations under the License. */ ...@@ -37,8 +37,8 @@ limitations under the License. */
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h" #include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
DEFINE_bool( // NOLINT PD_DEFINE_bool( // NOLINT
custom_model_save_cpu, custom_model_save_cpu,
false, false,
"Keep old mode for developers, the model is saved on cpu not device."); "Keep old mode for developers, the model is saved on cpu not device.");
......
...@@ -18,9 +18,9 @@ limitations under the License. */ ...@@ -18,9 +18,9 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <string> #include <string>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -30,7 +30,7 @@ extern void ReadBinaryFile(const std::string& filename, std::string* contents); ...@@ -30,7 +30,7 @@ extern void ReadBinaryFile(const std::string& filename, std::string* contents);
namespace analysis { namespace analysis {
DEFINE_string(inference_model_dir, "", "inference test model dir"); PD_DEFINE_string(inference_model_dir, "", "inference test model dir");
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -1708,10 +1708,10 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>( ...@@ -1708,10 +1708,10 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
auto SetGflags = [](const AnalysisConfig &config) { auto SetGflags = [](const AnalysisConfig &config) {
auto SetGflag = [](const char *name, const char *value) { auto SetGflag = [](const char *name, const char *value) {
std::string ret = ::GFLAGS_NAMESPACE::SetCommandLineOption(name, value); bool success = paddle::flags::SetFlagValue(name, value);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ret.empty(), success,
false, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Fail to set gflag: %s, please make sure the gflag exists.", "Fail to set gflag: %s, please make sure the gflag exists.",
name)); name));
...@@ -3089,8 +3089,8 @@ std::tuple<int, int, int> GetTrtRuntimeVersion() { ...@@ -3089,8 +3089,8 @@ std::tuple<int, int, int> GetTrtRuntimeVersion() {
#endif #endif
} }
std::string UpdateDllFlag(const char *name, const char *value) { void UpdateDllFlag(const char *name, const char *value) {
return paddle::UpdateDllFlag(name, value); paddle::UpdateDllFlag(name, value);
} }
void ConvertToMixedPrecision(const std::string &model_file, void ConvertToMixedPrecision(const std::string &model_file,
......
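The predictor now sets runtime flags through paddle::flags::SetFlagValue, which reports success as a bool, instead of gflags' SetCommandLineOption, which returned a status string that was empty on failure; that is why the enforce check flips from `ret.empty() == false` to `success == true`. A minimal sketch of the new calling pattern (assuming SetFlagValue is exposed via paddle/utils/flags.h, as elsewhere in this patch):

#include <iostream>
#include "paddle/utils/flags.h"

void SetMemoryFraction(const char* value) {
  // Returns false when the flag name was never registered.
  bool success = paddle::flags::SetFlagValue("fraction_of_gpu_memory_to_use", value);
  if (!success) {
    std::cerr << "flag fraction_of_gpu_memory_to_use does not exist" << std::endl;
  }
}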
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
#include <sstream> #include <sstream>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/commit.h" #include "paddle/fluid/framework/commit.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h" #include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/utils/flags.h"
namespace paddle { namespace paddle {
...@@ -134,20 +134,18 @@ std::string get_version() { ...@@ -134,20 +134,18 @@ std::string get_version() {
return ss.str(); return ss.str();
} }
std::string UpdateDllFlag(const char *name, const char *value) { void UpdateDllFlag(const char *name, const char *value) {
std::string ret; std::string ret;
LOG(WARNING) LOG(WARNING)
<< "The function \"UpdateDllFlag\" is only used to update the flag " << "The function \"UpdateDllFlag\" is only used to update the flag "
"on the Windows shared library"; "on the Windows shared library";
ret = ::GFLAGS_NAMESPACE::SetCommandLineOption(name, value); bool success = paddle::flags::SetFlagValue(name, value);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ret.empty(), success,
false, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Fail to update flag: %s, please make sure the flag exists.", name)); "Fail to update flag: %s, please make sure the flag exists.", name));
LOG(INFO) << ret;
return ret;
} }
#ifdef PADDLE_WITH_CRYPTO #ifdef PADDLE_WITH_CRYPTO
......
...@@ -26,7 +26,7 @@ limitations under the License. */ ...@@ -26,7 +26,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DEFINE_bool(profile, false, "Turn on profiler for fluid"); // NOLINT PD_DEFINE_bool(profile, false, "Turn on profiler for fluid"); // NOLINT
namespace paddle { namespace paddle {
namespace { namespace {
...@@ -373,7 +373,6 @@ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>( ...@@ -373,7 +373,6 @@ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
std::vector<std::string> flags; std::vector<std::string> flags;
if (config.fraction_of_gpu_memory >= 0.0f || if (config.fraction_of_gpu_memory >= 0.0f ||
config.fraction_of_gpu_memory <= 0.95f) { config.fraction_of_gpu_memory <= 0.95f) {
flags.emplace_back("dummpy");
std::string flag = "--fraction_of_gpu_memory_to_use=" + std::string flag = "--fraction_of_gpu_memory_to_use=" +
num2str<float>(config.fraction_of_gpu_memory); num2str<float>(config.fraction_of_gpu_memory);
flags.push_back(flag); flags.push_back(flag);
......
...@@ -87,7 +87,7 @@ void Main() { ...@@ -87,7 +87,7 @@ void Main() {
} // namespace paddle } // namespace paddle
int main(int argc, char** argv) { int main(int argc, char** argv) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
paddle::demo::Main(); paddle::demo::Main();
return 0; return 0;
} }
...@@ -133,7 +133,7 @@ void MainThreads(int num_threads, bool use_gpu) { ...@@ -133,7 +133,7 @@ void MainThreads(int num_threads, bool use_gpu) {
} // namespace paddle } // namespace paddle
int main(int argc, char** argv) { int main(int argc, char** argv) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
paddle::demo::Main(false /* use_gpu*/); paddle::demo::Main(false /* use_gpu*/);
paddle::demo::MainThreads(1, false /* use_gpu*/); paddle::demo::MainThreads(1, false /* use_gpu*/);
paddle::demo::MainThreads(4, false /* use_gpu*/); paddle::demo::MainThreads(4, false /* use_gpu*/);
......
...@@ -73,7 +73,7 @@ void Main() { ...@@ -73,7 +73,7 @@ void Main() {
} // namespace paddle } // namespace paddle
int main(int argc, char** argv) { int main(int argc, char** argv) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
paddle::demo::Main(); paddle::demo::Main();
return 0; return 0;
} }
...@@ -28,9 +28,6 @@ DEFINE_string(data, ...@@ -28,9 +28,6 @@ DEFINE_string(data,
"path of data; each line is a record, format is " "path of data; each line is a record, format is "
"'<space split floats as data>\t<space split ints as shape'"); "'<space split floats as data>\t<space split ints as shape'");
DEFINE_bool(use_gpu, false, "Whether use gpu."); DEFINE_bool(use_gpu, false, "Whether use gpu.");
#ifdef PADDLE_WITH_SHARED_LIB
DECLARE_bool(profile);
#endif
namespace paddle { namespace paddle {
namespace demo { namespace demo {
...@@ -81,7 +78,7 @@ void Main(bool use_gpu) { ...@@ -81,7 +78,7 @@ void Main(bool use_gpu) {
} // namespace paddle } // namespace paddle
int main(int argc, char** argv) { int main(int argc, char** argv) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_use_gpu) { if (FLAGS_use_gpu) {
paddle::demo::Main(true /*use_gpu*/); paddle::demo::Main(true /*use_gpu*/);
} else { } else {
......
...@@ -85,7 +85,7 @@ void RunAnalysis() { ...@@ -85,7 +85,7 @@ void RunAnalysis() {
} // namespace paddle } // namespace paddle
int main(int argc, char** argv) { int main(int argc, char** argv) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
paddle::demo::RunAnalysis(); paddle::demo::RunAnalysis();
std::cout << "=========================Runs successfully====================" std::cout << "=========================Runs successfully===================="
<< std::endl; << std::endl;
......
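The standalone inference demos (demo_ci) keep their plain gflags usage, as noted in the commit message; their main functions only drop the ::GFLAGS_NAMESPACE spelling in favor of the gflags:: namespace. A self-contained entry point along the same lines (the flag name mirrors the demo above; gflags itself is assumed to be installed):

#include <gflags/gflags.h>
#include <iostream>

DEFINE_bool(use_gpu, false, "Whether to run the demo on GPU.");

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  std::cout << "use_gpu = " << (FLAGS_use_gpu ? "true" : "false") << std::endl;
  return 0;
}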
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "test/cpp/inference/api/tester_helper.h" #include "test/cpp/inference/api/tester_helper.h"
DEFINE_string(dirname, "", "dirname to tests."); PD_DEFINE_string(dirname, "", "dirname to tests.");
namespace paddle { namespace paddle {
......
...@@ -459,7 +459,7 @@ PD_INFER_DECL int PaddleDtypeSize(PaddleDType dtype); ...@@ -459,7 +459,7 @@ PD_INFER_DECL int PaddleDtypeSize(PaddleDType dtype);
PD_INFER_DECL std::string get_version(); PD_INFER_DECL std::string get_version();
PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value); PD_INFER_DECL void UpdateDllFlag(const char* name, const char* value);
PD_INFER_DECL std::shared_ptr<framework::Cipher> MakeCipher( PD_INFER_DECL std::shared_ptr<framework::Cipher> MakeCipher(
const std::string& config_file); const std::string& config_file);
......
...@@ -235,7 +235,7 @@ PD_INFER_DECL int GetNumBytesOfDataType(DataType dtype); ...@@ -235,7 +235,7 @@ PD_INFER_DECL int GetNumBytesOfDataType(DataType dtype);
PD_INFER_DECL std::string GetVersion(); PD_INFER_DECL std::string GetVersion();
PD_INFER_DECL std::tuple<int, int, int> GetTrtCompileVersion(); PD_INFER_DECL std::tuple<int, int, int> GetTrtCompileVersion();
PD_INFER_DECL std::tuple<int, int, int> GetTrtRuntimeVersion(); PD_INFER_DECL std::tuple<int, int, int> GetTrtRuntimeVersion();
PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value); PD_INFER_DECL void UpdateDllFlag(const char* name, const char* value);
PD_INFER_DECL void ConvertToMixedPrecision( PD_INFER_DECL void ConvertToMixedPrecision(
const std::string& model_file, const std::string& model_file,
......
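API note: UpdateDllFlag in both public inference headers changes its return type from std::string to void, so client code that captured and logged the returned status string should drop that use and call it only for its side effect; failures now surface through the enforce inside the implementation (see the api.cc hunk above). A hedged before/after sketch of calling code, assuming the function lives in the paddle_infer namespace of the newer header:

#include "paddle/fluid/inference/api/paddle_inference_api.h"

void ConfigureProfiling() {
  // Old API: std::string msg = paddle_infer::UpdateDllFlag("profile", "true");
  // New API returns void; a bad flag name is reported by the library itself.
  paddle_infer::UpdateDllFlag("profile", "true");
}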
...@@ -29,12 +29,12 @@ limitations under the License. */ ...@@ -29,12 +29,12 @@ limitations under the License. */
// phi // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
DEFINE_string(devices, // NOLINT PD_DEFINE_string(devices, // NOLINT
"", "",
"The devices to be used which is joined by comma."); "The devices to be used which is joined by comma.");
DEFINE_int32(math_num_threads, PD_DEFINE_int32(math_num_threads,
1, 1,
"Number of threads used to run math functions."); "Number of threads used to run math functions.");
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -45,6 +45,7 @@ ...@@ -45,6 +45,7 @@
*paddle::RegisterSymbolsFor*; *paddle::RegisterSymbolsFor*;
*paddle::from_blob*; *paddle::from_blob*;
*paddle::InitPhi*; *paddle::InitPhi*;
*paddle::flags*;
/* ut needs the following symbol, we need to modify all the ut to hidden such symbols */ /* ut needs the following symbol, we need to modify all the ut to hidden such symbols */
......
...@@ -31,7 +31,7 @@ namespace nvinfer1 { ...@@ -31,7 +31,7 @@ namespace nvinfer1 {
class ITensor; class ITensor;
} // namespace nvinfer1 } // namespace nvinfer1
DECLARE_bool(profile); PD_DECLARE_bool(profile);
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -22,7 +22,7 @@ PHI_DECLARE_double(fraction_of_gpu_memory_to_use); ...@@ -22,7 +22,7 @@ PHI_DECLARE_double(fraction_of_gpu_memory_to_use);
PHI_DECLARE_double(fraction_of_cuda_pinned_memory_to_use); PHI_DECLARE_double(fraction_of_cuda_pinned_memory_to_use);
PHI_DECLARE_uint64(initial_gpu_memory_in_mb); PHI_DECLARE_uint64(initial_gpu_memory_in_mb);
PHI_DECLARE_uint64(reallocate_gpu_memory_in_mb); PHI_DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_int64(gpu_allocator_retry_time); PD_DECLARE_int64(gpu_allocator_retry_time);
#endif #endif
PHI_DECLARE_string(allocator_strategy); PHI_DECLARE_string(allocator_strategy);
......
...@@ -22,7 +22,7 @@ PHI_DECLARE_double(fraction_of_gpu_memory_to_use); ...@@ -22,7 +22,7 @@ PHI_DECLARE_double(fraction_of_gpu_memory_to_use);
PHI_DECLARE_double(fraction_of_cuda_pinned_memory_to_use); PHI_DECLARE_double(fraction_of_cuda_pinned_memory_to_use);
PHI_DECLARE_uint64(initial_gpu_memory_in_mb); PHI_DECLARE_uint64(initial_gpu_memory_in_mb);
PHI_DECLARE_uint64(reallocate_gpu_memory_in_mb); PHI_DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_int64(gpu_allocator_retry_time); PD_DECLARE_int64(gpu_allocator_retry_time);
#endif #endif
PHI_DECLARE_string(allocator_strategy); PHI_DECLARE_string(allocator_strategy);
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PHI_DECLARE_double(fraction_of_gpu_memory_to_use); PHI_DECLARE_double(fraction_of_gpu_memory_to_use);
PHI_DECLARE_double(fraction_of_cuda_pinned_memory_to_use); PHI_DECLARE_double(fraction_of_cuda_pinned_memory_to_use);
DECLARE_int64(gpu_allocator_retry_time); PD_DECLARE_int64(gpu_allocator_retry_time);
#endif #endif
PHI_DECLARE_string(allocator_strategy); PHI_DECLARE_string(allocator_strategy);
......
(The remaining file diffs in this commit are collapsed in the web view.)