未验证 提交 94736d60 编写于 作者: S seemingwang 提交者: GitHub

graph engine (#31226)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segment fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return  py:bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundent threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions
Co-authored-by: NHuang Zhengjie <270018958@qq.com>
Co-authored-by: NWeiyue Su <weiyue.su@gmail.com>
Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
Co-authored-by: Nluobin06 <luobin06@baidu.com>
Co-authored-by: Nliweibin02 <liweibin02@baidu.com>
上级 9e06a641
......@@ -24,11 +24,12 @@ set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUT
set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(brpc_utils SRCS brpc_utils.cc DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})
cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table brpc_utils ${RPC_DEPS})
cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table brpc_utils ${RPC_DEPS})
cc_library(downpour_server SRCS graph_brpc_server.cc brpc_ps_server.cc DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})
cc_library(downpour_client SRCS graph_brpc_client.cc brpc_ps_client.cc DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})
cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS})
cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
......@@ -38,3 +39,6 @@ cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RP
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
set_source_files_properties(graph_py_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_py_service SRCS graph_py_service.cc DEPS ps_service)
......@@ -170,9 +170,22 @@ class BrpcPsClient : public PSClient {
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path);
private:
protected:
virtual size_t get_server_nums() { return _server_channels.size(); }
inline brpc::Channel *get_sparse_channel(size_t server_id) {
return _server_channels[server_id][0].get();
}
inline brpc::Channel *get_dense_channel(size_t server_id) {
return _server_channels[server_id][1].get();
}
inline brpc::Channel *get_cmd_channel(size_t server_id) {
return _server_channels[server_id][2].get();
}
virtual int32_t initialize() override;
private:
// virtual int32_t initialize() override;
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
......@@ -184,16 +197,6 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
inline brpc::Channel *get_sparse_channel(size_t server_id) {
return _server_channels[server_id][0].get();
}
inline brpc::Channel *get_dense_channel(size_t server_id) {
return _server_channels[server_id][1].get();
}
inline brpc::Channel *get_cmd_channel(size_t server_id) {
return _server_channels[server_id][2].get();
}
bool _running = false;
bool _flushing = false;
std::atomic<uint32_t> _async_call_num; //异步请求计数
......@@ -220,8 +223,6 @@ class BrpcPsClient : public PSClient {
size_t num,
void *done) override;
virtual size_t get_server_nums() { return _server_channels.size(); }
private:
int32_t start_client_service();
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/graph_brpc_client.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
void GraphPsService_Stub::service(
::google::protobuf::RpcController *controller,
const ::paddle::distributed::PsRequestMessage *request,
::paddle::distributed::PsResponseMessage *response,
::google::protobuf::Closure *done) {
if (graph_service != NULL && local_channel == channel()) {
// VLOG(0)<<"use local";
task_pool->enqueue([this, controller, request, response, done]() -> int {
this->graph_service->service(controller, request, response, done);
return 0;
});
} else {
// VLOG(0)<<"use server";
PsService_Stub::service(controller, request, response, done);
}
}
int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
int shard_num = get_shard_num();
int shard_per_server = shard_num % server_size == 0
? shard_num / server_size
: shard_num / server_size + 1;
return id % shard_num / shard_per_server;
}
std::future<int32_t> GraphBrpcClient::get_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) {
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
}
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
++fail_num;
} else {
auto &res_io_buffer =
closure->cntl(request_idx)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
char *buffer = buffer_wrapper.get();
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
for (size_t feat_idx = 0; feat_idx < feature_names.size();
++feat_idx) {
for (size_t node_idx = 0;
node_idx < query_idx_buckets.at(request_idx).size();
++node_idx) {
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
size_t feat_len = *(size_t *)(buffer);
buffer += sizeof(size_t);
auto feature = std::string(buffer, feat_len);
res[feat_idx][query_idx] = feature;
buffer += feat_len;
}
}
}
if (fail_num == request_call_num) {
ret = -1;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size());
PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}
return fut;
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
res.clear();
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
res.push_back(std::vector<std::pair<uint64_t, float>>());
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
}
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
++fail_num;
} else {
auto &res_io_buffer =
closure->cntl(request_idx)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
char *buffer = buffer_wrapper.get();
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
size_t node_num = *(size_t *)buffer;
int *actual_sizes = (int *)(buffer + sizeof(size_t));
char *node_buffer =
buffer + sizeof(size_t) + sizeof(int) * node_num;
int offset = 0;
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[query_idx].push_back(
{*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start +
GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
}
offset += actual_size;
}
}
if (fail_num == request_call_num) {
ret = -1;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
uint32_t table_id, int server_index, int sample_size,
std::vector<uint64_t> &ids) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) {
ret = -1;
} else {
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char buffer[bytes_size];
auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
ids.push_back(*(uint64_t *)(buffer + index));
index += GraphNode::id_size;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
;
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
;
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
return fut;
}
std::future<int32_t> GraphBrpcClient::pull_graph_list(
uint32_t table_id, int server_index, int start, int size, int step,
std::vector<FeatureNode> &res) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
ret = -1;
} else {
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char buffer[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
FeatureNode node;
node.recover_from_buffer(buffer + index);
index += node.get_size(false);
res.push_back(node);
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&start, sizeof(int));
closure->request(0)->add_params((char *)&size, sizeof(int));
closure->request(0)->add_params((char *)&step, sizeof(int));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
return fut;
}
int32_t GraphBrpcClient::initialize() {
// set_shard_num(_config.shard_num());
BrpcPsClient::initialize();
server_size = get_server_nums();
graph_service = NULL;
local_channel = NULL;
return 0;
}
}
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "ThreadPool.h"
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace distributed {
class GraphPsService_Stub : public PsService_Stub {
public:
GraphPsService_Stub(::google::protobuf::RpcChannel* channel,
::google::protobuf::RpcChannel* local_channel = NULL,
GraphBrpcService* service = NULL, int thread_num = 1)
: PsService_Stub(channel) {
this->local_channel = local_channel;
this->graph_service = service;
task_pool.reset(new ::ThreadPool(thread_num));
}
virtual ~GraphPsService_Stub() {}
// implements PsService ------------------------------------------
GraphBrpcService* graph_service;
std::shared_ptr<::ThreadPool> task_pool;
::google::protobuf::RpcChannel* local_channel;
GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GraphPsService_Stub);
void service(::google::protobuf::RpcController* controller,
const ::paddle::distributed::PsRequestMessage* request,
::paddle::distributed::PsResponseMessage* response,
::google::protobuf::Closure* done);
};
class GraphBrpcClient : public BrpcPsClient {
public:
GraphBrpcClient() {}
virtual ~GraphBrpcClient() {}
// given a batch of nodes, sample graph_neighboors for each of them
virtual std::future<int32_t> batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>>& res);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size, int step,
std::vector<FeatureNode>& res);
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int server_index,
int sample_size,
std::vector<uint64_t>& ids);
virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);
virtual int32_t initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
int get_server_index_by_id(uint64_t id);
void set_local_channel(int index) {
this->local_channel = get_cmd_channel(index);
}
void set_local_graph_service(GraphBrpcService* graph_service) {
this->graph_service = graph_service;
}
GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel,
int thread_num = 1) {
return GraphPsService_Stub(channel, local_channel, graph_service,
thread_num);
}
private:
int shard_num;
size_t server_size;
::google::protobuf::RpcChannel* local_channel;
GraphBrpcService* graph_service;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
int32_t GraphBrpcServer::initialize() {
auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) {
LOG(ERROR) << "miss service_class in ServerServiceParameter";
return -1;
}
auto *service =
CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
if (service == NULL) {
LOG(ERROR) << "service is unregistered, service_name:"
<< service_config.service_class();
return -1;
}
_service.reset(service);
if (service->configure(this) != 0 || service->initialize() != 0) {
LOG(ERROR) << "service initialize failed, service_name:"
<< service_config.service_class();
return -1;
}
if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
LOG(ERROR) << "service add to brpc failed, service:"
<< service_config.service_class();
return -1;
}
return 0;
}
uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port);
VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
brpc::ServerOptions options;
int num_threads = std::thread::hardware_concurrency();
auto trainers = _environment->get_trainers();
options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
return 0;
}
_environment->registe_ps_server(ip, port, _rank);
return 0;
}
int32_t GraphBrpcServer::port() { return _server.listen_address().port; }
int32_t GraphBrpcService::initialize() {
_is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::stop_server;
_service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::load_one_table;
_service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::load_all_table;
_service_handler_map[PS_PRINT_TABLE_STAT] =
&GraphBrpcService::print_table_stat;
_service_handler_map[PS_BARRIER] = &GraphBrpcService::barrier;
_service_handler_map[PS_START_PROFILER] = &GraphBrpcService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;
_service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
_service_handler_map[PS_GRAPH_SAMPLE_NEIGHBOORS] =
&GraphBrpcService::graph_random_sample_neighboors;
_service_handler_map[PS_GRAPH_SAMPLE_NODES] =
&GraphBrpcService::graph_random_sample_nodes;
_service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
&GraphBrpcService::graph_get_node_feat;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info();
return 0;
}
#define CHECK_TABLE_EXIST(table, request, response) \
if (table == NULL) { \
std::string err_msg("table not found with table_id:"); \
err_msg.append(std::to_string(request.table_id())); \
set_response_code(response, -1, err_msg.c_str()); \
return -1; \
}
int32_t GraphBrpcService::initialize_shard_info() {
if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) {
return 0;
}
size_t shard_num = _server->environment()->get_ps_servers().size();
auto &table_map = *(_server->table());
for (auto itr : table_map) {
itr.second->set_shard(_rank, shard_num);
}
_is_initialize_shard_info = true;
}
return 0;
}
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
const PsRequestMessage *request,
PsResponseMessage *response,
google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-");
if (!request->has_table_id()) {
set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
return;
}
response->set_err_code(0);
response->set_err_msg("");
auto *table = _server->table(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
set_response_code(*response, -1, err_msg.c_str());
return;
}
serviceFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(table, *request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
}
int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(response, -1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
}
auto trainer_id = request.client_id();
auto barrier_type = request.params(0);
table->barrier(trainer_id, barrier_type);
return 0;
}
int32_t GraphBrpcService::print_table_stat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat();
paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length());
response.set_data(table_info);
return 0;
}
int32_t GraphBrpcService::load_one_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1;
}
if (table->load(request.params(0), request.params(1)) != 0) {
set_response_code(response, -1, "table load failed");
return -1;
}
return 0;
}
int32_t GraphBrpcService::load_all_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) {
LOG(ERROR) << "load table[" << itr.first << "] failed";
return -1;
}
}
return 0;
}
int32_t GraphBrpcService::stop_server(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
std::thread t_stop([p_server]() {
p_server->stop();
LOG(INFO) << "Server Stoped";
});
p_server->export_cv()->notify_all();
t_stop.detach();
return 0;
}
int32_t GraphBrpcService::stop_profiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::DisableProfiler(platform::EventSortingKey::kDefault,
string::Sprintf("server_%s_profile", _rank));
return 0;
}
int32_t GraphBrpcService::start_profiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
return 0;
}
int32_t GraphBrpcService::pull_graph_list(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 3) {
set_response_code(response, -1,
"pull_graph_list request requires at least 3 arguments");
return 0;
}
int start = *(int *)(request.params(0).c_str());
int size = *(int *)(request.params(1).c_str());
int step = *(int *)(request.params(2).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
table->pull_graph_list(start, size, buffer, actual_size, false, step);
cntl->response_attachment().append(buffer.get(), actual_size);
return 0;
}
int32_t GraphBrpcService::graph_random_sample_neighboors(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"graph_random_sample request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
std::vector<std::unique_ptr<char[]>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
table->random_sample_neighboors(node_data, sample_size, buffers,
actual_sizes);
cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(),
sizeof(int) * node_num);
for (size_t idx = 0; idx < node_num; ++idx) {
cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
}
return 0;
}
int32_t GraphBrpcService::graph_random_sample_nodes(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
size_t size = *(uint64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
if (table->random_sample_nodes(size, buffer, actual_size) == 0) {
cntl->response_attachment().append(buffer.get(), actual_size);
} else
cntl->response_attachment().append(NULL, 0);
return 0;
}
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"graph_get_node_feat request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t");
std::vector<std::vector<std::string>> feature(
feature_names.size(), std::vector<std::string>(node_num));
table->get_node_feat(node_ids, feature_names, feature);
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = feature[feat_idx][node_idx].size();
cntl->response_attachment().append(&feat_len, sizeof(size_t));
cntl->response_attachment().append(feature[feat_idx][node_idx].data(),
feat_len);
}
}
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/server.h"
namespace paddle {
namespace distributed {
class GraphBrpcServer : public PSServer {
public:
GraphBrpcServer() {}
virtual ~GraphBrpcServer() {}
PsBaseService *get_service() { return _service.get(); }
virtual uint64_t start(const std::string &ip, uint32_t port);
virtual int32_t stop() {
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return 0;
stoped_ = true;
// cv_.notify_all();
_server.Stop(1000);
_server.Join();
return 0;
}
virtual int32_t port();
std::condition_variable *export_cv() { return &cv_; }
private:
virtual int32_t initialize();
mutable std::mutex mutex_;
std::condition_variable cv_;
bool stoped_ = false;
brpc::Server _server;
std::shared_ptr<PsBaseService> _service;
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class GraphBrpcService;
typedef int32_t (GraphBrpcService::*serviceFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl);
class GraphBrpcService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
protected:
std::unordered_map<int32_t, serviceFunc> _service_handler_map;
int32_t initialize_shard_info();
int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t graph_random_sample_neighboors(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_random_sample_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
private:
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
std::vector<float> _ori_values;
const int sample_nodes_ranges = 23;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/graph_py_service.h"
#include <thread> // NOLINT
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
std::vector<std::string> GraphPyService::split(std::string& str,
const char pattern) {
std::vector<std::string> res;
std::stringstream input(str);
std::string temp;
while (std::getline(input, temp, pattern)) {
res.push_back(temp);
}
return res;
}
void GraphPyService::add_table_feat_conf(std::string table_name,
std::string feat_name,
std::string feat_dtype,
int32_t feat_shape) {
if (this->table_id_map.count(table_name)) {
this->table_feat_conf_table_name.push_back(table_name);
this->table_feat_conf_feat_name.push_back(feat_name);
this->table_feat_conf_feat_dtype.push_back(feat_dtype);
this->table_feat_conf_feat_shape.push_back(feat_shape);
}
}
void GraphPyService::set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types) {
set_shard_num(shard_num);
set_num_node_types(node_types.size());
for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
this->table_id_map[node_types[table_id]] = this->table_id_map.size();
}
for (size_t table_id = 0; table_id < edge_types.size(); table_id++) {
this->table_id_map[edge_types[table_id]] = this->table_id_map.size();
}
std::istringstream stream(ips_str);
std::string ip;
server_size = 0;
std::vector<std::string> ips_list = split(ips_str, ';');
int index = 0;
for (auto ips : ips_list) {
auto ip_and_port = split(ips, ':');
server_list.push_back(ip_and_port[0]);
port_list.push_back(ip_and_port[1]);
uint32_t port = stoul(ip_and_port[1]);
auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
host_sign_list.push_back(ph_host.serialize_to_string());
index++;
}
}
void GraphPyClient::start_client() {
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0];
::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list, servers_);
worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto));
worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id);
worker_ptr->set_shard_num(get_shard_num());
}
void GraphPyServer::start_server(bool block) {
std::string ip = server_list[rank];
uint32_t port = std::stoul(port_list[rank]);
::paddle::distributed::PSParameter server_proto = this->GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&this->host_sign_list,
this->host_sign_list.size()); // test
pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto));
VLOG(0) << "pserver-ptr created ";
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->start(ip, port);
std::condition_variable* cv_ = pserver_ptr->export_cv();
if (block) {
std::mutex mutex_;
std::unique_lock<std::mutex> lock(mutex_);
cv_->wait(lock);
}
}
::paddle::distributed::PSParameter GraphPyServer::GetServerProto() {
// Generate server proto desc
::paddle::distributed::PSParameter server_fleet_desc;
::paddle::distributed::ServerParameter* server_proto =
server_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("GraphBrpcService");
server_service_proto->set_server_class("GraphBrpcServer");
server_service_proto->set_client_class("GraphBrpcClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
for (auto& tuple : this->table_id_map) {
VLOG(0) << " make a new table " << tuple.second;
::paddle::distributed::TableParameter* sparse_table_proto =
downpour_server_proto->add_downpour_table_param();
std::vector<std::string> feat_name;
std::vector<std::string> feat_dtype;
std::vector<int32_t> feat_shape;
for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
if (tuple.first == table_feat_conf_table_name[i]) {
feat_name.push_back(table_feat_conf_feat_name[i]);
feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
feat_shape.push_back(table_feat_conf_feat_shape[i]);
}
}
std::string table_type;
if (tuple.second < this->num_node_types) {
table_type = "node";
} else {
table_type = "edge";
}
GetDownpourSparseTableProto(sparse_table_proto, tuple.second, tuple.first,
table_type, feat_name, feat_dtype, feat_shape);
}
return server_fleet_desc;
}
::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
::paddle::distributed::PSParameter worker_fleet_desc;
::paddle::distributed::WorkerParameter* worker_proto =
worker_fleet_desc.mutable_worker_param();
::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
worker_proto->mutable_downpour_worker_param();
for (auto& tuple : this->table_id_map) {
VLOG(0) << " make a new table " << tuple.second;
::paddle::distributed::TableParameter* worker_sparse_table_proto =
downpour_worker_proto->add_downpour_table_param();
std::vector<std::string> feat_name;
std::vector<std::string> feat_dtype;
std::vector<int32_t> feat_shape;
for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
if (tuple.first == table_feat_conf_table_name[i]) {
feat_name.push_back(table_feat_conf_feat_name[i]);
feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
feat_shape.push_back(table_feat_conf_feat_shape[i]);
}
}
std::string table_type;
if (tuple.second < this->num_node_types) {
table_type = "node";
} else {
table_type = "edge";
}
GetDownpourSparseTableProto(worker_sparse_table_proto, tuple.second,
tuple.first, table_type, feat_name, feat_dtype,
feat_shape);
}
::paddle::distributed::ServerParameter* server_proto =
worker_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("GraphBrpcService");
server_service_proto->set_server_class("GraphBrpcServer");
server_service_proto->set_client_class("GraphBrpcClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
for (auto& tuple : this->table_id_map) {
VLOG(0) << " make a new table " << tuple.second;
::paddle::distributed::TableParameter* sparse_table_proto =
downpour_server_proto->add_downpour_table_param();
std::vector<std::string> feat_name;
std::vector<std::string> feat_dtype;
std::vector<int32_t> feat_shape;
for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
if (tuple.first == table_feat_conf_table_name[i]) {
feat_name.push_back(table_feat_conf_feat_name[i]);
feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
feat_shape.push_back(table_feat_conf_feat_shape[i]);
}
}
std::string table_type;
if (tuple.second < this->num_node_types) {
table_type = "node";
} else {
table_type = "edge";
}
GetDownpourSparseTableProto(sparse_table_proto, tuple.second, tuple.first,
table_type, feat_name, feat_dtype, feat_shape);
}
return worker_fleet_desc;
}
void GraphPyClient::load_edge_file(std::string name, std::string filepath,
bool reverse) {
// 'e' means load edge
std::string params = "e";
if (reverse) {
// 'e<' means load edges from $2 to $1
params += "<";
} else {
// 'e>' means load edges from $1 to $2
params += ">";
}
if (this->table_id_map.count(name)) {
VLOG(0) << "loadding data with type " << name << " from " << filepath;
uint32_t table_id = this->table_id_map[name];
auto status =
get_ps_client()->load(table_id, std::string(filepath), params);
status.wait();
}
}
void GraphPyClient::load_node_file(std::string name, std::string filepath) {
// 'n' means load nodes and 'node_type' follows
std::string params = "n" + name;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
get_ps_client()->load(table_id, std::string(filepath), params);
status.wait();
}
}
std::vector<std::vector<std::pair<uint64_t, float>>>
GraphPyClient::batch_sample_neighboors(std::string name,
std::vector<uint64_t> node_ids,
int sample_size) {
std::vector<std::vector<std::pair<uint64_t, float>>> v;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
worker_ptr->batch_sample_neighboors(table_id, node_ids, sample_size, v);
status.wait();
}
return v;
}
std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
int server_index,
int sample_size) {
std::vector<uint64_t> v;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
worker_ptr->random_sample_nodes(table_id, server_index, sample_size, v);
status.wait();
}
return v;
}
// (name, dtype, ndarray)
std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names) {
std::vector<std::vector<std::string>> v(
feature_names.size(), std::vector<std::string>(node_ids.size()));
if (this->table_id_map.count(node_type)) {
uint32_t table_id = this->table_id_map[node_type];
auto status =
worker_ptr->get_node_feat(table_id, node_ids, feature_names, v);
status.wait();
}
return v;
}
std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
int server_index,
int start, int size,
int step) {
std::vector<FeatureNode> res;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
size, step, res);
status.wait();
}
return res;
}
void GraphPyClient::stop_server() {
VLOG(0) << "going to stop server";
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return;
auto status = this->worker_ptr->stop_server();
if (status.get() == 0) stoped_ = true;
}
void GraphPyClient::finalize_worker() { this->worker_ptr->finalize_worker(); }
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <vector>
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/service/service.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace distributed {
class GraphPyService {
protected:
std::vector<std::string> server_list, port_list, host_sign_list;
int server_size, shard_num;
int num_node_types;
std::unordered_map<std::string, uint32_t> table_id_map;
std::vector<std::string> table_feat_conf_table_name;
std::vector<std::string> table_feat_conf_feat_name;
std::vector<std::string> table_feat_conf_feat_dtype;
std::vector<int32_t> table_feat_conf_feat_shape;
// std::thread *server_thread, *client_thread;
// std::shared_ptr<paddle::distributed::PSServer> pserver_ptr;
// std::shared_ptr<paddle::distributed::PSClient> worker_ptr;
public:
// std::shared_ptr<paddle::distributed::PSServer> get_ps_server() {
// return pserver_ptr;
// }
// std::shared_ptr<paddle::distributed::PSClient> get_ps_client() {
// return worker_ptr;
// }
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
void GetDownpourSparseTableProto(
::paddle::distributed::TableParameter* sparse_table_proto,
uint32_t table_id, std::string table_name, std::string table_type,
std::vector<std::string> feat_name, std::vector<std::string> feat_dtype,
std::vector<int32_t> feat_shape) {
sparse_table_proto->set_table_id(table_id);
sparse_table_proto->set_table_class("GraphTable");
sparse_table_proto->set_shard_num(shard_num);
sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
::paddle::distributed::TableAccessorParameter* accessor_proto =
sparse_table_proto->mutable_accessor();
::paddle::distributed::CommonAccessorParameter* common_proto =
sparse_table_proto->mutable_common();
// Set GraphTable Parameter
common_proto->set_table_name(table_name);
common_proto->set_name(table_type);
for (size_t i = 0; i < feat_name.size(); i++) {
common_proto->add_params(feat_dtype[i]);
common_proto->add_dims(feat_shape[i]);
common_proto->add_attributes(feat_name[i]);
}
accessor_proto->set_accessor_class("CommMergeAccessor");
}
void set_server_size(int server_size) { this->server_size = server_size; }
void set_num_node_types(int num_node_types) {
this->num_node_types = num_node_types;
}
int get_server_size(int server_size) { return server_size; }
std::vector<std::string> split(std::string& str, const char pattern);
void set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types);
void add_table_feat_conf(std::string node_type, std::string feat_name,
std::string feat_dtype, int32_t feat_shape);
};
class GraphPyServer : public GraphPyService {
public:
GraphPyServer() {}
void set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types, int rank) {
set_rank(rank);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
}
int get_rank() { return rank; }
void set_rank(int rank) { this->rank = rank; }
void start_server(bool block = true);
::paddle::distributed::PSParameter GetServerProto();
std::shared_ptr<paddle::distributed::GraphBrpcServer> get_ps_server() {
return pserver_ptr;
}
protected:
int rank;
std::shared_ptr<paddle::distributed::GraphBrpcServer> pserver_ptr;
std::thread* server_thread;
};
class GraphPyClient : public GraphPyService {
public:
void set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types, int client_id) {
set_client_id(client_id);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
}
std::shared_ptr<paddle::distributed::GraphBrpcClient> get_ps_client() {
return worker_ptr;
}
void bind_local_server(int local_channel_index, GraphPyServer& server) {
worker_ptr->set_local_channel(local_channel_index);
worker_ptr->set_local_graph_service(
(paddle::distributed::GraphBrpcService*)server.get_ps_server()
->get_service());
}
void stop_server();
void finalize_worker();
void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath);
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighboors(
std::string name, std::vector<uint64_t> node_ids, int sample_size);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();
protected:
mutable std::mutex mutex_;
int client_id;
std::shared_ptr<paddle::distributed::GraphBrpcClient> worker_ptr;
std::thread* client_thread;
bool stoped_ = false;
};
}
}
......@@ -15,12 +15,13 @@
#include "paddle/fluid/distributed/service/ps_client.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
int32_t PSClient::configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
......
......@@ -24,16 +24,11 @@
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/graph_node.h"
namespace paddle {
namespace distributed {
class PSEnvironment;
class PsRequestMessage;
class PsResponseMessage;
class ValueAccessor;
struct Region;
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
......@@ -160,6 +155,7 @@ class PSClient {
promise.set_value(-1);
return fut;
}
// client2client消息处理,std::function<int32_t (int, int, const std::string&)
// -> ret (msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
......
......@@ -48,6 +48,10 @@ enum PsCmdID {
PS_START_PROFILER = 27;
PS_STOP_PROFILER = 28;
PS_PUSH_GLOBAL_STEP = 29;
PS_PULL_GRAPH_LIST = 30;
PS_GRAPH_SAMPLE_NEIGHBOORS = 31;
PS_GRAPH_SAMPLE_NODES = 32;
PS_GRAPH_GET_NODE_FEAT = 33;
}
message PsRequestMessage {
......
......@@ -16,6 +16,7 @@
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
......@@ -23,6 +24,8 @@ namespace distributed {
REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
......
set_property(GLOBAL PROPERTY TABLE_DEPS string_helper)
get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS)
set_source_files_properties(graph_edge.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_edge SRCS graph_edge.cc)
set_source_files_properties(graph_weighted_sampler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(WeightedSampler SRCS graph_weighted_sampler.cc DEPS graph_edge)
set_source_files_properties(graph_node.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_node SRCS graph_node.cc DEPS WeightedSampler)
set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc DEPS ${TABLE_DEPS} device_context string_helper simple_threadpool xxhash generator)
cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator)
set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/common_graph_table.h"
#include <time.h>
#include <algorithm>
#include <set>
#include <sstream>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/graph_node.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
if (start < 0) start = 0;
std::vector<Node *> res;
for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) {
res.push_back(bucket[pos]);
}
return res;
}
size_t GraphShard::get_size() { return bucket.size(); }
GraphNode *GraphShard::add_graph_node(uint64_t id) {
if (node_location.find(id) == node_location.end()) {
node_location[id] = bucket.size();
bucket.push_back(new GraphNode(id));
}
return (GraphNode *)bucket[node_location[id]];
}
FeatureNode *GraphShard::add_feature_node(uint64_t id) {
if (node_location.find(id) == node_location.end()) {
node_location[id] = bucket.size();
bucket.push_back(new FeatureNode(id));
}
return (FeatureNode *)bucket[node_location[id]];
}
void GraphShard::add_neighboor(uint64_t id, uint64_t dst_id, float weight) {
find_node(id)->add_edge(dst_id, weight);
}
Node *GraphShard::find_node(uint64_t id) {
auto iter = node_location.find(id);
return iter == node_location.end() ? nullptr : bucket[iter->second];
}
int32_t GraphTable::load(const std::string &path, const std::string &param) {
bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n');
if (load_edge) {
bool reverse_edge = (param[1] == '<');
return this->load_edges(path, reverse_edge);
}
if (load_node) {
std::string node_type = param.substr(1);
return this->load_nodes(path, node_type);
}
return 0;
}
int32_t GraphTable::get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<uint64_t> &res) {
int start = 0, end, index = 0, total_size = 0;
res.clear();
std::vector<std::future<std::vector<uint64_t>>> tasks;
// std::string temp = "";
// for(int i = 0;i < shards.size();i++)
// temp+= std::to_string((int)shards[i].get_size()) + " ";
// VLOG(0)<<"range distribution "<<temp;
for (int i = 0; i < shards.size() && index < ranges.size(); i++) {
end = total_size + shards[i].get_size();
start = total_size;
while (start < end && index < ranges.size()) {
if (ranges[index].second <= start)
index++;
else if (ranges[index].first >= end) {
break;
} else {
int first = std::max(ranges[index].first, start);
int second = std::min(ranges[index].second, end);
start = second;
first -= total_size;
second -= total_size;
// VLOG(0)<<" FIND RANGE "<<i<<" "<<first<<" "<<second;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, first, second, i]() -> std::vector<uint64_t> {
return shards[i].get_ids_by_range(first, second);
}));
}
}
total_size += shards[i].get_size();
}
for (int i = 0; i < tasks.size(); i++) {
auto vec = tasks[i].get();
for (auto &id : vec) {
res.push_back(id);
std::swap(res[rand() % res.size()], res[(int)res.size() - 1]);
}
}
return 0;
}
int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
auto paths = paddle::string::split_string<std::string>(path, ";");
int64_t count = 0;
int64_t valid_count = 0;
for (auto path : paths) {
std::ifstream file(path);
std::string line;
while (std::getline(file, line)) {
count++;
auto values = paddle::string::split_string<std::string>(line, "\t");
if (values.size() < 2) continue;
auto id = std::stoull(values[1]);
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(4) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;
}
if (count % 1000000 == 0) {
VLOG(0) << count << " nodes are loaded from filepath";
}
std::string nt = values[0];
if (nt != node_type) {
continue;
}
size_t index = shard_id - shard_start;
auto node = shards[index].add_feature_node(id);
node->set_feature_size(feat_name.size());
for (size_t slice = 2; slice < values.size(); slice++) {
auto feat = this->parse_feature(values[slice]);
if (feat.first >= 0) {
node->set_feature(feat.first, feat.second);
} else {
VLOG(4) << "Node feature: " << values[slice]
<< " not in feature_map.";
}
}
valid_count++;
}
}
VLOG(0) << valid_count << "/" << count << " nodes in type " << node_type
<< " are loaded successfully in " << path;
return 0;
}
int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
auto paths = paddle::string::split_string<std::string>(path, ";");
int count = 0;
std::string sample_type = "random";
bool is_weighted = false;
int valid_count = 0;
for (auto path : paths) {
std::ifstream file(path);
std::string line;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
count++;
if (values.size() < 2) continue;
auto src_id = std::stoull(values[0]);
auto dst_id = std::stoull(values[1]);
if (reverse_edge) {
std::swap(src_id, dst_id);
}
float weight = 1;
if (values.size() == 3) {
weight = std::stof(values[2]);
sample_type = "weighted";
is_weighted = true;
}
size_t src_shard_id = src_id % shard_num;
if (src_shard_id >= shard_end || src_shard_id < shard_start) {
VLOG(4) << "will not load " << src_id << " from " << path
<< ", please check id distribution";
continue;
}
if (count % 1000000 == 0) {
VLOG(0) << count << " edges are loaded from filepath";
}
size_t index = src_shard_id - shard_start;
shards[index].add_graph_node(src_id)->build_edges(is_weighted);
shards[index].add_neighboor(src_id, dst_id, weight);
valid_count++;
}
}
VLOG(0) << valid_count << "/" << count << " edges are loaded successfully in "
<< path;
// Build Sampler j
for (auto &shard : shards) {
auto bucket = shard.get_bucket();
for (int i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
}
}
return 0;
}
Node *GraphTable::find_node(uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return nullptr;
}
size_t index = shard_id - shard_start;
Node *node = shards[index].find_node(id);
return node;
}
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
return node_id % shard_num % shard_num_per_table % task_pool_size_;
}
int32_t GraphTable::random_sample_nodes(int sample_size,
std::unique_ptr<char[]> &buffer,
int &actual_size) {
bool need_feature = false;
int total_size = 0;
for (int i = 0; i < shards.size(); i++) {
total_size += shards[i].get_size();
}
if (sample_size > total_size) sample_size = total_size;
int range_num = random_sample_nodes_ranges;
if (range_num > sample_size) range_num = sample_size;
if (sample_size == 0 || range_num == 0) return 0;
std::vector<int> ranges_len, ranges_pos;
int remain = sample_size, last_pos = -1, num;
std::set<int> separator_set;
for (int i = 0; i < range_num - 1; i++) {
while (separator_set.find(num = rand() % (sample_size - 1)) !=
separator_set.end())
;
separator_set.insert(num);
}
for (auto p : separator_set) {
ranges_len.push_back(p - last_pos);
last_pos = p;
}
ranges_len.push_back(sample_size - 1 - last_pos);
remain = total_size - sample_size + range_num;
separator_set.clear();
for (int i = 0; i < range_num; i++) {
while (separator_set.find(num = rand() % remain) != separator_set.end())
;
separator_set.insert(num);
}
int used = 0, index = 0;
last_pos = -1;
for (auto p : separator_set) {
used += p - last_pos - 1;
last_pos = p;
ranges_pos.push_back(used);
used += ranges_len[index++];
}
std::vector<std::pair<int, int>> first_half, second_half;
int start_index = rand() % total_size;
for (int i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size)
first_half.push_back({ranges_pos[i] + start_index,
ranges_pos[i] + ranges_len[i] + start_index});
else if (ranges_pos[i] + start_index >= total_size) {
second_half.push_back(
{ranges_pos[i] + start_index - total_size,
ranges_pos[i] + ranges_len[i] + start_index - total_size});
} else {
first_half.push_back({ranges_pos[i] + start_index, total_size});
second_half.push_back(
{0, ranges_pos[i] + ranges_len[i] + start_index - total_size});
}
}
for (auto &pair : first_half) second_half.push_back(pair);
std::vector<uint64_t> res;
get_nodes_ids_by_ranges(second_half, res);
actual_size = res.size() * sizeof(uint64_t);
buffer.reset(new char[actual_size]);
char *pointer = buffer.get();
memcpy(pointer, res.data(), actual_size);
return 0;
}
int32_t GraphTable::random_sample_neighboors(
uint64_t *node_ids, int sample_size,
std::vector<std::unique_ptr<char[]>> &buffers,
std::vector<int> &actual_sizes) {
size_t node_num = buffers.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t &node_id = node_ids[idx];
std::unique_ptr<char[]> &buffer = buffers[idx];
int &actual_size = actual_sizes[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&]() -> int {
Node *node = find_node(node_id);
if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}
int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t node_id = node_ids[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, node_id]() -> int {
Node *node = find_node(node_id);
if (node == nullptr) {
return 0;
}
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map.find(feature_name) != feat_id_map.end()) {
// res[feat_idx][idx] =
// node->get_feature(feat_id_map[feature_name]);
auto feat = node->get_feature(feat_id_map[feature_name]);
res[feat_idx][idx] = feat;
}
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}
std::pair<int32_t, std::string> GraphTable::parse_feature(
std::string feat_str) {
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
// "")
auto fields = paddle::string::split_string<std::string>(feat_str, " ");
if (this->feat_id_map.count(fields[0])) {
int32_t id = this->feat_id_map[fields[0]];
std::string dtype = this->feat_dtype[id];
int32_t shape = this->feat_shape[id];
std::vector<std::string> values(fields.begin() + 1, fields.end());
if (dtype == "feasign") {
return std::make_pair<int32_t, std::string>(
int32_t(id), paddle::string::join_strings(values, ' '));
} else if (dtype == "string") {
return std::make_pair<int32_t, std::string>(
int32_t(id), paddle::string::join_strings(values, ' '));
} else if (dtype == "float32") {
return std::make_pair<int32_t, std::string>(
int32_t(id), FeatureNode::parse_value_to_bytes<float>(values));
} else if (dtype == "float64") {
return std::make_pair<int32_t, std::string>(
int32_t(id), FeatureNode::parse_value_to_bytes<double>(values));
} else if (dtype == "int32") {
return std::make_pair<int32_t, std::string>(
int32_t(id), FeatureNode::parse_value_to_bytes<int32_t>(values));
} else if (dtype == "int64") {
return std::make_pair<int32_t, std::string>(
int32_t(id), FeatureNode::parse_value_to_bytes<int64_t>(values));
}
}
return std::make_pair<int32_t, std::string>(-1, "");
}
int32_t GraphTable::pull_graph_list(int start, int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature,
int step) {
if (start < 0) start = 0;
int size = 0, cur_size;
std::vector<std::future<std::vector<Node *>>> tasks;
for (size_t i = 0; i < shards.size() && total_size > 0; i++) {
cur_size = shards[i].get_size();
if (size + cur_size <= start) {
size += cur_size;
continue;
}
int count = std::min(1 + (size + cur_size - start - 1) / step, total_size);
int end = start + (count - 1) * step + 1;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, end, step, size]() -> std::vector<Node *> {
return this->shards[i].get_batch(start - size, end - size, step);
}));
start += count * step;
total_size -= count;
size += cur_size;
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
size = 0;
std::vector<std::vector<Node *>> res;
for (size_t i = 0; i < tasks.size(); i++) {
res.push_back(tasks[i].get());
for (size_t j = 0; j < res.back().size(); j++) {
size += res.back()[j]->get_size(need_feature);
}
}
char *buffer_addr = new char[size];
buffer.reset(buffer_addr);
int index = 0;
for (size_t i = 0; i < res.size(); i++) {
for (size_t j = 0; j < res[i].size(); j++) {
res[i][j]->to_buffer(buffer_addr + index, need_feature);
index += res[i][j]->get_size(need_feature);
}
}
actual_size = size;
return 0;
}
int32_t GraphTable::initialize() {
_shards_task_pool.resize(task_pool_size_);
for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
}
server_num = _shard_num;
// VLOG(0) << "in init graph table server num = " << server_num;
/*
_shard_num is actually server number here
when a server initialize its tables, it sets tables' _shard_num to server_num,
and _shard_idx to server
rank
*/
auto common = _config.common();
this->table_name = common.table_name();
this->table_type = common.name();
VLOG(0) << " init graph table type " << this->table_type << " table name "
<< this->table_name;
int feat_conf_size = static_cast<int>(common.attributes().size());
for (int i = 0; i < feat_conf_size; i++) {
auto &f_name = common.attributes()[i];
auto &f_shape = common.dims()[i];
auto &f_dtype = common.params()[i];
this->feat_name.push_back(f_name);
this->feat_shape.push_back(f_shape);
this->feat_dtype.push_back(f_dtype);
this->feat_id_map[f_name] = i;
VLOG(0) << "init graph table feat conf name:" << f_name
<< " shape:" << f_shape << " dtype:" << f_dtype;
}
shard_num = _config.shard_num();
VLOG(0) << "in init graph table shard num = " << shard_num << " shard_idx"
<< _shard_idx;
shard_num_per_table = sparse_local_shard_num(shard_num, server_num);
shard_start = _shard_idx * shard_num_per_table;
shard_end = shard_start + shard_num_per_table;
VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start "
<< shard_start << " shard_end " << shard_end;
// shards.resize(shard_num_per_table);
shards = std::vector<GraphShard>(shard_num_per_table, GraphShard(shard_num));
return 0;
}
}
};
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <list>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/common_table.h"
#include "paddle/fluid/distributed/table/graph_node.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
class GraphShard {
public:
// static int bucket_low_bound;
// static int gcd(int s, int t) {
// if (s % t == 0) return t;
// return gcd(t, s % t);
// }
size_t get_size();
GraphShard() {}
GraphShard(int shard_num) {
this->shard_num = shard_num;
// bucket_size = init_bucket_size(shard_num);
// bucket.resize(bucket_size);
}
std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step);
// int init_bucket_size(int shard_num) {
// for (int i = bucket_low_bound;; i++) {
// if (gcd(i, shard_num) == 1) return i;
// }
// return -1;
// }
std::vector<uint64_t> get_ids_by_range(int start, int end) {
std::vector<uint64_t> res;
for (int i = start; i < end && i < bucket.size(); i++) {
res.push_back(bucket[i]->get_id());
}
return res;
}
GraphNode *add_graph_node(uint64_t id);
FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id);
void add_neighboor(uint64_t id, uint64_t dst_id, float weight);
// std::unordered_map<uint64_t, std::list<GraphNode *>::iterator>
std::unordered_map<uint64_t, int> get_node_location() {
return node_location;
}
private:
std::unordered_map<uint64_t, int> node_location;
int shard_num;
std::vector<Node *> bucket;
};
class GraphTable : public SparseTable {
public:
GraphTable() {}
virtual ~GraphTable() {}
virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature,
int step);
virtual int32_t random_sample_neighboors(
uint64_t *node_ids, int sample_size,
std::vector<std::unique_ptr<char[]>> &buffers,
std::vector<int> &actual_sizes);
int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers,
int &actual_sizes);
virtual int32_t get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<uint64_t> &res);
virtual int32_t initialize();
int32_t load(const std::string &path, const std::string &param);
int32_t load_edges(const std::string &path, bool reverse);
int32_t load_nodes(const std::string &path, std::string node_type);
Node *find_node(uint64_t id);
virtual int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) {
return 0;
}
virtual int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) {
return 0;
}
virtual void clear() {}
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string &param) { return 0; }
//指定保存路径
virtual int32_t save(const std::string &path, const std::string &converter) {
return 0;
}
virtual int32_t initialize_shard() { return 0; }
virtual uint32_t get_thread_pool_index(uint64_t node_id);
virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);
virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res);
protected:
std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
const int task_pool_size_ = 11;
const int random_sample_nodes_ranges = 3;
std::vector<std::string> feat_name;
std::vector<std::string> feat_dtype;
std::vector<int32_t> feat_shape;
std::unordered_map<std::string, int32_t> feat_id_map;
std::string table_name;
std::string table_type;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
};
}
};
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/graph_edge.h"
#include <cstring>
namespace paddle {
namespace distributed {
void GraphEdgeBlob::add_edge(uint64_t id, float weight = 1) {
id_arr.push_back(id);
}
void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight = 1) {
id_arr.push_back(id);
weight_arr.push_back(weight);
}
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdint>
#include <vector>
namespace paddle {
namespace distributed {
class GraphEdgeBlob {
public:
GraphEdgeBlob() {}
virtual ~GraphEdgeBlob() {}
size_t size() { return id_arr.size(); }
virtual void add_edge(uint64_t id, float weight);
uint64_t get_id(int idx) { return id_arr[idx]; }
virtual float get_weight(int idx) { return 1; }
protected:
std::vector<uint64_t> id_arr;
};
class WeightedGraphEdgeBlob : public GraphEdgeBlob {
public:
WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {}
virtual void add_edge(uint64_t id, float weight);
virtual float get_weight(int idx) { return weight_arr[idx]; }
protected:
std::vector<float> weight_arr;
};
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/graph_node.h"
#include <cstring>
namespace paddle {
namespace distributed {
GraphNode::~GraphNode() {
if (sampler != nullptr) {
delete sampler;
sampler = nullptr;
}
if (edges != nullptr) {
delete edges;
edges = nullptr;
}
}
int Node::weight_size = sizeof(float);
int Node::id_size = sizeof(uint64_t);
int Node::int_size = sizeof(int);
int Node::get_size(bool need_feature) { return id_size + int_size; }
void Node::to_buffer(char* buffer, bool need_feature) {
memcpy(buffer, &id, id_size);
buffer += id_size;
int feat_num = 0;
memcpy(buffer, &feat_num, sizeof(int));
}
void Node::recover_from_buffer(char* buffer) { memcpy(&id, buffer, id_size); }
int FeatureNode::get_size(bool need_feature) {
int size = id_size + int_size; // id, feat_num
if (need_feature) {
size += feature.size() * int_size;
for (const std::string& fea : feature) {
size += fea.size();
}
}
return size;
}
void GraphNode::build_edges(bool is_weighted) {
if (edges == nullptr) {
if (is_weighted == true) {
edges = new WeightedGraphEdgeBlob();
} else {
edges = new GraphEdgeBlob();
}
}
}
void GraphNode::build_sampler(std::string sample_type) {
if (sample_type == "random") {
sampler = new RandomSampler();
} else if (sample_type == "weighted") {
sampler = new WeightedSampler();
}
sampler->build(edges);
}
void FeatureNode::to_buffer(char* buffer, bool need_feature) {
memcpy(buffer, &id, id_size);
buffer += id_size;
int feat_num = 0;
int feat_len;
if (need_feature) {
feat_num += feature.size();
memcpy(buffer, &feat_num, sizeof(int));
buffer += sizeof(int);
for (int i = 0; i < feat_num; ++i) {
feat_len = feature[i].size();
memcpy(buffer, &feat_len, sizeof(int));
buffer += sizeof(int);
memcpy(buffer, feature[i].c_str(), feature[i].size());
buffer += feature[i].size();
}
} else {
memcpy(buffer, &feat_num, sizeof(int));
}
}
void FeatureNode::recover_from_buffer(char* buffer) {
int feat_num, feat_len;
memcpy(&id, buffer, id_size);
buffer += id_size;
memcpy(&feat_num, buffer, sizeof(int));
buffer += sizeof(int);
feature.clear();
for (int i = 0; i < feat_num; ++i) {
memcpy(&feat_len, buffer, sizeof(int));
buffer += sizeof(int);
char str[feat_len + 1];
memcpy(str, buffer, feat_len);
buffer += feat_len;
str[feat_len] = '\0';
feature.push_back(std::string(str));
}
}
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <iostream>
#include <sstream>
#include <vector>
#include "paddle/fluid/distributed/table/graph_weighted_sampler.h"
namespace paddle {
namespace distributed {
class Node {
public:
Node() {}
Node(uint64_t id) : id(id) {}
virtual ~Node() {}
static int id_size, int_size, weight_size;
uint64_t get_id() { return id; }
void set_id(uint64_t id) { this->id = id; }
virtual void build_edges(bool is_weighted) {}
virtual void build_sampler(std::string sample_type) {}
virtual void add_edge(uint64_t id, float weight) {}
virtual std::vector<int> sample_k(int k) { return std::vector<int>(); }
virtual uint64_t get_neighbor_id(int idx) { return 0; }
virtual float get_neighbor_weight(int idx) { return 1.; }
virtual int get_size(bool need_feature);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual std::string get_feature(int idx) { return std::string(""); }
virtual void set_feature(int idx, std::string str) {}
virtual void set_feature_size(int size) {}
virtual int get_feature_size() { return 0; }
protected:
uint64_t id;
};
class GraphNode : public Node {
public:
GraphNode() : Node(), sampler(nullptr), edges(nullptr) {}
GraphNode(uint64_t id) : Node(id), sampler(nullptr), edges(nullptr) {}
virtual ~GraphNode();
virtual void build_edges(bool is_weighted);
virtual void build_sampler(std::string sample_type);
virtual void add_edge(uint64_t id, float weight) {
edges->add_edge(id, weight);
}
virtual std::vector<int> sample_k(int k) { return sampler->sample_k(k); }
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
protected:
Sampler *sampler;
GraphEdgeBlob *edges;
};
class FeatureNode : public Node {
public:
FeatureNode() : Node() {}
FeatureNode(uint64_t id) : Node(id) {}
virtual ~FeatureNode() {}
virtual int get_size(bool need_feature);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual std::string get_feature(int idx) {
if (idx < (int)this->feature.size()) {
return this->feature[idx];
} else {
return std::string("");
}
}
virtual void set_feature(int idx, std::string str) {
if (idx >= (int)this->feature.size()) {
this->feature.resize(idx + 1);
}
this->feature[idx] = str;
}
virtual void set_feature_size(int size) { this->feature.resize(size); }
virtual int get_feature_size() { return this->feature.size(); }
template <typename T>
static std::string parse_value_to_bytes(std::vector<std::string> feat_str) {
T v;
size_t Tsize = sizeof(T) * feat_str.size();
char buffer[Tsize];
for (size_t i = 0; i < feat_str.size(); i++) {
std::stringstream ss(feat_str[i]);
ss >> v;
std::memcpy(buffer + sizeof(T) * i, (char *)&v, sizeof(T));
}
return std::string(buffer, Tsize);
}
template <typename T>
static std::vector<T> parse_bytes_to_array(std::string feat_str) {
T v;
std::vector<T> out;
size_t start = 0;
const char *buffer = feat_str.data();
while (start < feat_str.size()) {
std::memcpy((char *)&v, buffer + start, sizeof(T));
start += sizeof(T);
out.push_back(v);
}
return out;
}
protected:
std::vector<std::string> feature;
};
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/graph_weighted_sampler.h"
#include <iostream>
#include <unordered_map>
namespace paddle {
namespace distributed {
void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; }
std::vector<int> RandomSampler::sample_k(int k) {
int n = edges->size();
if (k > n) {
k = n;
}
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
std::vector<int> sample_result;
std::unordered_map<int, int> replace_map;
while (k--) {
int rand_int = rand() % n;
auto iter = replace_map.find(rand_int);
if (iter == replace_map.end()) {
sample_result.push_back(rand_int);
} else {
sample_result.push_back(iter->second);
}
iter = replace_map.find(n - 1);
if (iter == replace_map.end()) {
replace_map[rand_int] = n - 1;
} else {
replace_map[rand_int] = iter->second;
}
--n;
}
return sample_result;
}
WeightedSampler::WeightedSampler() {
left = nullptr;
right = nullptr;
edges = nullptr;
}
WeightedSampler::~WeightedSampler() {
if (left != nullptr) {
delete left;
left = nullptr;
}
if (right != nullptr) {
delete right;
right = nullptr;
}
}
void WeightedSampler::build(GraphEdgeBlob *edges) {
if (left != nullptr) {
delete left;
left = nullptr;
}
if (right != nullptr) {
delete right;
right = nullptr;
}
return build_one((WeightedGraphEdgeBlob *)edges, 0, edges->size());
}
void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start,
int end) {
count = 0;
this->edges = edges;
if (start + 1 == end) {
left = right = nullptr;
idx = start;
count = 1;
weight = edges->get_weight(idx);
} else {
left = new WeightedSampler();
right = new WeightedSampler();
left->build_one(edges, start, start + (end - start) / 2);
right->build_one(edges, start + (end - start) / 2, end);
weight = left->weight + right->weight;
count = left->count + right->count;
}
}
std::vector<int> WeightedSampler::sample_k(int k) {
if (k > count) {
k = count;
}
std::vector<int> sample_result;
float subtract;
std::unordered_map<WeightedSampler *, float> subtract_weight_map;
std::unordered_map<WeightedSampler *, int> subtract_count_map;
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
while (k--) {
float query_weight = rand() % 100000 / 100000.0;
query_weight *= weight - subtract_weight_map[this];
sample_result.push_back(sample(query_weight, subtract_weight_map,
subtract_count_map, subtract));
}
return sample_result;
}
int WeightedSampler::sample(
float query_weight,
std::unordered_map<WeightedSampler *, float> &subtract_weight_map,
std::unordered_map<WeightedSampler *, int> &subtract_count_map,
float &subtract) {
if (left == nullptr) {
subtract_weight_map[this] = weight;
subtract = weight;
subtract_count_map[this] = 1;
return idx;
}
int left_count = left->count - subtract_count_map[left];
int right_count = right->count - subtract_count_map[right];
float left_subtract = subtract_weight_map[left];
int return_idx;
if (right_count == 0 ||
left_count > 0 && left->weight - left_subtract >= query_weight) {
return_idx = left->sample(query_weight, subtract_weight_map,
subtract_count_map, subtract);
} else {
return_idx =
right->sample(query_weight - (left->weight - left_subtract),
subtract_weight_map, subtract_count_map, subtract);
}
subtract_weight_map[this] += subtract;
subtract_count_map[this]++;
return return_idx;
}
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ctime>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/table/graph_edge.h"
namespace paddle {
namespace distributed {
class Sampler {
public:
virtual ~Sampler() {}
virtual void build(GraphEdgeBlob *edges) = 0;
virtual std::vector<int> sample_k(int k) = 0;
};
class RandomSampler : public Sampler {
public:
virtual ~RandomSampler() {}
virtual void build(GraphEdgeBlob *edges);
virtual std::vector<int> sample_k(int k);
GraphEdgeBlob *edges;
};
class WeightedSampler : public Sampler {
public:
WeightedSampler();
virtual ~WeightedSampler();
WeightedSampler *left, *right;
float weight;
int count;
int idx;
GraphEdgeBlob *edges;
virtual void build(GraphEdgeBlob *edges);
virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end);
virtual std::vector<int> sample_k(int k);
private:
int sample(float query_weight,
std::unordered_map<WeightedSampler *, float> &subtract_weight_map,
std::unordered_map<WeightedSampler *, int> &subtract_count_map,
float &subtract);
};
}
}
......@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/table/common_dense_table.h"
#include "paddle/fluid/distributed/table/common_graph_table.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/table/tensor_accessor.h"
......@@ -25,7 +26,7 @@
namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(Table, GraphTable);
REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
......@@ -75,5 +76,6 @@ int32_t Table::initialize_accessor() {
_value_accesor.reset(accessor);
return 0;
}
} // namespace distributed
} // namespace paddle
......@@ -21,6 +21,7 @@
#include <string>
#include <utility>
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/graph_node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -86,6 +87,31 @@ class Table {
return 0;
}
// only for graph table
virtual int32_t pull_graph_list(int start, int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature,
int step = 1) {
return 0;
}
// only for graph table
virtual int32_t random_sample_neighboors(
uint64_t *node_ids, int sample_size,
std::vector<std::unique_ptr<char[]>> &buffers,
std::vector<int> &actual_sizes) {
return 0;
}
virtual int32_t random_sample_nodes(int sample_size,
std::unique_ptr<char[]> &buffers,
int &actual_sizes) {
return 0;
}
virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) {
return 0;
}
virtual int32_t pour() { return 0; }
virtual void clear() = 0;
......@@ -141,5 +167,6 @@ class TableManager {
TableManager() {}
~TableManager() {}
};
} // namespace distributed
} // namespace paddle
......@@ -15,3 +15,6 @@ cc_test(brpc_service_sparse_sgd_test SRCS brpc_service_sparse_sgd_test.cc DEPS s
set_source_files_properties(brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_function ${COMMON_DEPS} ${RPC_DEPS})
set_source_files_properties(graph_node_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -x
cd `dirname $0`
rm -rf build/ data/
......
......@@ -7,6 +7,10 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator)
if (WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
set(PYBIND_DEPS ${PYBIND_DEPS} graph_py_service)
endif()
if (WITH_GPU OR WITH_ROCM)
set(PYBIND_DEPS ${PYBIND_DEPS} dynload_cuda)
set(PYBIND_DEPS ${PYBIND_DEPS} cuda_device_guard)
......
......@@ -32,6 +32,8 @@ limitations under the License. */
#include "paddle/fluid/distributed/fleet.h"
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/service/graph_py_service.h"
#include "paddle/fluid/distributed/service/heter_client.h"
namespace py = pybind11;
......@@ -39,6 +41,11 @@ using paddle::distributed::CommContext;
using paddle::distributed::Communicator;
using paddle::distributed::FleetWrapper;
using paddle::distributed::HeterClient;
using paddle::distributed::GraphPyService;
using paddle::distributed::GraphNode;
using paddle::distributed::GraphPyServer;
using paddle::distributed::GraphPyClient;
using paddle::distributed::FeatureNode;
namespace paddle {
namespace pybind {
......@@ -152,5 +159,58 @@ void BindHeterClient(py::module* m) {
.def("stop", &HeterClient::Stop);
}
void BindGraphNode(py::module* m) {
py::class_<GraphNode>(*m, "GraphNode")
.def(py::init<>())
.def("get_id", &GraphNode::get_id)
.def("get_feature", &GraphNode::get_feature);
}
void BindGraphPyFeatureNode(py::module* m) {
py::class_<FeatureNode>(*m, "FeatureNode")
.def(py::init<>())
.def("get_id", &GraphNode::get_id)
.def("get_feature", &GraphNode::get_feature);
}
void BindGraphPyService(py::module* m) {
py::class_<GraphPyService>(*m, "GraphPyService").def(py::init<>());
}
void BindGraphPyServer(py::module* m) {
py::class_<GraphPyServer>(*m, "GraphPyServer")
.def(py::init<>())
.def("start_server", &GraphPyServer::start_server)
.def("set_up", &GraphPyServer::set_up)
.def("add_table_feat_conf", &GraphPyServer::add_table_feat_conf);
}
void BindGraphPyClient(py::module* m) {
py::class_<GraphPyClient>(*m, "GraphPyClient")
.def(py::init<>())
.def("load_edge_file", &GraphPyClient::load_edge_file)
.def("load_node_file", &GraphPyClient::load_node_file)
.def("set_up", &GraphPyClient::set_up)
.def("add_table_feat_conf", &GraphPyClient::add_table_feat_conf)
.def("pull_graph_list", &GraphPyClient::pull_graph_list)
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server)
.def("get_node_feat",
[](GraphPyClient& self, std::string node_type,
std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names) {
auto feats =
self.get_node_feat(node_type, node_ids, feature_names);
std::vector<std::vector<py::bytes>> bytes_feats(feats.size());
for (int i = 0; i < feats.size(); ++i) {
for (int j = 0; j < feats[i].size(); ++j) {
bytes_feats[i].push_back(py::bytes(feats[i][j]));
}
}
return bytes_feats;
})
.def("bind_local_server", &GraphPyClient::bind_local_server);
}
} // end namespace pybind
} // namespace paddle
......@@ -27,6 +27,10 @@ void BindPSHost(py::module* m);
void BindCommunicatorContext(py::module* m);
void BindDistCommunicator(py::module* m);
void BindHeterClient(py::module* m);
void BindGraphNode(py::module* m);
void BindGraphPyService(py::module* m);
void BindGraphPyFeatureNode(py::module* m);
void BindGraphPyServer(py::module* m);
void BindGraphPyClient(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -2896,6 +2896,11 @@ All parameter, weight, gradient are variables in Paddle.
BindCommunicatorContext(&m);
BindDistCommunicator(&m);
BindHeterClient(&m);
BindGraphPyFeatureNode(&m);
BindGraphNode(&m);
BindGraphPyService(&m);
BindGraphPyServer(&m);
BindGraphPyClient(&m);
#endif
}
} // namespace pybind
......
#!/bin/sh
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -xe
REPO="${REPO:-paddlepaddle}"
......
#!bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# bash/zsh completion support for core Git.
#
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
## purple to echo
function purple(){
echo -e "\033[35m$1\033[0m"
......
......@@ -27,8 +27,8 @@ class _DatasetFetcher(object):
class _IterableDatasetFetcher(_DatasetFetcher):
def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collate_batch,
collate_fn, drop_last)
super(_IterableDatasetFetcher, self).__init__(
dataset, auto_collate_batch, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
def fetch(self, batch_indices):
......@@ -53,7 +53,8 @@ class _IterableDatasetFetcher(_DatasetFetcher):
class _MapDatasetFetcher(_DatasetFetcher):
def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, collate_fn, drop_last)
super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch,
collate_fn, drop_last)
def fetch(self, batch_indices):
if self.auto_collate_batch:
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# start pserver0
python fleet_deep_ctr.py \
--role pserver \
......
......@@ -40,9 +40,11 @@ class SquaredMatSubFusePassTest(InferencePassTest):
matmul_ab_square = paddle.square(matmul_ab)
matmul_square_ab = paddle.matmul(data_a_square, data_b_square)
scale = paddle.fluid.layers.fill_constant(shape=[1], value=0.5, dtype='float32')
scale = paddle.fluid.layers.fill_constant(
shape=[1], value=0.5, dtype='float32')
sub_val = paddle.fluid.layers.elementwise_sub(matmul_ab_square, matmul_square_ab)
sub_val = paddle.fluid.layers.elementwise_sub(matmul_ab_square,
matmul_square_ab)
squared_mat_sub_out = fluid.layers.elementwise_mul(sub_val, scale)
self.feeds = {
......
......@@ -25,19 +25,16 @@ class TensorRTMatMulDims2Test(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[24, 24], dtype="float32")
data = fluid.data(name="data", shape=[24, 24], dtype="float32")
matmul_out = fluid.layers.matmul(
x=data,
y=data,
transpose_x = self.transpose_x,
transpose_y = self.transpose_y,
alpha = self.alpha)
transpose_x=self.transpose_x,
transpose_y=self.transpose_y,
alpha=self.alpha)
out = fluid.layers.batch_norm(matmul_out, is_test=True)
self.feeds = {
"data": np.ones([24, 24]).astype("float32"),
}
self.feeds = {"data": np.ones([24, 24]).astype("float32"), }
self.enable_trt = True
self.trt_parameters = TensorRTMatMulDims2Test.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
......@@ -65,14 +62,12 @@ class TensorRTMatMulTest(InferencePassTest):
matmul_out = fluid.layers.matmul(
x=data,
y=data,
transpose_x = self.transpose_x,
transpose_y = self.transpose_y,
alpha = self.alpha)
transpose_x=self.transpose_x,
transpose_y=self.transpose_y,
alpha=self.alpha)
out = fluid.layers.batch_norm(matmul_out, is_test=True)
self.feeds = {
"data": np.ones([1, 6, 24, 24]).astype("float32"),
}
self.feeds = {"data": np.ones([1, 6, 24, 24]).astype("float32"), }
self.enable_trt = True
self.trt_parameters = TensorRTMatMulTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
unset https_proxy http_proxy
export FLAGS_rpc_disable_reuse_port=1
......
......@@ -27,8 +27,10 @@ def test_static_layer(place,
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.fluid.data(name='input', shape=input_np.shape, dtype='float64')
label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64')
input = paddle.fluid.data(
name='input', shape=input_np.shape, dtype='float64')
label = paddle.fluid.data(
name='label', shape=label_np.shape, dtype='float64')
if weight_np is not None:
weight = paddle.fluid.data(
name='weight', shape=weight_np.shape, dtype='float64')
......@@ -58,8 +60,10 @@ def test_static_functional(place,
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.fluid.data(name='input', shape=input_np.shape, dtype='float64')
label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64')
input = paddle.fluid.data(
name='input', shape=input_np.shape, dtype='float64')
label = paddle.fluid.data(
name='label', shape=label_np.shape, dtype='float64')
if weight_np is not None:
weight = paddle.fluid.data(
name='weight', shape=weight_np.shape, dtype='float64')
......
......@@ -48,8 +48,10 @@ def test_static(place,
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
logit = paddle.fluid.data(name='logit', shape=logit_np.shape, dtype='float64')
label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64')
logit = paddle.fluid.data(
name='logit', shape=logit_np.shape, dtype='float64')
label = paddle.fluid.data(
name='label', shape=label_np.shape, dtype='float64')
feed_dict = {"logit": logit_np, "label": label_np}
pos_weight = None
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# use default values
# FIXME: random fails on Unknown command lines -c (or -m).
......
......@@ -23,7 +23,6 @@ import os
paddle.enable_static()
# For Net
base_lr = 0.2
emb_lr = base_lr * 3
......
......@@ -170,7 +170,8 @@ class TestFlatten2OpError(unittest.TestCase):
x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
image_shape[3]).reshape(image_shape) / 100.
x2 = x2.astype('float16')
x2_var = paddle.fluid.data(name='x2', shape=[3, 2, 4, 5], dtype='float16')
x2_var = paddle.fluid.data(
name='x2', shape=[3, 2, 4, 5], dtype='float16')
paddle.flatten(x2_var)
self.assertRaises(TypeError, test_type)
......
......@@ -44,8 +44,10 @@ class TestFunctionalL1Loss(unittest.TestCase):
self.assertTrue(dy_result.shape, [10, 10, 5])
def run_static(self, use_gpu=False):
input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32')
label = paddle.fluid.data(name='label', shape=[10, 10, 5], dtype='float32')
input = paddle.fluid.data(
name='input', shape=[10, 10, 5], dtype='float32')
label = paddle.fluid.data(
name='label', shape=[10, 10, 5], dtype='float32')
result0 = paddle.nn.functional.l1_loss(input, label)
result1 = paddle.nn.functional.l1_loss(input, label, reduction='sum')
result2 = paddle.nn.functional.l1_loss(input, label, reduction='none')
......@@ -127,8 +129,10 @@ class TestClassL1Loss(unittest.TestCase):
self.assertTrue(dy_result.shape, [10, 10, 5])
def run_static(self, use_gpu=False):
input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32')
label = paddle.fluid.data(name='label', shape=[10, 10, 5], dtype='float32')
input = paddle.fluid.data(
name='input', shape=[10, 10, 5], dtype='float32')
label = paddle.fluid.data(
name='label', shape=[10, 10, 5], dtype='float32')
l1_loss = paddle.nn.loss.L1Loss()
result0 = l1_loss(input, label)
l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
unset https_proxy http_proxy
nohup python -u test_listen_and_serv_op.py > test_listen_and_serv_op.log 2>&1 &
......
......@@ -191,8 +191,10 @@ class TestNNFunctionalMseLoss(unittest.TestCase):
place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.fluid.data(name='input', shape=dim, dtype='float32')
target = paddle.fluid.data(name='target', shape=dim, dtype='float32')
input = paddle.fluid.data(
name='input', shape=dim, dtype='float32')
target = paddle.fluid.data(
name='target', shape=dim, dtype='float32')
mse_loss = paddle.nn.functional.mse_loss(input, target, 'mean')
exe = paddle.static.Executor(place)
......@@ -225,8 +227,10 @@ class TestNNFunctionalMseLoss(unittest.TestCase):
place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.fluid.data(name='input', shape=dim, dtype='float32')
target = paddle.fluid.data(name='target', shape=dim, dtype='float32')
input = paddle.fluid.data(
name='input', shape=dim, dtype='float32')
target = paddle.fluid.data(
name='target', shape=dim, dtype='float32')
mse_loss = paddle.nn.functional.mse_loss(input, target, 'sum')
exe = paddle.static.Executor(place)
......@@ -259,8 +263,10 @@ class TestNNFunctionalMseLoss(unittest.TestCase):
place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.fluid.data(name='input', shape=dim, dtype='float32')
target = paddle.fluid.data(name='target', shape=dim, dtype='float32')
input = paddle.fluid.data(
name='input', shape=dim, dtype='float32')
target = paddle.fluid.data(
name='target', shape=dim, dtype='float32')
mse_loss = paddle.nn.functional.mse_loss(input, target, 'none')
exe = paddle.static.Executor(place)
......
......@@ -160,5 +160,6 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
print("time cost", ret['time'], 'step_list', ret['step'])
return ret
if __name__ == '__main__':
unittest.main()
......@@ -97,8 +97,10 @@ class TestPixelShuffleAPI(unittest.TestCase):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x_1 = paddle.fluid.data(name="x", shape=[2, 9, 4, 4], dtype="float64")
x_2 = paddle.fluid.data(name="x2", shape=[2, 4, 4, 9], dtype="float64")
x_1 = paddle.fluid.data(
name="x", shape=[2, 9, 4, 4], dtype="float64")
x_2 = paddle.fluid.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64")
out_1 = F.pixel_shuffle(x_1, 3)
out_2 = F.pixel_shuffle(x_2, 3, "NHWC")
......@@ -123,8 +125,10 @@ class TestPixelShuffleAPI(unittest.TestCase):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x_1 = paddle.fluid.data(name="x", shape=[2, 9, 4, 4], dtype="float64")
x_2 = paddle.fluid.data(name="x2", shape=[2, 4, 4, 9], dtype="float64")
x_1 = paddle.fluid.data(
name="x", shape=[2, 9, 4, 4], dtype="float64")
x_2 = paddle.fluid.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64")
# init instance
ps_1 = paddle.nn.PixelShuffle(3)
ps_2 = paddle.nn.PixelShuffle(3, "NHWC")
......
......@@ -55,7 +55,8 @@ class TestProdOp(unittest.TestCase):
self.assertTrue(np.allclose(dy_result.numpy(), expected_result))
def run_static(self, use_gpu=False):
input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32')
input = paddle.fluid.data(
name='input', shape=[10, 10, 5], dtype='float32')
result0 = paddle.prod(input)
result1 = paddle.prod(input, axis=1)
result2 = paddle.prod(input, axis=-1)
......@@ -114,7 +115,8 @@ class TestProdOpError(unittest.TestCase):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.fluid.data(name='x', shape=[2, 2, 4], dtype='float32')
bool_x = paddle.fluid.data(name='bool_x', shape=[2, 2, 4], dtype='bool')
bool_x = paddle.fluid.data(
name='bool_x', shape=[2, 2, 4], dtype='bool')
# The argument x shoule be a Tensor
self.assertRaises(TypeError, paddle.prod, [1])
......
......@@ -128,15 +128,18 @@ class TestSeluAPI(unittest.TestCase):
# The input type must be Variable.
self.assertRaises(TypeError, F.selu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.selu, x_int32)
# The scale must be greater than 1.0
x_fp32 = paddle.fluid.data(name='x_fp32', shape=[12, 10], dtype='float32')
x_fp32 = paddle.fluid.data(
name='x_fp32', shape=[12, 10], dtype='float32')
self.assertRaises(ValueError, F.selu, x_fp32, -1.0)
# The alpha must be no less than 0
self.assertRaises(ValueError, F.selu, x_fp32, 1.6, -1.0)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.selu(x_fp16)
......
......@@ -42,8 +42,10 @@ def test_static(place,
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
logit = paddle.fluid.data(name='logit', shape=logit_np.shape, dtype='float64')
label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64')
logit = paddle.fluid.data(
name='logit', shape=logit_np.shape, dtype='float64')
label = paddle.fluid.data(
name='label', shape=label_np.shape, dtype='float64')
feed_dict = {"logit": logit_np, "label": label_np}
normalizer = None
......
......@@ -23,6 +23,7 @@ from paddle.fluid import Program, program_guard
paddle.enable_static()
class TestTransposeOp(OpTest):
def setUp(self):
self.init_op_type()
......@@ -151,6 +152,7 @@ class TestTransposeOpError(unittest.TestCase):
self.assertRaises(ValueError, test_each_elem_value_check)
class TestTransposeApi(unittest.TestCase):
def test_static_out(self):
paddle.enable_static()
......@@ -161,7 +163,8 @@ class TestTransposeApi(unittest.TestCase):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
x_np = np.random.random([2, 3, 4]).astype("float32")
result1, result2 = exe.run(feed={"x": x_np}, fetch_list=[x_trans1, x_trans2])
result1, result2 = exe.run(feed={"x": x_np},
fetch_list=[x_trans1, x_trans2])
expected_result1 = np.transpose(x_np, [1, 0, 2])
expected_result2 = np.transpose(x_np, (2, 1, 0))
......@@ -185,6 +188,7 @@ class TestTransposeApi(unittest.TestCase):
# dygraph test
paddle.enable_static()
class TestTAPI(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program()):
......
#!/bin/bash
function version(){
echo "PaddlePaddle , compiled with"
echo " with_avx: ON"
echo " with_gpu: OFF"
echo " with_mkl: ON"
echo " with_mkldnn: "
echo " with_python: ON"
}
function ver2num() {
set -e
# convert version to number.
if [ -z "$1" ]; then # empty argument
printf "%03d%03d%03d%03d%03d" 0
else
local VERN=$(echo $1 | sed 's#v##g' | sed 's#\.# #g' \
| sed 's#a# 0 #g' | sed 's#b# 1 #g' | sed 's#rc# 2 #g')
if [ `echo $VERN | wc -w` -eq 3 ] ; then
printf "%03d%03d%03d%03d%03d" $VERN 999 999
else
printf "%03d%03d%03d%03d%03d" $VERN
fi
fi
set +e
}
function cpu_config() {
# auto set KMP_AFFINITY and OMP_DYNAMIC from Hyper Threading Status
# only when MKL enabled
if [ "ON" == "OFF" ]; then
return 0
fi
platform="`uname -s`"
ht=0
if [ $platform == "Linux" ]; then
ht=`lscpu |grep "per core"|awk -F':' '{print $2}'|xargs`
elif [ $platform == "Darwin" ]; then
if [ `sysctl -n hw.physicalcpu` -eq `sysctl -n hw.logicalcpu` ]; then
# HT is OFF
ht=1
fi
else
return 0
fi
if [ $ht -eq 1 ]; then # HT is OFF
if [ -z "$KMP_AFFINITY" ]; then
export KMP_AFFINITY="granularity=fine,compact,0,0"
fi
if [ -z "$OMP_DYNAMIC" ]; then
export OMP_DYNAMIC="FALSE"
fi
else # HT is ON
if [ -z "$KMP_AFFINITY" ]; then
export KMP_AFFINITY="granularity=fine,compact,1,0"
fi
if [ -z "$OMP_DYNAMIC" ]; then
export OMP_DYNAMIC="True"
fi
fi
}
function threads_config() {
# auto set OMP_NUM_THREADS and MKL_NUM_THREADS
# according to trainer_count and total processors
# only when MKL enabled
# auto set OPENBLAS_NUM_THREADS when do not use MKL
platform="`uname -s`"
processors=0
if [ $platform == "Linux" ]; then
processors=`grep "processor" /proc/cpuinfo|sort -u|wc -l`
elif [ $platform == "Darwin" ]; then
processors=`sysctl -n hw.logicalcpu`
else
return 0
fi
trainers=`grep -Eo 'trainer_count.[0-9]+' <<< "$@" |grep -Eo '[0-9]+'|xargs`
if [ -z $trainers ]; then
trainers=1
fi
threads=$((processors / trainers))
if [ $threads -eq 0 ]; then
threads=1
fi
if [ "ON" == "ON" ]; then
if [ -z "$OMP_NUM_THREADS" ]; then
export OMP_NUM_THREADS=$threads
fi
if [ -z "$MKL_NUM_THREADS" ]; then
export MKL_NUM_THREADS=$threads
fi
else
if [ -z "$OPENBLAS_NUM_THREADS" ]; then
export OPENBLAS_NUM_THREADS=$threads
fi
if [ $threads -gt 1 ] && [ -z "$OPENBLAS_MAIN_FREE" ]; then
export OPENBLAS_MAIN_FREE=1
fi
fi
}
PADDLE_CONF_HOME="$HOME/.config/paddle"
mkdir -p ${PADDLE_CONF_HOME}
if [ -z "${PADDLE_NO_STAT+x}" ]; then
SERVER_VER=`curl -m 5 -X POST --data content="{ \"version\": \"\" }"\
-b ${PADDLE_CONF_HOME}/paddle.cookie \
-c ${PADDLE_CONF_HOME}/paddle.cookie \
http://api.paddlepaddle.org/version 2>/dev/null`
if [ $? -eq 0 ] && [ "$(ver2num )" -lt $(ver2num $SERVER_VER) ]; then
echo "Paddle release a new version ${SERVER_VER}, you can get the install package in http://www.paddlepaddle.org"
fi
fi
PADDLE_BIN_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
if [ ! -z "${DEBUGGER}" ]; then
echo "Using debug command ${DEBUGGER}"
fi
CUDNN_LIB_PATH=""
if [ ! -z "${CUDNN_LIB_PATH}" ]; then
export LD_LIBRARY_PATH=${CUDNN_LIB_PATH}:${LD_LIBRARY_PATH}
fi
export PYTHONPATH=${PWD}:${PYTHONPATH}
# Check python lib installed or not.
pip --help > /dev/null
if [ $? -ne 0 ]; then
echo "pip should be installed to run paddle."
exit 1
fi
if [ "OFF" == "ON" ]; then
PADDLE_NAME="paddlepaddle-gpu"
else
PADDLE_NAME="paddlepaddle"
fi
INSTALLED_VERSION=`pip freeze 2>/dev/null | grep "^${PADDLE_NAME}==" | sed 's/.*==//g'`
if [ -z "${INSTALLED_VERSION}" ]; then
INSTALLED_VERSION="0.0.0" # not installed
fi
cat <<EOF | python -
from distutils.version import LooseVersion
import sys
if LooseVersion("${INSTALLED_VERSION}") < LooseVersion(""):
sys.exit(1)
else:
sys.exit(0)
EOF
cpu_config
# echo $KMP_AFFINITY $OMP_DYNAMIC
case "$1" in
"version")
version
;;
*)
version
;;
esac
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ -z ${BRANCH} ]; then
BRANCH="develop"
fi
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../" && pwd )"
function check_sequnece_op_unitests(){
......
#!/usr/bin/env bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -ex
SYSTEM=`uname -s`
rm -f protoc-3.11.3-linux-x86_64.*
......
#!/usr/bin/env python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import difflib
import sys
......
#!/usr/bin/env python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import difflib
import sys
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function install_gcc(){
sed -i 's#<install_gcc>#RUN apt-get update \
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PADDLE_ROOT=/home
mkdir ${PADDLE_ROOT}
cd ${PADDLE_ROOT}
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ "`uname -s`" != "Linux" ]; then
echo "Current scenario only support in Linux yet!"
exit 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册