From 94736d6072a3fd5551b696d559a77603cafbad08 Mon Sep 17 00:00:00 2001
From: seemingwang
Date: Fri, 2 Apr 2021 13:58:45 +0800
Subject: [PATCH] graph engine (#31226)

* graph engine demo
* upload unsaved changes
* fix dependency error
* fix shard_num problem
* py client
* remove lock and graph-type
* add load direct graph
* add load direct graph
* add load direct graph
* batch random_sample
* batch_sample_k
* fix num_nodes size
* batch brpc
* batch brpc
* add test
* add test
* add load_nodes; change add_node function
* change sample return type to pair
* resolve conflict
* resolved conflict
* resolved conflict
* separate server and client
* merge pair type
* fix
* resolved conflict
* fixed segment fault; high-level VLOG for load edges and load nodes
* random_sample return 0
* rm useless loop
* test:load edge
* fix ret -1
* test: rm sample
* rm sample
* random_sample return future
* random_sample return int
* test fake node
* fixed here
* memory leak
* remove test code
* fix return problem
* add common_graph_table
* random sample node &test & change data-structure from linkedList to vector
* add common_graph_table
* sample with srand
* add node_types
* optimize nodes sample
* recover test
* random sample
* destruct weighted sampler
* GraphEdgeBlob
* WeightedGraphEdgeBlob to GraphEdgeBlob
* WeightedGraphEdgeBlob to GraphEdgeBlob
* pybind sample nodes api
* pull nodes with step
* fixed pull_graph_list bug; add test for pull_graph_list by step
* add graph table;name
* add graph table;name
* add pybind
* add pybind
* add FeatureNode
* add FeatureNode
* add FeatureNode Serialize
* add FeatureNode Serialize
* get_feat_node
* avoid local rpc
* fix get_node_feat
* fix get_node_feat
* remove log
* get_node_feat return py:bytes
* merge develop with graph_engine
* fix threadpool.h head
* fix
* fix typo
* resolve conflict
* fix conflict
* recover lost content
* fix pybind of FeatureNode
* recover cmake
* recover tools
* resolve conflict
* resolve linking problem
* code style
* change test_server port
* fix code problems
* remove shard_num config
* remove redundent threads
* optimize start server
* remove logs
* fix code problems by reviewers' suggestions

Co-authored-by: Huang Zhengjie <270018958@qq.com>
Co-authored-by: Weiyue Su
Co-authored-by: suweiyue
Co-authored-by: luobin06
Co-authored-by: liweibin02
---
 .github/ISSUE_TEMPLATE/---document-issue-.md | 2 +-
 .../fluid/distributed/service/CMakeLists.txt | 10 +-
 .../distributed/service/brpc_ps_client.cc | 2 +-
 .../distributed/service/brpc_ps_client.h | 27 +-
 .../distributed/service/graph_brpc_client.cc | 331 +++++++++++
 .../distributed/service/graph_brpc_client.h | 105 ++++
 .../distributed/service/graph_brpc_server.cc | 347 +++++++++++
 .../distributed/service/graph_brpc_server.h | 113 ++++
 .../distributed/service/graph_py_service.cc | 325 ++++++++++
 .../distributed/service/graph_py_service.h | 178 ++++++
 paddle/fluid/distributed/service/ps_client.cc | 5 +-
 paddle/fluid/distributed/service/ps_client.h | 8 +-
 .../fluid/distributed/service/sendrecv.proto | 6 +-
 paddle/fluid/distributed/service/server.cc | 3 +
 paddle/fluid/distributed/table/CMakeLists.txt | 10 +-
 .../distributed/table/common_graph_table.cc | 506 ++++++++++++++++
 .../distributed/table/common_graph_table.h | 144 +++++
 paddle/fluid/distributed/table/graph_edge.cc | 29 +
 paddle/fluid/distributed/table/graph_edge.h | 46 ++
 paddle/fluid/distributed/table/graph_node.cc | 117 ++++
 paddle/fluid/distributed/table/graph_node.h | 127 ++++
 .../table/graph_weighted_sampler.cc | 150 +++++
 .../table/graph_weighted_sampler.h | 58 ++
 paddle/fluid/distributed/table/table.cc | 4 +-
 paddle/fluid/distributed/table/table.h | 27 +
 paddle/fluid/distributed/test/CMakeLists.txt | 3 +
 .../fluid/distributed/test/graph_node_test.cc | 556 ++++++++++++++++++
 paddle/fluid/inference/api/demo_ci/clean.sh | 14 +
 paddle/fluid/pybind/CMakeLists.txt | 4 +
 paddle/fluid/pybind/fleet_py.cc | 60 ++
 paddle/fluid/pybind/fleet_py.h | 6 +-
 paddle/fluid/pybind/pybind.cc | 5 +
 paddle/scripts/build_docker_images.sh | 15 +
 .../docker/root/.scripts/git-completion.sh | 15 +
 paddle/scripts/fast_install.sh | 14 +
 python/paddle/fluid/dataloader/fetcher.py | 7 +-
 .../incubate/fleet/tests/cluster_train.sh | 14 +
 .../test_squared_mat_sub_fuse_pass.py | 6 +-
 .../unittests/ir/inference/test_trt_matmul.py | 23 +-
 .../fluid/tests/unittests/parallel_test.sh | 15 +
 .../fluid/tests/unittests/test_bce_loss.py | 12 +-
 .../unittests/test_bce_with_logits_loss.py | 6 +-
 .../tests/unittests/test_c_comm_init_op.sh | 15 +
 .../tests/unittests/test_dist_fleet_ps10.py | 1 -
 .../test_flatten_contiguous_range_op.py | 3 +-
 .../fluid/tests/unittests/test_l1_loss.py | 12 +-
 .../tests/unittests/test_listen_and_serv.sh | 15 +
 .../fluid/tests/unittests/test_mse_loss.py | 18 +-
 ...ess_dataloader_iterable_dataset_dynamic.py | 1 +
 .../tests/unittests/test_pixel_shuffle.py | 12 +-
 .../fluid/tests/unittests/test_prod_op.py | 6 +-
 .../fluid/tests/unittests/test_selu_op.py | 9 +-
 .../unittests/test_sigmoid_focal_loss.py | 6 +-
 .../tests/unittests/test_transpose_op.py | 8 +-
 scripts/paddle | 169 ++++++
 tools/check_api_approvals.sh | 14 +
 tools/check_sequence_op.sh | 14 +
 tools/cudaError/start.sh | 15 +
 tools/diff_api.py | 15 +
 tools/diff_unittest.py | 15 +
 tools/dockerfile/icode.sh | 14 +
 tools/document_preview.sh | 15 +
 tools/get_cpu_info.sh | 14 +
 63 files changed, 3765 insertions(+), 81 deletions(-)
 create mode 100644 paddle/fluid/distributed/service/graph_brpc_client.cc
 create mode 100644 paddle/fluid/distributed/service/graph_brpc_client.h
 create mode 100644 paddle/fluid/distributed/service/graph_brpc_server.cc
 create mode 100644 paddle/fluid/distributed/service/graph_brpc_server.h
 create mode 100644 paddle/fluid/distributed/service/graph_py_service.cc
 create mode 100644 paddle/fluid/distributed/service/graph_py_service.h
 create mode 100644 paddle/fluid/distributed/table/common_graph_table.cc
 create mode 100644 paddle/fluid/distributed/table/common_graph_table.h
 create mode 100644 paddle/fluid/distributed/table/graph_edge.cc
 create mode 100644 paddle/fluid/distributed/table/graph_edge.h
 create mode 100644 paddle/fluid/distributed/table/graph_node.cc
 create mode 100644 paddle/fluid/distributed/table/graph_node.h
 create mode 100644 paddle/fluid/distributed/table/graph_weighted_sampler.cc
 create mode 100644 paddle/fluid/distributed/table/graph_weighted_sampler.h
 create mode 100644 paddle/fluid/distributed/test/graph_node_test.cc
 create mode 100644 scripts/paddle

diff --git a/.github/ISSUE_TEMPLATE/---document-issue-.md b/.github/ISSUE_TEMPLATE/---document-issue-.md
index 7c464ac584b..ffc2fcd7817 100644
--- a/.github/ISSUE_TEMPLATE/---document-issue-.md
+++ b/.github/ISSUE_TEMPLATE/---document-issue-.md
@@ -56,4 +56,4 @@ For example: no sample code; The sample code is not helpful; The sample code not
 For example:Chinese API in this doc is inconsistent with English API, including params, description, sample code, formula, etc.
 
 #### Other
-For example: The doc link is broken; The doc page is missing; Dead link in docs.
\ No newline at end of file +For example: The doc link is broken; The doc page is missing; Dead link in docs. diff --git a/paddle/fluid/distributed/service/CMakeLists.txt b/paddle/fluid/distributed/service/CMakeLists.txt index bb3f6f1174d..843dea9eea6 100644 --- a/paddle/fluid/distributed/service/CMakeLists.txt +++ b/paddle/fluid/distributed/service/CMakeLists.txt @@ -24,11 +24,12 @@ set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUT set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - +set_source_files_properties(graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_library(brpc_utils SRCS brpc_utils.cc DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS}) -cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table brpc_utils ${RPC_DEPS}) -cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table brpc_utils ${RPC_DEPS}) +cc_library(downpour_server SRCS graph_brpc_server.cc brpc_ps_server.cc DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS}) +cc_library(downpour_client SRCS graph_brpc_client.cc brpc_ps_client.cc DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS}) cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS}) cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS}) @@ -38,3 +39,6 @@ cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RP cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) + +set_source_files_properties(graph_py_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(graph_py_service SRCS graph_py_service.cc DEPS ps_service) diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc index 163526fe3b2..5c226e6a0dd 100644 --- a/paddle/fluid/distributed/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/service/brpc_ps_client.cc @@ -990,4 +990,4 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, } } // namespace distributed -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h index 8f9d2653864..84a31fdbd5d 100644 --- a/paddle/fluid/distributed/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/service/brpc_ps_client.h @@ -170,9 +170,22 @@ class BrpcPsClient : public PSClient { virtual int32_t recv_and_save_table(const uint64_t table_id, const std::string &path); - private: + protected: + virtual size_t get_server_nums() { return _server_channels.size(); } + inline brpc::Channel *get_sparse_channel(size_t server_id) { + return _server_channels[server_id][0].get(); + } + inline brpc::Channel *get_dense_channel(size_t server_id) { + return _server_channels[server_id][1].get(); + } + inline brpc::Channel *get_cmd_channel(size_t server_id) { + return _server_channels[server_id][2].get(); + } virtual int32_t initialize() override; + private: + // virtual int32_t initialize() override; + inline uint32_t 
dense_dim_per_shard(uint32_t dense_dim_total, uint32_t shard_num) { return dense_dim_total / shard_num + 1; @@ -184,16 +197,6 @@ class BrpcPsClient : public PSClient { std::future send_save_cmd(uint32_t table_id, int cmd_id, const std::vector ¶m); - inline brpc::Channel *get_sparse_channel(size_t server_id) { - return _server_channels[server_id][0].get(); - } - inline brpc::Channel *get_dense_channel(size_t server_id) { - return _server_channels[server_id][1].get(); - } - inline brpc::Channel *get_cmd_channel(size_t server_id) { - return _server_channels[server_id][2].get(); - } - bool _running = false; bool _flushing = false; std::atomic _async_call_num; //异步请求计数 @@ -220,8 +223,6 @@ class BrpcPsClient : public PSClient { size_t num, void *done) override; - virtual size_t get_server_nums() { return _server_channels.size(); } - private: int32_t start_client_service(); diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc new file mode 100644 index 00000000000..a6271cac83c --- /dev/null +++ b/paddle/fluid/distributed/service/graph_brpc_client.cc @@ -0,0 +1,331 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/graph_brpc_client.h" +#include +#include +#include +#include +#include +#include +#include "Eigen/Dense" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/string/string_helper.h" +namespace paddle { +namespace distributed { + +void GraphPsService_Stub::service( + ::google::protobuf::RpcController *controller, + const ::paddle::distributed::PsRequestMessage *request, + ::paddle::distributed::PsResponseMessage *response, + ::google::protobuf::Closure *done) { + if (graph_service != NULL && local_channel == channel()) { + // VLOG(0)<<"use local"; + task_pool->enqueue([this, controller, request, response, done]() -> int { + this->graph_service->service(controller, request, response, done); + return 0; + }); + } else { + // VLOG(0)<<"use server"; + PsService_Stub::service(controller, request, response, done); + } +} + +int GraphBrpcClient::get_server_index_by_id(uint64_t id) { + int shard_num = get_shard_num(); + int shard_per_server = shard_num % server_size == 0 + ? 
shard_num / server_size + : shard_num / server_size + 1; + return id % shard_num / shard_per_server; +} + +std::future GraphBrpcClient::get_node_feat( + const uint32_t &table_id, const std::vector &node_ids, + const std::vector &feature_names, + std::vector> &res) { + std::vector request2server; + std::vector server2request(server_size, -1); + for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { + int server_index = get_server_index_by_id(node_ids[query_idx]); + if (server2request[server_index] == -1) { + server2request[server_index] = request2server.size(); + request2server.push_back(server_index); + } + } + size_t request_call_num = request2server.size(); + std::vector> node_id_buckets(request_call_num); + std::vector> query_idx_buckets(request_call_num); + for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { + int server_index = get_server_index_by_id(node_ids[query_idx]); + int request_idx = server2request[server_index]; + node_id_buckets[request_idx].push_back(node_ids[query_idx]); + query_idx_buckets[request_idx].push_back(query_idx); + } + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, + [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + int fail_num = 0; + for (int request_idx = 0; request_idx < request_call_num; + ++request_idx) { + if (closure->check_response(request_idx, + PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) { + ++fail_num; + } else { + auto &res_io_buffer = + closure->cntl(request_idx)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + size_t bytes_size = io_buffer_itr.bytes_left(); + std::unique_ptr buffer_wrapper(new char[bytes_size]); + char *buffer = buffer_wrapper.get(); + io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size); + + for (size_t feat_idx = 0; feat_idx < feature_names.size(); + ++feat_idx) { + for (size_t node_idx = 0; + node_idx < query_idx_buckets.at(request_idx).size(); + ++node_idx) { + int query_idx = query_idx_buckets.at(request_idx).at(node_idx); + size_t feat_len = *(size_t *)(buffer); + buffer += sizeof(size_t); + auto feature = std::string(buffer, feat_len); + res[feat_idx][query_idx] = feature; + buffer += feat_len; + } + } + } + if (fail_num == request_call_num) { + ret = -1; + } + } + closure->set_promise_value(ret); + }); + + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + for (int request_idx = 0; request_idx < request_call_num; ++request_idx) { + int server_index = request2server[request_idx]; + closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT); + closure->request(request_idx)->set_table_id(table_id); + closure->request(request_idx)->set_client_id(_client_id); + size_t node_num = node_id_buckets[request_idx].size(); + + closure->request(request_idx) + ->add_params((char *)node_id_buckets[request_idx].data(), + sizeof(uint64_t) * node_num); + std::string joint_feature_name = + paddle::string::join_strings(feature_names, '\t'); + closure->request(request_idx) + ->add_params(joint_feature_name.c_str(), joint_feature_name.size()); + + PsService_Stub rpc_stub(get_cmd_channel(server_index)); + closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), + closure->response(request_idx), closure); + } + + return fut; +} +// char* &buffer,int &actual_size +std::future 
GraphBrpcClient::batch_sample_neighboors( + uint32_t table_id, std::vector node_ids, int sample_size, + std::vector>> &res) { + std::vector request2server; + std::vector server2request(server_size, -1); + res.clear(); + for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { + int server_index = get_server_index_by_id(node_ids[query_idx]); + if (server2request[server_index] == -1) { + server2request[server_index] = request2server.size(); + request2server.push_back(server_index); + } + res.push_back(std::vector>()); + } + size_t request_call_num = request2server.size(); + std::vector> node_id_buckets(request_call_num); + std::vector> query_idx_buckets(request_call_num); + for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { + int server_index = get_server_index_by_id(node_ids[query_idx]); + int request_idx = server2request[server_index]; + node_id_buckets[request_idx].push_back(node_ids[query_idx]); + query_idx_buckets[request_idx].push_back(query_idx); + } + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, + [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + int fail_num = 0; + for (int request_idx = 0; request_idx < request_call_num; + ++request_idx) { + if (closure->check_response(request_idx, + PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) { + ++fail_num; + } else { + auto &res_io_buffer = + closure->cntl(request_idx)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + size_t bytes_size = io_buffer_itr.bytes_left(); + std::unique_ptr buffer_wrapper(new char[bytes_size]); + char *buffer = buffer_wrapper.get(); + io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size); + + size_t node_num = *(size_t *)buffer; + int *actual_sizes = (int *)(buffer + sizeof(size_t)); + char *node_buffer = + buffer + sizeof(size_t) + sizeof(int) * node_num; + + int offset = 0; + for (size_t node_idx = 0; node_idx < node_num; ++node_idx) { + int query_idx = query_idx_buckets.at(request_idx).at(node_idx); + int actual_size = actual_sizes[node_idx]; + int start = 0; + while (start < actual_size) { + res[query_idx].push_back( + {*(uint64_t *)(node_buffer + offset + start), + *(float *)(node_buffer + offset + start + + GraphNode::id_size)}); + start += GraphNode::id_size + GraphNode::weight_size; + } + offset += actual_size; + } + } + if (fail_num == request_call_num) { + ret = -1; + } + } + closure->set_promise_value(ret); + }); + + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + for (int request_idx = 0; request_idx < request_call_num; ++request_idx) { + int server_index = request2server[request_idx]; + closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS); + closure->request(request_idx)->set_table_id(table_id); + closure->request(request_idx)->set_client_id(_client_id); + size_t node_num = node_id_buckets[request_idx].size(); + + closure->request(request_idx) + ->add_params((char *)node_id_buckets[request_idx].data(), + sizeof(uint64_t) * node_num); + closure->request(request_idx) + ->add_params((char *)&sample_size, sizeof(int)); + // PsService_Stub rpc_stub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = + getServiceStub(get_cmd_channel(server_index)); + closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), + closure->response(request_idx), closure); + } + + 
return fut; +} +std::future GraphBrpcClient::random_sample_nodes( + uint32_t table_id, int server_index, int sample_size, + std::vector &ids) { + DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) { + ret = -1; + } else { + auto &res_io_buffer = closure->cntl(0)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + size_t bytes_size = io_buffer_itr.bytes_left(); + char buffer[bytes_size]; + auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size); + int index = 0; + while (index < bytes_size) { + ids.push_back(*(uint64_t *)(buffer + index)); + index += GraphNode::id_size; + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + ; + closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES); + closure->request(0)->set_table_id(table_id); + closure->request(0)->set_client_id(_client_id); + closure->request(0)->add_params((char *)&sample_size, sizeof(int)); + ; + // PsService_Stub rpc_stub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); + closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} +std::future GraphBrpcClient::pull_graph_list( + uint32_t table_id, int server_index, int start, int size, int step, + std::vector &res) { + DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) { + ret = -1; + } else { + auto &res_io_buffer = closure->cntl(0)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + size_t bytes_size = io_buffer_itr.bytes_left(); + char buffer[bytes_size]; + io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size); + int index = 0; + while (index < bytes_size) { + FeatureNode node; + node.recover_from_buffer(buffer + index); + index += node.get_size(false); + res.push_back(node); + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST); + closure->request(0)->set_table_id(table_id); + closure->request(0)->set_client_id(_client_id); + closure->request(0)->add_params((char *)&start, sizeof(int)); + closure->request(0)->add_params((char *)&size, sizeof(int)); + closure->request(0)->add_params((char *)&step, sizeof(int)); + // PsService_Stub rpc_stub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); + closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} +int32_t GraphBrpcClient::initialize() { + // set_shard_num(_config.shard_num()); + BrpcPsClient::initialize(); + server_size = get_server_nums(); + graph_service = NULL; + local_channel = NULL; + return 0; +} +} +} diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h new file mode 100644 index 00000000000..4e6775a4bed --- /dev/null +++ b/paddle/fluid/distributed/service/graph_brpc_client.h @@ 
-0,0 +1,105 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include +#include "ThreadPool.h" +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/graph_brpc_server.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" + +namespace paddle { +namespace distributed { + +class GraphPsService_Stub : public PsService_Stub { + public: + GraphPsService_Stub(::google::protobuf::RpcChannel* channel, + ::google::protobuf::RpcChannel* local_channel = NULL, + GraphBrpcService* service = NULL, int thread_num = 1) + : PsService_Stub(channel) { + this->local_channel = local_channel; + this->graph_service = service; + task_pool.reset(new ::ThreadPool(thread_num)); + } + virtual ~GraphPsService_Stub() {} + + // implements PsService ------------------------------------------ + GraphBrpcService* graph_service; + std::shared_ptr<::ThreadPool> task_pool; + ::google::protobuf::RpcChannel* local_channel; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GraphPsService_Stub); + void service(::google::protobuf::RpcController* controller, + const ::paddle::distributed::PsRequestMessage* request, + ::paddle::distributed::PsResponseMessage* response, + ::google::protobuf::Closure* done); +}; +class GraphBrpcClient : public BrpcPsClient { + public: + GraphBrpcClient() {} + virtual ~GraphBrpcClient() {} + // given a batch of nodes, sample graph_neighboors for each of them + virtual std::future batch_sample_neighboors( + uint32_t table_id, std::vector node_ids, int sample_size, + std::vector>>& res); + + virtual std::future pull_graph_list(uint32_t table_id, + int server_index, int start, + int size, int step, + std::vector& res); + virtual std::future random_sample_nodes(uint32_t table_id, + int server_index, + int sample_size, + std::vector& ids); + virtual std::future get_node_feat( + const uint32_t& table_id, const std::vector& node_ids, + const std::vector& feature_names, + std::vector>& res); + virtual int32_t initialize(); + int get_shard_num() { return shard_num; } + void set_shard_num(int shard_num) { this->shard_num = shard_num; } + int get_server_index_by_id(uint64_t id); + void set_local_channel(int index) { + this->local_channel = get_cmd_channel(index); + } + void set_local_graph_service(GraphBrpcService* graph_service) { + this->graph_service = graph_service; + } + GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel, + int thread_num = 1) { + return GraphPsService_Stub(channel, local_channel, graph_service, + thread_num); + } + + private: + int shard_num; + size_t server_size; + ::google::protobuf::RpcChannel* local_channel; + 
GraphBrpcService* graph_service; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc new file mode 100644 index 00000000000..4f6cc1143e9 --- /dev/null +++ b/paddle/fluid/distributed/service/graph_brpc_server.cc @@ -0,0 +1,347 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/graph_brpc_server.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" + +#include // NOLINT +#include "butil/endpoint.h" +#include "iomanip" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/platform/profiler.h" +namespace paddle { +namespace distributed { + +int32_t GraphBrpcServer::initialize() { + auto &service_config = _config.downpour_server_param().service_param(); + if (!service_config.has_service_class()) { + LOG(ERROR) << "miss service_class in ServerServiceParameter"; + return -1; + } + auto *service = + CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class()); + if (service == NULL) { + LOG(ERROR) << "service is unregistered, service_name:" + << service_config.service_class(); + return -1; + } + + _service.reset(service); + if (service->configure(this) != 0 || service->initialize() != 0) { + LOG(ERROR) << "service initialize failed, service_name:" + << service_config.service_class(); + return -1; + } + if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { + LOG(ERROR) << "service add to brpc failed, service:" + << service_config.service_class(); + return -1; + } + return 0; +} + +uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) { + std::unique_lock lock(mutex_); + + std::string ip_port = ip + ":" + std::to_string(port); + VLOG(3) << "server of rank " << _rank << " starts at " << ip_port; + brpc::ServerOptions options; + + int num_threads = std::thread::hardware_concurrency(); + auto trainers = _environment->get_trainers(); + options.num_threads = trainers > num_threads ? 
trainers : num_threads; + + if (_server.Start(ip_port.c_str(), &options) != 0) { + LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port; + return 0; + } + _environment->registe_ps_server(ip, port, _rank); + return 0; +} + +int32_t GraphBrpcServer::port() { return _server.listen_address().port; } + +int32_t GraphBrpcService::initialize() { + _is_initialize_shard_info = false; + _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::stop_server; + _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::load_one_table; + _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::load_all_table; + + _service_handler_map[PS_PRINT_TABLE_STAT] = + &GraphBrpcService::print_table_stat; + _service_handler_map[PS_BARRIER] = &GraphBrpcService::barrier; + _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::start_profiler; + _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler; + + _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list; + _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBOORS] = + &GraphBrpcService::graph_random_sample_neighboors; + _service_handler_map[PS_GRAPH_SAMPLE_NODES] = + &GraphBrpcService::graph_random_sample_nodes; + _service_handler_map[PS_GRAPH_GET_NODE_FEAT] = + &GraphBrpcService::graph_get_node_feat; + + // shard初始化,server启动后才可从env获取到server_list的shard信息 + initialize_shard_info(); + + return 0; +} + +#define CHECK_TABLE_EXIST(table, request, response) \ + if (table == NULL) { \ + std::string err_msg("table not found with table_id:"); \ + err_msg.append(std::to_string(request.table_id())); \ + set_response_code(response, -1, err_msg.c_str()); \ + return -1; \ + } + +int32_t GraphBrpcService::initialize_shard_info() { + if (!_is_initialize_shard_info) { + std::lock_guard guard(_initialize_shard_mutex); + if (_is_initialize_shard_info) { + return 0; + } + size_t shard_num = _server->environment()->get_ps_servers().size(); + auto &table_map = *(_server->table()); + for (auto itr : table_map) { + itr.second->set_shard(_rank, shard_num); + } + _is_initialize_shard_info = true; + } + return 0; +} + +void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, + const PsRequestMessage *request, + PsResponseMessage *response, + google::protobuf::Closure *done) { + brpc::ClosureGuard done_guard(done); + std::string log_label("ReceiveCmd-"); + if (!request->has_table_id()) { + set_response_code(*response, -1, "PsRequestMessage.tabel_id is required"); + return; + } + + response->set_err_code(0); + response->set_err_msg(""); + auto *table = _server->table(request->table_id()); + brpc::Controller *cntl = static_cast(cntl_base); + auto itr = _service_handler_map.find(request->cmd_id()); + if (itr == _service_handler_map.end()) { + std::string err_msg( + "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); + err_msg.append(std::to_string(request->cmd_id())); + set_response_code(*response, -1, err_msg.c_str()); + return; + } + serviceFunc handler_func = itr->second; + int service_ret = (this->*handler_func)(table, *request, *response, cntl); + if (service_ret != 0) { + response->set_err_code(service_ret); + response->set_err_msg("server internal error"); + } +} + +int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + 
return 0; + } + + auto trainer_id = request.client_id(); + auto barrier_type = request.params(0); + table->barrier(trainer_id, barrier_type); + return 0; +} + +int32_t GraphBrpcService::print_table_stat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + std::pair ret = table->print_table_stat(); + paddle::framework::BinaryArchive ar; + ar << ret.first << ret.second; + std::string table_info(ar.Buffer(), ar.Length()); + response.set_data(table_info); + + return 0; +} + +int32_t GraphBrpcService::load_one_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 2 for path & load_param"); + return -1; + } + if (table->load(request.params(0), request.params(1)) != 0) { + set_response_code(response, -1, "table load failed"); + return -1; + } + return 0; +} + +int32_t GraphBrpcService::load_all_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + for (auto &itr : table_map) { + if (load_one_table(itr.second.get(), request, response, cntl) != 0) { + LOG(ERROR) << "load table[" << itr.first << "] failed"; + return -1; + } + } + return 0; +} + +int32_t GraphBrpcService::stop_server(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + GraphBrpcServer *p_server = (GraphBrpcServer *)_server; + std::thread t_stop([p_server]() { + p_server->stop(); + LOG(INFO) << "Server Stoped"; + }); + p_server->export_cv()->notify_all(); + t_stop.detach(); + return 0; +} + +int32_t GraphBrpcService::stop_profiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::DisableProfiler(platform::EventSortingKey::kDefault, + string::Sprintf("server_%s_profile", _rank)); + return 0; +} + +int32_t GraphBrpcService::start_profiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::EnableProfiler(platform::ProfilerState::kCPU); + return 0; +} + +int32_t GraphBrpcService::pull_graph_list(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 3) { + set_response_code(response, -1, + "pull_graph_list request requires at least 3 arguments"); + return 0; + } + int start = *(int *)(request.params(0).c_str()); + int size = *(int *)(request.params(1).c_str()); + int step = *(int *)(request.params(2).c_str()); + std::unique_ptr buffer; + int actual_size; + table->pull_graph_list(start, size, buffer, actual_size, false, step); + cntl->response_attachment().append(buffer.get(), actual_size); + return 0; +} +int32_t GraphBrpcService::graph_random_sample_neighboors( + Table *table, const PsRequestMessage &request, PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "graph_random_sample request requires at least 2 arguments"); + return 0; + } + size_t node_num = request.params(0).size() / sizeof(uint64_t); + uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); + int 
sample_size = *(uint64_t *)(request.params(1).c_str()); + std::vector> buffers(node_num); + std::vector actual_sizes(node_num, 0); + table->random_sample_neighboors(node_data, sample_size, buffers, + actual_sizes); + + cntl->response_attachment().append(&node_num, sizeof(size_t)); + cntl->response_attachment().append(actual_sizes.data(), + sizeof(int) * node_num); + for (size_t idx = 0; idx < node_num; ++idx) { + cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]); + } + return 0; +} +int32_t GraphBrpcService::graph_random_sample_nodes( + Table *table, const PsRequestMessage &request, PsResponseMessage &response, + brpc::Controller *cntl) { + size_t size = *(uint64_t *)(request.params(0).c_str()); + std::unique_ptr buffer; + int actual_size; + if (table->random_sample_nodes(size, buffer, actual_size) == 0) { + cntl->response_attachment().append(buffer.get(), actual_size); + } else + cntl->response_attachment().append(NULL, 0); + + return 0; +} + +int32_t GraphBrpcService::graph_get_node_feat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "graph_get_node_feat request requires at least 2 arguments"); + return 0; + } + size_t node_num = request.params(0).size() / sizeof(uint64_t); + uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); + std::vector node_ids(node_data, node_data + node_num); + + std::vector feature_names = + paddle::string::split_string(request.params(1), "\t"); + + std::vector> feature( + feature_names.size(), std::vector(node_num)); + + table->get_node_feat(node_ids, feature_names, feature); + + for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { + for (size_t node_idx = 0; node_idx < node_num; ++node_idx) { + size_t feat_len = feature[feat_idx][node_idx].size(); + cntl->response_attachment().append(&feat_len, sizeof(size_t)); + cntl->response_attachment().append(feature[feat_idx][node_idx].data(), + feat_len); + } + } + + return 0; +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h new file mode 100644 index 00000000000..af63bf5d99e --- /dev/null +++ b/paddle/fluid/distributed/service/graph_brpc_server.h @@ -0,0 +1,113 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" + +#include +#include +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/server.h" + +namespace paddle { +namespace distributed { +class GraphBrpcServer : public PSServer { + public: + GraphBrpcServer() {} + virtual ~GraphBrpcServer() {} + PsBaseService *get_service() { return _service.get(); } + virtual uint64_t start(const std::string &ip, uint32_t port); + virtual int32_t stop() { + std::unique_lock lock(mutex_); + if (stoped_) return 0; + stoped_ = true; + // cv_.notify_all(); + _server.Stop(1000); + _server.Join(); + return 0; + } + virtual int32_t port(); + + std::condition_variable *export_cv() { return &cv_; } + + private: + virtual int32_t initialize(); + mutable std::mutex mutex_; + std::condition_variable cv_; + bool stoped_ = false; + brpc::Server _server; + std::shared_ptr _service; + std::vector> _pserver_channels; +}; + +class GraphBrpcService; + +typedef int32_t (GraphBrpcService::*serviceFunc)( + Table *table, const PsRequestMessage &request, PsResponseMessage &response, + brpc::Controller *cntl); + +class GraphBrpcService : public PsBaseService { + public: + virtual int32_t initialize() override; + + virtual void service(::google::protobuf::RpcController *controller, + const PsRequestMessage *request, + PsResponseMessage *response, + ::google::protobuf::Closure *done) override; + + protected: + std::unordered_map _service_handler_map; + int32_t initialize_shard_info(); + int32_t pull_graph_list(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t graph_random_sample_neighboors(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); + int32_t graph_random_sample_nodes(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); + int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); + int32_t barrier(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t load_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t load_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t stop_server(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t start_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t stop_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + int32_t print_table_stat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + private: + bool _is_initialize_shard_info; + std::mutex _initialize_shard_mutex; + std::unordered_map _msg_handler_map; + std::vector _ori_values; + const int sample_nodes_ranges = 23; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc new file mode 100644 index 00000000000..61e4e0cf7bb --- /dev/null +++ b/paddle/fluid/distributed/service/graph_py_service.cc @@ -0,0 +1,325 @@ +// Copyright (c) 2021 
PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/graph_py_service.h" +#include // NOLINT +#include "butil/endpoint.h" +#include "iomanip" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/platform/profiler.h" +namespace paddle { +namespace distributed { +std::vector GraphPyService::split(std::string& str, + const char pattern) { + std::vector res; + std::stringstream input(str); + std::string temp; + while (std::getline(input, temp, pattern)) { + res.push_back(temp); + } + return res; +} + +void GraphPyService::add_table_feat_conf(std::string table_name, + std::string feat_name, + std::string feat_dtype, + int32_t feat_shape) { + if (this->table_id_map.count(table_name)) { + this->table_feat_conf_table_name.push_back(table_name); + this->table_feat_conf_feat_name.push_back(feat_name); + this->table_feat_conf_feat_dtype.push_back(feat_dtype); + this->table_feat_conf_feat_shape.push_back(feat_shape); + } +} + +void GraphPyService::set_up(std::string ips_str, int shard_num, + std::vector node_types, + std::vector edge_types) { + set_shard_num(shard_num); + set_num_node_types(node_types.size()); + + for (size_t table_id = 0; table_id < node_types.size(); table_id++) { + this->table_id_map[node_types[table_id]] = this->table_id_map.size(); + } + for (size_t table_id = 0; table_id < edge_types.size(); table_id++) { + this->table_id_map[edge_types[table_id]] = this->table_id_map.size(); + } + std::istringstream stream(ips_str); + std::string ip; + server_size = 0; + std::vector ips_list = split(ips_str, ';'); + int index = 0; + for (auto ips : ips_list) { + auto ip_and_port = split(ips, ':'); + server_list.push_back(ip_and_port[0]); + port_list.push_back(ip_and_port[1]); + uint32_t port = stoul(ip_and_port[1]); + auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index); + host_sign_list.push_back(ph_host.serialize_to_string()); + index++; + } +} +void GraphPyClient::start_client() { + std::map> dense_regions; + dense_regions.insert( + std::pair>(0, {})); + auto regions = dense_regions[0]; + ::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); + paddle::distributed::PaddlePSEnvironment _ps_env; + auto servers_ = host_sign_list.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list, servers_); + worker_ptr = std::shared_ptr( + (paddle::distributed::GraphBrpcClient*) + paddle::distributed::PSClientFactory::create(worker_proto)); + worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id); + worker_ptr->set_shard_num(get_shard_num()); +} +void GraphPyServer::start_server(bool block) { + std::string ip = server_list[rank]; + uint32_t port = std::stoul(port_list[rank]); + ::paddle::distributed::PSParameter server_proto = this->GetServerProto(); + + auto _ps_env = paddle::distributed::PaddlePSEnvironment(); + 
_ps_env.set_ps_servers(&this->host_sign_list, + this->host_sign_list.size()); // test + pserver_ptr = std::shared_ptr( + (paddle::distributed::GraphBrpcServer*) + paddle::distributed::PSServerFactory::create(server_proto)); + VLOG(0) << "pserver-ptr created "; + std::vector empty_vec; + framework::ProgramDesc empty_prog; + empty_vec.push_back(empty_prog); + pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec); + pserver_ptr->start(ip, port); + std::condition_variable* cv_ = pserver_ptr->export_cv(); + if (block) { + std::mutex mutex_; + std::unique_lock lock(mutex_); + cv_->wait(lock); + } +} +::paddle::distributed::PSParameter GraphPyServer::GetServerProto() { + // Generate server proto desc + ::paddle::distributed::PSParameter server_fleet_desc; + ::paddle::distributed::ServerParameter* server_proto = + server_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("GraphBrpcService"); + server_service_proto->set_server_class("GraphBrpcServer"); + server_service_proto->set_client_class("GraphBrpcClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + for (auto& tuple : this->table_id_map) { + VLOG(0) << " make a new table " << tuple.second; + ::paddle::distributed::TableParameter* sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + std::vector feat_name; + std::vector feat_dtype; + std::vector feat_shape; + for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) { + if (tuple.first == table_feat_conf_table_name[i]) { + feat_name.push_back(table_feat_conf_feat_name[i]); + feat_dtype.push_back(table_feat_conf_feat_dtype[i]); + feat_shape.push_back(table_feat_conf_feat_shape[i]); + } + } + std::string table_type; + if (tuple.second < this->num_node_types) { + table_type = "node"; + } else { + table_type = "edge"; + } + + GetDownpourSparseTableProto(sparse_table_proto, tuple.second, tuple.first, + table_type, feat_name, feat_dtype, feat_shape); + } + + return server_fleet_desc; +} + +::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() { + ::paddle::distributed::PSParameter worker_fleet_desc; + ::paddle::distributed::WorkerParameter* worker_proto = + worker_fleet_desc.mutable_worker_param(); + + ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = + worker_proto->mutable_downpour_worker_param(); + + for (auto& tuple : this->table_id_map) { + VLOG(0) << " make a new table " << tuple.second; + ::paddle::distributed::TableParameter* worker_sparse_table_proto = + downpour_worker_proto->add_downpour_table_param(); + std::vector feat_name; + std::vector feat_dtype; + std::vector feat_shape; + for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) { + if (tuple.first == table_feat_conf_table_name[i]) { + feat_name.push_back(table_feat_conf_feat_name[i]); + feat_dtype.push_back(table_feat_conf_feat_dtype[i]); + feat_shape.push_back(table_feat_conf_feat_shape[i]); + } + } + std::string table_type; + if (tuple.second < this->num_node_types) { + table_type = "node"; + } else { + table_type = "edge"; + } + + GetDownpourSparseTableProto(worker_sparse_table_proto, tuple.second, + tuple.first, table_type, feat_name, feat_dtype, + feat_shape); + } + + ::paddle::distributed::ServerParameter* 
server_proto = + worker_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("GraphBrpcService"); + server_service_proto->set_server_class("GraphBrpcServer"); + server_service_proto->set_client_class("GraphBrpcClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + for (auto& tuple : this->table_id_map) { + VLOG(0) << " make a new table " << tuple.second; + ::paddle::distributed::TableParameter* sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + std::vector feat_name; + std::vector feat_dtype; + std::vector feat_shape; + for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) { + if (tuple.first == table_feat_conf_table_name[i]) { + feat_name.push_back(table_feat_conf_feat_name[i]); + feat_dtype.push_back(table_feat_conf_feat_dtype[i]); + feat_shape.push_back(table_feat_conf_feat_shape[i]); + } + } + std::string table_type; + if (tuple.second < this->num_node_types) { + table_type = "node"; + } else { + table_type = "edge"; + } + + GetDownpourSparseTableProto(sparse_table_proto, tuple.second, tuple.first, + table_type, feat_name, feat_dtype, feat_shape); + } + + return worker_fleet_desc; +} +void GraphPyClient::load_edge_file(std::string name, std::string filepath, + bool reverse) { + // 'e' means load edge + std::string params = "e"; + if (reverse) { + // 'e<' means load edges from $2 to $1 + params += "<"; + } else { + // 'e>' means load edges from $1 to $2 + params += ">"; + } + if (this->table_id_map.count(name)) { + VLOG(0) << "loadding data with type " << name << " from " << filepath; + uint32_t table_id = this->table_id_map[name]; + auto status = + get_ps_client()->load(table_id, std::string(filepath), params); + status.wait(); + } +} + +void GraphPyClient::load_node_file(std::string name, std::string filepath) { + // 'n' means load nodes and 'node_type' follows + std::string params = "n" + name; + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = + get_ps_client()->load(table_id, std::string(filepath), params); + status.wait(); + } +} +std::vector>> +GraphPyClient::batch_sample_neighboors(std::string name, + std::vector node_ids, + int sample_size) { + std::vector>> v; + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = + worker_ptr->batch_sample_neighboors(table_id, node_ids, sample_size, v); + status.wait(); + } + return v; +} + +std::vector GraphPyClient::random_sample_nodes(std::string name, + int server_index, + int sample_size) { + std::vector v; + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = + worker_ptr->random_sample_nodes(table_id, server_index, sample_size, v); + status.wait(); + } + return v; +} + +// (name, dtype, ndarray) +std::vector> GraphPyClient::get_node_feat( + std::string node_type, std::vector node_ids, + std::vector feature_names) { + std::vector> v( + feature_names.size(), std::vector(node_ids.size())); + if (this->table_id_map.count(node_type)) { + uint32_t table_id = this->table_id_map[node_type]; + auto status = + worker_ptr->get_node_feat(table_id, node_ids, feature_names, v); + status.wait(); + } + return v; +} + +std::vector 
GraphPyClient::pull_graph_list(std::string name, + int server_index, + int start, int size, + int step) { + std::vector res; + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = worker_ptr->pull_graph_list(table_id, server_index, start, + size, step, res); + status.wait(); + } + return res; +} + +void GraphPyClient::stop_server() { + VLOG(0) << "going to stop server"; + std::unique_lock lock(mutex_); + if (stoped_) return; + auto status = this->worker_ptr->stop_server(); + if (status.get() == 0) stoped_ = true; +} +void GraphPyClient::finalize_worker() { this->worker_ptr->finalize_worker(); } +} +} diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h new file mode 100644 index 00000000000..e185f23e3d2 --- /dev/null +++ b/paddle/fluid/distributed/service/graph_py_service.h @@ -0,0 +1,178 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include "google/protobuf/text_format.h" + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/graph_brpc_client.h" +#include "paddle/fluid/distributed/service/graph_brpc_server.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/service.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" +namespace paddle { +namespace distributed { +class GraphPyService { + protected: + std::vector server_list, port_list, host_sign_list; + int server_size, shard_num; + int num_node_types; + std::unordered_map table_id_map; + std::vector table_feat_conf_table_name; + std::vector table_feat_conf_feat_name; + std::vector table_feat_conf_feat_dtype; + std::vector table_feat_conf_feat_shape; + + // std::thread *server_thread, *client_thread; + + // std::shared_ptr pserver_ptr; + + // std::shared_ptr worker_ptr; + + public: + // std::shared_ptr get_ps_server() { + // return pserver_ptr; + // } + // std::shared_ptr get_ps_client() { + // return worker_ptr; + // } + int get_shard_num() { return shard_num; } + void set_shard_num(int shard_num) { this->shard_num = shard_num; } + void GetDownpourSparseTableProto( + ::paddle::distributed::TableParameter* sparse_table_proto, + uint32_t table_id, std::string table_name, std::string table_type, + std::vector feat_name, std::vector feat_dtype, + std::vector feat_shape) { + sparse_table_proto->set_table_id(table_id); + 
sparse_table_proto->set_table_class("GraphTable"); + sparse_table_proto->set_shard_num(shard_num); + sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE); + ::paddle::distributed::TableAccessorParameter* accessor_proto = + sparse_table_proto->mutable_accessor(); + + ::paddle::distributed::CommonAccessorParameter* common_proto = + sparse_table_proto->mutable_common(); + + // Set GraphTable Parameter + common_proto->set_table_name(table_name); + common_proto->set_name(table_type); + for (size_t i = 0; i < feat_name.size(); i++) { + common_proto->add_params(feat_dtype[i]); + common_proto->add_dims(feat_shape[i]); + common_proto->add_attributes(feat_name[i]); + } + + accessor_proto->set_accessor_class("CommMergeAccessor"); + } + + void set_server_size(int server_size) { this->server_size = server_size; } + void set_num_node_types(int num_node_types) { + this->num_node_types = num_node_types; + } + int get_server_size(int server_size) { return server_size; } + std::vector split(std::string& str, const char pattern); + void set_up(std::string ips_str, int shard_num, + std::vector node_types, + std::vector edge_types); + + void add_table_feat_conf(std::string node_type, std::string feat_name, + std::string feat_dtype, int32_t feat_shape); +}; +class GraphPyServer : public GraphPyService { + public: + GraphPyServer() {} + void set_up(std::string ips_str, int shard_num, + std::vector node_types, + std::vector edge_types, int rank) { + set_rank(rank); + GraphPyService::set_up(ips_str, shard_num, node_types, edge_types); + } + int get_rank() { return rank; } + void set_rank(int rank) { this->rank = rank; } + + void start_server(bool block = true); + ::paddle::distributed::PSParameter GetServerProto(); + std::shared_ptr get_ps_server() { + return pserver_ptr; + } + + protected: + int rank; + std::shared_ptr pserver_ptr; + std::thread* server_thread; +}; +class GraphPyClient : public GraphPyService { + public: + void set_up(std::string ips_str, int shard_num, + std::vector node_types, + std::vector edge_types, int client_id) { + set_client_id(client_id); + GraphPyService::set_up(ips_str, shard_num, node_types, edge_types); + } + std::shared_ptr get_ps_client() { + return worker_ptr; + } + void bind_local_server(int local_channel_index, GraphPyServer& server) { + worker_ptr->set_local_channel(local_channel_index); + worker_ptr->set_local_graph_service( + (paddle::distributed::GraphBrpcService*)server.get_ps_server() + ->get_service()); + } + void stop_server(); + void finalize_worker(); + void load_edge_file(std::string name, std::string filepath, bool reverse); + void load_node_file(std::string name, std::string filepath); + int get_client_id() { return client_id; } + void set_client_id(int client_id) { this->client_id = client_id; } + void start_client(); + std::vector>> batch_sample_neighboors( + std::string name, std::vector node_ids, int sample_size); + std::vector random_sample_nodes(std::string name, int server_index, + int sample_size); + std::vector> get_node_feat( + std::string node_type, std::vector node_ids, + std::vector feature_names); + std::vector pull_graph_list(std::string name, int server_index, + int start, int size, int step = 1); + ::paddle::distributed::PSParameter GetWorkerProto(); + + protected: + mutable std::mutex mutex_; + int client_id; + std::shared_ptr worker_ptr; + std::thread* client_thread; + bool stoped_ = false; +}; +} +} diff --git a/paddle/fluid/distributed/service/ps_client.cc b/paddle/fluid/distributed/service/ps_client.cc index 
d427ecfc538..3f78908baa3 100644 --- a/paddle/fluid/distributed/service/ps_client.cc +++ b/paddle/fluid/distributed/service/ps_client.cc @@ -15,12 +15,13 @@ #include "paddle/fluid/distributed/service/ps_client.h" #include "glog/logging.h" #include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/graph_brpc_client.h" #include "paddle/fluid/distributed/table/table.h" namespace paddle { namespace distributed { REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient); - +REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient); int32_t PSClient::configure( const PSParameter &config, const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions, @@ -82,4 +83,4 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) { return client; } } // namespace distributed -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h index 50f5802c63a..7b698afa726 100644 --- a/paddle/fluid/distributed/service/ps_client.h +++ b/paddle/fluid/distributed/service/ps_client.h @@ -24,16 +24,11 @@ #include "paddle/fluid/distributed/service/env.h" #include "paddle/fluid/distributed/service/sendrecv.pb.h" #include "paddle/fluid/distributed/table/accessor.h" +#include "paddle/fluid/distributed/table/graph_node.h" namespace paddle { namespace distributed { -class PSEnvironment; -class PsRequestMessage; -class PsResponseMessage; -class ValueAccessor; -struct Region; - using paddle::distributed::PsRequestMessage; using paddle::distributed::PsResponseMessage; @@ -160,6 +155,7 @@ class PSClient { promise.set_value(-1); return fut; } + // client-to-client message handling: std::function<int32_t(int, int, const std::string &)> ret(msg_type, from_client_id, msg) typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc; diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto index 6250f84c987..d908c26da98 100644 --- a/paddle/fluid/distributed/service/sendrecv.proto +++ b/paddle/fluid/distributed/service/sendrecv.proto @@ -48,6 +48,10 @@ enum PsCmdID { PS_START_PROFILER = 27; PS_STOP_PROFILER = 28; PS_PUSH_GLOBAL_STEP = 29; + PS_PULL_GRAPH_LIST = 30; + PS_GRAPH_SAMPLE_NEIGHBOORS = 31; + PS_GRAPH_SAMPLE_NODES = 32; + PS_GRAPH_GET_NODE_FEAT = 33; } message PsRequestMessage { @@ -111,4 +115,4 @@ message MultiVariableMessage { service PsService { rpc service(PsRequestMessage) returns (PsResponseMessage); rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage); -}; \ No newline at end of file +}; diff --git a/paddle/fluid/distributed/service/server.cc b/paddle/fluid/distributed/service/server.cc index fc230a0b9c9..9324adad697 100644 --- a/paddle/fluid/distributed/service/server.cc +++ b/paddle/fluid/distributed/service/server.cc @@ -16,6 +16,7 @@ #include "glog/logging.h" #include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/graph_brpc_server.h" #include "paddle/fluid/distributed/table/table.h" namespace paddle { @@ -23,6 +24,8 @@ namespace distributed { REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer); REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService); +REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer); +REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService); PSServer *PSServerFactory::create(const PSParameter &ps_config) { const auto &config = ps_config.server_param(); diff --git a/paddle/fluid/distributed/table/CMakeLists.txt b/paddle/fluid/distributed/table/CMakeLists.txt index 1e98e193d54..33873abc5f7 100644 --- a/paddle/fluid/distributed/table/CMakeLists.txt +++
b/paddle/fluid/distributed/table/CMakeLists.txt @@ -1,13 +1,19 @@ set_property(GLOBAL PROPERTY TABLE_DEPS string_helper) get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS) - +set_source_files_properties(graph_edge.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(graph_edge SRCS graph_edge.cc) +set_source_files_properties(graph_weighted_sampler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(WeightedSampler SRCS graph_weighted_sampler.cc DEPS graph_edge) +set_source_files_properties(graph_node.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(graph_node SRCS graph_node.cc DEPS WeightedSampler) set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc DEPS ${TABLE_DEPS} device_context string_helper simple_threadpool xxhash generator) +cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator) set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc new file mode 100644 index 00000000000..995a39a6543 --- /dev/null +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -0,0 +1,506 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
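+ +// common_graph_table.cc implements GraphTable, the in-memory graph storage +// behind the graph engine service. Node ids are mapped to shards by +// id % shard_num; each server owns the contiguous range [shard_start, +// shard_end) and keeps one GraphShard (a bucket of Node pointers) per local +// shard. Neighbor sampling and feature queries are dispatched to a small +// thread pool chosen by get_thread_pool_index(node_id).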
+ +#include "paddle/fluid/distributed/table/common_graph_table.h" +#include +#include +#include +#include +#include "paddle/fluid/distributed/common/utils.h" +#include "paddle/fluid/distributed/table/graph_node.h" +#include "paddle/fluid/string/printf.h" +#include "paddle/fluid/string/string_helper.h" +namespace paddle { +namespace distributed { + +std::vector GraphShard::get_batch(int start, int end, int step) { + if (start < 0) start = 0; + std::vector res; + for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) { + res.push_back(bucket[pos]); + } + return res; +} + +size_t GraphShard::get_size() { return bucket.size(); } + +GraphNode *GraphShard::add_graph_node(uint64_t id) { + if (node_location.find(id) == node_location.end()) { + node_location[id] = bucket.size(); + bucket.push_back(new GraphNode(id)); + } + return (GraphNode *)bucket[node_location[id]]; +} + +FeatureNode *GraphShard::add_feature_node(uint64_t id) { + if (node_location.find(id) == node_location.end()) { + node_location[id] = bucket.size(); + bucket.push_back(new FeatureNode(id)); + } + return (FeatureNode *)bucket[node_location[id]]; +} + +void GraphShard::add_neighboor(uint64_t id, uint64_t dst_id, float weight) { + find_node(id)->add_edge(dst_id, weight); +} + +Node *GraphShard::find_node(uint64_t id) { + auto iter = node_location.find(id); + return iter == node_location.end() ? nullptr : bucket[iter->second]; +} + +int32_t GraphTable::load(const std::string &path, const std::string ¶m) { + bool load_edge = (param[0] == 'e'); + bool load_node = (param[0] == 'n'); + if (load_edge) { + bool reverse_edge = (param[1] == '<'); + return this->load_edges(path, reverse_edge); + } + if (load_node) { + std::string node_type = param.substr(1); + return this->load_nodes(path, node_type); + } + return 0; +} + +int32_t GraphTable::get_nodes_ids_by_ranges( + std::vector> ranges, std::vector &res) { + int start = 0, end, index = 0, total_size = 0; + res.clear(); + std::vector>> tasks; + // std::string temp = ""; + // for(int i = 0;i < shards.size();i++) + // temp+= std::to_string((int)shards[i].get_size()) + " "; + // VLOG(0)<<"range distribution "<= end) { + break; + } else { + int first = std::max(ranges[index].first, start); + int second = std::min(ranges[index].second, end); + start = second; + first -= total_size; + second -= total_size; + // VLOG(0)<<" FIND RANGE "<enqueue( + [this, first, second, i]() -> std::vector { + return shards[i].get_ids_by_range(first, second); + })); + } + } + total_size += shards[i].get_size(); + } + for (int i = 0; i < tasks.size(); i++) { + auto vec = tasks[i].get(); + for (auto &id : vec) { + res.push_back(id); + std::swap(res[rand() % res.size()], res[(int)res.size() - 1]); + } + } + return 0; +} + +int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { + auto paths = paddle::string::split_string(path, ";"); + int64_t count = 0; + int64_t valid_count = 0; + for (auto path : paths) { + std::ifstream file(path); + std::string line; + while (std::getline(file, line)) { + count++; + auto values = paddle::string::split_string(line, "\t"); + if (values.size() < 2) continue; + auto id = std::stoull(values[1]); + + size_t shard_id = id % shard_num; + if (shard_id >= shard_end || shard_id < shard_start) { + VLOG(4) << "will not load " << id << " from " << path + << ", please check id distribution"; + continue; + } + + if (count % 1000000 == 0) { + VLOG(0) << count << " nodes are loaded from filepath"; + } + + std::string nt = values[0]; + if (nt != 
node_type) { + continue; + } + + size_t index = shard_id - shard_start; + + auto node = shards[index].add_feature_node(id); + + node->set_feature_size(feat_name.size()); + + for (size_t slice = 2; slice < values.size(); slice++) { + auto feat = this->parse_feature(values[slice]); + if (feat.first >= 0) { + node->set_feature(feat.first, feat.second); + } else { + VLOG(4) << "Node feature: " << values[slice] + << " not in feature_map."; + } + } + valid_count++; + } + } + + VLOG(0) << valid_count << "/" << count << " nodes in type " << node_type + << " are loaded successfully in " << path; + return 0; +} + +int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { + auto paths = paddle::string::split_string(path, ";"); + int count = 0; + std::string sample_type = "random"; + bool is_weighted = false; + int valid_count = 0; + + for (auto path : paths) { + std::ifstream file(path); + std::string line; + while (std::getline(file, line)) { + auto values = paddle::string::split_string(line, "\t"); + count++; + if (values.size() < 2) continue; + auto src_id = std::stoull(values[0]); + auto dst_id = std::stoull(values[1]); + if (reverse_edge) { + std::swap(src_id, dst_id); + } + float weight = 1; + if (values.size() == 3) { + weight = std::stof(values[2]); + sample_type = "weighted"; + is_weighted = true; + } + + size_t src_shard_id = src_id % shard_num; + + if (src_shard_id >= shard_end || src_shard_id < shard_start) { + VLOG(4) << "will not load " << src_id << " from " << path + << ", please check id distribution"; + continue; + } + if (count % 1000000 == 0) { + VLOG(0) << count << " edges are loaded from filepath"; + } + + size_t index = src_shard_id - shard_start; + shards[index].add_graph_node(src_id)->build_edges(is_weighted); + shards[index].add_neighboor(src_id, dst_id, weight); + valid_count++; + } + } + VLOG(0) << valid_count << "/" << count << " edges are loaded successfully in " + << path; + + // Build Sampler j + + for (auto &shard : shards) { + auto bucket = shard.get_bucket(); + for (int i = 0; i < bucket.size(); i++) { + bucket[i]->build_sampler(sample_type); + } + } + return 0; +} + +Node *GraphTable::find_node(uint64_t id) { + size_t shard_id = id % shard_num; + if (shard_id >= shard_end || shard_id < shard_start) { + return nullptr; + } + size_t index = shard_id - shard_start; + Node *node = shards[index].find_node(id); + return node; +} +uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) { + return node_id % shard_num % shard_num_per_table % task_pool_size_; +} +int32_t GraphTable::random_sample_nodes(int sample_size, + std::unique_ptr &buffer, + int &actual_size) { + bool need_feature = false; + int total_size = 0; + for (int i = 0; i < shards.size(); i++) { + total_size += shards[i].get_size(); + } + if (sample_size > total_size) sample_size = total_size; + int range_num = random_sample_nodes_ranges; + if (range_num > sample_size) range_num = sample_size; + if (sample_size == 0 || range_num == 0) return 0; + std::vector ranges_len, ranges_pos; + int remain = sample_size, last_pos = -1, num; + std::set separator_set; + for (int i = 0; i < range_num - 1; i++) { + while (separator_set.find(num = rand() % (sample_size - 1)) != + separator_set.end()) + ; + separator_set.insert(num); + } + for (auto p : separator_set) { + ranges_len.push_back(p - last_pos); + last_pos = p; + } + ranges_len.push_back(sample_size - 1 - last_pos); + remain = total_size - sample_size + range_num; + separator_set.clear(); + for (int i = 0; i < range_num; i++) { + while 
(separator_set.find(num = rand() % remain) != separator_set.end()) + ; + separator_set.insert(num); + } + int used = 0, index = 0; + last_pos = -1; + for (auto p : separator_set) { + used += p - last_pos - 1; + last_pos = p; + ranges_pos.push_back(used); + used += ranges_len[index++]; + } + std::vector<std::pair<int, int>> first_half, second_half; + int start_index = rand() % total_size; + for (int i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) { + if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size) + first_half.push_back({ranges_pos[i] + start_index, + ranges_pos[i] + ranges_len[i] + start_index}); + else if (ranges_pos[i] + start_index >= total_size) { + second_half.push_back( + {ranges_pos[i] + start_index - total_size, + ranges_pos[i] + ranges_len[i] + start_index - total_size}); + } else { + first_half.push_back({ranges_pos[i] + start_index, total_size}); + second_half.push_back( + {0, ranges_pos[i] + ranges_len[i] + start_index - total_size}); + } + } + for (auto &pair : first_half) second_half.push_back(pair); + std::vector<uint64_t> res; + get_nodes_ids_by_ranges(second_half, res); + actual_size = res.size() * sizeof(uint64_t); + buffer.reset(new char[actual_size]); + char *pointer = buffer.get(); + memcpy(pointer, res.data(), actual_size); + return 0; +} +int32_t GraphTable::random_sample_neighboors( + uint64_t *node_ids, int sample_size, + std::vector<std::unique_ptr<char[]>> &buffers, + std::vector<int> &actual_sizes) { + size_t node_num = buffers.size(); + std::vector<std::future<int>> tasks; + for (size_t idx = 0; idx < node_num; ++idx) { + uint64_t &node_id = node_ids[idx]; + std::unique_ptr<char[]> &buffer = buffers[idx]; + int &actual_size = actual_sizes[idx]; + tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue( + [&]() -> int { + Node *node = find_node(node_id); + + if (node == nullptr) { + actual_size = 0; + return 0; + } + std::vector<int> res = node->sample_k(sample_size); + actual_size = res.size() * (Node::id_size + Node::weight_size); + int offset = 0; + uint64_t id; + float weight; + char *buffer_addr = new char[actual_size]; + buffer.reset(buffer_addr); + for (int &x : res) { + id = node->get_neighbor_id(x); + weight = node->get_neighbor_weight(x); + memcpy(buffer_addr + offset, &id, Node::id_size); + offset += Node::id_size; + memcpy(buffer_addr + offset, &weight, Node::weight_size); + offset += Node::weight_size; + } + return 0; + })); + } + for (size_t idx = 0; idx < node_num; ++idx) { + tasks[idx].get(); + } + return 0; +} + +int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids, + const std::vector<std::string> &feature_names, + std::vector<std::vector<std::string>> &res) { + size_t node_num = node_ids.size(); + std::vector<std::future<int>> tasks; + for (size_t idx = 0; idx < node_num; ++idx) { + uint64_t node_id = node_ids[idx]; + tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue( + [&, idx, node_id]() -> int { + Node *node = find_node(node_id); + + if (node == nullptr) { + return 0; + } + for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { + const std::string &feature_name = feature_names[feat_idx]; + if (feat_id_map.find(feature_name) != feat_id_map.end()) { + // res[feat_idx][idx] = + // node->get_feature(feat_id_map[feature_name]); + auto feat = node->get_feature(feat_id_map[feature_name]); + res[feat_idx][idx] = feat; + } + } + return 0; + })); + } + for (size_t idx = 0; idx < node_num; ++idx) { + tasks[idx].get(); + } + return 0; +} + +std::pair<int32_t, std::string> GraphTable::parse_feature( + std::string feat_str) { + // Return (feat_id, bytes) if name is in this->feat_name, else return (-1, + // "") + auto
fields = paddle::string::split_string<std::string>(feat_str, " "); + if (this->feat_id_map.count(fields[0])) { + int32_t id = this->feat_id_map[fields[0]]; + std::string dtype = this->feat_dtype[id]; + int32_t shape = this->feat_shape[id]; + std::vector<std::string> values(fields.begin() + 1, fields.end()); + if (dtype == "feasign") { + return std::make_pair( + int32_t(id), paddle::string::join_strings(values, ' ')); + } else if (dtype == "string") { + return std::make_pair( + int32_t(id), paddle::string::join_strings(values, ' ')); + } else if (dtype == "float32") { + return std::make_pair( + int32_t(id), FeatureNode::parse_value_to_bytes<float>(values)); + } else if (dtype == "float64") { + return std::make_pair( + int32_t(id), FeatureNode::parse_value_to_bytes<double>(values)); + } else if (dtype == "int32") { + return std::make_pair( + int32_t(id), FeatureNode::parse_value_to_bytes<int32_t>(values)); + } else if (dtype == "int64") { + return std::make_pair( + int32_t(id), FeatureNode::parse_value_to_bytes<int64_t>(values)); + } + } + return std::make_pair(-1, ""); +} + +int32_t GraphTable::pull_graph_list(int start, int total_size, + std::unique_ptr<char[]> &buffer, + int &actual_size, bool need_feature, + int step) { + if (start < 0) start = 0; + int size = 0, cur_size; + std::vector<std::future<std::vector<Node *>>> tasks; + for (size_t i = 0; i < shards.size() && total_size > 0; i++) { + cur_size = shards[i].get_size(); + if (size + cur_size <= start) { + size += cur_size; + continue; + } + int count = std::min(1 + (size + cur_size - start - 1) / step, total_size); + int end = start + (count - 1) * step + 1; + tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( + [this, i, start, end, step, size]() -> std::vector<Node *> { + + return this->shards[i].get_batch(start - size, end - size, step); + })); + start += count * step; + total_size -= count; + size += cur_size; + } + for (size_t i = 0; i < tasks.size(); ++i) { + tasks[i].wait(); + } + size = 0; + std::vector<std::vector<Node *>> res; + for (size_t i = 0; i < tasks.size(); i++) { + res.push_back(tasks[i].get()); + for (size_t j = 0; j < res.back().size(); j++) { + size += res.back()[j]->get_size(need_feature); + } + } + char *buffer_addr = new char[size]; + buffer.reset(buffer_addr); + int index = 0; + for (size_t i = 0; i < res.size(); i++) { + for (size_t j = 0; j < res[i].size(); j++) { + res[i][j]->to_buffer(buffer_addr + index, need_feature); + index += res[i][j]->get_size(need_feature); + } + } + actual_size = size; + return 0; +} +int32_t GraphTable::initialize() { + _shards_task_pool.resize(task_pool_size_); + for (size_t i = 0; i < _shards_task_pool.size(); ++i) { + _shards_task_pool[i].reset(new ::ThreadPool(1)); + } + server_num = _shard_num; + // VLOG(0) << "in init graph table server num = " << server_num; + /* + _shard_num is actually server number here + when a server initialize its tables, it sets tables' _shard_num to server_num, + and _shard_idx to server + rank */ + auto common = _config.common(); + + this->table_name = common.table_name(); + this->table_type = common.name(); + VLOG(0) << " init graph table type " << this->table_type << " table name " + << this->table_name; + int feat_conf_size = static_cast<int>(common.attributes().size()); + for (int i = 0; i < feat_conf_size; i++) { + auto &f_name = common.attributes()[i]; + auto &f_shape = common.dims()[i]; + auto &f_dtype = common.params()[i]; + this->feat_name.push_back(f_name); + this->feat_shape.push_back(f_shape); + this->feat_dtype.push_back(f_dtype); + this->feat_id_map[f_name] = i; + VLOG(0) << "init graph table feat conf name:" << f_name + << " shape:" << f_shape
<< " dtype:" << f_dtype; + } + + shard_num = _config.shard_num(); + VLOG(0) << "in init graph table shard num = " << shard_num << " shard_idx" + << _shard_idx; + shard_num_per_table = sparse_local_shard_num(shard_num, server_num); + shard_start = _shard_idx * shard_num_per_table; + shard_end = shard_start + shard_num_per_table; + VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start " + << shard_start << " shard_end " << shard_end; + // shards.resize(shard_num_per_table); + shards = std::vector(shard_num_per_table, GraphShard(shard_num)); + return 0; +} +} +}; diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h new file mode 100644 index 00000000000..de3cac134cd --- /dev/null +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -0,0 +1,144 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include "paddle/fluid/distributed/table/accessor.h" +#include "paddle/fluid/distributed/table/common_table.h" +#include "paddle/fluid/distributed/table/graph_node.h" +#include "paddle/fluid/framework/rw_lock.h" +#include "paddle/fluid/string/string_helper.h" +namespace paddle { +namespace distributed { +class GraphShard { + public: + // static int bucket_low_bound; + // static int gcd(int s, int t) { + // if (s % t == 0) return t; + // return gcd(t, s % t); + // } + size_t get_size(); + GraphShard() {} + GraphShard(int shard_num) { + this->shard_num = shard_num; + // bucket_size = init_bucket_size(shard_num); + // bucket.resize(bucket_size); + } + std::vector &get_bucket() { return bucket; } + std::vector get_batch(int start, int end, int step); + // int init_bucket_size(int shard_num) { + // for (int i = bucket_low_bound;; i++) { + // if (gcd(i, shard_num) == 1) return i; + // } + // return -1; + // } + std::vector get_ids_by_range(int start, int end) { + std::vector res; + for (int i = start; i < end && i < bucket.size(); i++) { + res.push_back(bucket[i]->get_id()); + } + return res; + } + GraphNode *add_graph_node(uint64_t id); + FeatureNode *add_feature_node(uint64_t id); + Node *find_node(uint64_t id); + void add_neighboor(uint64_t id, uint64_t dst_id, float weight); + // std::unordered_map::iterator> + std::unordered_map get_node_location() { + return node_location; + } + + private: + std::unordered_map node_location; + int shard_num; + std::vector bucket; +}; +class GraphTable : public SparseTable { + public: + GraphTable() {} + virtual ~GraphTable() {} + virtual int32_t pull_graph_list(int start, int size, + std::unique_ptr &buffer, + int &actual_size, bool need_feature, + int step); + + virtual int32_t random_sample_neighboors( + uint64_t *node_ids, int sample_size, + std::vector> &buffers, + std::vector &actual_sizes); + + int32_t random_sample_nodes(int sample_size, std::unique_ptr &buffers, + int 
&actual_sizes); + + virtual int32_t get_nodes_ids_by_ranges( + std::vector<std::pair<int, int>> ranges, std::vector<uint64_t> &res); + virtual int32_t initialize(); + + int32_t load(const std::string &path, const std::string &param); + + int32_t load_edges(const std::string &path, bool reverse); + + int32_t load_nodes(const std::string &path, std::string node_type); + + Node *find_node(uint64_t id); + + virtual int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) { + return 0; + } + virtual int32_t push_sparse(const uint64_t *keys, const float *values, + size_t num) { + return 0; + } + virtual void clear() {} + virtual int32_t flush() { return 0; } + virtual int32_t shrink(const std::string &param) { return 0; } + // Specify the save path + virtual int32_t save(const std::string &path, const std::string &converter) { + return 0; + } + virtual int32_t initialize_shard() { return 0; } + virtual uint32_t get_thread_pool_index(uint64_t node_id); + virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str); + + virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids, + const std::vector<std::string> &feature_names, + std::vector<std::vector<std::string>> &res); + + protected: + std::vector<GraphShard> shards; + size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num; + const int task_pool_size_ = 11; + const int random_sample_nodes_ranges = 3; + + std::vector<std::string> feat_name; + std::vector<std::string> feat_dtype; + std::vector<int32_t> feat_shape; + std::unordered_map<std::string, int32_t> feat_id_map; + std::string table_name; + std::string table_type; + + std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool; +}; +} +}; diff --git a/paddle/fluid/distributed/table/graph_edge.cc b/paddle/fluid/distributed/table/graph_edge.cc new file mode 100644 index 00000000000..cc90f4c6516 --- /dev/null +++ b/paddle/fluid/distributed/table/graph_edge.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/graph_edge.h" +#include +namespace paddle { +namespace distributed { + +void GraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { + id_arr.push_back(id); +} + +void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { + id_arr.push_back(id); + weight_arr.push_back(weight); +} +} +} diff --git a/paddle/fluid/distributed/table/graph_edge.h b/paddle/fluid/distributed/table/graph_edge.h new file mode 100644 index 00000000000..3dfe5a6f357 --- /dev/null +++ b/paddle/fluid/distributed/table/graph_edge.h @@ -0,0 +1,46 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +namespace paddle { +namespace distributed { + +class GraphEdgeBlob { + public: + GraphEdgeBlob() {} + virtual ~GraphEdgeBlob() {} + size_t size() { return id_arr.size(); } + virtual void add_edge(uint64_t id, float weight); + uint64_t get_id(int idx) { return id_arr[idx]; } + virtual float get_weight(int idx) { return 1; } + + protected: + std::vector id_arr; +}; + +class WeightedGraphEdgeBlob : public GraphEdgeBlob { + public: + WeightedGraphEdgeBlob() {} + virtual ~WeightedGraphEdgeBlob() {} + virtual void add_edge(uint64_t id, float weight); + virtual float get_weight(int idx) { return weight_arr[idx]; } + + protected: + std::vector weight_arr; +}; +} +} diff --git a/paddle/fluid/distributed/table/graph_node.cc b/paddle/fluid/distributed/table/graph_node.cc new file mode 100644 index 00000000000..27a2cafaf4f --- /dev/null +++ b/paddle/fluid/distributed/table/graph_node.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/graph_node.h" +#include +namespace paddle { +namespace distributed { + +GraphNode::~GraphNode() { + if (sampler != nullptr) { + delete sampler; + sampler = nullptr; + } + if (edges != nullptr) { + delete edges; + edges = nullptr; + } +} + +int Node::weight_size = sizeof(float); +int Node::id_size = sizeof(uint64_t); +int Node::int_size = sizeof(int); + +int Node::get_size(bool need_feature) { return id_size + int_size; } + +void Node::to_buffer(char* buffer, bool need_feature) { + memcpy(buffer, &id, id_size); + buffer += id_size; + + int feat_num = 0; + memcpy(buffer, &feat_num, sizeof(int)); +} + +void Node::recover_from_buffer(char* buffer) { memcpy(&id, buffer, id_size); } + +int FeatureNode::get_size(bool need_feature) { + int size = id_size + int_size; // id, feat_num + if (need_feature) { + size += feature.size() * int_size; + for (const std::string& fea : feature) { + size += fea.size(); + } + } + return size; +} + +void GraphNode::build_edges(bool is_weighted) { + if (edges == nullptr) { + if (is_weighted == true) { + edges = new WeightedGraphEdgeBlob(); + } else { + edges = new GraphEdgeBlob(); + } + } +} +void GraphNode::build_sampler(std::string sample_type) { + if (sample_type == "random") { + sampler = new RandomSampler(); + } else if (sample_type == "weighted") { + sampler = new WeightedSampler(); + } + sampler->build(edges); +} +void FeatureNode::to_buffer(char* buffer, bool need_feature) { + memcpy(buffer, &id, id_size); + buffer += id_size; + + int feat_num = 0; + int feat_len; + if (need_feature) { + feat_num += feature.size(); + memcpy(buffer, &feat_num, sizeof(int)); + buffer += sizeof(int); + for (int i = 0; i < feat_num; ++i) { + feat_len = feature[i].size(); + memcpy(buffer, &feat_len, sizeof(int)); + buffer += sizeof(int); + memcpy(buffer, feature[i].c_str(), 
feature[i].size()); + buffer += feature[i].size(); + } + } else { + memcpy(buffer, &feat_num, sizeof(int)); + } +} +void FeatureNode::recover_from_buffer(char* buffer) { + int feat_num, feat_len; + memcpy(&id, buffer, id_size); + buffer += id_size; + + memcpy(&feat_num, buffer, sizeof(int)); + buffer += sizeof(int); + + feature.clear(); + for (int i = 0; i < feat_num; ++i) { + memcpy(&feat_len, buffer, sizeof(int)); + buffer += sizeof(int); + + char str[feat_len + 1]; + memcpy(str, buffer, feat_len); + buffer += feat_len; + str[feat_len] = '\0'; + feature.push_back(std::string(str)); + } +} +} +} diff --git a/paddle/fluid/distributed/table/graph_node.h b/paddle/fluid/distributed/table/graph_node.h new file mode 100644 index 00000000000..c3e8e3ce5b5 --- /dev/null +++ b/paddle/fluid/distributed/table/graph_node.h @@ -0,0 +1,127 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/distributed/table/graph_weighted_sampler.h" +namespace paddle { +namespace distributed { + +class Node { + public: + Node() {} + Node(uint64_t id) : id(id) {} + virtual ~Node() {} + static int id_size, int_size, weight_size; + uint64_t get_id() { return id; } + void set_id(uint64_t id) { this->id = id; } + + virtual void build_edges(bool is_weighted) {} + virtual void build_sampler(std::string sample_type) {} + virtual void add_edge(uint64_t id, float weight) {} + virtual std::vector sample_k(int k) { return std::vector(); } + virtual uint64_t get_neighbor_id(int idx) { return 0; } + virtual float get_neighbor_weight(int idx) { return 1.; } + + virtual int get_size(bool need_feature); + virtual void to_buffer(char *buffer, bool need_feature); + virtual void recover_from_buffer(char *buffer); + virtual std::string get_feature(int idx) { return std::string(""); } + virtual void set_feature(int idx, std::string str) {} + virtual void set_feature_size(int size) {} + virtual int get_feature_size() { return 0; } + + protected: + uint64_t id; +}; + +class GraphNode : public Node { + public: + GraphNode() : Node(), sampler(nullptr), edges(nullptr) {} + GraphNode(uint64_t id) : Node(id), sampler(nullptr), edges(nullptr) {} + virtual ~GraphNode(); + virtual void build_edges(bool is_weighted); + virtual void build_sampler(std::string sample_type); + virtual void add_edge(uint64_t id, float weight) { + edges->add_edge(id, weight); + } + virtual std::vector sample_k(int k) { return sampler->sample_k(k); } + virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); } + virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); } + + protected: + Sampler *sampler; + GraphEdgeBlob *edges; +}; + +class FeatureNode : public Node { + public: + FeatureNode() : Node() {} + FeatureNode(uint64_t id) : Node(id) {} + virtual ~FeatureNode() {} + virtual int get_size(bool need_feature); + virtual void to_buffer(char *buffer, bool need_feature); + virtual void 
recover_from_buffer(char *buffer); + virtual std::string get_feature(int idx) { + if (idx < (int)this->feature.size()) { + return this->feature[idx]; + } else { + return std::string(""); + } + } + + virtual void set_feature(int idx, std::string str) { + if (idx >= (int)this->feature.size()) { + this->feature.resize(idx + 1); + } + this->feature[idx] = str; + } + virtual void set_feature_size(int size) { this->feature.resize(size); } + virtual int get_feature_size() { return this->feature.size(); } + + template + static std::string parse_value_to_bytes(std::vector feat_str) { + T v; + size_t Tsize = sizeof(T) * feat_str.size(); + char buffer[Tsize]; + for (size_t i = 0; i < feat_str.size(); i++) { + std::stringstream ss(feat_str[i]); + ss >> v; + std::memcpy(buffer + sizeof(T) * i, (char *)&v, sizeof(T)); + } + return std::string(buffer, Tsize); + } + + template + static std::vector parse_bytes_to_array(std::string feat_str) { + T v; + std::vector out; + size_t start = 0; + const char *buffer = feat_str.data(); + while (start < feat_str.size()) { + std::memcpy((char *)&v, buffer + start, sizeof(T)); + start += sizeof(T); + out.push_back(v); + } + return out; + } + + protected: + std::vector feature; +}; +} +} diff --git a/paddle/fluid/distributed/table/graph_weighted_sampler.cc b/paddle/fluid/distributed/table/graph_weighted_sampler.cc new file mode 100644 index 00000000000..059a1d64bc3 --- /dev/null +++ b/paddle/fluid/distributed/table/graph_weighted_sampler.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
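+ +// graph_weighted_sampler.cc provides the two neighbor samplers used by +// GraphNode::build_sampler. RandomSampler draws k distinct edge indices +// uniformly, using a replacement map to avoid repeats. WeightedSampler +// builds a binary tree over the edge list in which every node caches the +// total weight and edge count of its range; sample_k then repeatedly +// descends the tree with a random query weight, recording already-drawn +// weight in per-node maps so samples are taken without replacement.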
+ +#include "paddle/fluid/distributed/table/graph_weighted_sampler.h" +#include +#include +namespace paddle { +namespace distributed { + +void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; } + +std::vector RandomSampler::sample_k(int k) { + int n = edges->size(); + if (k > n) { + k = n; + } + struct timespec tn; + clock_gettime(CLOCK_REALTIME, &tn); + srand(tn.tv_nsec); + std::vector sample_result; + std::unordered_map replace_map; + while (k--) { + int rand_int = rand() % n; + auto iter = replace_map.find(rand_int); + if (iter == replace_map.end()) { + sample_result.push_back(rand_int); + } else { + sample_result.push_back(iter->second); + } + + iter = replace_map.find(n - 1); + if (iter == replace_map.end()) { + replace_map[rand_int] = n - 1; + } else { + replace_map[rand_int] = iter->second; + } + --n; + } + return sample_result; +} + +WeightedSampler::WeightedSampler() { + left = nullptr; + right = nullptr; + edges = nullptr; +} + +WeightedSampler::~WeightedSampler() { + if (left != nullptr) { + delete left; + left = nullptr; + } + if (right != nullptr) { + delete right; + right = nullptr; + } +} + +void WeightedSampler::build(GraphEdgeBlob *edges) { + if (left != nullptr) { + delete left; + left = nullptr; + } + if (right != nullptr) { + delete right; + right = nullptr; + } + return build_one((WeightedGraphEdgeBlob *)edges, 0, edges->size()); +} + +void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start, + int end) { + count = 0; + this->edges = edges; + if (start + 1 == end) { + left = right = nullptr; + idx = start; + count = 1; + weight = edges->get_weight(idx); + + } else { + left = new WeightedSampler(); + right = new WeightedSampler(); + left->build_one(edges, start, start + (end - start) / 2); + right->build_one(edges, start + (end - start) / 2, end); + weight = left->weight + right->weight; + count = left->count + right->count; + } +} +std::vector WeightedSampler::sample_k(int k) { + if (k > count) { + k = count; + } + std::vector sample_result; + float subtract; + std::unordered_map subtract_weight_map; + std::unordered_map subtract_count_map; + struct timespec tn; + clock_gettime(CLOCK_REALTIME, &tn); + srand(tn.tv_nsec); + while (k--) { + float query_weight = rand() % 100000 / 100000.0; + query_weight *= weight - subtract_weight_map[this]; + sample_result.push_back(sample(query_weight, subtract_weight_map, + subtract_count_map, subtract)); + } + return sample_result; +} + +int WeightedSampler::sample( + float query_weight, + std::unordered_map &subtract_weight_map, + std::unordered_map &subtract_count_map, + float &subtract) { + if (left == nullptr) { + subtract_weight_map[this] = weight; + subtract = weight; + subtract_count_map[this] = 1; + return idx; + } + int left_count = left->count - subtract_count_map[left]; + int right_count = right->count - subtract_count_map[right]; + float left_subtract = subtract_weight_map[left]; + int return_idx; + if (right_count == 0 || + left_count > 0 && left->weight - left_subtract >= query_weight) { + return_idx = left->sample(query_weight, subtract_weight_map, + subtract_count_map, subtract); + } else { + return_idx = + right->sample(query_weight - (left->weight - left_subtract), + subtract_weight_map, subtract_count_map, subtract); + } + subtract_weight_map[this] += subtract; + subtract_count_map[this]++; + return return_idx; +} +} +} diff --git a/paddle/fluid/distributed/table/graph_weighted_sampler.h b/paddle/fluid/distributed/table/graph_weighted_sampler.h new file mode 100644 index 
00000000000..cfc341d27c6 --- /dev/null +++ b/paddle/fluid/distributed/table/graph_weighted_sampler.h @@ -0,0 +1,58 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/fluid/distributed/table/graph_edge.h" +namespace paddle { +namespace distributed { + +class Sampler { + public: + virtual ~Sampler() {} + virtual void build(GraphEdgeBlob *edges) = 0; + virtual std::vector sample_k(int k) = 0; +}; + +class RandomSampler : public Sampler { + public: + virtual ~RandomSampler() {} + virtual void build(GraphEdgeBlob *edges); + virtual std::vector sample_k(int k); + GraphEdgeBlob *edges; +}; + +class WeightedSampler : public Sampler { + public: + WeightedSampler(); + virtual ~WeightedSampler(); + WeightedSampler *left, *right; + float weight; + int count; + int idx; + GraphEdgeBlob *edges; + virtual void build(GraphEdgeBlob *edges); + virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end); + virtual std::vector sample_k(int k); + + private: + int sample(float query_weight, + std::unordered_map &subtract_weight_map, + std::unordered_map &subtract_count_map, + float &subtract); +}; +} +} diff --git a/paddle/fluid/distributed/table/table.cc b/paddle/fluid/distributed/table/table.cc index dfaaa6ffc12..600be954cb5 100644 --- a/paddle/fluid/distributed/table/table.cc +++ b/paddle/fluid/distributed/table/table.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/distributed/common/registerer.h" #include "paddle/fluid/distributed/table/common_dense_table.h" +#include "paddle/fluid/distributed/table/common_graph_table.h" #include "paddle/fluid/distributed/table/common_sparse_table.h" #include "paddle/fluid/distributed/table/sparse_geo_table.h" #include "paddle/fluid/distributed/table/tensor_accessor.h" @@ -25,7 +26,7 @@ namespace paddle { namespace distributed { - +REGISTER_PSCORE_CLASS(Table, GraphTable); REGISTER_PSCORE_CLASS(Table, CommonDenseTable); REGISTER_PSCORE_CLASS(Table, CommonSparseTable); REGISTER_PSCORE_CLASS(Table, SparseGeoTable); @@ -75,5 +76,6 @@ int32_t Table::initialize_accessor() { _value_accesor.reset(accessor); return 0; } + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/table/table.h b/paddle/fluid/distributed/table/table.h index 65c99d2bbd4..d64e805af40 100644 --- a/paddle/fluid/distributed/table/table.h +++ b/paddle/fluid/distributed/table/table.h @@ -21,6 +21,7 @@ #include #include #include "paddle/fluid/distributed/table/accessor.h" +#include "paddle/fluid/distributed/table/graph_node.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/device_context.h" @@ -86,6 +87,31 @@ class Table { return 0; } + // only for graph table + virtual int32_t pull_graph_list(int start, int total_size, + std::unique_ptr &buffer, + int &actual_size, bool need_feature, + int step = 1) { + return 0; + } + // only for graph table + virtual int32_t 
random_sample_neighboors( + uint64_t *node_ids, int sample_size, + std::vector> &buffers, + std::vector &actual_sizes) { + return 0; + } + + virtual int32_t random_sample_nodes(int sample_size, + std::unique_ptr &buffers, + int &actual_sizes) { + return 0; + } + virtual int32_t get_node_feat(const std::vector &node_ids, + const std::vector &feature_names, + std::vector> &res) { + return 0; + } virtual int32_t pour() { return 0; } virtual void clear() = 0; @@ -141,5 +167,6 @@ class TableManager { TableManager() {} ~TableManager() {} }; + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/test/CMakeLists.txt b/paddle/fluid/distributed/test/CMakeLists.txt index adedd049023..b756c740ac7 100644 --- a/paddle/fluid/distributed/test/CMakeLists.txt +++ b/paddle/fluid/distributed/test/CMakeLists.txt @@ -15,3 +15,6 @@ cc_test(brpc_service_sparse_sgd_test SRCS brpc_service_sparse_sgd_test.cc DEPS s set_source_files_properties(brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_function ${COMMON_DEPS} ${RPC_DEPS}) + +set_source_files_properties(graph_node_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc new file mode 100644 index 00000000000..79ab2795963 --- /dev/null +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -0,0 +1,556 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include // NOLINT +#include +#include +#include +#include // NOLINT +#include +#include +#include "google/protobuf/text_format.h" + +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/graph_brpc_client.h" +#include "paddle/fluid/distributed/service/graph_brpc_server.h" +#include "paddle/fluid/distributed/service/graph_py_service.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/service.h" +#include "paddle/fluid/distributed/table/graph_node.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; +namespace distributed = paddle::distributed; + +void testSampleNodes( + std::shared_ptr& worker_ptr_) { + std::vector ids; + auto pull_status = worker_ptr_->random_sample_nodes(0, 0, 6, ids); + std::unordered_set s; + std::unordered_set s1 = {37, 59}; + pull_status.wait(); + for (auto id : ids) s.insert(id); + ASSERT_EQ(true, s.size() == s1.size()); + for (auto id : s) { + ASSERT_EQ(true, s1.find(id) != s1.end()); + } +} + +void testFeatureNodeSerializeInt() { + std::string out = + distributed::FeatureNode::parse_value_to_bytes({"123", "345"}); + std::vector out2 = + distributed::FeatureNode::parse_bytes_to_array(out); + ASSERT_EQ(out2[0], 123); + ASSERT_EQ(out2[1], 345); +} + +void testFeatureNodeSerializeInt64() { + std::string out = + distributed::FeatureNode::parse_value_to_bytes({"123", "345"}); + std::vector out2 = + distributed::FeatureNode::parse_bytes_to_array(out); + ASSERT_EQ(out2[0], 123); + ASSERT_EQ(out2[1], 345); +} + +void testFeatureNodeSerializeFloat32() { + std::string out = distributed::FeatureNode::parse_value_to_bytes( + {"123.123", "345.123"}); + std::vector out2 = + distributed::FeatureNode::parse_bytes_to_array(out); + float eps; + std::cout << "Float " << out2[0] << " " << 123.123 << std::endl; + eps = out2[0] - 123.123; + ASSERT_LE(eps * eps, 1e-5); + eps = out2[1] - 345.123; + ASSERT_LE(eps * eps, 1e-5); +} + +void testFeatureNodeSerializeFloat64() { + std::string out = distributed::FeatureNode::parse_value_to_bytes( + {"123.123", "345.123"}); + std::vector out2 = + distributed::FeatureNode::parse_bytes_to_array(out); + float eps; + eps = out2[0] - 123.123; + std::cout << "Float64 " << out2[0] << " " << 123.123 << std::endl; + ASSERT_LE(eps * eps, 1e-5); + eps = out2[1] - 345.123; + ASSERT_LE(eps * eps, 1e-5); +} + +void testSingleSampleNeighboor( + std::shared_ptr& worker_ptr_) { + std::vector>> vs; + auto pull_status = worker_ptr_->batch_sample_neighboors( + 0, std::vector(1, 37), 4, vs); + pull_status.wait(); + + std::unordered_set s; + std::unordered_set s1 = {112, 45, 145}; + for (auto g : vs[0]) { + s.insert(g.first); + } + ASSERT_EQ(s.size(), 3); + for (auto g : s) { + ASSERT_EQ(true, 
s1.find(g) != s1.end()); + } + VLOG(0) << "test single done"; + s.clear(); + s1.clear(); + vs.clear(); + pull_status = worker_ptr_->batch_sample_neighboors( + 0, std::vector(1, 96), 4, vs); + pull_status.wait(); + s1 = {111, 48, 247}; + for (auto g : vs[0]) { + s.insert(g.first); + } + ASSERT_EQ(s.size(), 3); + for (auto g : s) { + ASSERT_EQ(true, s1.find(g) != s1.end()); + } +} + +void testBatchSampleNeighboor( + std::shared_ptr& worker_ptr_) { + std::vector>> vs; + std::vector v = {37, 96}; + auto pull_status = worker_ptr_->batch_sample_neighboors(0, v, 4, vs); + pull_status.wait(); + std::unordered_set s; + std::unordered_set s1 = {112, 45, 145}; + for (auto g : vs[0]) { + s.insert(g.first); + } + ASSERT_EQ(s.size(), 3); + for (auto g : s) { + ASSERT_EQ(true, s1.find(g) != s1.end()); + } + s.clear(); + s1.clear(); + s1 = {111, 48, 247}; + for (auto g : vs[1]) { + s.insert(g.first); + } + ASSERT_EQ(s.size(), 3); + for (auto g : s) { + ASSERT_EQ(true, s1.find(g) != s1.end()); + } +} + +void testGraphToBuffer(); +// std::string nodes[] = {std::string("37\taa\t45;0.34\t145;0.31\t112;0.21"), +// std::string("96\tfeature\t48;1.4\t247;0.31\t111;1.21"), +// std::string("59\ttreat\t45;0.34\t145;0.31\t112;0.21"), +// std::string("97\tfood\t48;1.4\t247;0.31\t111;1.21")}; + +std::string edges[] = { + std::string("37\t45\t0.34"), std::string("37\t145\t0.31"), + std::string("37\t112\t0.21"), std::string("96\t48\t1.4"), + std::string("96\t247\t0.31"), std::string("96\t111\t1.21"), + std::string("59\t45\t0.34"), std::string("59\t145\t0.31"), + std::string("59\t122\t0.21"), std::string("97\t48\t0.34"), + std::string("97\t247\t0.31"), std::string("97\t111\t0.21")}; +char edge_file_name[] = "edges.txt"; + +std::string nodes[] = { + std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"), + std::string("user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd"), + std::string("user\t59\ta 0.11\tb 11 14"), + std::string("user\t97\ta 0.11\tb 12 11"), + std::string("item\t45\ta 0.21"), + std::string("item\t145\ta 0.21"), + std::string("item\t112\ta 0.21"), + std::string("item\t48\ta 0.21"), + std::string("item\t247\ta 0.21"), + std::string("item\t111\ta 0.21"), + std::string("item\t46\ta 0.21"), + std::string("item\t146\ta 0.21"), + std::string("item\t122\ta 0.21"), + std::string("item\t49\ta 0.21"), + std::string("item\t248\ta 0.21"), + std::string("item\t113\ta 0.21")}; +char node_file_name[] = "nodes.txt"; + +void prepare_file(char file_name[], bool load_edge) { + std::ofstream ofile; + ofile.open(file_name); + if (load_edge) { + for (auto x : edges) { + ofile << x << std::endl; + } + } else { + for (auto x : nodes) { + ofile << x << std::endl; + } + } + ofile.close(); +} +void GetDownpourSparseTableProto( + ::paddle::distributed::TableParameter* sparse_table_proto) { + sparse_table_proto->set_table_id(0); + sparse_table_proto->set_table_class("GraphTable"); + sparse_table_proto->set_shard_num(127); + sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE); + ::paddle::distributed::TableAccessorParameter* accessor_proto = + sparse_table_proto->mutable_accessor(); + accessor_proto->set_accessor_class("CommMergeAccessor"); +} + +::paddle::distributed::PSParameter GetServerProto() { + // Generate server proto desc + ::paddle::distributed::PSParameter server_fleet_desc; + ::paddle::distributed::ServerParameter* server_proto = + server_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + 
::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("GraphBrpcService"); + server_service_proto->set_server_class("GraphBrpcServer"); + server_service_proto->set_client_class("GraphBrpcClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(sparse_table_proto); + return server_fleet_desc; +} + +::paddle::distributed::PSParameter GetWorkerProto() { + ::paddle::distributed::PSParameter worker_fleet_desc; + ::paddle::distributed::WorkerParameter* worker_proto = + worker_fleet_desc.mutable_worker_param(); + + ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = + worker_proto->mutable_downpour_worker_param(); + + ::paddle::distributed::TableParameter* worker_sparse_table_proto = + downpour_worker_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(worker_sparse_table_proto); + + ::paddle::distributed::ServerParameter* server_proto = + worker_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("GraphBrpcService"); + server_service_proto->set_server_class("GraphBrpcServer"); + server_service_proto->set_client_class("GraphBrpcClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* server_sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(server_sparse_table_proto); + + return worker_fleet_desc; +} + +/*-------------------------------------------------------------------------*/ + +std::string ip_ = "127.0.0.1", ip2 = "127.0.0.1"; +uint32_t port_ = 5209, port2 = 5210; + +std::vector host_sign_list_; + +std::shared_ptr pserver_ptr_, + pserver_ptr2; + +std::shared_ptr worker_ptr_; + +void RunServer() { + LOG(INFO) << "init first server"; + ::paddle::distributed::PSParameter server_proto = GetServerProto(); + + auto _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, 2); // test + pserver_ptr_ = std::shared_ptr( + (paddle::distributed::GraphBrpcServer*) + paddle::distributed::PSServerFactory::create(server_proto)); + std::vector empty_vec; + framework::ProgramDesc empty_prog; + empty_vec.push_back(empty_prog); + pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); + LOG(INFO) << "first server, run start(ip,port)"; + pserver_ptr_->start(ip_, port_); + LOG(INFO) << "init first server Done"; +} + +void RunServer2() { + LOG(INFO) << "init second server"; + ::paddle::distributed::PSParameter server_proto2 = GetServerProto(); + + auto _ps_env2 = paddle::distributed::PaddlePSEnvironment(); + _ps_env2.set_ps_servers(&host_sign_list_, 2); // test + pserver_ptr2 = std::shared_ptr( + (paddle::distributed::GraphBrpcServer*) + paddle::distributed::PSServerFactory::create(server_proto2)); + std::vector empty_vec2; + framework::ProgramDesc empty_prog2; + empty_vec2.push_back(empty_prog2); + pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2); + pserver_ptr2->start(ip2, port2); +} + +void RunClient( + 
std::map<uint64_t, std::vector<paddle::distributed::Region>>& dense_regions, + int index, paddle::distributed::PsBaseService* service) { + ::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); + paddle::distributed::PaddlePSEnvironment _ps_env; + auto servers_ = host_sign_list_.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, servers_); + worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>( + (paddle::distributed::GraphBrpcClient*) + paddle::distributed::PSClientFactory::create(worker_proto)); + worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + worker_ptr_->set_shard_num(127); + worker_ptr_->set_local_channel(index); + worker_ptr_->set_local_graph_service( + (paddle::distributed::GraphBrpcService*)service); +} + +void RunBrpcPushSparse() { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + prepare_file(edge_file_name, 1); + prepare_file(node_file_name, 0); + auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); + host_sign_list_.push_back(ph_host.serialize_to_string()); + + // test-start + auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1); + host_sign_list_.push_back(ph_host2.serialize_to_string()); + // test-end + // Start Server + std::thread* server_thread = new std::thread(RunServer); + std::thread* server_thread2 = new std::thread(RunServer2); + sleep(1); + + std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions; + dense_regions.insert( + std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {})); + auto regions = dense_regions[0]; + + RunClient(dense_regions, 0, pserver_ptr_->get_service()); + + /*-----------------------Test Server Init----------------------------------*/ + auto pull_status = + worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); + srand(time(0)); + pull_status.wait(); + std::vector<std::vector<std::pair<uint64_t, float>>> vs; + testSampleNodes(worker_ptr_); + sleep(5); + testSingleSampleNeighboor(worker_ptr_); + testBatchSampleNeighboor(worker_ptr_); + pull_status = worker_ptr_->batch_sample_neighboors( + 0, std::vector<uint64_t>(1, 10240001024), 4, vs); + pull_status.wait(); + ASSERT_EQ(0, vs[0].size()); + + std::vector<distributed::FeatureNode> nodes; + pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, 1, nodes); + pull_status.wait(); + ASSERT_EQ(nodes.size(), 1); + ASSERT_EQ(nodes[0].get_id(), 37); + nodes.clear(); + pull_status = worker_ptr_->pull_graph_list(0, 0, 1, 4, 1, nodes); + pull_status.wait(); + ASSERT_EQ(nodes.size(), 1); + ASSERT_EQ(nodes[0].get_id(), 59); + for (auto g : nodes) { + std::cout << g.get_id() << std::endl; + } + distributed::GraphPyServer server1, server2; + distributed::GraphPyClient client1, client2; + std::string ips_str = "127.0.0.1:5211;127.0.0.1:5212"; + std::vector<std::string> edge_types = {std::string("user2item")}; + std::vector<std::string> node_types = {std::string("user"), + std::string("item")}; + VLOG(0) << "make 2 servers"; + server1.set_up(ips_str, 127, node_types, edge_types, 0); + server2.set_up(ips_str, 127, node_types, edge_types, 1); + + server1.add_table_feat_conf("user", "a", "float32", 1); + server1.add_table_feat_conf("user", "b", "int32", 2); + server1.add_table_feat_conf("user", "c", "string", 1); + server1.add_table_feat_conf("user", "d", "string", 1); + server1.add_table_feat_conf("item", "a", "float32", 1); + + server2.add_table_feat_conf("user", "a", "float32", 1); + server2.add_table_feat_conf("user", "b", "int32", 2); + server2.add_table_feat_conf("user", "c", "string", 1); + server2.add_table_feat_conf("user", "d", "string", 1); + server2.add_table_feat_conf("item", "a", "float32", 1); + + client1.set_up(ips_str, 127, node_types, edge_types, 0); + + client1.add_table_feat_conf("user", "a", "float32", 1); +
client1.add_table_feat_conf("user", "b", "int32", 2); + client1.add_table_feat_conf("user", "c", "string", 1); + client1.add_table_feat_conf("user", "d", "string", 1); + client1.add_table_feat_conf("item", "a", "float32", 1); + + client2.set_up(ips_str, 127, node_types, edge_types, 1); + + client2.add_table_feat_conf("user", "a", "float32", 1); + client2.add_table_feat_conf("user", "b", "int32", 2); + client2.add_table_feat_conf("user", "c", "string", 1); + client2.add_table_feat_conf("user", "d", "string", 1); + client2.add_table_feat_conf("item", "a", "float32", 1); + + server1.start_server(false); + std::cout << "first server done" << std::endl; + server2.start_server(false); + std::cout << "second server done" << std::endl; + client1.start_client(); + std::cout << "first client done" << std::endl; + client2.start_client(); + std::cout << "first client done" << std::endl; + std::cout << "started" << std::endl; + VLOG(0) << "come to set local server"; + client1.bind_local_server(0, server1); + VLOG(0) << "first bound"; + client2.bind_local_server(1, server2); + VLOG(0) << "second bound"; + client1.load_node_file(std::string("user"), std::string(node_file_name)); + client1.load_node_file(std::string("item"), std::string(node_file_name)); + client1.load_edge_file(std::string("user2item"), std::string(edge_file_name), + 0); + nodes.clear(); + + nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4, 1); + + ASSERT_EQ(nodes[0].get_id(), 59); + nodes.clear(); + + // Test Pull by step + + std::unordered_set count_item_nodes; + // pull by step 2 + for (int test_step = 1; test_step < 4; test_step++) { + count_item_nodes.clear(); + std::cout << "check pull graph list by step " << test_step << std::endl; + for (int server_id = 0; server_id < 2; server_id++) { + for (int start_step = 0; start_step < test_step; start_step++) { + nodes = client1.pull_graph_list(std::string("item"), server_id, + start_step, 12, test_step); + for (auto g : nodes) { + count_item_nodes.insert(g.get_id()); + } + nodes.clear(); + } + } + ASSERT_EQ(count_item_nodes.size(), 12); + } + + vs = client1.batch_sample_neighboors(std::string("user2item"), + std::vector(1, 96), 4); + ASSERT_EQ(vs[0].size(), 3); + std::vector node_ids; + node_ids.push_back(96); + node_ids.push_back(37); + vs = client1.batch_sample_neighboors(std::string("user2item"), node_ids, 4); + + ASSERT_EQ(vs.size(), 2); + std::vector nodes_ids = client2.random_sample_nodes("user", 0, 6); + ASSERT_EQ(nodes_ids.size(), 2); + ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) || + (nodes_ids[0] == 37 && nodes_ids[1] == 59)); + + // Test get node feat + node_ids.clear(); + node_ids.push_back(37); + node_ids.push_back(96); + std::vector feature_names; + feature_names.push_back(std::string("c")); + feature_names.push_back(std::string("d")); + auto node_feat = + client1.get_node_feat(std::string("user"), node_ids, feature_names); + ASSERT_EQ(node_feat.size(), 2); + ASSERT_EQ(node_feat[0].size(), 2); + VLOG(0) << "get_node_feat: " << node_feat[0][0]; + VLOG(0) << "get_node_feat: " << node_feat[0][1]; + VLOG(0) << "get_node_feat: " << node_feat[1][0]; + VLOG(0) << "get_node_feat: " << node_feat[1][1]; + + // Test string + node_ids.clear(); + node_ids.push_back(37); + node_ids.push_back(96); + // std::vector feature_names; + feature_names.clear(); + feature_names.push_back(std::string("a")); + feature_names.push_back(std::string("b")); + node_feat = + client1.get_node_feat(std::string("user"), node_ids, feature_names); + ASSERT_EQ(node_feat.size(), 2); + 
ASSERT_EQ(node_feat[0].size(), 2); + VLOG(0) << "get_node_feat: " << node_feat[0][0].size(); + VLOG(0) << "get_node_feat: " << node_feat[0][1].size(); + VLOG(0) << "get_node_feat: " << node_feat[1][0].size(); + VLOG(0) << "get_node_feat: " << node_feat[1][1].size(); + + std::remove(edge_file_name); + std::remove(node_file_name); + LOG(INFO) << "Run stop_server"; + worker_ptr_->stop_server(); + LOG(INFO) << "Run finalize_worker"; + worker_ptr_->finalize_worker(); + testFeatureNodeSerializeInt(); + testFeatureNodeSerializeInt64(); + testFeatureNodeSerializeFloat32(); + testFeatureNodeSerializeFloat64(); + testGraphToBuffer(); + client1.stop_server(); +} + +void testGraphToBuffer() { + ::paddle::distributed::GraphNode s, s1; + s.set_feature_size(1); + s.set_feature(0, std::string("hhhh")); + s.set_id(65); + int size = s.get_size(true); + char str[size]; + s.to_buffer(str, true); + s1.recover_from_buffer(str); + ASSERT_EQ(s.get_id(), s1.get_id()); + VLOG(0) << s.get_feature(0); + VLOG(0) << s1.get_feature(0); +} + +TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } diff --git a/paddle/fluid/inference/api/demo_ci/clean.sh b/paddle/fluid/inference/api/demo_ci/clean.sh index 0d9f3d2aa23..c265721db57 100755 --- a/paddle/fluid/inference/api/demo_ci/clean.sh +++ b/paddle/fluid/inference/api/demo_ci/clean.sh @@ -1,3 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -x cd `dirname $0` rm -rf build/ data/ diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 97ebd64a07e..10c79933546 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -7,6 +7,10 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator) +if (WITH_PSCORE) + set(PYBIND_DEPS ${PYBIND_DEPS} ps_service) + set(PYBIND_DEPS ${PYBIND_DEPS} graph_py_service) +endif() if (WITH_GPU OR WITH_ROCM) set(PYBIND_DEPS ${PYBIND_DEPS} dynload_cuda) set(PYBIND_DEPS ${PYBIND_DEPS} cuda_device_guard) diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index ba716fb3b55..0a2159667f3 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -32,6 +32,8 @@ limitations under the License. 
*/ #include "paddle/fluid/distributed/fleet.h" #include "paddle/fluid/distributed/service/communicator.h" #include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/graph_brpc_client.h" +#include "paddle/fluid/distributed/service/graph_py_service.h" #include "paddle/fluid/distributed/service/heter_client.h" namespace py = pybind11; @@ -39,6 +41,11 @@ using paddle::distributed::CommContext; using paddle::distributed::Communicator; using paddle::distributed::FleetWrapper; using paddle::distributed::HeterClient; +using paddle::distributed::GraphPyService; +using paddle::distributed::GraphNode; +using paddle::distributed::GraphPyServer; +using paddle::distributed::GraphPyClient; +using paddle::distributed::FeatureNode; namespace paddle { namespace pybind { @@ -152,5 +159,58 @@ void BindHeterClient(py::module* m) { .def("stop", &HeterClient::Stop); } +void BindGraphNode(py::module* m) { + py::class_(*m, "GraphNode") + .def(py::init<>()) + .def("get_id", &GraphNode::get_id) + .def("get_feature", &GraphNode::get_feature); +} +void BindGraphPyFeatureNode(py::module* m) { + py::class_(*m, "FeatureNode") + .def(py::init<>()) + .def("get_id", &GraphNode::get_id) + .def("get_feature", &GraphNode::get_feature); +} + +void BindGraphPyService(py::module* m) { + py::class_(*m, "GraphPyService").def(py::init<>()); +} + +void BindGraphPyServer(py::module* m) { + py::class_(*m, "GraphPyServer") + .def(py::init<>()) + .def("start_server", &GraphPyServer::start_server) + .def("set_up", &GraphPyServer::set_up) + .def("add_table_feat_conf", &GraphPyServer::add_table_feat_conf); +} +void BindGraphPyClient(py::module* m) { + py::class_(*m, "GraphPyClient") + .def(py::init<>()) + .def("load_edge_file", &GraphPyClient::load_edge_file) + .def("load_node_file", &GraphPyClient::load_node_file) + .def("set_up", &GraphPyClient::set_up) + .def("add_table_feat_conf", &GraphPyClient::add_table_feat_conf) + .def("pull_graph_list", &GraphPyClient::pull_graph_list) + .def("start_client", &GraphPyClient::start_client) + .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors) + .def("random_sample_nodes", &GraphPyClient::random_sample_nodes) + .def("stop_server", &GraphPyClient::stop_server) + .def("get_node_feat", + [](GraphPyClient& self, std::string node_type, + std::vector node_ids, + std::vector feature_names) { + auto feats = + self.get_node_feat(node_type, node_ids, feature_names); + std::vector> bytes_feats(feats.size()); + for (int i = 0; i < feats.size(); ++i) { + for (int j = 0; j < feats[i].size(); ++j) { + bytes_feats[i].push_back(py::bytes(feats[i][j])); + } + } + return bytes_feats; + }) + .def("bind_local_server", &GraphPyClient::bind_local_server); +} + } // end namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/fleet_py.h b/paddle/fluid/pybind/fleet_py.h index 7f471598ad2..11b430cd208 100644 --- a/paddle/fluid/pybind/fleet_py.h +++ b/paddle/fluid/pybind/fleet_py.h @@ -27,6 +27,10 @@ void BindPSHost(py::module* m); void BindCommunicatorContext(py::module* m); void BindDistCommunicator(py::module* m); void BindHeterClient(py::module* m); - +void BindGraphNode(py::module* m); +void BindGraphPyService(py::module* m); +void BindGraphPyFeatureNode(py::module* m); +void BindGraphPyServer(py::module* m); +void BindGraphPyClient(py::module* m); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index d8ee80c0070..29c7f00142d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ 
b/paddle/fluid/pybind/pybind.cc @@ -2896,6 +2896,11 @@ All parameter, weight, gradient are variables in Paddle. BindCommunicatorContext(&m); BindDistCommunicator(&m); BindHeterClient(&m); + BindGraphPyFeatureNode(&m); + BindGraphNode(&m); + BindGraphPyService(&m); + BindGraphPyServer(&m); + BindGraphPyClient(&m); #endif } } // namespace pybind diff --git a/paddle/scripts/build_docker_images.sh b/paddle/scripts/build_docker_images.sh index a90f0885294..2b584cdca6b 100644 --- a/paddle/scripts/build_docker_images.sh +++ b/paddle/scripts/build_docker_images.sh @@ -1,4 +1,19 @@ #!/bin/sh + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -xe REPO="${REPO:-paddlepaddle}" diff --git a/paddle/scripts/docker/root/.scripts/git-completion.sh b/paddle/scripts/docker/root/.scripts/git-completion.sh index bdddef5ac2f..c43e88a4acd 100755 --- a/paddle/scripts/docker/root/.scripts/git-completion.sh +++ b/paddle/scripts/docker/root/.scripts/git-completion.sh @@ -1,4 +1,19 @@ #!bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # # bash/zsh completion support for core Git. # diff --git a/paddle/scripts/fast_install.sh b/paddle/scripts/fast_install.sh index 1034b1c5c10..cacec55d3bc 100644 --- a/paddle/scripts/fast_install.sh +++ b/paddle/scripts/fast_install.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ ## purple to echo function purple(){ echo -e "\033[35m$1\033[0m" diff --git a/python/paddle/fluid/dataloader/fetcher.py b/python/paddle/fluid/dataloader/fetcher.py index 9382a704223..41e12fbc68e 100644 --- a/python/paddle/fluid/dataloader/fetcher.py +++ b/python/paddle/fluid/dataloader/fetcher.py @@ -27,8 +27,8 @@ class _DatasetFetcher(object): class _IterableDatasetFetcher(_DatasetFetcher): def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): - super(_IterableDatasetFetcher, self).__init__(dataset, auto_collate_batch, - collate_fn, drop_last) + super(_IterableDatasetFetcher, self).__init__( + dataset, auto_collate_batch, collate_fn, drop_last) self.dataset_iter = iter(dataset) def fetch(self, batch_indices): @@ -53,7 +53,8 @@ class _IterableDatasetFetcher(_DatasetFetcher): class _MapDatasetFetcher(_DatasetFetcher): def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): - super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, collate_fn, drop_last) + super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, + collate_fn, drop_last) def fetch(self, batch_indices): if self.auto_collate_batch: diff --git a/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh b/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh index 1df6b0618de..cac2f7234bd 100644 --- a/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh +++ b/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # start pserver0 python fleet_deep_ctr.py \ --role pserver \ diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py index 95cff4de6f6..69a9ae3c0ad 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py @@ -40,9 +40,11 @@ class SquaredMatSubFusePassTest(InferencePassTest): matmul_ab_square = paddle.square(matmul_ab) matmul_square_ab = paddle.matmul(data_a_square, data_b_square) - scale = paddle.fluid.layers.fill_constant(shape=[1], value=0.5, dtype='float32') + scale = paddle.fluid.layers.fill_constant( + shape=[1], value=0.5, dtype='float32') - sub_val = paddle.fluid.layers.elementwise_sub(matmul_ab_square, matmul_square_ab) + sub_val = paddle.fluid.layers.elementwise_sub(matmul_ab_square, + matmul_square_ab) squared_mat_sub_out = fluid.layers.elementwise_mul(sub_val, scale) self.feeds = { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py index 94434f40434..080d1ccc905 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py @@ -25,19 +25,16 @@ class TensorRTMatMulDims2Test(InferencePassTest): def setUp(self): self.set_params() with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[24, 24], dtype="float32") + data = fluid.data(name="data", shape=[24, 24], dtype="float32") matmul_out = fluid.layers.matmul( x=data, y=data, - transpose_x = self.transpose_x, - transpose_y = self.transpose_y, - alpha = self.alpha) + transpose_x=self.transpose_x, + transpose_y=self.transpose_y, + alpha=self.alpha) out = fluid.layers.batch_norm(matmul_out, is_test=True) - self.feeds = { - "data": np.ones([24, 24]).astype("float32"), - } + self.feeds = {"data": np.ones([24, 24]).astype("float32"), } self.enable_trt = True self.trt_parameters = TensorRTMatMulDims2Test.TensorRTParam( 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) @@ -65,14 +62,12 @@ class TensorRTMatMulTest(InferencePassTest): matmul_out = fluid.layers.matmul( x=data, y=data, - transpose_x = self.transpose_x, - transpose_y = self.transpose_y, - alpha = self.alpha) + transpose_x=self.transpose_x, + transpose_y=self.transpose_y, + alpha=self.alpha) out = fluid.layers.batch_norm(matmul_out, is_test=True) - self.feeds = { - "data": np.ones([1, 6, 24, 24]).astype("float32"), - } + self.feeds = {"data": np.ones([1, 6, 24, 24]).astype("float32"), } self.enable_trt = True self.trt_parameters = TensorRTMatMulTest.TensorRTParam( 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) diff --git a/python/paddle/fluid/tests/unittests/parallel_test.sh b/python/paddle/fluid/tests/unittests/parallel_test.sh index 9da4f035345..551b7cdb7a4 100644 --- a/python/paddle/fluid/tests/unittests/parallel_test.sh +++ b/python/paddle/fluid/tests/unittests/parallel_test.sh @@ -1,4 +1,19 @@ #!/bin/bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + unset https_proxy http_proxy export FLAGS_rpc_disable_reuse_port=1 diff --git a/python/paddle/fluid/tests/unittests/test_bce_loss.py b/python/paddle/fluid/tests/unittests/test_bce_loss.py index 4b39436842b..ea1a22780f0 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_loss.py @@ -27,8 +27,10 @@ def test_static_layer(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=input_np.shape, dtype='float64') - label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') + input = paddle.fluid.data( + name='input', shape=input_np.shape, dtype='float64') + label = paddle.fluid.data( + name='label', shape=label_np.shape, dtype='float64') if weight_np is not None: weight = paddle.fluid.data( name='weight', shape=weight_np.shape, dtype='float64') @@ -58,8 +60,10 @@ def test_static_functional(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=input_np.shape, dtype='float64') - label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') + input = paddle.fluid.data( + name='input', shape=input_np.shape, dtype='float64') + label = paddle.fluid.data( + name='label', shape=label_np.shape, dtype='float64') if weight_np is not None: weight = paddle.fluid.data( name='weight', shape=weight_np.shape, dtype='float64') diff --git a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py index a6175aa471d..153b8fd3e7f 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py @@ -48,8 +48,10 @@ def test_static(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - logit = paddle.fluid.data(name='logit', shape=logit_np.shape, dtype='float64') - label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') + logit = paddle.fluid.data( + name='logit', shape=logit_np.shape, dtype='float64') + label = paddle.fluid.data( + name='label', shape=label_np.shape, dtype='float64') feed_dict = {"logit": logit_np, "label": label_np} pos_weight = None diff --git a/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh b/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh index a9d450e223f..aba95a68ab7 100644 --- a/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh +++ b/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh @@ -1,4 +1,19 @@ #!/bin/bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -e # use default values # FIXME: random fails on Unknown command lines -c (or -m). diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py index 16584ee5008..a82866a797d 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py @@ -23,7 +23,6 @@ import os paddle.enable_static() - # For Net base_lr = 0.2 emb_lr = base_lr * 3 diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index aa85eb3df35..28803f5ac62 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -170,7 +170,8 @@ class TestFlatten2OpError(unittest.TestCase): x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] * image_shape[3]).reshape(image_shape) / 100. x2 = x2.astype('float16') - x2_var = paddle.fluid.data(name='x2', shape=[3, 2, 4, 5], dtype='float16') + x2_var = paddle.fluid.data( + name='x2', shape=[3, 2, 4, 5], dtype='float16') paddle.flatten(x2_var) self.assertRaises(TypeError, test_type) diff --git a/python/paddle/fluid/tests/unittests/test_l1_loss.py b/python/paddle/fluid/tests/unittests/test_l1_loss.py index fba16959901..c35188623b4 100644 --- a/python/paddle/fluid/tests/unittests/test_l1_loss.py +++ b/python/paddle/fluid/tests/unittests/test_l1_loss.py @@ -44,8 +44,10 @@ class TestFunctionalL1Loss(unittest.TestCase): self.assertTrue(dy_result.shape, [10, 10, 5]) def run_static(self, use_gpu=False): - input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32') - label = paddle.fluid.data(name='label', shape=[10, 10, 5], dtype='float32') + input = paddle.fluid.data( + name='input', shape=[10, 10, 5], dtype='float32') + label = paddle.fluid.data( + name='label', shape=[10, 10, 5], dtype='float32') result0 = paddle.nn.functional.l1_loss(input, label) result1 = paddle.nn.functional.l1_loss(input, label, reduction='sum') result2 = paddle.nn.functional.l1_loss(input, label, reduction='none') @@ -127,8 +129,10 @@ class TestClassL1Loss(unittest.TestCase): self.assertTrue(dy_result.shape, [10, 10, 5]) def run_static(self, use_gpu=False): - input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32') - label = paddle.fluid.data(name='label', shape=[10, 10, 5], dtype='float32') + input = paddle.fluid.data( + name='input', shape=[10, 10, 5], dtype='float32') + label = paddle.fluid.data( + name='label', shape=[10, 10, 5], dtype='float32') l1_loss = paddle.nn.loss.L1Loss() result0 = l1_loss(input, label) l1_loss = paddle.nn.loss.L1Loss(reduction='sum') diff --git a/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh b/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh index bee230fba5a..d9d64e4dfa6 100644 --- a/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh +++ b/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh @@ -1,4 +1,19 @@ #!/bin/bash + +# Copyright (c) 2021 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + unset https_proxy http_proxy nohup python -u test_listen_and_serv_op.py > test_listen_and_serv_op.log 2>&1 & diff --git a/python/paddle/fluid/tests/unittests/test_mse_loss.py b/python/paddle/fluid/tests/unittests/test_mse_loss.py index bc5d35d3254..89eef6ca242 100644 --- a/python/paddle/fluid/tests/unittests/test_mse_loss.py +++ b/python/paddle/fluid/tests/unittests/test_mse_loss.py @@ -191,8 +191,10 @@ class TestNNFunctionalMseLoss(unittest.TestCase): place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=dim, dtype='float32') - target = paddle.fluid.data(name='target', shape=dim, dtype='float32') + input = paddle.fluid.data( + name='input', shape=dim, dtype='float32') + target = paddle.fluid.data( + name='target', shape=dim, dtype='float32') mse_loss = paddle.nn.functional.mse_loss(input, target, 'mean') exe = paddle.static.Executor(place) @@ -225,8 +227,10 @@ class TestNNFunctionalMseLoss(unittest.TestCase): place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=dim, dtype='float32') - target = paddle.fluid.data(name='target', shape=dim, dtype='float32') + input = paddle.fluid.data( + name='input', shape=dim, dtype='float32') + target = paddle.fluid.data( + name='target', shape=dim, dtype='float32') mse_loss = paddle.nn.functional.mse_loss(input, target, 'sum') exe = paddle.static.Executor(place) @@ -259,8 +263,10 @@ class TestNNFunctionalMseLoss(unittest.TestCase): place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=dim, dtype='float32') - target = paddle.fluid.data(name='target', shape=dim, dtype='float32') + input = paddle.fluid.data( + name='input', shape=dim, dtype='float32') + target = paddle.fluid.data( + name='target', shape=dim, dtype='float32') mse_loss = paddle.nn.functional.mse_loss(input, target, 'none') exe = paddle.static.Executor(place) diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py index 0533a0d09fa..3bb3e843b1b 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py @@ -160,5 +160,6 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader): print("time cost", ret['time'], 'step_list', ret['step']) return ret + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py 
index f75d6e9df54..f1a409c712f 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py @@ -97,8 +97,10 @@ class TestPixelShuffleAPI(unittest.TestCase): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() paddle.enable_static() - x_1 = paddle.fluid.data(name="x", shape=[2, 9, 4, 4], dtype="float64") - x_2 = paddle.fluid.data(name="x2", shape=[2, 4, 4, 9], dtype="float64") + x_1 = paddle.fluid.data( + name="x", shape=[2, 9, 4, 4], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 4, 4, 9], dtype="float64") out_1 = F.pixel_shuffle(x_1, 3) out_2 = F.pixel_shuffle(x_2, 3, "NHWC") @@ -123,8 +125,10 @@ class TestPixelShuffleAPI(unittest.TestCase): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() paddle.enable_static() - x_1 = paddle.fluid.data(name="x", shape=[2, 9, 4, 4], dtype="float64") - x_2 = paddle.fluid.data(name="x2", shape=[2, 4, 4, 9], dtype="float64") + x_1 = paddle.fluid.data( + name="x", shape=[2, 9, 4, 4], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 4, 4, 9], dtype="float64") # init instance ps_1 = paddle.nn.PixelShuffle(3) ps_2 = paddle.nn.PixelShuffle(3, "NHWC") diff --git a/python/paddle/fluid/tests/unittests/test_prod_op.py b/python/paddle/fluid/tests/unittests/test_prod_op.py index 15fd79542d6..cdfcbb4e4e7 100644 --- a/python/paddle/fluid/tests/unittests/test_prod_op.py +++ b/python/paddle/fluid/tests/unittests/test_prod_op.py @@ -55,7 +55,8 @@ class TestProdOp(unittest.TestCase): self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) def run_static(self, use_gpu=False): - input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32') + input = paddle.fluid.data( + name='input', shape=[10, 10, 5], dtype='float32') result0 = paddle.prod(input) result1 = paddle.prod(input, axis=1) result2 = paddle.prod(input, axis=-1) @@ -114,7 +115,8 @@ class TestProdOpError(unittest.TestCase): with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): x = paddle.fluid.data(name='x', shape=[2, 2, 4], dtype='float32') - bool_x = paddle.fluid.data(name='bool_x', shape=[2, 2, 4], dtype='bool') + bool_x = paddle.fluid.data( + name='bool_x', shape=[2, 2, 4], dtype='bool') # The argument x shoule be a Tensor self.assertRaises(TypeError, paddle.prod, [1]) diff --git a/python/paddle/fluid/tests/unittests/test_selu_op.py b/python/paddle/fluid/tests/unittests/test_selu_op.py index 95ae1eecc66..e71adae8d9b 100644 --- a/python/paddle/fluid/tests/unittests/test_selu_op.py +++ b/python/paddle/fluid/tests/unittests/test_selu_op.py @@ -128,15 +128,18 @@ class TestSeluAPI(unittest.TestCase): # The input type must be Variable. self.assertRaises(TypeError, F.selu, 1) # The input dtype must be float16, float32, float64. 
- x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32') + x_int32 = paddle.fluid.data( + name='x_int32', shape=[12, 10], dtype='int32') self.assertRaises(TypeError, F.selu, x_int32) # The scale must be greater than 1.0 - x_fp32 = paddle.fluid.data(name='x_fp32', shape=[12, 10], dtype='float32') + x_fp32 = paddle.fluid.data( + name='x_fp32', shape=[12, 10], dtype='float32') self.assertRaises(ValueError, F.selu, x_fp32, -1.0) # The alpha must be no less than 0 self.assertRaises(ValueError, F.selu, x_fp32, 1.6, -1.0) # support the input dtype is float16 - x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16') + x_fp16 = paddle.fluid.data( + name='x_fp16', shape=[12, 10], dtype='float16') F.selu(x_fp16) diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py index 85f9501e53f..2ef04d9cbfa 100644 --- a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py @@ -42,8 +42,10 @@ def test_static(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - logit = paddle.fluid.data(name='logit', shape=logit_np.shape, dtype='float64') - label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') + logit = paddle.fluid.data( + name='logit', shape=logit_np.shape, dtype='float64') + label = paddle.fluid.data( + name='label', shape=label_np.shape, dtype='float64') feed_dict = {"logit": logit_np, "label": label_np} normalizer = None diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index f72df8cbe46..59b4afdf8b0 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -23,6 +23,7 @@ from paddle.fluid import Program, program_guard paddle.enable_static() + class TestTransposeOp(OpTest): def setUp(self): self.init_op_type() @@ -151,6 +152,7 @@ class TestTransposeOpError(unittest.TestCase): self.assertRaises(ValueError, test_each_elem_value_check) + class TestTransposeApi(unittest.TestCase): def test_static_out(self): paddle.enable_static() @@ -161,10 +163,11 @@ class TestTransposeApi(unittest.TestCase): place = paddle.CPUPlace() exe = paddle.static.Executor(place) x_np = np.random.random([2, 3, 4]).astype("float32") - result1, result2 = exe.run(feed={"x": x_np}, fetch_list=[x_trans1, x_trans2]) + result1, result2 = exe.run(feed={"x": x_np}, + fetch_list=[x_trans1, x_trans2]) expected_result1 = np.transpose(x_np, [1, 0, 2]) expected_result2 = np.transpose(x_np, (2, 1, 0)) - + np.testing.assert_array_equal(result1, expected_result1) np.testing.assert_array_equal(result2, expected_result2) @@ -185,6 +188,7 @@ class TestTransposeApi(unittest.TestCase): # dygraph test paddle.enable_static() + class TestTAPI(unittest.TestCase): def test_out(self): with fluid.program_guard(fluid.Program()): diff --git a/scripts/paddle b/scripts/paddle new file mode 100644 index 00000000000..5f256ccf157 --- /dev/null +++ b/scripts/paddle @@ -0,0 +1,169 @@ +#!/bin/bash + +function version(){ + echo "PaddlePaddle , compiled with" + echo " with_avx: ON" + echo " with_gpu: OFF" + echo " with_mkl: ON" + echo " with_mkldnn: " + echo " with_python: ON" +} + +function ver2num() { + set -e + # convert version to number. 
+ if [ -z "$1" ]; then # empty argument + printf "%03d%03d%03d%03d%03d" 0 + else + local VERN=$(echo $1 | sed 's#v##g' | sed 's#\.# #g' \ + | sed 's#a# 0 #g' | sed 's#b# 1 #g' | sed 's#rc# 2 #g') + if [ `echo $VERN | wc -w` -eq 3 ] ; then + printf "%03d%03d%03d%03d%03d" $VERN 999 999 + else + printf "%03d%03d%03d%03d%03d" $VERN + fi + fi + set +e +} + +function cpu_config() { + # auto set KMP_AFFINITY and OMP_DYNAMIC from Hyper Threading Status + # only when MKL enabled + if [ "ON" == "OFF" ]; then + return 0 + fi + platform="`uname -s`" + ht=0 + if [ $platform == "Linux" ]; then + ht=`lscpu |grep "per core"|awk -F':' '{print $2}'|xargs` + elif [ $platform == "Darwin" ]; then + if [ `sysctl -n hw.physicalcpu` -eq `sysctl -n hw.logicalcpu` ]; then + # HT is OFF + ht=1 + fi + else + return 0 + fi + if [ $ht -eq 1 ]; then # HT is OFF + if [ -z "$KMP_AFFINITY" ]; then + export KMP_AFFINITY="granularity=fine,compact,0,0" + fi + if [ -z "$OMP_DYNAMIC" ]; then + export OMP_DYNAMIC="FALSE" + fi + else # HT is ON + if [ -z "$KMP_AFFINITY" ]; then + export KMP_AFFINITY="granularity=fine,compact,1,0" + fi + if [ -z "$OMP_DYNAMIC" ]; then + export OMP_DYNAMIC="True" + fi + fi +} + +function threads_config() { + # auto set OMP_NUM_THREADS and MKL_NUM_THREADS + # according to trainer_count and total processors + # only when MKL enabled + # auto set OPENBLAS_NUM_THREADS when do not use MKL + platform="`uname -s`" + processors=0 + if [ $platform == "Linux" ]; then + processors=`grep "processor" /proc/cpuinfo|sort -u|wc -l` + elif [ $platform == "Darwin" ]; then + processors=`sysctl -n hw.logicalcpu` + else + return 0 + fi + trainers=`grep -Eo 'trainer_count.[0-9]+' <<< "$@" |grep -Eo '[0-9]+'|xargs` + if [ -z $trainers ]; then + trainers=1 + fi + threads=$((processors / trainers)) + if [ $threads -eq 0 ]; then + threads=1 + fi + if [ "ON" == "ON" ]; then + if [ -z "$OMP_NUM_THREADS" ]; then + export OMP_NUM_THREADS=$threads + fi + if [ -z "$MKL_NUM_THREADS" ]; then + export MKL_NUM_THREADS=$threads + fi + else + if [ -z "$OPENBLAS_NUM_THREADS" ]; then + export OPENBLAS_NUM_THREADS=$threads + fi + if [ $threads -gt 1 ] && [ -z "$OPENBLAS_MAIN_FREE" ]; then + export OPENBLAS_MAIN_FREE=1 + fi + fi + +} + +PADDLE_CONF_HOME="$HOME/.config/paddle" +mkdir -p ${PADDLE_CONF_HOME} + +if [ -z "${PADDLE_NO_STAT+x}" ]; then + SERVER_VER=`curl -m 5 -X POST --data content="{ \"version\": \"\" }"\ + -b ${PADDLE_CONF_HOME}/paddle.cookie \ + -c ${PADDLE_CONF_HOME}/paddle.cookie \ + http://api.paddlepaddle.org/version 2>/dev/null` + if [ $? -eq 0 ] && [ "$(ver2num )" -lt $(ver2num $SERVER_VER) ]; then + echo "Paddle release a new version ${SERVER_VER}, you can get the install package in http://www.paddlepaddle.org" + fi +fi + +PADDLE_BIN_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +if [ ! -z "${DEBUGGER}" ]; then + echo "Using debug command ${DEBUGGER}" +fi + +CUDNN_LIB_PATH="" + +if [ ! -z "${CUDNN_LIB_PATH}" ]; then + export LD_LIBRARY_PATH=${CUDNN_LIB_PATH}:${LD_LIBRARY_PATH} +fi + +export PYTHONPATH=${PWD}:${PYTHONPATH} + + +# Check python lib installed or not. +pip --help > /dev/null +if [ $? -ne 0 ]; then + echo "pip should be installed to run paddle." 
+ exit 1 +fi + +if [ "OFF" == "ON" ]; then + PADDLE_NAME="paddlepaddle-gpu" +else + PADDLE_NAME="paddlepaddle" +fi + +INSTALLED_VERSION=`pip freeze 2>/dev/null | grep "^${PADDLE_NAME}==" | sed 's/.*==//g'` + +if [ -z "${INSTALLED_VERSION}" ]; then + INSTALLED_VERSION="0.0.0" # not installed +fi +cat <#RUN apt-get update \ diff --git a/tools/document_preview.sh b/tools/document_preview.sh index 10f486f8fd4..83c758d0aa8 100755 --- a/tools/document_preview.sh +++ b/tools/document_preview.sh @@ -1,4 +1,19 @@ #!/bin/bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + PADDLE_ROOT=/home mkdir ${PADDLE_ROOT} cd ${PADDLE_ROOT} diff --git a/tools/get_cpu_info.sh b/tools/get_cpu_info.sh index 81eb19dc066..bce338a8619 100755 --- a/tools/get_cpu_info.sh +++ b/tools/get_cpu_info.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + if [ "`uname -s`" != "Linux" ]; then echo "Current scenario only support in Linux yet!" exit 0 -- GitLab
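For orientation, the classes registered in fleet_py.cc can be driven end to end from Python roughly as follows. This is a minimal sketch, not code shipped in the patch: it assumes the bindings are reachable through paddle.fluid.core (the module pybind.cc registers them into), mirrors the call order and argument layout used in graph_node_test.cc, and uses illustrative shard numbers, ports, and file paths.

# Hypothetical single-server walkthrough of the new graph engine bindings.
# Assumptions: bindings exposed via paddle.fluid.core; node/edge files and the
# endpoint are placeholders; shard_num = 127 and the feature schema follow the
# values used in graph_node_test.cc.
from paddle.fluid import core  # assumed module path for the pybind targets

ips = "127.0.0.1:5211"            # one endpoint; the C++ test runs two servers
node_types = ["user", "item"]
edge_types = ["user2item"]

server = core.GraphPyServer()
client = core.GraphPyClient()

server.set_up(ips, 127, node_types, edge_types, 0)
client.set_up(ips, 127, node_types, edge_types, 0)
for obj in (server, client):
    # feature name, dtype, and width must match on server and client
    obj.add_table_feat_conf("user", "c", "string", 1)
    obj.add_table_feat_conf("item", "a", "float32", 1)

server.start_server(False)        # False mirrors start_server(false) in the test
client.start_client()
client.bind_local_server(0, server)  # serve local requests without an extra rpc hop

client.load_node_file("user", "node_file.txt")           # placeholder path
client.load_edge_file("user2item", "edge_file.txt", 0)   # placeholder path

nodes = client.pull_graph_list("user", 0, 0, 10, 1)      # server 0, start 0, size 10, step 1
neighbors = client.batch_sample_neighboors("user2item", [96, 37], 4)
sampled = client.random_sample_nodes("user", 0, 6)
feats = client.get_node_feat("user", [96, 37], ["c"])    # bytes per node per feature
client.stop_server()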