Commit 54731749 authored by phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_some_yaml_config

@@ -52,12 +52,12 @@ tools/__pycache__
 # This file is automatically generated.
 # TODO(zhiqiang) Move this file to build directory.
-paddle/infrt/dialect/pd_ops.td
+paddle/infrt/dialect/pd/ir/pd_ops.td
 paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td
 paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td
 tools/infrt/kernels.json
 tools/infrt/kernel_signature.json
-paddle/infrt/dialect/pd_ops_info.h
+paddle/infrt/dialect/pd/common/pd_ops_info.h
 .lit_test_times.txt
 paddle/infrt/tests/dialect/Output
 paddle/infrt/tests/lit.cfg.py
......
@@ -61,6 +61,7 @@ set(PADDLE2ONNX_OPTIONAL_ARGS
 -DONNX_CUSTOM_PROTOC_PATH=${PROTOC_BIN_PATH}
 -DWITH_STATIC=OFF
 -DCMAKE_INSTALL_PREFIX=${PADDLE2ONNX_INSTALL_DIR}
+-DCMAKE_INSTALL_LIBDIR=${PADDLE2ONNX_INSTALL_DIR}/${LIBDIR}
 -DCMAKE_POSITION_INDEPENDENT_CODE=ON
 -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
 ${EXTERNAL_OPTIONAL_ARGS}
......
@@ -4,7 +4,7 @@ if(WITH_PYTHON)
 endif()
 proto_library(interceptor_message_proto SRCS interceptor_message.proto)
-if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
+if(WITH_DISTRIBUTE AND WITH_PSCORE)
 set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog)
 else()
 set(BRPC_DEPS "")
......
@@ -67,8 +67,7 @@ bool MessageBus::IsInit() const { return is_init_; }
 MessageBus::~MessageBus() {
 VLOG(3) << "Message bus releases resource.";
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
-!defined(PADDLE_WITH_ASCEND_CL)
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
 server_.Stop(1000);
 server_.Join();
 #endif
@@ -87,8 +86,7 @@ bool MessageBus::Send(int64_t dst_rank,
 IsInit(), true,
 platform::errors::PreconditionNotMet(
 "Using message bus since it has not been initialized."));
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
-!defined(PADDLE_WITH_ASCEND_CL)
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
 int retry_time = 0; // message bus will retry sending for 10 times
 while (retry_time < 10) {
 ++retry_time;
@@ -173,8 +171,7 @@ void MessageBus::ListenPort() {
 LOG(INFO) << "No need listen to port since training on single card.";
 return;
 }
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
-!defined(PADDLE_WITH_ASCEND_CL)
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
 // function keep listen the port and handle the message
 PADDLE_ENFORCE_EQ(
 server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0,
@@ -203,8 +200,7 @@ void MessageBus::ListenPort() {
 #endif
 }
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
-!defined(PADDLE_WITH_ASCEND_CL)
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
 bool MessageBus::SendInterRank(int64_t dst_rank,
 const InterceptorMessage& interceptor_message) {
 const auto& dst_addr = GetAddr(dst_rank);
......
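The same two-macro guard now brackets every brpc-dependent region of the message bus. A minimal sketch of the pattern, with a hypothetical helper function (not part of this commit), showing how such code compiles into a no-op when either macro is absent:

// Sketch only: the helper name is hypothetical; brpc::Server::Stop/Join are
// the same calls used by MessageBus::~MessageBus above.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#include "brpc/server.h"
#endif

void StopRpcServerIfEnabled(void* server_handle) {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  // Only reachable in builds that link brpc.
  auto* server = static_cast<brpc::Server*>(server_handle);
  server->Stop(1000);
  server->Join();
#else
  (void)server_handle;  // the RPC stack is compiled out
#endif
}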
@@ -20,8 +20,7 @@
 #include <thread>
 #include <unordered_map>
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
-!defined(PADDLE_WITH_ASCEND_CL)
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
 #include "brpc/channel.h"
 #include "brpc/server.h"
 #include "paddle/fluid/distributed/fleet_executor/message_service.h"
@@ -64,8 +63,7 @@ class MessageBus final {
 const std::string& GetAddr(int64_t rank) const;
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
-!defined(PADDLE_WITH_ASCEND_CL)
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
 // send the message inter rank (dst is different rank with src)
 bool SendInterRank(int64_t dst_rank,
 const InterceptorMessage& interceptor_message);
@@ -81,8 +79,7 @@ class MessageBus final {
 // the ip needs to be listened
 std::string addr_;
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
-!defined(PADDLE_WITH_ASCEND_CL)
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
 MessageServiceImpl message_service_;
 // brpc server
 brpc::Server server_;
......
@@ -11,8 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
-!defined(PADDLE_WITH_ASCEND_CL)
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
 #include "paddle/fluid/distributed/fleet_executor/message_service.h"
 #include "brpc/server.h"
 #include "paddle/fluid/distributed/fleet_executor/global.h"
......
@@ -11,8 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
-!defined(PADDLE_WITH_ASCEND_CL)
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
 #pragma once
 #include "brpc/server.h"
......
@@ -115,6 +115,7 @@ message TableParameter {
 optional CommonAccessorParameter common = 6;
 optional TableType type = 7;
 optional bool compress_in_save = 8 [ default = false ];
+optional GraphParameter graph_parameter = 9;
 }
 message TableAccessorParameter {
@@ -211,3 +212,25 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule
 optional double ada_epsilon = 5 [ default = 1e-08 ];
 repeated float weight_bounds = 6;
 }
+message GraphParameter {
+optional int32 task_pool_size = 1 [ default = 24 ];
+optional bool gpups_mode = 2 [ default = false ];
+optional string gpups_graph_sample_class = 3
+[ default = "CompleteGraphSampler" ];
+optional string gpups_graph_sample_args = 4 [ default = "" ];
+optional bool use_cache = 5 [ default = true ];
+optional float cache_ratio = 6 [ default = 0.3 ];
+optional int32 cache_ttl = 7 [ default = 5 ];
+optional GraphFeature graph_feature = 8;
+optional string table_name = 9 [ default = "" ];
+optional string table_type = 10 [ default = "" ];
+optional int32 gpups_mode_shard_num = 11 [ default = 127 ];
+optional int32 gpu_num = 12 [ default = 1 ];
+}
+message GraphFeature {
+repeated string name = 1;
+repeated string dtype = 2;
+repeated int32 shape = 3;
+}
\ No newline at end of file
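The new GraphParameter / GraphFeature messages above are consumed through the usual protobuf-generated C++ accessors; the setter names below are the same ones this commit uses later in graph_py_service.h. A minimal sketch of filling one in (the generated header path and the table name/type values are assumptions):

// Sketch: populate the graph_parameter field of a TableParameter.
#include "paddle/fluid/distributed/the_one_ps.pb.h"  // assumed generated header

void FillGraphParameter(::paddle::distributed::TableParameter* table) {
  auto* graph = table->mutable_graph_parameter();
  graph->set_task_pool_size(24);        // threads used by GraphTable
  graph->set_use_cache(false);          // LRU neighbor-sample cache switch
  graph->set_table_name("user2item");   // hypothetical table name
  graph->set_table_type("edge_table");  // hypothetical table type
  auto* feat = graph->mutable_graph_feature();
  feat->add_name("click");              // one feature column
  feat->add_dtype("float32");
  feat->add_shape(1);
}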
@@ -44,7 +44,7 @@ void GraphPsService_Stub::service(
 }
 }
-int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
+int GraphBrpcClient::get_server_index_by_id(int64_t id) {
 int shard_num = get_shard_num();
 int shard_per_server = shard_num % server_size == 0
 ? shard_num / server_size
@@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
 }
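For reference, the routing above maps a node id to the server that owns its shard; the rounding-up branch is collapsed in this view, so the sketch below completes it as an assumption rather than a copy of the upstream code:

#include <cassert>
#include <cstdint>

// Hypothetical standalone version of GraphBrpcClient::get_server_index_by_id.
int ServerIndexById(int64_t id, int shard_num, int server_size) {
  // Each server owns ceil(shard_num / server_size) consecutive shards.
  int shard_per_server = shard_num % server_size == 0
                             ? shard_num / server_size
                             : shard_num / server_size + 1;
  return static_cast<int>(id % shard_num) / shard_per_server;
}

int main() {
  // 127 shards over 4 servers -> 32 shards per server.
  // id 200 hashes to shard 200 % 127 = 73, so server 73 / 32 = 2.
  assert(ServerIndexById(200, 127, 4) == 2);
  return 0;
}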
 std::future<int32_t> GraphBrpcClient::get_node_feat(
-const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
+const uint32_t &table_id, const std::vector<int64_t> &node_ids,
 const std::vector<std::string> &feature_names,
 std::vector<std::vector<std::string>> &res) {
 std::vector<int> request2server;
@@ -66,7 +66,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
 }
 }
 size_t request_call_num = request2server.size();
-std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
+std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
 std::vector<std::vector<int>> query_idx_buckets(request_call_num);
 for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
 int server_index = get_server_index_by_id(node_ids[query_idx]);
@@ -129,7 +129,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
 closure->request(request_idx)
 ->add_params((char *)node_id_buckets[request_idx].data(),
-sizeof(uint64_t) * node_num);
+sizeof(int64_t) * node_num);
 std::string joint_feature_name =
 paddle::string::join_strings(feature_names, '\t');
 closure->request(request_idx)
@@ -179,9 +179,9 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
 return fut;
 }
 std::future<int32_t> GraphBrpcClient::add_graph_node(
-uint32_t table_id, std::vector<uint64_t> &node_id_list,
+uint32_t table_id, std::vector<int64_t> &node_id_list,
 std::vector<bool> &is_weighted_list) {
-std::vector<std::vector<uint64_t>> request_bucket;
+std::vector<std::vector<int64_t>> request_bucket;
 std::vector<std::vector<bool>> is_weighted_bucket;
 bool add_weight = is_weighted_list.size() > 0;
 std::vector<int> server_index_arr;
@@ -191,7 +191,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
 if (index_mapping[server_index] == -1) {
 index_mapping[server_index] = request_bucket.size();
 server_index_arr.push_back(server_index);
-request_bucket.push_back(std::vector<uint64_t>());
+request_bucket.push_back(std::vector<int64_t>());
 if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
 }
 request_bucket[index_mapping[server_index]].push_back(
@@ -229,7 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
 size_t node_num = request_bucket[request_idx].size();
 closure->request(request_idx)
 ->add_params((char *)request_bucket[request_idx].data(),
-sizeof(uint64_t) * node_num);
+sizeof(int64_t) * node_num);
 if (add_weight) {
 bool weighted[is_weighted_bucket[request_idx].size() + 1];
 for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++)
@@ -248,8 +248,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
 return fut;
 }
 std::future<int32_t> GraphBrpcClient::remove_graph_node(
-uint32_t table_id, std::vector<uint64_t> &node_id_list) {
-std::vector<std::vector<uint64_t>> request_bucket;
+uint32_t table_id, std::vector<int64_t> &node_id_list) {
+std::vector<std::vector<int64_t>> request_bucket;
 std::vector<int> server_index_arr;
 std::vector<int> index_mapping(server_size, -1);
 for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
@@ -257,7 +257,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
 if (index_mapping[server_index] == -1) {
 index_mapping[server_index] = request_bucket.size();
 server_index_arr.push_back(server_index);
-request_bucket.push_back(std::vector<uint64_t>());
+request_bucket.push_back(std::vector<int64_t>());
 }
 request_bucket[index_mapping[server_index]].push_back(
 node_id_list[query_idx]);
@@ -291,7 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
 closure->request(request_idx)
 ->add_params((char *)request_bucket[request_idx].data(),
-sizeof(uint64_t) * node_num);
+sizeof(int64_t) * node_num);
 // PsService_Stub rpc_stub(get_cmd_channel(server_index));
 GraphPsService_Stub rpc_stub =
 getServiceStub(get_cmd_channel(server_index));
@@ -303,9 +303,9 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
 }
 // char* &buffer,int &actual_size
 std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
-uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
-// std::vector<std::vector<std::pair<uint64_t, float>>> &res,
-std::vector<std::vector<uint64_t>> &res,
+uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
+// std::vector<std::vector<std::pair<int64_t, float>>> &res,
+std::vector<std::vector<int64_t>> &res,
 std::vector<std::vector<float>> &res_weight, bool need_weight,
 int server_index) {
 if (server_index != -1) {
@@ -337,7 +337,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
 int start = 0;
 while (start < actual_size) {
 res[node_idx].emplace_back(
-*(uint64_t *)(node_buffer + offset + start));
+*(int64_t *)(node_buffer + offset + start));
 start += GraphNode::id_size;
 if (need_weight) {
 res_weight[node_idx].emplace_back(
@@ -358,7 +358,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
 closure->request(0)->set_table_id(table_id);
 closure->request(0)->set_client_id(_client_id);
 closure->request(0)->add_params((char *)node_ids.data(),
-sizeof(uint64_t) * node_ids.size());
+sizeof(int64_t) * node_ids.size());
 closure->request(0)->add_params((char *)&sample_size, sizeof(int));
 closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
 ;
@@ -380,14 +380,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
 server2request[server_index] = request2server.size();
 request2server.push_back(server_index);
 }
-// res.push_back(std::vector<std::pair<uint64_t, float>>());
+// res.push_back(std::vector<std::pair<int64_t, float>>());
 res.push_back({});
 if (need_weight) {
 res_weight.push_back({});
 }
 }
 size_t request_call_num = request2server.size();
-std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
+std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
 std::vector<std::vector<int>> query_idx_buckets(request_call_num);
 for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
 int server_index = get_server_index_by_id(node_ids[query_idx]);
@@ -428,7 +428,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
 int start = 0;
 while (start < actual_size) {
 res[query_idx].emplace_back(
-*(uint64_t *)(node_buffer + offset + start));
+*(int64_t *)(node_buffer + offset + start));
 start += GraphNode::id_size;
 if (need_weight) {
 res_weight[query_idx].emplace_back(
@@ -459,7 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
 closure->request(request_idx)
 ->add_params((char *)node_id_buckets[request_idx].data(),
-sizeof(uint64_t) * node_num);
+sizeof(int64_t) * node_num);
 closure->request(request_idx)
 ->add_params((char *)&sample_size, sizeof(int));
 closure->request(request_idx)
@@ -476,7 +476,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
 }
 std::future<int32_t> GraphBrpcClient::random_sample_nodes(
 uint32_t table_id, int server_index, int sample_size,
-std::vector<uint64_t> &ids) {
+std::vector<int64_t> &ids) {
 DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
 int ret = 0;
 auto *closure = (DownpourBrpcClosure *)done;
@@ -490,7 +490,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
 auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
 int index = 0;
 while (index < bytes_size) {
-ids.push_back(*(uint64_t *)(buffer + index));
+ids.push_back(*(int64_t *)(buffer + index));
 index += GraphNode::id_size;
 }
 delete[] buffer;
@@ -633,7 +633,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
 }
 std::future<int32_t> GraphBrpcClient::set_node_feat(
-const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
+const uint32_t &table_id, const std::vector<int64_t> &node_ids,
 const std::vector<std::string> &feature_names,
 const std::vector<std::vector<std::string>> &features) {
 std::vector<int> request2server;
@@ -646,7 +646,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
 }
 }
 size_t request_call_num = request2server.size();
-std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
+std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
 std::vector<std::vector<int>> query_idx_buckets(request_call_num);
 std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
 request_call_num);
@@ -696,7 +696,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
 closure->request(request_idx)
 ->add_params((char *)node_id_buckets[request_idx].data(),
-sizeof(uint64_t) * node_num);
+sizeof(int64_t) * node_num);
 std::string joint_feature_name =
 paddle::string::join_strings(feature_names, '\t');
 closure->request(request_idx)
......
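Throughout these RPCs the node ids travel as raw bytes in a request parameter: the client appends sizeof(int64_t) * n bytes with add_params, and the server divides the parameter size by sizeof(int64_t) and casts the buffer back to an int64_t pointer. A self-contained sketch of that round trip, with std::string standing in for the brpc request parameter:

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

int main() {
  std::vector<int64_t> node_ids = {101, 202, 303};

  // Client side: pack the ids as raw bytes, as add_params() does.
  std::string param(reinterpret_cast<const char *>(node_ids.data()),
                    sizeof(int64_t) * node_ids.size());

  // Server side: recover the ids, mirroring
  // "node_num = request.params(0).size() / sizeof(int64_t)" above.
  size_t node_num = param.size() / sizeof(int64_t);
  const int64_t *node_data = reinterpret_cast<const int64_t *>(param.data());
  std::vector<int64_t> decoded(node_data, node_data + node_num);

  assert(decoded == node_ids);
  return 0;
}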
@@ -63,8 +63,8 @@ class GraphBrpcClient : public BrpcPsClient {
 virtual ~GraphBrpcClient() {}
 // given a batch of nodes, sample graph_neighbors for each of them
 virtual std::future<int32_t> batch_sample_neighbors(
-uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
-std::vector<std::vector<uint64_t>>& res,
+uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
+std::vector<std::vector<int64_t>>& res,
 std::vector<std::vector<float>>& res_weight, bool need_weight,
 int server_index = -1);
@@ -75,20 +75,20 @@ class GraphBrpcClient : public BrpcPsClient {
 virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
 int server_index,
 int sample_size,
-std::vector<uint64_t>& ids);
+std::vector<int64_t>& ids);
 virtual std::future<int32_t> get_node_feat(
-const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
+const uint32_t& table_id, const std::vector<int64_t>& node_ids,
 const std::vector<std::string>& feature_names,
 std::vector<std::vector<std::string>>& res);
 virtual std::future<int32_t> set_node_feat(
-const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
+const uint32_t& table_id, const std::vector<int64_t>& node_ids,
 const std::vector<std::string>& feature_names,
 const std::vector<std::vector<std::string>>& features);
 virtual std::future<int32_t> clear_nodes(uint32_t table_id);
 virtual std::future<int32_t> add_graph_node(
-uint32_t table_id, std::vector<uint64_t>& node_id_list,
+uint32_t table_id, std::vector<int64_t>& node_id_list,
 std::vector<bool>& is_weighted_list);
 virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
 size_t size_limit,
@@ -96,11 +96,11 @@ class GraphBrpcClient : public BrpcPsClient {
 virtual std::future<int32_t> load_graph_split_config(uint32_t table_id,
 std::string path);
 virtual std::future<int32_t> remove_graph_node(
-uint32_t table_id, std::vector<uint64_t>& node_id_list);
+uint32_t table_id, std::vector<int64_t>& node_id_list);
 virtual int32_t initialize();
 int get_shard_num() { return shard_num; }
 void set_shard_num(int shard_num) { this->shard_num = shard_num; }
-int get_server_index_by_id(uint64_t id);
+int get_server_index_by_id(int64_t id);
 void set_local_channel(int index) {
 this->local_channel = get_cmd_channel(index);
 }
......
@@ -140,9 +140,9 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
 return 0;
 }
-size_t node_num = request.params(0).size() / sizeof(uint64_t);
-uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
-std::vector<uint64_t> node_ids(node_data, node_data + node_num);
+size_t node_num = request.params(0).size() / sizeof(int64_t);
+int64_t *node_data = (int64_t *)(request.params(0).c_str());
+std::vector<int64_t> node_ids(node_data, node_data + node_num);
 std::vector<bool> is_weighted_list;
 if (request.params_size() == 2) {
 size_t weight_list_size = request.params(1).size() / sizeof(bool);
@@ -165,9 +165,9 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
 "graph_get_node_feat request requires at least 1 argument");
 return 0;
 }
-size_t node_num = request.params(0).size() / sizeof(uint64_t);
-uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
-std::vector<uint64_t> node_ids(node_data, node_data + node_num);
+size_t node_num = request.params(0).size() / sizeof(int64_t);
+int64_t *node_data = (int64_t *)(request.params(0).c_str());
+std::vector<int64_t> node_ids(node_data, node_data + node_num);
 ((GraphTable *)table)->remove_graph_node(node_ids);
 return 0;
@@ -386,9 +386,9 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
 "graph_random_sample_neighbors request requires at least 3 arguments");
 return 0;
 }
-size_t node_num = request.params(0).size() / sizeof(uint64_t);
-uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
-int sample_size = *(uint64_t *)(request.params(1).c_str());
+size_t node_num = request.params(0).size() / sizeof(int64_t);
+int64_t *node_data = (int64_t *)(request.params(0).c_str());
+int sample_size = *(int64_t *)(request.params(1).c_str());
 bool need_weight = *(bool *)(request.params(2).c_str());
 std::vector<std::shared_ptr<char>> buffers(node_num);
 std::vector<int> actual_sizes(node_num, 0);
@@ -407,7 +407,7 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
 int32_t GraphBrpcService::graph_random_sample_nodes(
 Table *table, const PsRequestMessage &request, PsResponseMessage &response,
 brpc::Controller *cntl) {
-size_t size = *(uint64_t *)(request.params(0).c_str());
+size_t size = *(int64_t *)(request.params(0).c_str());
 std::unique_ptr<char[]> buffer;
 int actual_size;
 if (((GraphTable *)table)->random_sample_nodes(size, buffer, actual_size) ==
@@ -430,9 +430,9 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
 "graph_get_node_feat request requires at least 2 arguments");
 return 0;
 }
-size_t node_num = request.params(0).size() / sizeof(uint64_t);
-uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
-std::vector<uint64_t> node_ids(node_data, node_data + node_num);
+size_t node_num = request.params(0).size() / sizeof(int64_t);
+int64_t *node_data = (int64_t *)(request.params(0).c_str());
+std::vector<int64_t> node_ids(node_data, node_data + node_num);
 std::vector<std::string> feature_names =
 paddle::string::split_string<std::string>(request.params(1), "\t");
@@ -464,16 +464,16 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
 "at least 3 arguments");
 return 0;
 }
-size_t node_num = request.params(0).size() / sizeof(uint64_t),
+size_t node_num = request.params(0).size() / sizeof(int64_t),
 size_of_size_t = sizeof(size_t);
-uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
-int sample_size = *(uint64_t *)(request.params(1).c_str());
-bool need_weight = *(uint64_t *)(request.params(2).c_str());
-// std::vector<uint64_t> res = ((GraphTable
+int64_t *node_data = (int64_t *)(request.params(0).c_str());
+int sample_size = *(int64_t *)(request.params(1).c_str());
+bool need_weight = *(int64_t *)(request.params(2).c_str());
+// std::vector<int64_t> res = ((GraphTable
 // *)table).filter_out_non_exist_nodes(node_data, sample_size);
 std::vector<int> request2server;
 std::vector<int> server2request(server_size, -1);
-std::vector<uint64_t> local_id;
+std::vector<int64_t> local_id;
 std::vector<int> local_query_idx;
 size_t rank = get_rank();
 for (int query_idx = 0; query_idx < node_num; ++query_idx) {
@@ -496,7 +496,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
 std::vector<std::shared_ptr<char>> local_buffers;
 std::vector<int> local_actual_sizes;
 std::vector<size_t> seq;
-std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
+std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
 std::vector<std::vector<int>> query_idx_buckets(request_call_num);
 for (int query_idx = 0; query_idx < node_num; ++query_idx) {
 int server_index =
@@ -583,7 +583,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
 closure->request(request_idx)
 ->add_params((char *)node_id_buckets[request_idx].data(),
-sizeof(uint64_t) * node_num);
+sizeof(int64_t) * node_num);
 closure->request(request_idx)
 ->add_params((char *)&sample_size, sizeof(int));
 closure->request(request_idx)
@@ -618,9 +618,9 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
 "graph_set_node_feat request requires at least 3 arguments");
 return 0;
 }
-size_t node_num = request.params(0).size() / sizeof(uint64_t);
-uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
-std::vector<uint64_t> node_ids(node_data, node_data + node_num);
+size_t node_num = request.params(0).size() / sizeof(int64_t);
+int64_t *node_data = (int64_t *)(request.params(0).c_str());
+std::vector<int64_t> node_ids(node_data, node_data + node_num);
 std::vector<std::string> feature_names =
 paddle::string::split_string<std::string>(request.params(1), "\t");
......
@@ -44,9 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name,
 }
 }
-void add_graph_node(std::vector<uint64_t> node_ids,
+void add_graph_node(std::vector<int64_t> node_ids,
 std::vector<bool> weight_list) {}
-void remove_graph_node(std::vector<uint64_t> node_ids) {}
+void remove_graph_node(std::vector<int64_t> node_ids) {}
 void GraphPyService::set_up(std::string ips_str, int shard_num,
 std::vector<std::string> node_types,
 std::vector<std::string> edge_types) {
@@ -260,7 +260,7 @@ void GraphPyClient::clear_nodes(std::string name) {
 }
 void GraphPyClient::add_graph_node(std::string name,
-std::vector<uint64_t>& node_ids,
+std::vector<int64_t>& node_ids,
 std::vector<bool>& weight_list) {
 if (this->table_id_map.count(name)) {
 uint32_t table_id = this->table_id_map[name];
@@ -271,7 +271,7 @@ void GraphPyClient::add_graph_node(std::string name,
 }
 void GraphPyClient::remove_graph_node(std::string name,
-std::vector<uint64_t>& node_ids) {
+std::vector<int64_t>& node_ids) {
 if (this->table_id_map.count(name)) {
 uint32_t table_id = this->table_id_map[name];
 auto status = get_ps_client()->remove_graph_node(table_id, node_ids);
@@ -290,13 +290,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
 }
 }
-std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
+std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
 GraphPyClient::batch_sample_neighbors(std::string name,
-std::vector<uint64_t> node_ids,
+std::vector<int64_t> node_ids,
 int sample_size, bool return_weight,
 bool return_edges) {
-// std::vector<std::vector<std::pair<uint64_t, float>>> v;
-std::vector<std::vector<uint64_t>> v;
+std::vector<std::vector<int64_t>> v;
 std::vector<std::vector<float>> v1;
 if (this->table_id_map.count(name)) {
 uint32_t table_id = this->table_id_map[name];
@@ -309,7 +308,7 @@ GraphPyClient::batch_sample_neighbors(std::string name,
 // res.first[1]: slice index
 // res.first[2]: src nodes
 // res.second: edges weight
-std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res;
+std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
 res.first.push_back({});
 res.first.push_back({});
 if (return_edges) res.first.push_back({});
@@ -342,10 +341,10 @@ void GraphPyClient::use_neighbors_sample_cache(std::string name,
 status.wait();
 }
 }
-std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
+std::vector<int64_t> GraphPyClient::random_sample_nodes(std::string name,
 int server_index,
 int sample_size) {
-std::vector<uint64_t> v;
+std::vector<int64_t> v;
 if (this->table_id_map.count(name)) {
 uint32_t table_id = this->table_id_map[name];
 auto status =
@@ -357,7 +356,7 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
 // (name, dtype, ndarray)
 std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
-std::string node_type, std::vector<uint64_t> node_ids,
+std::string node_type, std::vector<int64_t> node_ids,
 std::vector<std::string> feature_names) {
 std::vector<std::vector<std::string>> v(
 feature_names.size(), std::vector<std::string>(node_ids.size()));
@@ -371,7 +370,7 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
 }
 void GraphPyClient::set_node_feat(
-std::string node_type, std::vector<uint64_t> node_ids,
+std::string node_type, std::vector<int64_t> node_ids,
 std::vector<std::string> feature_names,
 const std::vector<std::vector<std::string>> features) {
 if (this->table_id_map.count(node_type)) {
......
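The comments kept above spell out the return layout of GraphPyClient::batch_sample_neighbors: res.first[0] holds the flattened neighbor ids, res.first[1] a slice index that splits them per source node, res.first[2] the source nodes when return_edges is set, and res.second the edge weights. A consumer-side sketch, under the assumption that the slice index stores the end offset of each source's neighbor run (the diff only names the fields, so treat that as an assumption; "client" and the "u2u" table are placeholders):

// Sketch: walk the flattened neighbor list per source node.
std::vector<int64_t> srcs = {1, 2, 3};
auto res = client.batch_sample_neighbors("u2u", srcs, /*sample_size=*/5,
                                         /*return_weight=*/true,
                                         /*return_edges=*/false);
const std::vector<int64_t> &neighbors = res.first[0];
const std::vector<int64_t> &slice_index = res.first[1];
const std::vector<float> &weights = res.second;
for (size_t i = 0; i < srcs.size(); ++i) {
  size_t begin = (i == 0) ? 0 : static_cast<size_t>(slice_index[i - 1]);
  size_t end = static_cast<size_t>(slice_index[i]);
  for (size_t j = begin; j < end; ++j) {
    // neighbors[j] was sampled for srcs[i]; weights[j] is its edge weight.
  }
}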
@@ -70,18 +70,34 @@ class GraphPyService {
 ::paddle::distributed::TableAccessorParameter* accessor_proto =
 sparse_table_proto->mutable_accessor();
-::paddle::distributed::CommonAccessorParameter* common_proto =
-sparse_table_proto->mutable_common();
+// ::paddle::distributed::CommonAccessorParameter* common_proto =
+// sparse_table_proto->mutable_common();
+::paddle::distributed::GraphParameter* graph_proto =
+sparse_table_proto->mutable_graph_parameter();
+::paddle::distributed::GraphFeature* graph_feature =
+graph_proto->mutable_graph_feature();
+graph_proto->set_task_pool_size(24);
+graph_proto->set_table_name(table_name);
+graph_proto->set_table_type(table_type);
+graph_proto->set_use_cache(false);
 // Set GraphTable Parameter
-common_proto->set_table_name(table_name);
-common_proto->set_name(table_type);
+// common_proto->set_table_name(table_name);
+// common_proto->set_name(table_type);
+// for (size_t i = 0; i < feat_name.size(); i++) {
+// common_proto->add_params(feat_dtype[i]);
+// common_proto->add_dims(feat_shape[i]);
+// common_proto->add_attributes(feat_name[i]);
+// }
 for (size_t i = 0; i < feat_name.size(); i++) {
-common_proto->add_params(feat_dtype[i]);
-common_proto->add_dims(feat_shape[i]);
-common_proto->add_attributes(feat_name[i]);
+graph_feature->add_dtype(feat_dtype[i]);
+graph_feature->add_shape(feat_shape[i]);
+graph_feature->add_name(feat_name[i]);
 }
 accessor_proto->set_accessor_class("CommMergeAccessor");
 }
@@ -143,24 +159,24 @@ class GraphPyClient : public GraphPyService {
 void load_edge_file(std::string name, std::string filepath, bool reverse);
 void load_node_file(std::string name, std::string filepath);
 void clear_nodes(std::string name);
-void add_graph_node(std::string name, std::vector<uint64_t>& node_ids,
+void add_graph_node(std::string name, std::vector<int64_t>& node_ids,
 std::vector<bool>& weight_list);
-void remove_graph_node(std::string name, std::vector<uint64_t>& node_ids);
+void remove_graph_node(std::string name, std::vector<int64_t>& node_ids);
 int get_client_id() { return client_id; }
 void set_client_id(int client_id) { this->client_id = client_id; }
 void start_client();
-std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
-batch_sample_neighbors(std::string name, std::vector<uint64_t> node_ids,
+std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
+batch_sample_neighbors(std::string name, std::vector<int64_t> node_ids,
 int sample_size, bool return_weight,
 bool return_edges);
-std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
+std::vector<int64_t> random_sample_nodes(std::string name, int server_index,
 int sample_size);
 std::vector<std::vector<std::string>> get_node_feat(
-std::string node_type, std::vector<uint64_t> node_ids,
+std::string node_type, std::vector<int64_t> node_ids,
 std::vector<std::string> feature_names);
 void use_neighbors_sample_cache(std::string name, size_t total_size_limit,
 size_t ttl);
-void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
+void set_node_feat(std::string node_type, std::vector<int64_t> node_ids,
 std::vector<std::string> feature_names,
 const std::vector<std::vector<std::string>> features);
 std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
......
@@ -53,7 +53,6 @@ cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_pro
 set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)
 cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
 target_link_libraries(table -fopenmp)
@@ -38,10 +38,14 @@
 #include <vector>
 #include "paddle/fluid/distributed/ps/table/accessor.h"
 #include "paddle/fluid/distributed/ps/table/common_table.h"
+#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
 #include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
 #include "paddle/fluid/string/string_helper.h"
 #include "paddle/phi/core/utils/rw_lock.h"
+#ifdef PADDLE_WITH_HETERPS
+#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
+#endif
 namespace paddle {
 namespace distributed {
 class GraphShard {
@@ -51,37 +55,37 @@ class GraphShard {
 ~GraphShard();
 std::vector<Node *> &get_bucket() { return bucket; }
 std::vector<Node *> get_batch(int start, int end, int step);
-std::vector<uint64_t> get_ids_by_range(int start, int end) {
-std::vector<uint64_t> res;
+std::vector<int64_t> get_ids_by_range(int start, int end) {
+std::vector<int64_t> res;
 for (int i = start; i < end && i < (int)bucket.size(); i++) {
 res.push_back(bucket[i]->get_id());
 }
 return res;
 }
-GraphNode *add_graph_node(uint64_t id);
+GraphNode *add_graph_node(int64_t id);
 GraphNode *add_graph_node(Node *node);
-FeatureNode *add_feature_node(uint64_t id);
-Node *find_node(uint64_t id);
-void delete_node(uint64_t id);
+FeatureNode *add_feature_node(int64_t id);
+Node *find_node(int64_t id);
+void delete_node(int64_t id);
 void clear();
-void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
-std::unordered_map<uint64_t, int> &get_node_location() {
+void add_neighbor(int64_t id, int64_t dst_id, float weight);
+std::unordered_map<int64_t, int> &get_node_location() {
 return node_location;
 }
 private:
-std::unordered_map<uint64_t, int> node_location;
+std::unordered_map<int64_t, int> node_location;
 std::vector<Node *> bucket;
 };
 enum LRUResponse { ok = 0, blocked = 1, err = 2 };
 struct SampleKey {
-uint64_t node_key;
+int64_t node_key;
 size_t sample_size;
 bool is_weighted;
-SampleKey(uint64_t _node_key, size_t _sample_size, bool _is_weighted)
+SampleKey(int64_t _node_key, size_t _sample_size, bool _is_weighted)
 : node_key(_node_key),
 sample_size(_sample_size),
 is_weighted(_is_weighted) {}
@@ -300,7 +304,7 @@ class ScaledLRU {
 node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
 }
-if (node_size <= size_t(1.1 * size_limit) + 1) return 0;
+if ((size_t)node_size <= size_t(1.1 * size_limit) + 1) return 0;
 if (pthread_rwlock_wrlock(&rwlock) == 0) {
 // VLOG(0)<"in shrink\n";
 global_count = 0;
@@ -308,9 +312,9 @@ class ScaledLRU {
 global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
 }
 // VLOG(0)<<"global_count "<<global_count<<"\n";
-if (global_count > size_limit) {
+if ((size_t)global_count > size_limit) {
 size_t remove = global_count - size_limit;
-for (int i = 0; i < lru_pool.size(); i++) {
+for (size_t i = 0; i < lru_pool.size(); i++) {
 lru_pool[i].total_diff = 0;
 lru_pool[i].remove_count +=
 1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
@@ -352,9 +356,69 @@ class ScaledLRU {
 friend class RandomSampleLRU<K, V>;
 };
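The shrink path above removes entries proportionally from every RandomSampleLRU shard: each shard gives up a share of the overage in proportion to how many live nodes it still holds. The tail of the expression is collapsed in this view, so the divisor below is an assumed completion; a small arithmetic sketch of that policy:

#include <cassert>
#include <cstddef>
#include <vector>

// Assumed completion of the formula above:
// extra remove_count for shard i = live_i / global_count * overage.
int main() {
  std::vector<size_t> live = {400, 600};  // node_size - remove_count per shard
  size_t size_limit = 800;
  size_t global_count = live[0] + live[1];      // 1000 live entries
  size_t overage = global_count - size_limit;   // 200 entries must go
  std::vector<size_t> extra_remove(live.size());
  for (size_t i = 0; i < live.size(); ++i) {
    extra_remove[i] =
        static_cast<size_t>(1.0 * live[i] / global_count * overage);
  }
  assert(extra_remove[0] == 80 && extra_remove[1] == 120);
  return 0;
}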
+#ifdef PADDLE_WITH_HETERPS
+enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
+class GraphTable;
+class GraphSampler {
+public:
+GraphSampler() {
+status = GraphSamplerStatus::waiting;
+thread_pool.reset(new ::ThreadPool(1));
+callback = [](std::vector<paddle::framework::GpuPsCommGraph> &res) {
+return;
+};
+}
+virtual int run_graph_sampling() = 0;
+virtual int start_graph_sampling() {
+if (status != GraphSamplerStatus::waiting) {
+return -1;
+}
+std::promise<int> prom;
+std::future<int> fut = prom.get_future();
+graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
+prom.set_value(0);
+status = GraphSamplerStatus::running;
+return run_graph_sampling();
+});
+return fut.get();
+}
+virtual void init(size_t gpu_num, GraphTable *graph_table,
+std::vector<std::string> args) = 0;
+virtual void set_graph_sample_callback(
+std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
+callback) {
+this->callback = callback;
+}
+virtual int end_graph_sampling() {
+if (status == GraphSamplerStatus::running) {
+status = GraphSamplerStatus::terminating;
+return graph_sample_task_over.get();
+}
+return -1;
+}
+virtual GraphSamplerStatus get_graph_sampler_status() { return status; }
+protected:
+std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
+callback;
+std::shared_ptr<::ThreadPool> thread_pool;
+GraphSamplerStatus status;
+std::future<int> graph_sample_task_over;
+std::vector<paddle::framework::GpuPsCommGraph> sample_res;
+};
+#endif
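GraphSampler is an abstract base: start_graph_sampling() hands run_graph_sampling() to the single-thread pool and signals readiness through a promise, while end_graph_sampling() flips the status to terminating and waits on the task future. A hedged sketch of a derived sampler (the class name and its body are hypothetical; only the overridden signatures and the protected members come from the interface above, and it assumes the surrounding graph_table.h declarations are in scope):

#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace distributed {

class ToyGraphSampler : public GraphSampler {
 public:
  virtual void init(size_t gpu_num, GraphTable *graph_table,
                    std::vector<std::string> args) {
    gpu_num_ = gpu_num;
    table_ = graph_table;
    args_ = args;  // subclass-specific options, e.g. a batch size
  }
  virtual int run_graph_sampling() {
    // Keep producing batches until end_graph_sampling() flips the status.
    while (status == GraphSamplerStatus::running) {
      sample_res.clear();
      // ... fill sample_res with GpuPsCommGraph partitions built from table_ ...
      callback(sample_res);  // deliver the batch to the registered consumer
    }
    return 0;
  }

 private:
  size_t gpu_num_ = 0;
  GraphTable *table_ = nullptr;
  std::vector<std::string> args_;
};

}  // namespace distributed
}  // namespace paddle
#endif

With this shape, start_graph_sampling() returns once the worker has been launched, and end_graph_sampling() blocks until the loop observes the terminating status and the task future resolves.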
class GraphTable : public SparseTable { class GraphTable : public SparseTable {
public: public:
GraphTable() { use_cache = false; } GraphTable() {
use_cache = false;
shard_num = 0;
#ifdef PADDLE_WITH_HETERPS
gpups_mode = false;
#endif
rw_lock.reset(new pthread_rwlock_t());
}
virtual ~GraphTable(); virtual ~GraphTable();
virtual int32_t pull_graph_list(int start, int size, virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer, std::unique_ptr<char[]> &buffer,
...@@ -362,7 +426,7 @@ class GraphTable : public SparseTable { ...@@ -362,7 +426,7 @@ class GraphTable : public SparseTable {
int step); int step);
virtual int32_t random_sample_neighbors( virtual int32_t random_sample_neighbors(
uint64_t *node_ids, int sample_size, int64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers, std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes, bool need_weight); std::vector<int> &actual_sizes, bool need_weight);
...@@ -370,9 +434,11 @@ class GraphTable : public SparseTable { ...@@ -370,9 +434,11 @@ class GraphTable : public SparseTable {
int &actual_sizes); int &actual_sizes);
virtual int32_t get_nodes_ids_by_ranges( virtual int32_t get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<uint64_t> &res); std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res);
virtual int32_t initialize(); virtual int32_t initialize() { return 0; }
virtual int32_t initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t initialize(const GraphParameter &config);
int32_t load(const std::string &path, const std::string &param); int32_t load(const std::string &path, const std::string &param);
int32_t load_graph_split_config(const std::string &path); int32_t load_graph_split_config(const std::string &path);
...@@ -380,13 +446,13 @@ class GraphTable : public SparseTable { ...@@ -380,13 +446,13 @@ class GraphTable : public SparseTable {
int32_t load_nodes(const std::string &path, std::string node_type); int32_t load_nodes(const std::string &path, std::string node_type);
int32_t add_graph_node(std::vector<uint64_t> &id_list, int32_t add_graph_node(std::vector<int64_t> &id_list,
std::vector<bool> &is_weight_list); std::vector<bool> &is_weight_list);
int32_t remove_graph_node(std::vector<uint64_t> &id_list); int32_t remove_graph_node(std::vector<int64_t> &id_list);
int32_t get_server_index_by_id(uint64_t id); int32_t get_server_index_by_id(int64_t id);
Node *find_node(uint64_t id); Node *find_node(int64_t id);
virtual int32_t pull_sparse(float *values, virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) { const PullSparseValue &pull_value) {
...@@ -407,16 +473,27 @@ class GraphTable : public SparseTable { ...@@ -407,16 +473,27 @@ class GraphTable : public SparseTable {
return 0; return 0;
} }
virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize_shard() { return 0; }
virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index); virtual int32_t set_shard(size_t shard_idx, size_t server_num) {
virtual uint32_t get_thread_pool_index(uint64_t node_id); _shard_idx = shard_idx;
/*
_shard_num is not used in graph_table; the following assignment is kept only
for compatibility with the base Table class.
*/
_shard_num = server_num;
this->server_num = server_num;
return 0;
}
virtual uint32_t get_thread_pool_index_by_shard_index(int64_t shard_index);
virtual uint32_t get_thread_pool_index(int64_t node_id);
virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str); virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);
virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids, virtual int32_t get_node_feat(const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res); std::vector<std::vector<std::string>> &res);
virtual int32_t set_node_feat( virtual int32_t set_node_feat(
const std::vector<uint64_t> &node_ids, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res); const std::vector<std::vector<std::string>> &res);
...@@ -433,11 +510,25 @@ class GraphTable : public SparseTable { ...@@ -433,11 +510,25 @@ class GraphTable : public SparseTable {
} }
return 0; return 0;
} }
#ifdef PADDLE_WITH_HETERPS
virtual int32_t start_graph_sampling() {
return this->graph_sampler->start_graph_sampling();
}
virtual int32_t end_graph_sampling() {
return this->graph_sampler->end_graph_sampling();
}
virtual int32_t set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
graph_sampler->set_graph_sample_callback(callback);
return 0;
}
// virtual GraphSampler *get_graph_sampler() { return graph_sampler.get(); }
#endif
protected: protected:
std::vector<GraphShard *> shards, extra_shards; std::vector<GraphShard *> shards, extra_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num; size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
const int task_pool_size_ = 24; int task_pool_size_ = 24;
const int random_sample_nodes_ranges = 3; const int random_sample_nodes_ranges = 3;
std::vector<std::string> feat_name; std::vector<std::string> feat_name;
...@@ -450,11 +541,61 @@ class GraphTable : public SparseTable { ...@@ -450,11 +541,61 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool; std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool; std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru; std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
std::unordered_set<uint64_t> extra_nodes; std::unordered_set<int64_t> extra_nodes;
std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index; std::unordered_map<int64_t, size_t> extra_nodes_to_thread_index;
bool use_cache, use_duplicate_nodes; bool use_cache, use_duplicate_nodes;
mutable std::mutex mutex_; mutable std::mutex mutex_;
std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
// paddle::framework::GpuPsGraphTable gpu_graph_table;
bool gpups_mode;
// std::shared_ptr<::ThreadPool> graph_sample_pool;
std::shared_ptr<GraphSampler> graph_sampler;
REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
};
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
public:
CompleteGraphSampler() {}
~CompleteGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
// std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int gpu_num;
};
class BasicBfsGraphSampler : public GraphSampler {
public:
BasicBfsGraphSampler() {}
~BasicBfsGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
// std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
size_t gpu_num;
int node_num_for_each_shard, edge_num_for_each_node;
int rounds, interval;
std::vector<std::unordered_map<int64_t, std::vector<int64_t>>>
sample_neighbors_map;
}; };
#endif
} // namespace distributed } // namespace distributed
}; // namespace paddle }; // namespace paddle
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h" #include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle { namespace paddle {
...@@ -117,13 +118,9 @@ class TruncatedGaussianInitializer : public Initializer { ...@@ -117,13 +118,9 @@ class TruncatedGaussianInitializer : public Initializer {
seed_ = static_cast<unsigned int>(std::stoi(attrs[1])); seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]); mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]); std_ = std::stof(attrs[3]);
auto normal_cdf = [](float x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; std::uniform_real_distribution<float> dist_(
}; std::numeric_limits<float>::min(), 1.0);
float a_normal_cdf = normal_cdf((-2.0 - mean_) / std_);
float b_normal_cdf = normal_cdf((2.0 - mean_) / std_);
std::uniform_real_distribution<float> dist_(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
random_engine_ = framework::GetCPURandomEngine(seed_); random_engine_ = framework::GetCPURandomEngine(seed_);
} }
......
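The interval the removed lambda computed comes from the standard inverse-transform construction for a normal distribution truncated to [-2, 2]. A minimal sketch of that math, assuming the final erfinv mapping happens inside the TruncatedNormal logic included from truncated_gaussian_random_op.h (not shown in this diff):

\Phi(x) = \tfrac{1}{2}\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr), \qquad
v \sim \mathcal{U}\!\bigl(2\,\Phi(\tfrac{a-\mu}{\sigma}) - 1,\; 2\,\Phi(\tfrac{b-\mu}{\sigma}) - 1\bigr), \qquad
x = \mu + \sigma\sqrt{2}\,\operatorname{erfinv}(v), \quad a = -2,\ b = 2.

The old initializer drew v over exactly that interval; the new code widens the uniform range to (std::numeric_limits<float>::min(), 1.0), and how truncation is applied afterwards is not visible in this hunk.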
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#define DECLARE_GRAPH_FRIEND_CLASS(a) friend class a;
#define DECLARE_1_FRIEND_CLASS(a, ...) DECLARE_GRAPH_FRIEND_CLASS(a)
#define DECLARE_2_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_1_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_3_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_2_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_4_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_3_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_5_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_4_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_6_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_5_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_7_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_6_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_8_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_7_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_9_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_8_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_10_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_9_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_11_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_10_FRIEND_CLASS(__VA_ARGS__)
#define REGISTER_GRAPH_FRIEND_CLASS(n, ...) \
DECLARE_##n##_FRIEND_CLASS(__VA_ARGS__)
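These DECLARE_*_FRIEND_CLASS macros peel one class name off the variadic list per level, and REGISTER_GRAPH_FRIEND_CLASS(n, ...) token-pastes the count n to select the right rung. A minimal sketch of how the two-argument use inside GraphTable above expands (GraphTableSketch is a hypothetical stand-in class; the samplers are only forward-declared here):

// Expansion trace:
//   REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
//   -> DECLARE_2_FRIEND_CLASS(CompleteGraphSampler, BasicBfsGraphSampler)
//   -> DECLARE_GRAPH_FRIEND_CLASS(CompleteGraphSampler)
//      DECLARE_1_FRIEND_CLASS(BasicBfsGraphSampler)
//   -> friend class CompleteGraphSampler; friend class BasicBfsGraphSampler;
class CompleteGraphSampler;
class BasicBfsGraphSampler;

class GraphTableSketch {  // hypothetical class standing in for GraphTable
  REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)

 private:
  int private_state_ = 0;  // both samplers can now reach private members
};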
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
void GraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { void GraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
id_arr.push_back(id); id_arr.push_back(id);
} }
void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { void WeightedGraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
id_arr.push_back(id); id_arr.push_back(id);
weight_arr.push_back(weight); weight_arr.push_back(weight);
} }
......
...@@ -24,19 +24,20 @@ class GraphEdgeBlob { ...@@ -24,19 +24,20 @@ class GraphEdgeBlob {
GraphEdgeBlob() {} GraphEdgeBlob() {}
virtual ~GraphEdgeBlob() {} virtual ~GraphEdgeBlob() {}
size_t size() { return id_arr.size(); } size_t size() { return id_arr.size(); }
virtual void add_edge(uint64_t id, float weight); virtual void add_edge(int64_t id, float weight);
uint64_t get_id(int idx) { return id_arr[idx]; } int64_t get_id(int idx) { return id_arr[idx]; }
virtual float get_weight(int idx) { return 1; } virtual float get_weight(int idx) { return 1; }
std::vector<int64_t>& export_id_array() { return id_arr; }
protected: protected:
std::vector<uint64_t> id_arr; std::vector<int64_t> id_arr;
}; };
class WeightedGraphEdgeBlob : public GraphEdgeBlob { class WeightedGraphEdgeBlob : public GraphEdgeBlob {
public: public:
WeightedGraphEdgeBlob() {} WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {} virtual ~WeightedGraphEdgeBlob() {}
virtual void add_edge(uint64_t id, float weight); virtual void add_edge(int64_t id, float weight);
virtual float get_weight(int idx) { return weight_arr[idx]; } virtual float get_weight(int idx) { return weight_arr[idx]; }
protected: protected:
......
...@@ -48,6 +48,7 @@ class Node { ...@@ -48,6 +48,7 @@ class Node {
virtual void set_feature(int idx, std::string str) {} virtual void set_feature(int idx, std::string str) {}
virtual void set_feature_size(int size) {} virtual void set_feature_size(int size) {}
virtual int get_feature_size() { return 0; } virtual int get_feature_size() { return 0; }
virtual size_t get_neighbor_size() { return 0; }
protected: protected:
uint64_t id; uint64_t id;
...@@ -70,6 +71,7 @@ class GraphNode : public Node { ...@@ -70,6 +71,7 @@ class GraphNode : public Node {
} }
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); } virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); } virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
virtual size_t get_neighbor_size() { return edges->size(); }
protected: protected:
Sampler *sampler; Sampler *sampler;
......
...@@ -37,6 +37,8 @@ REGISTER_PSCORE_CLASS(Table, CommonDenseTable); ...@@ -37,6 +37,8 @@ REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_PSCORE_CLASS(Table, CommonSparseTable); REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_CLASS(Table, SSDSparseTable); REGISTER_PSCORE_CLASS(Table, SSDSparseTable);
REGISTER_PSCORE_CLASS(GraphSampler, CompleteGraphSampler);
REGISTER_PSCORE_CLASS(GraphSampler, BasicBfsGraphSampler);
#endif #endif
REGISTER_PSCORE_CLASS(Table, SparseGeoTable); REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
REGISTER_PSCORE_CLASS(Table, BarrierTable); REGISTER_PSCORE_CLASS(Table, BarrierTable);
......
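The two REGISTER_PSCORE_CLASS(GraphSampler, ...) lines plug the samplers into a string-keyed factory, which is what lets the gpups_graph_sample_class field in the test further below select BasicBfsGraphSampler by name at runtime. A generic sketch of that registry pattern, assuming REGISTER_PSCORE_CLASS boils down to something equivalent (Registry, Create and the registration lambda are illustrative, not Paddle's actual implementation):

#include <functional>
#include <map>
#include <memory>
#include <string>

// Generic string-keyed factory; each REGISTER_*_CLASS macro is assumed to add
// one creator so the derived type can be constructed from its class name.
template <typename Base>
class Registry {
 public:
  using Creator = std::function<std::unique_ptr<Base>()>;

  static Registry& Instance() {
    static Registry registry;
    return registry;
  }
  void Register(const std::string& name, Creator creator) {
    creators_[name] = std::move(creator);
  }
  std::unique_ptr<Base> Create(const std::string& name) const {
    auto it = creators_.find(name);
    return it == creators_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, Creator> creators_;
};

// REGISTER_PSCORE_CLASS(GraphSampler, BasicBfsGraphSampler) would then reduce
// to a one-time registration along the lines of:
//   Registry<GraphSampler>::Instance().Register(
//       "BasicBfsGraphSampler",
//       [] { return std::make_unique<BasicBfsGraphSampler>(); });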
...@@ -24,6 +24,9 @@ cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope serv ...@@ -24,6 +24,9 @@ cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope serv
set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_table_sample_test SRCS graph_table_sample_test.cc DEPS scope server communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table) cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table)
......
...@@ -236,7 +236,7 @@ void RunGraphSplit() { ...@@ -236,7 +236,7 @@ void RunGraphSplit() {
sleep(2); sleep(2);
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions; std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert( dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {})); std::pair<int64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0]; auto regions = dense_regions[0];
RunClient(dense_regions, 0, pserver_ptr_->get_service()); RunClient(dense_regions, 0, pserver_ptr_->get_service());
...@@ -250,16 +250,16 @@ void RunGraphSplit() { ...@@ -250,16 +250,16 @@ void RunGraphSplit() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0)); srand(time(0));
pull_status.wait(); pull_status.wait();
std::vector<std::vector<uint64_t>> _vs; std::vector<std::vector<int64_t>> _vs;
std::vector<std::vector<float>> vs; std::vector<std::vector<float>> vs;
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true); 0, std::vector<int64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(0, _vs[0].size()); ASSERT_EQ(0, _vs[0].size());
_vs.clear(); _vs.clear();
vs.clear(); vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 97), 4, _vs, vs, true); 0, std::vector<int64_t>(1, 97), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(3, _vs[0].size()); ASSERT_EQ(3, _vs[0].size());
std::remove(edge_file_name); std::remove(edge_file_name);
......
...@@ -48,10 +48,10 @@ namespace distributed = paddle::distributed; ...@@ -48,10 +48,10 @@ namespace distributed = paddle::distributed;
void testSampleNodes( void testSampleNodes(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<uint64_t> ids; std::vector<int64_t> ids;
auto pull_status = worker_ptr_->random_sample_nodes(0, 0, 6, ids); auto pull_status = worker_ptr_->random_sample_nodes(0, 0, 6, ids);
std::unordered_set<uint64_t> s; std::unordered_set<int64_t> s;
std::unordered_set<uint64_t> s1 = {37, 59}; std::unordered_set<int64_t> s1 = {37, 59};
pull_status.wait(); pull_status.wait();
for (auto id : ids) s.insert(id); for (auto id : ids) s.insert(id);
ASSERT_EQ(true, s.size() == s1.size()); ASSERT_EQ(true, s.size() == s1.size());
...@@ -106,14 +106,14 @@ void testFeatureNodeSerializeFloat64() { ...@@ -106,14 +106,14 @@ void testFeatureNodeSerializeFloat64() {
void testSingleSampleNeighboor( void testSingleSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<uint64_t>> vs; std::vector<std::vector<int64_t>> vs;
std::vector<std::vector<float>> vs1; std::vector<std::vector<float>> vs1;
auto pull_status = worker_ptr_->batch_sample_neighbors( auto pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 4, vs, vs1, true); 0, std::vector<int64_t>(1, 37), 4, vs, vs1, true);
pull_status.wait(); pull_status.wait();
std::unordered_set<uint64_t> s; std::unordered_set<int64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145}; std::unordered_set<int64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) { for (auto g : vs[0]) {
s.insert(g); s.insert(g);
} }
...@@ -126,7 +126,7 @@ void testSingleSampleNeighboor( ...@@ -126,7 +126,7 @@ void testSingleSampleNeighboor(
vs.clear(); vs.clear();
vs1.clear(); vs1.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 96), 4, vs, vs1, true); 0, std::vector<int64_t>(1, 96), 4, vs, vs1, true);
pull_status.wait(); pull_status.wait();
s1 = {111, 48, 247}; s1 = {111, 48, 247};
for (auto g : vs[0]) { for (auto g : vs[0]) {
...@@ -147,30 +147,30 @@ void testAddNode( ...@@ -147,30 +147,30 @@ void testAddNode(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
worker_ptr_->clear_nodes(0); worker_ptr_->clear_nodes(0);
int total_num = 270000; int total_num = 270000;
uint64_t id; int64_t id;
std::unordered_set<uint64_t> id_set; std::unordered_set<int64_t> id_set;
for (int i = 0; i < total_num; i++) { for (int i = 0; i < total_num; i++) {
while (id_set.find(id = rand()) != id_set.end()) while (id_set.find(id = rand()) != id_set.end())
; ;
id_set.insert(id); id_set.insert(id);
} }
std::vector<uint64_t> id_list(id_set.begin(), id_set.end()); std::vector<int64_t> id_list(id_set.begin(), id_set.end());
std::vector<bool> weight_list; std::vector<bool> weight_list;
auto status = worker_ptr_->add_graph_node(0, id_list, weight_list); auto status = worker_ptr_->add_graph_node(0, id_list, weight_list);
status.wait(); status.wait();
std::vector<uint64_t> ids[2]; std::vector<int64_t> ids[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
auto sample_status = auto sample_status =
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]); worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait(); sample_status.wait();
} }
std::unordered_set<uint64_t> id_set_check(ids[0].begin(), ids[0].end()); std::unordered_set<int64_t> id_set_check(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check.insert(x); for (auto x : ids[1]) id_set_check.insert(x);
ASSERT_EQ(id_set.size(), id_set_check.size()); ASSERT_EQ(id_set.size(), id_set_check.size());
for (auto x : id_set) { for (auto x : id_set) {
ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true); ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
} }
std::vector<uint64_t> remove_ids; std::vector<int64_t> remove_ids;
for (auto p : id_set_check) { for (auto p : id_set_check) {
if (remove_ids.size() == 0) if (remove_ids.size() == 0)
remove_ids.push_back(p); remove_ids.push_back(p);
...@@ -187,7 +187,7 @@ void testAddNode( ...@@ -187,7 +187,7 @@ void testAddNode(
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]); worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait(); sample_status.wait();
} }
std::unordered_set<uint64_t> id_set_check1(ids[0].begin(), ids[0].end()); std::unordered_set<int64_t> id_set_check1(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check1.insert(x); for (auto x : ids[1]) id_set_check1.insert(x);
ASSERT_EQ(id_set_check1.size(), id_set_check.size()); ASSERT_EQ(id_set_check1.size(), id_set_check.size());
for (auto x : id_set_check1) { for (auto x : id_set_check1) {
...@@ -196,14 +196,14 @@ void testAddNode( ...@@ -196,14 +196,14 @@ void testAddNode(
} }
void testBatchSampleNeighboor( void testBatchSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<uint64_t>> vs; std::vector<std::vector<int64_t>> vs;
std::vector<std::vector<float>> vs1; std::vector<std::vector<float>> vs1;
std::vector<std::uint64_t> v = {37, 96}; std::vector<std::int64_t> v = {37, 96};
auto pull_status = auto pull_status =
worker_ptr_->batch_sample_neighbors(0, v, 4, vs, vs1, false); worker_ptr_->batch_sample_neighbors(0, v, 4, vs, vs1, false);
pull_status.wait(); pull_status.wait();
std::unordered_set<uint64_t> s; std::unordered_set<int64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145}; std::unordered_set<int64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) { for (auto g : vs[0]) {
s.insert(g); s.insert(g);
} }
...@@ -417,7 +417,7 @@ void RunBrpcPushSparse() { ...@@ -417,7 +417,7 @@ void RunBrpcPushSparse() {
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions; std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert( dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {})); std::pair<int64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0]; auto regions = dense_regions[0];
RunClient(dense_regions, 0, pserver_ptr_->get_service()); RunClient(dense_regions, 0, pserver_ptr_->get_service());
...@@ -427,14 +427,14 @@ void RunBrpcPushSparse() { ...@@ -427,14 +427,14 @@ void RunBrpcPushSparse() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0)); srand(time(0));
pull_status.wait(); pull_status.wait();
std::vector<std::vector<uint64_t>> _vs; std::vector<std::vector<int64_t>> _vs;
std::vector<std::vector<float>> vs; std::vector<std::vector<float>> vs;
testSampleNodes(worker_ptr_); testSampleNodes(worker_ptr_);
sleep(5); sleep(5);
testSingleSampleNeighboor(worker_ptr_); testSingleSampleNeighboor(worker_ptr_);
testBatchSampleNeighboor(worker_ptr_); testBatchSampleNeighboor(worker_ptr_);
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true); 0, std::vector<int64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(0, _vs[0].size()); ASSERT_EQ(0, _vs[0].size());
paddle::distributed::GraphTable* g = paddle::distributed::GraphTable* g =
...@@ -445,14 +445,14 @@ void RunBrpcPushSparse() { ...@@ -445,14 +445,14 @@ void RunBrpcPushSparse() {
while (round--) { while (round--) {
vs.clear(); vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, _vs, vs, false); 0, std::vector<int64_t>(1, 37), 1, _vs, vs, false);
pull_status.wait(); pull_status.wait();
for (int i = 0; i < ttl; i++) { for (int i = 0; i < ttl; i++) {
std::vector<std::vector<uint64_t>> vs1; std::vector<std::vector<int64_t>> vs1;
std::vector<std::vector<float>> vs2; std::vector<std::vector<float>> vs2;
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs1, vs2, false); 0, std::vector<int64_t>(1, 37), 1, vs1, vs2, false);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(_vs[0].size(), vs1[0].size()); ASSERT_EQ(_vs[0].size(), vs1[0].size());
...@@ -540,7 +540,7 @@ void RunBrpcPushSparse() { ...@@ -540,7 +540,7 @@ void RunBrpcPushSparse() {
// Test Pull by step // Test Pull by step
std::unordered_set<uint64_t> count_item_nodes; std::unordered_set<int64_t> count_item_nodes;
// pull by step 2 // pull by step 2
for (int test_step = 1; test_step < 4; test_step++) { for (int test_step = 1; test_step < 4; test_step++) {
count_item_nodes.clear(); count_item_nodes.clear();
...@@ -558,18 +558,18 @@ void RunBrpcPushSparse() { ...@@ -558,18 +558,18 @@ void RunBrpcPushSparse() {
ASSERT_EQ(count_item_nodes.size(), 12); ASSERT_EQ(count_item_nodes.size(), 12);
} }
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res; std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
res = client1.batch_sample_neighbors( res = client1.batch_sample_neighbors(
std::string("user2item"), std::vector<uint64_t>(1, 96), 4, true, false); std::string("user2item"), std::vector<int64_t>(1, 96), 4, true, false);
ASSERT_EQ(res.first[0].size(), 3); ASSERT_EQ(res.first[0].size(), 3);
std::vector<uint64_t> node_ids; std::vector<int64_t> node_ids;
node_ids.push_back(96); node_ids.push_back(96);
node_ids.push_back(37); node_ids.push_back(37);
res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4, res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4,
true, false); true, false);
ASSERT_EQ(res.first[1].size(), 1); ASSERT_EQ(res.first[1].size(), 1);
std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6); std::vector<int64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
ASSERT_EQ(nodes_ids.size(), 2); ASSERT_EQ(nodes_ids.size(), 2);
ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) || ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) ||
(nodes_ids[0] == 37 && nodes_ids[1] == 59)); (nodes_ids[0] == 37 && nodes_ids[1] == 59));
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <string>
#include <thread> // NOLINT
#include <unordered_set>
#include <vector>
#include "google/protobuf/text_format.h"
#include <chrono>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace memory = paddle::memory;
namespace distributed = paddle::distributed;
std::vector<std::string> edges = {
std::string("37\t45\t0.34"), std::string("37\t145\t0.31"),
std::string("37\t112\t0.21"), std::string("96\t48\t1.4"),
std::string("96\t247\t0.31"), std::string("96\t111\t1.21"),
std::string("59\t45\t0.34"), std::string("59\t145\t0.31"),
std::string("59\t122\t0.21"), std::string("97\t48\t0.34"),
std::string("97\t247\t0.31"), std::string("97\t111\t0.21")};
// odd id:96 48 122 112
char edge_file_name[] = "edges.txt";
std::vector<std::string> nodes = {
std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"),
std::string("user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd"),
std::string("user\t59\ta 0.11\tb 11 14"),
std::string("user\t97\ta 0.11\tb 12 11"),
std::string("item\t45\ta 0.21"),
std::string("item\t145\ta 0.21"),
std::string("item\t112\ta 0.21"),
std::string("item\t48\ta 0.21"),
std::string("item\t247\ta 0.21"),
std::string("item\t111\ta 0.21"),
std::string("item\t46\ta 0.21"),
std::string("item\t146\ta 0.21"),
std::string("item\t122\ta 0.21"),
std::string("item\t49\ta 0.21"),
std::string("item\t248\ta 0.21"),
std::string("item\t113\ta 0.21")};
char node_file_name[] = "nodes.txt";
void prepare_file(char file_name[], std::vector<std::string> data) {
std::ofstream ofile;
ofile.open(file_name);
for (auto x : data) {
ofile << x << std::endl;
}
ofile.close();
}
void testGraphSample() {
#ifdef PADDLE_WITH_HETERPS
::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(true);
table_proto.set_gpups_mode_shard_num(127);
table_proto.set_gpu_num(2);
distributed::GraphTable graph_table, graph_table1;
graph_table.initialize(table_proto);
prepare_file(edge_file_name, edges);
graph_table.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res;
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_table.set_graph_sample_callback(
[&res, &prom](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res = res0;
prom.set_value(0);
});
graph_table.start_graph_sampling();
fut.get();
graph_table.end_graph_sampling();
ASSERT_EQ(2, res.size());
// 37 59 97
for (int i = 0; i < (int)res[1].node_size; i++) {
std::cout << res[1].node_list[i].node_id << std::endl;
}
ASSERT_EQ(3, res[1].node_size);
::paddle::distributed::GraphParameter table_proto1;
table_proto1.set_gpups_mode(true);
table_proto1.set_gpups_mode_shard_num(127);
table_proto1.set_gpu_num(2);
table_proto1.set_gpups_graph_sample_class("BasicBfsGraphSampler");
table_proto1.set_gpups_graph_sample_args("5,5,1,1");
graph_table1.initialize(table_proto1);
graph_table1.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res1;
std::promise<int> prom1;
std::future<int> fut1 = prom1.get_future();
graph_table1.set_graph_sample_callback(
[&res1, &prom1](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res1 = res0;
prom1.set_value(0);
});
graph_table1.start_graph_sampling();
fut1.get();
graph_table1.end_graph_sampling();
// distributed::BasicBfsGraphSampler *sampler1 =
// (distributed::BasicBfsGraphSampler *)graph_table1.get_graph_sampler();
// sampler1->start_graph_sampling();
// std::this_thread::sleep_for (std::chrono::seconds(1));
// std::vector<paddle::framework::GpuPsCommGraph> res1;// =
// sampler1->fetch_sample_res();
ASSERT_EQ(2, res1.size());
// odd id:96 48 122 112
for (int i = 0; i < (int)res1[0].node_size; i++) {
std::cout << res1[0].node_list[i].node_id << std::endl;
}
ASSERT_EQ(4, res1[0].node_size);
#endif
}
TEST(testGraphSample, Run) { testGraphSample(); }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "glog/logging.h" #include "glog/logging.h"
DECLARE_bool(retain_grad_for_all_tensor);
namespace egr { namespace egr {
static void CopyOrAddTensor(paddle::experimental::Tensor* tensor, static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
...@@ -39,8 +39,8 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor, ...@@ -39,8 +39,8 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
} }
std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation:: std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation::
operator()( operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) { bool create_graph) {
VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation"; VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation";
PADDLE_ENFORCE(grads.size() == 1, PADDLE_ENFORCE(grads.size() == 1,
paddle::platform::errors::Fatal( paddle::platform::errors::Fatal(
...@@ -62,7 +62,7 @@ operator()( ...@@ -62,7 +62,7 @@ operator()(
grad_out = grads[0][0]; grad_out = grads[0][0];
} }
if (!weak_grad_.expired()) { if (!weak_grad_.expired() && FLAGS_retain_grad_for_all_tensor) {
auto grad = weak_grad_.lock(); auto grad = weak_grad_.lock();
CopyOrAddTensor(grad.get(), grad_out); CopyOrAddTensor(grad.get(), grad_out);
} }
......
...@@ -35,8 +35,15 @@ class GradNodeAccumulation : public GradNodeBase { ...@@ -35,8 +35,15 @@ class GradNodeAccumulation : public GradNodeBase {
// Functor: perform backward computations // Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
override; bool create_graph = false) override;
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
std::string name() { return "GradNodeAccumulation"; } std::string name() { return "GradNodeAccumulation"; }
......
...@@ -145,8 +145,8 @@ void GradNodeScale::SetTensorWrappers_X( ...@@ -145,8 +145,8 @@ void GradNodeScale::SetTensorWrappers_X(
void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; } void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; }
std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale:: std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale::
operator()( operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) { bool create_graph) {
// 1. Check Output Size // 1. Check Output Size
PADDLE_ENFORCE( PADDLE_ENFORCE(
((grads.size() == 1) && (grads[0].size() == 1)), ((grads.size() == 1) && (grads[0].size() == 1)),
......
...@@ -39,8 +39,15 @@ class GradNodeScale : public GradNodeBase { ...@@ -39,8 +39,15 @@ class GradNodeScale : public GradNodeBase {
// Functor: perform backward computations // Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
override; bool create_graph = false) override;
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
void SetTensorWrappers_X( void SetTensorWrappers_X(
const std::vector<paddle::experimental::Tensor>& tensors); const std::vector<paddle::experimental::Tensor>& tensors);
......
...@@ -2074,7 +2074,8 @@ static std::string GenerateGradNodeCCContents( ...@@ -2074,7 +2074,8 @@ static std::string GenerateGradNodeCCContents(
const char* GRAD_FUNCTION_TEMPLATE = const char* GRAD_FUNCTION_TEMPLATE =
"std::vector<std::vector<paddle::experimental::Tensor>> " "std::vector<std::vector<paddle::experimental::Tensor>> "
"GradNode%s::operator()(const " "GradNode%s::operator()(const "
"std::vector<std::vector<paddle::experimental::Tensor>>& grads) {\n%s\n}"; "std::vector<std::vector<paddle::experimental::Tensor>>& grads, "
"bool create_graph) {\n%s\n}";
std::string grad_function_str = paddle::string::Sprintf( std::string grad_function_str = paddle::string::Sprintf(
GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body); GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body);
...@@ -2109,18 +2110,28 @@ static std::string GenerateGradNodeHeaderContents( ...@@ -2109,18 +2110,28 @@ static std::string GenerateGradNodeHeaderContents(
"\n" "\n"
" virtual std::vector<std::vector<paddle::experimental::Tensor>> " " virtual std::vector<std::vector<paddle::experimental::Tensor>> "
"operator()(const " "operator()(const "
"std::vector<std::vector<paddle::experimental::Tensor>>& grads) " "std::vector<std::vector<paddle::experimental::Tensor>>& grads, const "
"bool create_graph = false) "
"override;\n" "override;\n"
"\n" "\n"
" void ClearTensorWrappers() override { \n"
"%s\n"
" is_tensor_wrappers_cleared = true;\n"
" }\n"
" std::string name() override { return \" GradNode%s \"; } \n " " std::string name() override { return \" GradNode%s \"; } \n "
"\n" "\n"
" // SetX, SetY, ...\n" " // SetX, SetY, ...\n"
"%s\n" "%s\n"
" // SetAttrMap\n" " // SetAttrMap\n"
"%s\n" "%s\n"
" bool IsTensorWrappersCleared() override { \n"
" return is_tensor_wrappers_cleared;\n"
" }\n"
" private:\n" " private:\n"
" // TensorWrappers\n" " // TensorWrappers\n"
"%s\n" "%s\n"
" bool is_tensor_wrappers_cleared = false;\n"
"\n"
" // Attribute Map\n" " // Attribute Map\n"
"%s\n" "%s\n"
"};"; "};";
...@@ -2154,6 +2165,7 @@ static std::string GenerateGradNodeHeaderContents( ...@@ -2154,6 +2165,7 @@ static std::string GenerateGradNodeHeaderContents(
std::string set_tensor_wrappers_str = ""; std::string set_tensor_wrappers_str = "";
std::string tensor_wrapper_members_str = ""; std::string tensor_wrapper_members_str = "";
std::string clear_tensor_wrappers_str = "";
for (const auto& iter : op_base_infos) { for (const auto& iter : op_base_infos) {
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map = const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
iter.GetGradInsFwdSlotnameMap(); iter.GetGradInsFwdSlotnameMap();
...@@ -2185,6 +2197,13 @@ static std::string GenerateGradNodeHeaderContents( ...@@ -2185,6 +2197,13 @@ static std::string GenerateGradNodeHeaderContents(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name, SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
struct_tensor_wrapper_name); struct_tensor_wrapper_name);
const char* CLEAR_TENSOR_WRAPPER_TEMPLATE =
"for (auto tw: %s) {\n"
" tw.clear();\n"
" }\n";
clear_tensor_wrappers_str += paddle::string::Sprintf(
CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);
} else { } else {
const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
"const paddle::experimental::Tensor& %s"; "const paddle::experimental::Tensor& %s";
...@@ -2197,10 +2216,14 @@ static std::string GenerateGradNodeHeaderContents( ...@@ -2197,10 +2216,14 @@ static std::string GenerateGradNodeHeaderContents(
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"%s = egr::TensorWrapper(%s, %s /*full_reserved*/);"; "%s = egr::TensorWrapper(%s, %s /*full_reserved*/);\n";
tensor_wrapper_body_str = paddle::string::Sprintf( tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name, SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
tensor_wrapper_name, full_reserved_str); tensor_wrapper_name, full_reserved_str);
const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = " %s.clear();\n";
clear_tensor_wrappers_str += paddle::string::Sprintf(
CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);
} }
std::string full_reserved_signature_str = "bool full_reserved"; std::string full_reserved_signature_str = "bool full_reserved";
const char* SET_TENSOR_WRAPPER_TEMPLATE = const char* SET_TENSOR_WRAPPER_TEMPLATE =
...@@ -2215,8 +2238,8 @@ static std::string GenerateGradNodeHeaderContents( ...@@ -2215,8 +2238,8 @@ static std::string GenerateGradNodeHeaderContents(
std::string grad_node_str = paddle::string::Sprintf( std::string grad_node_str = paddle::string::Sprintf(
GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, op_type, op_type, GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, op_type, op_type,
op_type, op_type, set_tensor_wrappers_str, set_attr_map_str, op_type, clear_tensor_wrappers_str, op_type, set_tensor_wrappers_str,
tensor_wrapper_members_str, attr_members_str); set_attr_map_str, tensor_wrapper_members_str, attr_members_str);
return grad_node_str; return grad_node_str;
} }
......
...@@ -213,7 +213,8 @@ def ParseYamlArgs(string): ...@@ -213,7 +213,8 @@ def ParseYamlArgs(string):
default_value = m.group(3).split("=")[1].strip() if len( default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None m.group(3).split("=")) > 1 else None
assert arg_type in yaml_types_mapping.keys(), arg_type assert arg_type in yaml_types_mapping.keys(
), f"The argument type {arg_type} in yaml config is not supported in yaml_types_mapping."
arg_type = yaml_types_mapping[arg_type] arg_type = yaml_types_mapping[arg_type]
arg_name = RemoveSpecialSymbolsInName(arg_name) arg_name = RemoveSpecialSymbolsInName(arg_name)
...@@ -248,7 +249,8 @@ def ParseYamlReturns(string): ...@@ -248,7 +249,8 @@ def ParseYamlReturns(string):
else: else:
ret_type = ret.strip() ret_type = ret.strip()
assert ret_type in yaml_types_mapping.keys(), ret_type assert ret_type in yaml_types_mapping.keys(
), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
ret_type = yaml_types_mapping[ret_type] ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type assert "Tensor" in ret_type
...@@ -477,6 +479,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -477,6 +479,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
# SetTensorWrapper Methods & TensorWrapper Members # SetTensorWrapper Methods & TensorWrapper Members
set_tensor_wrapper_methods_str = "" set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = "" tensor_wrapper_members_str = ""
clear_tensor_wrapper_str = ""
for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items(): for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items():
if tname in no_need_buffer_set: if tname in no_need_buffer_set:
no_need_buffer = "true" no_need_buffer = "true"
...@@ -498,6 +501,13 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -498,6 +501,13 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
""" """
tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format( tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name) tensor_wrapper_name)
CLEAR_TENSOR_WRAPPERS_TEMPLATE = """
{}.clear();
"""
clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format(
tensor_wrapper_name)
else: else:
assert IsVectorTensorType(ttype) assert IsVectorTensorType(ttype)
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """
...@@ -515,6 +525,15 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -515,6 +525,15 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
""" """
tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format( tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name) tensor_wrapper_name)
CLEAR_TENSOR_WRAPPERS_TEMPLATE = """
       for (auto tw: {}) {{
         tw.clear();
       }};
"""
clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format(
tensor_wrapper_name)
# End: SetTensorWrapper Methods & TensorWrapper Members # End: SetTensorWrapper Methods & TensorWrapper Members
# SetAttributes & Attribute Members # SetAttributes & Attribute Members
...@@ -523,7 +542,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -523,7 +542,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
for aname, atype, default_val, _ in backward_attrs_list: for aname, atype, default_val, _ in backward_attrs_list:
saved_attr_name = GetSavedName(aname) saved_attr_name = GetSavedName(aname)
SET_ATTR_METHOD_TEMPLATE = """ SET_ATTR_METHOD_TEMPLATE = """
void SetAttribute{}({} {}) {{ void SetAttribute{}({} {}) {{
{} = {}; {} = {};
}} }}
""" """
...@@ -554,25 +573,37 @@ class {} : public egr::GradNodeBase {{ ...@@ -554,25 +573,37 @@ class {} : public egr::GradNodeBase {{
~{}() override = default; ~{}() override = default;
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) override; const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }} std::string name() override {{ return \" {} \"; }}
void ClearTensorWrappers() override {{
{}
is_tensor_wrappers_cleared = true;
}}
// SetTensorWrapperX, SetTensorWrapperY, ... // SetTensorWrapperX, SetTensorWrapperY, ...
{} {}
// SetAttributes // SetAttributes
{} {}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
}}
private: private:
// TensorWrappers // TensorWrappers
{} {}
bool is_tensor_wrappers_cleared = false;
// Attributes // Attributes
{} {}
}}; }};
""" """
node_declaration_str = NODE_DECLARATION_TEMPLATE.format( node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name, grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, set_tensor_wrapper_methods_str, grad_node_name, clear_tensor_wrapper_str,
set_attribute_methods_str, tensor_wrapper_members_str, set_tensor_wrapper_methods_str, set_attribute_methods_str,
attribute_members_str) tensor_wrapper_members_str, attribute_members_str)
return node_declaration_str return node_declaration_str
...@@ -636,7 +667,7 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, ...@@ -636,7 +667,7 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
grad_api_namespace = f"paddle::experimental" grad_api_namespace = f"paddle::experimental"
FUNCTION_TEMPLATE = """ FUNCTION_TEMPLATE = """
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
// Call grad_api function // Call grad_api function
auto grad_api_returns = {}::{}({}); auto grad_api_returns = {}::{}({});
{} {}
......
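The string templates above are easier to follow once instantiated. Below is a sketch of the declaration NODE_DECLARATION_TEMPLATE would emit for a made-up op `foo` with one tensor wrapper X and one float attribute scale; the op name and member names are hypothetical, the setter bodies follow SET_ATTR_METHOD_TEMPLATE and the tensor-wrapper templates quoted earlier, the constructors come from the unchanged part of the template and are omitted here, and Paddle's eager headers are assumed for egr::GradNodeBase and egr::TensorWrapper.

// Hypothetical output of NODE_DECLARATION_TEMPLATE for an op named "foo".
class GradNodefoo : public egr::GradNodeBase {
 public:
  ~GradNodefoo() override = default;

  virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
      const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
      bool create_graph = false) override;

  std::string name() override { return " GradNodefoo "; }

  void ClearTensorWrappers() override {
    X_.clear();  // one line per tensor wrapper, from clear_tensor_wrapper_str
    is_tensor_wrappers_cleared = true;
  }

  // SetTensorWrapperX, SetTensorWrapperY, ...
  void SetTensorWrapperX(const paddle::experimental::Tensor& X,
                         bool full_reserved) {
    X_ = egr::TensorWrapper(X, full_reserved /*full_reserved*/);
  }

  // SetAttributes
  void SetAttributescale(const float scale) { scale_ = scale; }

  bool IsTensorWrappersCleared() override { return is_tensor_wrappers_cleared; }

 private:
  // TensorWrappers
  egr::TensorWrapper X_;
  bool is_tensor_wrappers_cleared = false;

  // Attributes
  float scale_;
};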
...@@ -25,7 +25,7 @@ atype_to_parsing_function = { ...@@ -25,7 +25,7 @@ atype_to_parsing_function = {
"std::string": "CastPyArg2String", "std::string": "CastPyArg2String",
"int64_t": "CastPyArg2Long", "int64_t": "CastPyArg2Long",
"float": "CastPyArg2Float", "float": "CastPyArg2Float",
"string": "CastPyArg2String", "std::string": "CastPyArg2String",
"std::vector<bool>": "CastPyArg2Booleans", "std::vector<bool>": "CastPyArg2Booleans",
"std::vector<int>": "CastPyArg2Ints", "std::vector<int>": "CastPyArg2Ints",
"std::vector<long>": "CastPyArg2Longs", "std::vector<long>": "CastPyArg2Longs",
......
...@@ -39,12 +39,21 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap( ...@@ -39,12 +39,21 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap(
// Copy nodes // Copy nodes
std::queue<GradNodeBase*> queue = init_queue; std::queue<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited; std::unordered_set<GradNodeBase*> visited;
size_t potential_startup_ops_cnt = queue.size();
size_t cnt = 0;
// Visit each node exactly once in any order // Visit each node exactly once in any order
while (!queue.empty()) { while (!queue.empty()) {
GradNodeBase* node = queue.front(); GradNodeBase* node = queue.front();
queue.pop(); queue.pop();
if (cnt < potential_startup_ops_cnt) {
if (!node_in_degree_map.count(node)) {
node_in_degree_map[node] = 0;
}
cnt += 1;
}
if (visited.count(node)) { if (visited.count(node)) {
continue; continue;
} }
...@@ -76,23 +85,248 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap( ...@@ -76,23 +85,248 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap(
return node_in_degree_map; return node_in_degree_map;
} }
void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, // Remove nodes that do not need to be stored in
const std::vector<paddle::experimental::Tensor>& grad_tensors, // potential_stop_nodes or potential_startup_nodes
bool retain_graph) { void UpdateGraphInfo(
paddle::platform::RecordEvent backward_record_event( std::unordered_map<GradNodeBase*, AutogradMeta*>*
"backward", paddle::platform::TracerEventType::Operator, 1); target_nodes_inputmeta_map,
std::unordered_map<GradNodeBase*, std::unordered_set<GradNodeBase*>>*
depending_nodes,
std::unordered_set<GradNodeBase*>* potential_stop_nodes,
std::unordered_set<GradNodeBase*>* potential_startup_nodes) {
// Update potential_stop_nodes using depending_nodes,
// making sure the path from the root to each target_node stays reachable
std::unordered_set<GradNodeBase*> _startup_ops;
VLOG(6) << "Running in UpdateGraphInfo";
std::queue<GradNodeBase*> queue;
for (auto& target_nodes_inputmeta_pair : *target_nodes_inputmeta_map) {
queue.emplace(target_nodes_inputmeta_pair.first);
}
while (!queue.empty()) {
auto* target_node = queue.front();
queue.pop();
if (!(*depending_nodes)[target_node].empty()) {
auto precedding_nodes = (*depending_nodes)[target_node];
for (auto pre_nodes : precedding_nodes) {
queue.emplace(pre_nodes);
if (potential_stop_nodes->find(pre_nodes) !=
potential_stop_nodes->end()) {
potential_stop_nodes->erase(pre_nodes);
}
}
} else { // startup ops have no preceding nodes
VLOG(6) << "Emplace _startup_ops";
_startup_ops.emplace(target_node);
}
}
// Purify potential_startup_nodes again: remove potential
// startup nodes that cannot reach the input target nodes
if (!_startup_ops.empty()) {
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto node : *potential_startup_nodes) {
if (_startup_ops.count(node) == 0) {
VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
potential_startup_nodes_to_be_erased.emplace(node);
}
}
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto node : potential_startup_nodes_to_be_erased) {
VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
potential_startup_nodes->erase(node);
}
}
}
}
// Get graph info between the input target grad nodes and the outputs,
// recording depending_nodes, potential_stop_nodes and potential_startup_nodes
void GetGraphInfoBetweenTargets(
const std::queue<GradNodeBase*>& init_queue,
std::unordered_map<GradNodeBase*, AutogradMeta*>*
input_target_nodes_inputmeta_map,
std::unordered_map</*child node*/ GradNodeBase*,
/*father nodes*/ std::unordered_set<GradNodeBase*>>*
depending_nodes,
std::unordered_set<GradNodeBase*>* potential_stop_nodes,
std::unordered_set<GradNodeBase*>* potential_startup_nodes) {
if (input_target_nodes_inputmeta_map->empty()) return;
VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
// Calculate in_degree for each node
std::unordered_map<GradNodeBase*, int> node_in_degree_map;
// Copy nodes
std::queue<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited;
// Visit each node exactly once in any order
while (!queue.empty()) {
GradNodeBase* node = queue.front();
queue.pop();
if (visited.count(node)) {
continue;
}
visited.insert(node);
// Check whether the node is one of the input target nodes; if it is,
// all of its next_nodes will be marked as potential_stop_nodes
bool is_potential_stop_nodes =
input_target_nodes_inputmeta_map->count(node);
// Find and append next nodes
const std::vector<std::vector<Edge>>& edges = node->GetEdges();
for (const auto& edge_list : edges) {
for (const Edge& edge : edge_list) {
GradNodeBase* next_node = edge.GetMutableGradNode().get();
// Next node could be nullptr if it is leaf tensor with no
// AccumulationNode attached
// Or it could also originated from dispensable inputs
if (!next_node) continue;
// if the current node is one of the input target nodes,
// all of its next_nodes are inserted into
// potential_stop_nodes
if (is_potential_stop_nodes) {
potential_stop_nodes->emplace(next_node);
}
// Update in_degree
if (!node_in_degree_map.count(next_node))
node_in_degree_map[next_node] = 0;
node_in_degree_map[next_node]++;
// Record depending relationship
(*depending_nodes)[next_node].emplace(node);
queue.push(next_node);
}
}
}
// Update graph info and remove some stop nodes from potential_stop_nodes
UpdateGraphInfo(input_target_nodes_inputmeta_map, depending_nodes,
potential_stop_nodes, potential_startup_nodes);
}
void GetTargetNodesInfo(const std::vector<paddle::experimental::Tensor>& inputs,
std::unordered_map<GradNodeBase*, AutogradMeta*>*
target_nodes_inputmeta_map) {
VLOG(6) << "Running in GetTargetNodesInfo";
if (!inputs.empty()) {
VLOG(6) << "Inputs are not empty";
size_t num_inputs = inputs.size();
for (size_t i = 0; i < num_inputs; i++) {
AutogradMeta* auto_grad_meta =
EagerUtils::unsafe_autograd_meta(inputs[i]);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
PADDLE_ENFORCE_NOT_NULL(target_node,
paddle::platform::errors::Fatal(
"There is no grad op for input:%d or it's"
"stop_gradient=True",
i));
(*target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
}
}
}
std::vector<paddle::experimental::Tensor> GetResults(
const std::vector<paddle::experimental::Tensor>& inputs,
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor>*
results_map,
bool allow_unused, bool create_graph) {
VLOG(6) << "Running in GetResults";
if (inputs.empty()) return {};
std::vector<paddle::experimental::Tensor> results;
results.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
auto iter = results_map->find(target_node);
if (iter != results_map->end()) {
// set StopGradient = !create_graph
AutogradMeta* tensor_auto_grad_meta =
EagerUtils::autograd_meta(&(iter->second));
tensor_auto_grad_meta->SetStopGradient(!create_graph);
results.emplace_back(iter->second);
} else {
PADDLE_ENFORCE_EQ(allow_unused, true,
paddle::platform::errors::InvalidArgument(
"The %d-th input does not appear in the backward "
"graph. Please check the input variable or set "
"allow_unused=True to get None result.",
i));
results.emplace_back();
}
}
return results;
}
// Enforce GradNode has TensorWrappers as Input
void EnforceGradNodeHasInput(GradNodeBase* node) {
VLOG(6) << "Running in EnforceGradNodeHasInput";
PADDLE_ENFORCE_NE(
node->IsTensorWrappersCleared(), true,
paddle::platform::errors::Fatal(
"The TensorWrappers of %s do not exist. This may be because:\n"
"You calculate backward twice for the same subgraph without "
"setting retain_graph=True. Please set retain_graph=True in the "
"first backward/grad call.\n",
node->name()));
}
// Purify potential_startup_nodes: remove nodes that are the same as the
// input_target_nodes
void PurifyPotentialStartUpNodes(
std::unordered_set<GradNodeBase*>* potential_startup_nodes,
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>*
input_target_nodes_inputmeta_map) {
VLOG(6) << "Running in PurifyPotentialStartUpNodes";
if (input_target_nodes_inputmeta_map->empty()) return;
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto startup_op : *potential_startup_nodes) {
auto iter = input_target_nodes_inputmeta_map->find(startup_op);
if (iter != input_target_nodes_inputmeta_map->end()) {
potential_startup_nodes_to_be_erased.emplace(iter->first);
}
}
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto nodes : potential_startup_nodes_to_be_erased) {
potential_startup_nodes->erase(nodes);
}
}
}
std::vector<paddle::experimental::Tensor> RunBackward(
const std::vector<paddle::experimental::Tensor>& tensors, // output
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph, bool create_graph = false,
const std::vector<paddle::experimental::Tensor>& inputs = {},
bool allow_unused = false,
const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
VLOG(6) << "Start Backward"; VLOG(6) << "Start Backward";
// *Gradient Hook should happen at node-level // *Gradient Hook should happen at node-level
// *Inplace version check should perform at node-level // *Inplace version check should perform at node-level
// *Cross-batch accumulation happens at forward pass // *Cross-batch accumulation happens at forward pass
std::unordered_map<GradNodeBase*, AutogradMeta*>
no_grad_var_nodes_inputmeta_map;
// Get no_grad_vars's GradNodes and InputMeta Info
GetTargetNodesInfo(no_grad_vars, &no_grad_var_nodes_inputmeta_map);
/* --- Initialization --- */ /* --- Initialization --- */
// 1. Init queue with starting nodes // 1. Init queue with starting nodes
// 2. Prepare initial input buffers // 2. Prepare initial input buffers
std::queue<GradNodeBase*> queue; std::queue<GradNodeBase*> queue;
std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>> std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
node_input_buffers_dict; node_input_buffers_dict;
std::unordered_set<GradNodeBase*> potential_startup_nodes;
for (size_t i = 0; i < tensors.size(); i++) { for (size_t i = 0; i < tensors.size(); i++) {
const paddle::experimental::Tensor& tensor = tensors[i]; const paddle::experimental::Tensor& tensor = tensors[i];
...@@ -132,8 +366,17 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -132,8 +366,17 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
"size = 0 or same size as tensors")); "size = 0 or same size as tensors"));
// Feed given tensor if it's provided // Feed given tensor if it's provided
VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor"; VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor";
node_input_buffers_dict[grad_node]->add(
input_info.first, input_info.second, grad_tensors[i]); if (grad_tensors[i].is_initialized()) {
// Deep copy
paddle::experimental::Tensor tmp_tensor;
tmp_tensor.copy_(grad_tensors[i], grad_tensors[i].inner_place(), true);
node_input_buffers_dict[grad_node]->add(input_info.first,
input_info.second, tmp_tensor);
} else {
node_input_buffers_dict[grad_node]->add(
input_info.first, input_info.second, grad_tensors[i]);
}
} else { } else {
VLOG(6) << "Fill grad input tensor " << i << " with 1.0"; VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
...@@ -146,8 +389,9 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -146,8 +389,9 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
input_info.first, input_info.second, tensor, true /*fill_one=true*/); input_info.first, input_info.second, tensor, true /*fill_one=true*/);
} }
// Prepare queue // Prepare queue, potential startup_nodes
queue.push(grad_node); queue.push(grad_node);
potential_startup_nodes.emplace(grad_node);
} }
VLOG(6) << "Update In degree Map for backward"; VLOG(6) << "Update In degree Map for backward";
...@@ -155,25 +399,74 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -155,25 +399,74 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
std::unordered_map<GradNodeBase*, int> node_in_degree_map = std::unordered_map<GradNodeBase*, int> node_in_degree_map =
getInDegreeMap(queue); getInDegreeMap(queue);
// Get input's GradNodes and InputMeta Info
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
input_target_nodes_inputmeta_map;
GetTargetNodesInfo(inputs, &input_target_nodes_inputmeta_map);
  // Purify potential_startup_nodes, remove those nodes that are the same as
// input_target_nodes
PurifyPotentialStartUpNodes(&potential_startup_nodes,
&input_target_nodes_inputmeta_map);
  // Get graph info between the input target grad nodes and the outputs
// Record the depending_nodes and potential_stop_nodes
std::unordered_map<GradNodeBase* /* child node */,
std::unordered_set<GradNodeBase*> /* father node */>
depending_nodes;
std::unordered_set<GradNodeBase*> potential_stop_nodes;
// std::unordered_set<GradNodeBase*> startup_ops;
GetGraphInfoBetweenTargets(queue, &input_target_nodes_inputmeta_map,
&depending_nodes, &potential_stop_nodes,
&potential_startup_nodes);
  // ready_queue stores all startup nodes
std::queue<GradNodeBase*> ready_queue;
  // a startup node's in-degree should be 0
for (auto node : potential_startup_nodes) {
if (node_in_degree_map[node] == 0) {
ready_queue.emplace(node);
}
}
VLOG(1) << " startup_ops' size is :" << ready_queue.size();
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
  // ready_queue is empty only when 1. an input equals an output, or 2. an
  // input cannot reach any output.
if (ready_queue.size() == 0) {
for (auto input_target_node : input_target_nodes_inputmeta_map) {
// out rank_info of forward op
auto rank_info = input_target_node.second->OutRankInfo();
if (node_input_buffers_dict[input_target_node.first]) {
auto& target_result =
node_input_buffers_dict[input_target_node.first]
->Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[input_target_node.first] = target_result;
}
}
}
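  // Illustrative example for the branch above: a call like
  // Grad(tensors={x}, inputs={x}) makes the requested input coincide with a
  // starting node, so no node needs to run and the gradient is read straight
  // out of the initial input buffer at the input's (slot, rank), as done above.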
/* --- Topological Visit --- */ /* --- Topological Visit --- */
// 1. Pop queue // 1. Pop queue
// 2. Run node // 2. Run node
// |- Check and capture target result
// |- node(grads) // |- node(grads)
// |- Prepare for next node // |- Prepare for next node
// 3. Update queue // 3. Update queue
VLOG(6) << "Run Backward"; VLOG(6) << "Run Backward";
while (!queue.empty()) { while (!ready_queue.empty()) {
GradNodeBase* node = queue.front(); GradNodeBase* node = ready_queue.front();
VLOG(6) << "Running GradNode:" << node->name();
ready_queue.pop();
paddle::platform::RecordEvent node_record_event( paddle::platform::RecordEvent node_record_event(
std::string(typeid(*node).name()) + " grad_node", std::string(typeid(*node).name()) + " grad_node",
paddle::platform::TracerEventType::Operator, 1); paddle::platform::TracerEventType::Operator, 1);
if (queue.size() > 1 && node_in_degree_map[node] != 0) {
queue.pop();
continue;
}
queue.pop();
// Run node: This is where Hook happens // Run node: This is where Hook happens
PADDLE_ENFORCE( PADDLE_ENFORCE(
node_input_buffers_dict.count(node), node_input_buffers_dict.count(node),
...@@ -184,10 +477,45 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -184,10 +477,45 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
std::unique_ptr<GradTensorHolder> node_input_buffer = std::unique_ptr<GradTensorHolder> node_input_buffer =
std::move(node_input_buffers_dict[node]); std::move(node_input_buffers_dict[node]);
// get target grad_var from node_input_buffer by inputmeta
if (input_target_nodes_inputmeta_map.find(node) !=
input_target_nodes_inputmeta_map.end()) {
VLOG(6) << "Get target result by by inputmeta";
// out rank_info of forward op
auto rank_info = input_target_nodes_inputmeta_map[node]->OutRankInfo();
// rank_info is a pair, first means slot_id, second means rank.
auto& target_result =
node_input_buffer->Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[node] = target_result;
}
// no_grad_vars
if (no_grad_var_nodes_inputmeta_map.find(node) !=
no_grad_var_nodes_inputmeta_map.end()) {
VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
auto rank_info = no_grad_var_nodes_inputmeta_map[node]->OutRankInfo();
node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
rank_info.second);
}
VLOG(6) << "Running GradNode:" << node->name();
// check input
EnforceGradNodeHasInput(node);
VLOG(6) << "Run Backward Kernel with GradTensorHolder"; VLOG(6) << "Run Backward Kernel with GradTensorHolder";
// Run Pre Backward Node and get outputs // Run Pre Backward Node and get outputs
std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors = std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
(*node)(node_input_buffer->Buffers()); (*node)(node_input_buffer->Buffers(), create_graph);
    // retain_graph or not
    if (!retain_graph) {
      VLOG(6)
          << "retain_graph is false, need to clear the TensorWrappers of nodes.";
node->ClearTensorWrappers();
}
// TODO(jiabin): Should we erase it or find a more efficient way. // TODO(jiabin): Should we erase it or find a more efficient way.
node_input_buffers_dict.erase(node); node_input_buffers_dict.erase(node);
...@@ -252,18 +580,44 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -252,18 +580,44 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
// Update queue // Update queue
node_in_degree_map[next_node]--; node_in_degree_map[next_node]--;
PADDLE_ENFORCE( PADDLE_ENFORCE(
node_in_degree_map[next_node] >= 0, node_in_degree_map[next_node] >= 0,
paddle::platform::errors::Fatal( paddle::platform::errors::Fatal(
"Detected in-degree value smaller than zero. For Node: %s" "Detected in-degree value smaller than zero. For Node: %s"
"Node's in-degree cannot be negative", "Node's in-degree cannot be negative",
next_node->name())); next_node->name()));
if (node_in_degree_map[next_node] == 0) {
queue.emplace(std::move(next_node)); bool is_potential_stop_node = potential_stop_nodes.count(next_node);
if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
ready_queue.emplace(std::move(next_node));
} }
} }
} }
} }
return GetResults(inputs, &results_map, allow_unused, create_graph);
} }
void Backward(
const std::vector<paddle::experimental::Tensor>& tensors, // output
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph) {
VLOG(6) << "Run in Backward";
paddle::platform::RecordEvent backward_record_event(
"backward", paddle::platform::TracerEventType::Operator, 1);
RunBackward(tensors, grad_tensors, retain_graph);
}
std::vector<paddle::experimental::Tensor> Grad(
const std::vector<paddle::experimental::Tensor>& tensors, // output
const std::vector<paddle::experimental::Tensor>& inputs,
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
VLOG(6) << "Run in Grad";
return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
allow_unused, no_grad_vars);
}
} // namespace egr } // namespace egr
...@@ -19,12 +19,20 @@ ...@@ -19,12 +19,20 @@
namespace egr { namespace egr {
// run_backward(): // Backward():
// tensors correspond to those that live in the backward graph // tensors correspond to those that live in the backward graph
// each grad_tensors[i] keeps the value for its corresponding tensors[i] // each grad_tensors[i] keeps the value for its corresponding tensors[i]
void RunBackward(const std::vector<paddle::experimental::Tensor> &tensors, void Backward(const std::vector<paddle::experimental::Tensor>& tensors,
const std::vector<paddle::experimental::Tensor> &grad_tensors, const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph = false); bool retain_graph = false);
std::vector<paddle::experimental::Tensor> Grad(
const std::vector<paddle::experimental::Tensor>& tensors,
const std::vector<paddle::experimental::Tensor>& inputs,
const std::vector<paddle::experimental::Tensor>& grad_tensors = {},
bool retain_graph = false, bool create_graph = false,
bool only_inputs = false, bool allow_unused = false,
const std::vector<paddle::experimental::Tensor>& no_grad_vars = {});
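// A minimal usage sketch (illustrative only; tensor names are placeholders,
// mirroring the unit tests added in this change):
//
//   std::vector<paddle::experimental::Tensor> outs = {output_tensor};
//   auto grads = egr::Grad(outs, /*inputs=*/{leaf_tensor}, /*grad_tensors=*/{});
//   // grads[i] corresponds to inputs[i]; with allow_unused=true an input that
//   // does not appear in the backward graph yields an empty tensor.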
// Reserved for gradient() // Reserved for gradient()
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
namespace egr { namespace egr {
std::vector<std::vector<paddle::experimental::Tensor>> RunCustomOpNode:: std::vector<std::vector<paddle::experimental::Tensor>> RunCustomOpNode::
operator()( operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) { bool create_graph) {
paddle::CustomOpKernelContext ctx; paddle::CustomOpKernelContext ctx;
auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs( auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs(
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
......
...@@ -37,8 +37,8 @@ class RunCustomOpNode : public GradNodeBase { ...@@ -37,8 +37,8 @@ class RunCustomOpNode : public GradNodeBase {
// Functor: perform backward computations // Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
override; bool create_graph) override;
std::string name() { std::string name() {
return paddle::string::Sprintf("RunCustomOpNode: %s_grad", op_type_); return paddle::string::Sprintf("RunCustomOpNode: %s_grad", op_type_);
...@@ -62,6 +62,12 @@ class RunCustomOpNode : public GradNodeBase { ...@@ -62,6 +62,12 @@ class RunCustomOpNode : public GradNodeBase {
return res; return res;
} }
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; } void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; }
public: public:
......
...@@ -95,8 +95,12 @@ class GradNodeBase { ...@@ -95,8 +95,12 @@ class GradNodeBase {
* is better choice to fit this format. * is better choice to fit this format.
* **/ * **/
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) = 0; const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph = false) = 0;
virtual void ClearTensorWrappers() = 0;
virtual bool IsTensorWrappersCleared() = 0;
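  // Nodes that do not capture any forward tensors may provide trivial
  // overrides, as the custom-op and test nodes in this change do, e.g.:
  //   void ClearTensorWrappers() override {}
  //   bool IsTensorWrappersCleared() override { return false; }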
/** /**
* AddEdges is designed to set input tensors' backward Node as current * AddEdges is designed to set input tensors' backward Node as current
* node's Edges. * node's Edges.
......
...@@ -21,6 +21,11 @@ ...@@ -21,6 +21,11 @@
namespace egr { namespace egr {
void GradTensorHolder::SetBufferSlotRankZeros(size_t slot_id, size_t rank) {
buffer_[slot_id][rank] =
paddle::experimental::zeros_like(buffer_[slot_id][rank]);
}
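// Used by RunBackward for tensors listed in no_grad_vars: before a node runs,
// the matching buffer_[slot_id][rank] entry is replaced with a zeros_like
// tensor so that no gradient flows through that variable.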
void GradTensorHolder::add(size_t slot_id, size_t rank, void GradTensorHolder::add(size_t slot_id, size_t rank,
const paddle::experimental::Tensor& t, const paddle::experimental::Tensor& t,
bool fill_one) { bool fill_one) {
......
...@@ -56,6 +56,8 @@ class GradTensorHolder { ...@@ -56,6 +56,8 @@ class GradTensorHolder {
return buffer_; return buffer_;
} }
void SetBufferSlotRankZeros(size_t slot_id, size_t rank);
private: private:
std::vector<std::vector<paddle::experimental::Tensor>> buffer_; std::vector<std::vector<paddle::experimental::Tensor>> buffer_;
}; };
......
...@@ -98,6 +98,8 @@ class TensorWrapper { ...@@ -98,6 +98,8 @@ class TensorWrapper {
} }
} }
void clear() { intermidiate_tensor_.reset(); }
private: private:
bool full_reserved_ = false; bool full_reserved_ = false;
std::pair<size_t, size_t> out_rank_info_; std::pair<size_t, size_t> out_rank_info_;
......
...@@ -17,6 +17,14 @@ ...@@ -17,6 +17,14 @@
#include "paddle/fluid/eager/eager_tensor.h" #include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/kernel_registry.h"
PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(copy_sr, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(copy_sr, GPU, ALL_LAYOUT);
#endif
namespace eager_test { namespace eager_test {
using AbstractAutogradMeta = paddle::experimental::AbstractAutogradMeta; using AbstractAutogradMeta = paddle::experimental::AbstractAutogradMeta;
...@@ -151,5 +159,50 @@ TEST(EagerVariable, Constructor) { ...@@ -151,5 +159,50 @@ TEST(EagerVariable, Constructor) {
CHECK_EQ(dt3_tmp_ptr[1], 10.0f); CHECK_EQ(dt3_tmp_ptr[1], 10.0f);
t4.reset(); t4.reset();
CHECK(t4.defined() == false); CHECK(t4.defined() == false);
VLOG(6) << "Check Tensor Copy_";
std::vector<int64_t> rows = {1, 2};
std::vector<int64_t> dims = {2};
paddle::experimental::Tensor t7(std::make_shared<phi::SelectedRows>(rows, 2));
std::dynamic_pointer_cast<phi::SelectedRows>(t7.impl())
->mutable_value()
->Resize(phi::make_ddim(dims));
auto* dt7_tmp_ptr = std::dynamic_pointer_cast<phi::SelectedRows>(t7.impl())
->mutable_value()
->mutable_data<float>(paddle::platform::CPUPlace());
dt7_tmp_ptr[0] = 6.0f;
dt7_tmp_ptr[1] = 11.0f;
paddle::experimental::Tensor t8;
paddle::experimental::Tensor t5;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle::experimental::Tensor t6;
paddle::experimental::Tensor t9;
VLOG(6) << "Check Tensor Copy_ Selected Rows";
t8.copy_(t7, paddle::platform::CUDAPlace(0), true);
t9.copy_(t8, paddle::platform::CPUPlace(), true);
auto* dt9_tmp_ptr = std::dynamic_pointer_cast<phi::SelectedRows>(t9.impl())
->value()
.data<float>();
CHECK_EQ(dt9_tmp_ptr[0], 6.0f);
CHECK_EQ(dt9_tmp_ptr[1], 11.0f);
CHECK_EQ(std::dynamic_pointer_cast<phi::SelectedRows>(t9.impl())->height(),
2);
VLOG(6) << "Check Tensor Copy_ Dense Tensor";
t5.copy_(t3, paddle::platform::CUDAPlace(0), true);
t6.copy_(t5, paddle::platform::CPUPlace(), true);
auto* dt6_tmp_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(t6.impl())->data<float>();
CHECK_EQ(dt6_tmp_ptr[0], 5.0f);
CHECK_EQ(dt6_tmp_ptr[1], 10.0f);
#else
t5.copy_(t3, paddle::platform::CPUPlace(), true);
auto* dt5_tmp_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(t5.impl())->data<float>();
CHECK_EQ(dt5_tmp_ptr[0], 5.0f);
CHECK_EQ(dt5_tmp_ptr[1], 10.0f);
#endif
VLOG(6) << "Finish"; VLOG(6) << "Finish";
} }
...@@ -32,8 +32,8 @@ class GradTestNode : public egr::GradNodeBase { ...@@ -32,8 +32,8 @@ class GradTestNode : public egr::GradNodeBase {
GradTestNode() : GradNodeBase() { val_ = 1.0; } GradTestNode() : GradNodeBase() { val_ = 1.0; }
std::string name() override { return "GradTestNode"; } std::string name() override { return "GradTestNode"; }
std::vector<std::vector<paddle::experimental::Tensor>> operator()( std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
override { bool create_graph = false) override {
val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl()) val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl())
->data<float>()[0]; ->data<float>()[0];
phi::DenseTensorMeta meta = phi::DenseTensorMeta meta =
...@@ -49,6 +49,11 @@ class GradTestNode : public egr::GradNodeBase { ...@@ -49,6 +49,11 @@ class GradTestNode : public egr::GradNodeBase {
std::vector<std::vector<paddle::experimental::Tensor>> res = {{et1}}; std::vector<std::vector<paddle::experimental::Tensor>> res = {{et1}};
return res; return res;
} }
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
float val_; float val_;
}; };
} // namespace eager_test } // namespace eager_test
...@@ -58,7 +58,7 @@ void benchmark_eager_scale(const paddle::experimental::Tensor& tensor, ...@@ -58,7 +58,7 @@ void benchmark_eager_scale(const paddle::experimental::Tensor& tensor,
} }
std::vector<paddle::experimental::Tensor> target_tensors = {input_tensor}; std::vector<paddle::experimental::Tensor> target_tensors = {input_tensor};
RunBackward(target_tensors, {}); Backward(target_tensors, {});
if (accuracy_check) { if (accuracy_check) {
// Examine Forward Grad (w.r.t max_num_runs = 10) // Examine Forward Grad (w.r.t max_num_runs = 10)
...@@ -80,7 +80,7 @@ void benchmark_eager_matmul(const paddle::experimental::Tensor& X, ...@@ -80,7 +80,7 @@ void benchmark_eager_matmul(const paddle::experimental::Tensor& X,
} }
std::vector<paddle::experimental::Tensor> target_tensors = {input_tensor0}; std::vector<paddle::experimental::Tensor> target_tensors = {input_tensor0};
RunBackward(target_tensors, {}); Backward(target_tensors, {});
if (accuracy_check) { if (accuracy_check) {
// Examine Forward Grad (w.r.t max_num_runs = 2) // Examine Forward Grad (w.r.t max_num_runs = 2)
...@@ -106,7 +106,7 @@ void benchmark_eager_intermediate_matmul(const paddle::experimental::Tensor& X, ...@@ -106,7 +106,7 @@ void benchmark_eager_intermediate_matmul(const paddle::experimental::Tensor& X,
} }
std::vector<paddle::experimental::Tensor> target_tensors = {input_tensor0}; std::vector<paddle::experimental::Tensor> target_tensors = {input_tensor0};
RunBackward(target_tensors, {}); Backward(target_tensors, {});
if (accuracy_check) { if (accuracy_check) {
// Examine Forward Grad (w.r.t max_num_runs = 2) // Examine Forward Grad (w.r.t max_num_runs = 2)
...@@ -137,7 +137,7 @@ void benchmark_eager_intermediate_mlp( ...@@ -137,7 +137,7 @@ void benchmark_eager_intermediate_mlp(
reduce_sum_dygraph_function(input0, {{"reduce_all", true}}); reduce_sum_dygraph_function(input0, {{"reduce_all", true}});
std::vector<paddle::experimental::Tensor> target_tensors = {Out}; std::vector<paddle::experimental::Tensor> target_tensors = {Out};
RunBackward(target_tensors, {}); Backward(target_tensors, {});
if (accuracy_check) { if (accuracy_check) {
std::unordered_map<std::string, float> result = std::unordered_map<std::string, float> result =
......
...@@ -5,6 +5,7 @@ cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_ ...@@ -5,6 +5,7 @@ cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_
cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_grad SRCS grad_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
cc_test(test_egr_task_hook_intermidiate SRCS hook_test_intermidiate.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} dygraph_node) cc_test(test_egr_task_hook_intermidiate SRCS hook_test_intermidiate.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} dygraph_node)
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
namespace egr { namespace egr {
...@@ -79,7 +80,7 @@ TEST(Backward, SingleNodeEmptyGrad) { ...@@ -79,7 +80,7 @@ TEST(Backward, SingleNodeEmptyGrad) {
} }
std::vector<paddle::experimental::Tensor> outs = {target_tensor}; std::vector<paddle::experimental::Tensor> outs = {target_tensor};
// Run Backward // Run Backward
RunBackward(outs, {}); Backward(outs, {});
// Check Output Value // Check Output Value
eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 5.0); eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 5.0);
...@@ -138,7 +139,7 @@ TEST(Backward, SingleNodeCustomGrad) { ...@@ -138,7 +139,7 @@ TEST(Backward, SingleNodeCustomGrad) {
} }
// Run Backward // Run Backward
RunBackward(target_tensors, grad_tensors); Backward(target_tensors, grad_tensors);
// Check Output Value // Check Output Value
eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 50.0); eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 50.0);
...@@ -211,7 +212,7 @@ TEST(Backward, LinearNodes) { ...@@ -211,7 +212,7 @@ TEST(Backward, LinearNodes) {
} }
// Use Empty Grad Tensor // Use Empty Grad Tensor
RunBackward(target_tensors, {}); Backward(target_tensors, {});
// Check Output Value // Check Output Value
eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 50.0); eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 50.0);
...@@ -315,7 +316,7 @@ TEST(Backward, WithAccumulation) { ...@@ -315,7 +316,7 @@ TEST(Backward, WithAccumulation) {
node2_ptr->AddEdges(&res2, 0); node2_ptr->AddEdges(&res2, 0);
} }
RunBackward(target_tensors, grad_tensors); Backward(target_tensors, grad_tensors);
eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 2500.0); eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 2500.0);
} }
......
...@@ -71,12 +71,12 @@ TEST(CrossBatchAccumulation, SingleScaleNode) { ...@@ -71,12 +71,12 @@ TEST(CrossBatchAccumulation, SingleScaleNode) {
std::vector<egr::AutogradMeta*> res = {meta}; std::vector<egr::AutogradMeta*> res = {meta};
scale_node_ptr->AddEdges(&res, 0); scale_node_ptr->AddEdges(&res, 0);
RunBackward(target_tensors, {}); Backward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(target_tensor, 1.0); eager_test::CompareGradTensorWithValue<float>(target_tensor, 1.0);
eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 5.0); eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 5.0);
RunBackward(target_tensors, {}); Backward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(target_tensor, 1.0); eager_test::CompareGradTensorWithValue<float>(target_tensor, 1.0);
eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 10.0); eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 10.0);
......
...@@ -86,7 +86,7 @@ TEST(FwdBwdJoint, SingleNode) { ...@@ -86,7 +86,7 @@ TEST(FwdBwdJoint, SingleNode) {
std::vector<paddle::experimental::Tensor> outs = {out}; std::vector<paddle::experimental::Tensor> outs = {out};
// 4. Run Backward // 4. Run Backward
RunBackward(outs, {}); Backward(outs, {});
VLOG(7) << "Target Grad is: " VLOG(7) << "Target Grad is: "
<< std::static_pointer_cast<phi::DenseTensor>( << std::static_pointer_cast<phi::DenseTensor>(
...@@ -137,7 +137,7 @@ TEST(FwdBwdJoint, LinearNodes) { ...@@ -137,7 +137,7 @@ TEST(FwdBwdJoint, LinearNodes) {
std::vector<paddle::experimental::Tensor> outs = {out1}; std::vector<paddle::experimental::Tensor> outs = {out1};
// 4. Run Backward // 4. Run Backward
RunBackward(outs, {}); Backward(outs, {});
// Examine Backward Grad // Examine Backward Grad
eager_test::CompareGradTensorWithValue<float>(tensor, 10.0); eager_test::CompareGradTensorWithValue<float>(tensor, 10.0);
...@@ -203,7 +203,7 @@ TEST(FwdBwdJoint, BranchedNodes) { ...@@ -203,7 +203,7 @@ TEST(FwdBwdJoint, BranchedNodes) {
// 4. Run Backward // 4. Run Backward
std::vector<paddle::experimental::Tensor> outs = {out1, out2}; std::vector<paddle::experimental::Tensor> outs = {out1, out2};
RunBackward(outs, {}); Backward(outs, {});
// Examine Backward Grad // Examine Backward Grad
eager_test::CompareGradTensorWithValue<float>(tensor, 30.0); eager_test::CompareGradTensorWithValue<float>(tensor, 30.0);
...@@ -260,7 +260,7 @@ TEST(FwdBwdJoint, GradientHook) { ...@@ -260,7 +260,7 @@ TEST(FwdBwdJoint, GradientHook) {
// 4. Run Backward // 4. Run Backward
std::vector<paddle::experimental::Tensor> outs = {out1, out2}; std::vector<paddle::experimental::Tensor> outs = {out1, out2};
RunBackward(outs, {}); Backward(outs, {});
// Examine Backward Grad // Examine Backward Grad
// leaf grad // leaf grad
...@@ -318,13 +318,13 @@ TEST(FwdBwdJoint, CrossBatchAccumulation) { ...@@ -318,13 +318,13 @@ TEST(FwdBwdJoint, CrossBatchAccumulation) {
// 4. Run Backward // 4. Run Backward
std::vector<paddle::experimental::Tensor> outs = {out1, out2}; std::vector<paddle::experimental::Tensor> outs = {out1, out2};
RunBackward(outs, {}); Backward(outs, {});
// Examine Backward Grad // Examine Backward Grad
eager_test::CompareGradTensorWithValue<float>(tensor, 30.0); eager_test::CompareGradTensorWithValue<float>(tensor, 30.0);
// Cross Batch Accumulation // Cross Batch Accumulation
RunBackward(outs, {}); Backward(outs, {});
// Examine Backward Grad // Examine Backward Grad
eager_test::CompareGradTensorWithValue<float>(tensor, 60.0); eager_test::CompareGradTensorWithValue<float>(tensor, 60.0);
...@@ -356,7 +356,7 @@ TEST(FwdBwdJoint, SingleNodeCUDA) { ...@@ -356,7 +356,7 @@ TEST(FwdBwdJoint, SingleNodeCUDA) {
std::vector<paddle::experimental::Tensor> outs = {out}; std::vector<paddle::experimental::Tensor> outs = {out};
// 4. Run Backward // 4. Run Backward
RunBackward(outs, {}); Backward(outs, {});
// Examine Backward Grad // Examine Backward Grad
eager_test::CompareGradTensorWithValue<float>(tensor, 2.0); eager_test::CompareGradTensorWithValue<float>(tensor, 2.0);
...@@ -412,7 +412,7 @@ TEST(FwdBwdJoint, BranchedNodesCUDA) { ...@@ -412,7 +412,7 @@ TEST(FwdBwdJoint, BranchedNodesCUDA) {
// TODO(jiabin): fix this with add functor // TODO(jiabin): fix this with add functor
// 4. Run Backward // 4. Run Backward
std::vector<paddle::experimental::Tensor> outs = {out1, out2}; std::vector<paddle::experimental::Tensor> outs = {out1, out2};
RunBackward(outs, {}); Backward(outs, {});
// Examine Backward Grad // Examine Backward Grad
eager_test::CompareGradTensorWithValue<float>(tensor, 30.0); eager_test::CompareGradTensorWithValue<float>(tensor, 30.0);
......
...@@ -57,7 +57,7 @@ TEST(Generated, Sigmoid) { ...@@ -57,7 +57,7 @@ TEST(Generated, Sigmoid) {
std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor}; std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor};
VLOG(6) << "Runing Backward"; VLOG(6) << "Runing Backward";
RunBackward(target_tensors, {}); Backward(target_tensors, {});
VLOG(6) << "Finish Backward"; VLOG(6) << "Finish Backward";
eager_test::CompareGradTensorWithValue<float>(tensor, 0.25); eager_test::CompareGradTensorWithValue<float>(tensor, 0.25);
...@@ -89,7 +89,7 @@ TEST(Generated, Matmul_v2) { ...@@ -89,7 +89,7 @@ TEST(Generated, Matmul_v2) {
eager_test::CompareTensorWithValue<float>(output_tensor, 96); eager_test::CompareTensorWithValue<float>(output_tensor, 96);
std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor}; std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor};
RunBackward(target_tensors, {}); Backward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(X, 2.0 * 20); eager_test::CompareGradTensorWithValue<float>(X, 2.0 * 20);
eager_test::CompareGradTensorWithValue<float>(Y, 3.0 * 4); eager_test::CompareGradTensorWithValue<float>(Y, 3.0 * 4);
...@@ -120,7 +120,7 @@ TEST(Generated, ElementwiseAdd) { ...@@ -120,7 +120,7 @@ TEST(Generated, ElementwiseAdd) {
eager_test::CompareTensorWithValue<float>(output_tensor, 5); eager_test::CompareTensorWithValue<float>(output_tensor, 5);
std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor}; std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor};
RunBackward(target_tensors, {}); Backward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(X, 1.0); eager_test::CompareGradTensorWithValue<float>(X, 1.0);
eager_test::CompareGradTensorWithValue<float>(Y, 1.0); eager_test::CompareGradTensorWithValue<float>(Y, 1.0);
...@@ -128,6 +128,6 @@ TEST(Generated, ElementwiseAdd) { ...@@ -128,6 +128,6 @@ TEST(Generated, ElementwiseAdd) {
} // namespace egr } // namespace egr
USE_OP(sigmoid); USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(matmul_v2); USE_OP_ITSELF(matmul_v2);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/backward.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/tests/test_utils.h"
#include "paddle/fluid/eager/api/all.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
namespace egr {
TEST(Grad, SingleNodeEmptyGrad) {
// Prepare Device Contexts
eager_test::InitEnv(paddle::platform::CPUPlace());
// Prepare Inputs
paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32});
// Create Target Tensor (output)
paddle::experimental::Tensor output_tensor =
egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/);
// Create input tensor
const paddle::experimental::Tensor leaf_tensor =
egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 1.0 /*value*/, true /*is_leaf*/);
{
// Create Scale Node
auto node0_ptr = std::make_shared<GradNodeScale>(1, 1);
node0_ptr->SetAttributes_scale(5.0 /*scale*/);
// Set grad in/out meta
node0_ptr->SetDefaultGradInOutMeta();
    // Set GradNode, OutRank, and StopGradient properties on output_tensor
AutogradMeta* auto_grad_meta = EagerUtils::autograd_meta(&output_tensor);
auto_grad_meta->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(node0_ptr));
auto_grad_meta->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta->SetStopGradient(false);
// Get autograd_meta from input tensor
AutogradMeta* auto_grad_meta1 =
EagerUtils::unsafe_autograd_meta(leaf_tensor);
// Connect Tensor and AccumulationNode via AutoGradMeta
auto acc_node_ptr =
std::make_shared<egr::GradNodeAccumulation>(auto_grad_meta1);
    // Set GradNode, OutRank, and StopGradient properties on the input tensor
auto_grad_meta1->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr));
auto_grad_meta1->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta1->SetStopGradient(false);
// grad_node Add Edges
std::vector<egr::AutogradMeta*> res = {auto_grad_meta1};
node0_ptr->AddEdges(&res, 0);
}
std::vector<paddle::experimental::Tensor> outs = {output_tensor};
// Run Grad
auto result = Grad(outs, {leaf_tensor}, {});
// Check Output Value
eager_test::CompareTensorWithValue<float>(result[0], 5.0);
}
TEST(Grad, SingleNodeCustomGrad) {
// Prepare Device Contexts
eager_test::InitEnv(paddle::platform::CPUPlace());
// Prepare Inputs
std::vector<paddle::experimental::Tensor> target_tensors;
paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32});
// Create Target Tensor
paddle::experimental::Tensor tensor = egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/);
target_tensors.emplace_back(std::move(tensor));
std::vector<paddle::experimental::Tensor> grad_tensors;
// Create Grad Tensor
paddle::experimental::Tensor grad_tensor =
egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 10.0 /*value*/, false /*is_leaf*/);
grad_tensors.emplace_back(std::move(grad_tensor));
paddle::experimental::Tensor leaf_tensor =
egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 1.0 /*value*/, true /*is_leaf*/);
{
// Create Scale Node
auto node0_ptr = std::make_shared<GradNodeScale>(1, 1);
node0_ptr->SetAttributes_scale(5.0 /*scale*/);
// Set grad in/out meta
node0_ptr->SetDefaultGradInOutMeta();
// Connect Tensor and Node via AutoGradMeta
AutogradMeta* auto_grad_meta =
EagerUtils::autograd_meta(&(target_tensors[0]));
auto_grad_meta->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(node0_ptr));
auto_grad_meta->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta->SetStopGradient(false);
AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor);
// Connect Tensor and AccumulationNode via AutoGradMeta
auto acc_node_ptr =
std::make_shared<egr::GradNodeAccumulation>(auto_grad_meta1);
auto_grad_meta1->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr));
auto_grad_meta1->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta1->SetStopGradient(false);
std::vector<egr::AutogradMeta*> res = {auto_grad_meta1};
node0_ptr->AddEdges(&res, 0);
}
auto result = Grad(target_tensors, {leaf_tensor}, grad_tensors);
// Check Output Value
eager_test::CompareTensorWithValue<float>(result[0], 50.0);
}
/*
Node1
|
Node0
|
{ } // empty grad tensor
*/
TEST(Grad, LinearNodes) {
// Prepare Device Contexts
eager_test::InitEnv(paddle::platform::CPUPlace());
// Prepare Target Tensor
std::vector<paddle::experimental::Tensor> target_tensors;
paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32});
// Create Target Tensor
paddle::experimental::Tensor tensor = egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/);
target_tensors.emplace_back(std::move(tensor));
paddle::experimental::Tensor leaf_tensor =
egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 1.0 /*value*/, true /*is_leaf*/);
{
// Create Node0
auto node0_ptr = std::make_shared<GradNodeScale>(1, 1);
node0_ptr->SetAttributes_scale(5.0 /*scale*/);
// Set grad in/out meta for node0
node0_ptr->SetDefaultGradInOutMeta();
// Create Node1
auto node1_ptr = std::make_shared<GradNodeScale>(1, 1);
node1_ptr->SetAttributes_scale(10.0 /*scale*/);
// Set grad in/out meta for node1
node1_ptr->SetDefaultGradInOutMeta();
// Connect Input Tensor and Node0 via AutoGradMeta
AutogradMeta* auto_grad_meta =
EagerUtils::autograd_meta(&(target_tensors[0]));
auto_grad_meta->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(node0_ptr));
auto_grad_meta->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta->SetStopGradient(false);
// Connect Node0 -> Node1 via Edge
auto meta0 = egr::AutogradMeta();
meta0.SetStopGradient(false);
meta0.SetSingleOutRankWithSlot(0, 0);
meta0.SetGradNode(node1_ptr);
std::vector<egr::AutogradMeta*> res0 = {&meta0};
node0_ptr->AddEdges(&res0, 0);
AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor);
// Connect Tensor and AccumulationNode via AutoGradMeta
auto acc_node_ptr =
std::make_shared<egr::GradNodeAccumulation>(auto_grad_meta1);
auto_grad_meta1->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr));
auto_grad_meta1->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta1->SetStopGradient(false);
std::vector<egr::AutogradMeta*> res1 = {auto_grad_meta1};
node1_ptr->AddEdges(&res1, 0);
}
// Use Empty Grad Tensor
auto result = Grad(target_tensors, {leaf_tensor}, {});
// Check Output Value
eager_test::CompareTensorWithValue<float>(result[0], 50.0);
}
/*
Node2
| |
Node0 Node1
| |
in0 in1
*/
TEST(Grad, WithAccumulation) {
// Prepare Device Contexts
eager_test::InitEnv(paddle::platform::CPUPlace());
// Prepare Inputs
paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32});
// Create Target Tensor
std::vector<paddle::experimental::Tensor> target_tensors;
paddle::experimental::Tensor tensor0 = egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/);
paddle::experimental::Tensor tensor1 = egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/);
target_tensors.emplace_back(std::move(tensor0));
target_tensors.emplace_back(std::move(tensor1));
// Create Grad Tensor
std::vector<paddle::experimental::Tensor> grad_tensors;
paddle::experimental::Tensor grad_tensor0 =
egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 5.0 /*value*/, false /*is_leaf*/);
paddle::experimental::Tensor grad_tensor1 =
egr_utils_api::CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 10.0 /*value*/, false /*is_leaf*/);
grad_tensors.emplace_back(std::move(grad_tensor0));
grad_tensors.emplace_back(std::move(grad_tensor1));
paddle::experimental::Tensor leaf_tensor;
{
// Create Node0
auto node0_ptr = std::make_shared<GradNodeScale>(1, 1);
node0_ptr->SetAttributes_scale(5.0 /*scale*/);
node0_ptr->SetDefaultGradInOutMeta();
// Create Node1
auto node1_ptr = std::make_shared<GradNodeScale>(1, 1);
node1_ptr->SetAttributes_scale(10.0 /*scale*/);
node1_ptr->SetDefaultGradInOutMeta();
// Create Node2
auto node2_ptr = std::make_shared<GradNodeScale>(1, 1);
node2_ptr->SetAttributes_scale(20.0 /*scale*/);
node2_ptr->SetDefaultGradInOutMeta();
// Connect Inp0 and Node0 via AutoGradMeta
AutogradMeta* auto_grad_meta0 =
EagerUtils::autograd_meta(&(target_tensors[0]));
auto_grad_meta0->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(node0_ptr));
auto_grad_meta0->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta0->SetStopGradient(false);
// Connect Inp1 and Node1 via AutoGradMeta
AutogradMeta* auto_grad_meta1 =
EagerUtils::autograd_meta(&(target_tensors[1]));
auto_grad_meta1->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(node1_ptr));
auto_grad_meta1->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta1->SetStopGradient(false);
// Connect Node0 -> Node2 via Edge
auto meta0 = egr::AutogradMeta();
meta0.SetStopGradient(false);
meta0.SetSingleOutRankWithSlot(0, 0);
meta0.SetGradNode(node2_ptr);
std::vector<egr::AutogradMeta*> res0 = {&meta0};
node0_ptr->AddEdges(&res0, 0);
// Connect Node1 -> Node2 via Edge
auto meta1 = egr::AutogradMeta();
meta1.SetStopGradient(false);
meta1.SetSingleOutRankWithSlot(0, 0);
meta1.SetGradNode(node2_ptr);
std::vector<egr::AutogradMeta*> res1 = {&meta1};
node1_ptr->AddEdges(&res1, 0);
AutogradMeta* auto_grad_meta2 = EagerUtils::autograd_meta(&leaf_tensor);
// Connect Tensor and AccumulationNode via AutoGradMeta
auto acc_node_ptr =
std::make_shared<egr::GradNodeAccumulation>(auto_grad_meta2);
auto_grad_meta2->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr));
auto_grad_meta2->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta2->SetStopGradient(false);
std::vector<egr::AutogradMeta*> res2 = {auto_grad_meta2};
node2_ptr->AddEdges(&res2, 0);
}
auto result = Grad(target_tensors, {leaf_tensor}, grad_tensors);
eager_test::CompareTensorWithValue<float>(result[0], 2500.0);
}
} // namespace egr
...@@ -132,7 +132,7 @@ TEST(RetainGrad, HookBeforeRetainGrad) { ...@@ -132,7 +132,7 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
leaf_tensor); // result: 4.0*5.0 + 3.0 = 23.0 leaf_tensor); // result: 4.0*5.0 + 3.0 = 23.0
} }
RunBackward(target_tensors, {}); Backward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(target_tensor, 4.0); eager_test::CompareGradTensorWithValue<float>(target_tensor, 4.0);
eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 23.0); eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 23.0);
...@@ -199,7 +199,7 @@ TEST(RetainGrad, HookAfterRetainGrad) { ...@@ -199,7 +199,7 @@ TEST(RetainGrad, HookAfterRetainGrad) {
leaf_tensor, std::make_shared<egr::CppTensorHook>(hook_function)); leaf_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
} }
RunBackward(target_tensors, {}); Backward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(target_tensor, 1.0); eager_test::CompareGradTensorWithValue<float>(target_tensor, 1.0);
eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 23.0); eager_test::CompareGradTensorWithValue<float>(leaf_tensor, 23.0);
} }
......
...@@ -108,7 +108,7 @@ void test_sigmoid(bool is_remove_gradient_hook) { ...@@ -108,7 +108,7 @@ void test_sigmoid(bool is_remove_gradient_hook) {
} }
VLOG(6) << "Runing Backward"; VLOG(6) << "Runing Backward";
RunBackward(target_tensors, {}); Backward(target_tensors, {});
VLOG(6) << "Finish Backward"; VLOG(6) << "Finish Backward";
eager_test::CompareGradTensorWithValue<float>( eager_test::CompareGradTensorWithValue<float>(
...@@ -166,7 +166,7 @@ void test_elementwiseAdd(bool is_remove_gradient_hook) { ...@@ -166,7 +166,7 @@ void test_elementwiseAdd(bool is_remove_gradient_hook) {
grad_node_tmp->RemoveGradientHook(hook_id); grad_node_tmp->RemoveGradientHook(hook_id);
} }
RunBackward(target_tensors, {}); Backward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(X, 1.0); eager_test::CompareGradTensorWithValue<float>(X, 1.0);
eager_test::CompareGradTensorWithValue<float>( eager_test::CompareGradTensorWithValue<float>(
...@@ -224,7 +224,7 @@ void test_matmul(bool is_remove_gradient_hook) { ...@@ -224,7 +224,7 @@ void test_matmul(bool is_remove_gradient_hook) {
grad_node_tmp->RemoveGradientHook(hook_id); grad_node_tmp->RemoveGradientHook(hook_id);
} }
RunBackward(target_tensors, {}); Backward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(X, 2.0 * 20); eager_test::CompareGradTensorWithValue<float>(X, 2.0 * 20);
eager_test::CompareGradTensorWithValue<float>( eager_test::CompareGradTensorWithValue<float>(
...@@ -255,6 +255,6 @@ TEST(Hook_intermidiate, Matmul_v2) { ...@@ -255,6 +255,6 @@ TEST(Hook_intermidiate, Matmul_v2) {
} }
} // namespace egr } // namespace egr
USE_OP(sigmoid); USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(matmul_v2); USE_OP_ITSELF(matmul_v2);
...@@ -370,8 +370,8 @@ class GradNodeRunProgram : public egr::GradNodeBase { ...@@ -370,8 +370,8 @@ class GradNodeRunProgram : public egr::GradNodeBase {
~GradNodeRunProgram() override = default; ~GradNodeRunProgram() override = default;
// Functor: perform backward computations // Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>> &grads) const std::vector<std::vector<paddle::experimental::Tensor>> &grads,
override { bool create_graph) override {
VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram"; VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram";
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
grads.size(), 1, grads.size(), 1,
...@@ -415,6 +415,12 @@ class GradNodeRunProgram : public egr::GradNodeBase { ...@@ -415,6 +415,12 @@ class GradNodeRunProgram : public egr::GradNodeBase {
// return {x_grad, details::DereferenceTensors(params_grad_ptr)}; // return {x_grad, details::DereferenceTensors(params_grad_ptr)};
} }
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
// SetAttrMap // SetAttrMap
void SetAttrMap(const paddle::framework::AttributeMap &attrs) { void SetAttrMap(const paddle::framework::AttributeMap &attrs) {
attrs_ = attrs; attrs_ = attrs;
......
...@@ -10,8 +10,9 @@ IF(WITH_GPU) ...@@ -10,8 +10,9 @@ IF(WITH_GPU)
nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS}) nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS})
nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm) nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm) nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm) nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table)
nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps) nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps)
nv_test(test_cpu_graph_sample SRCS test_cpu_graph_sample.cu DEPS graph_gpu_ps)
ENDIF() ENDIF()
IF(WITH_ROCM) IF(WITH_ROCM)
hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context) hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
struct GpuPsGraphNode {
int64_t node_id;
int neighbor_size, neighbor_offset;
  // this node's neighbors are stored in [neighbor_offset, neighbor_offset +
  // neighbor_size) of int64_t *neighbor_list;
};
struct GpuPsCommGraph {
int64_t *neighbor_list;
GpuPsGraphNode *node_list;
int neighbor_size, node_size;
// the size of neighbor array and graph_node_list array
GpuPsCommGraph()
: neighbor_list(NULL), node_list(NULL), neighbor_size(0), node_size(0) {}
GpuPsCommGraph(int64_t *neighbor_list_, GpuPsGraphNode *node_list_,
int neighbor_size_, int node_size_)
: neighbor_list(neighbor_list_),
node_list(node_list_),
neighbor_size(neighbor_size_),
node_size(node_size_) {}
};
/*
suppose we have a graph like this
0----3-----5----7
\ |\ |\
17 8 9 1 2
we save the nodes in arbitrary order,
in this example, the order is
[0,5,1,2,7,3,8,9,17]
let us name this array u_id;
we record each node's neighbors:
0:3,17
5:3,7
1:7
2:7
7:1,2,5
3:0,5,8,9
8:3
9:3
17:0
by concatenating each node's neighbor_list in the order in which we saved the node ids,
we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0]
this is the neighbor_list of GpuPsCommGraph
given this neighbor_list and the order to save node id,
we know,
node 0's neighbors are in the range [0,1] of neighbor_list
node 5's neighbors are in the range [2,3] of neighbor_list
node 1's neighbors are in the range [4,4] of neighbor_list
node 2:[5,5]
node 7:[6,6]
node 3:[9,12]
node 8:[13,13]
node 9:[14,14]
node 17:[15,15]
...
by the above information,
we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph
of size 9,
where node_list[i].id = u_id[i]
then we have:
node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0
node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2
node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4
node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5
node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6
node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9
node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13
node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
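/*
 a rough construction sketch for the layout above (illustrative only; u_id,
 adj and total_neighbors are placeholders, not part of this header):
   GpuPsCommGraph g;
   g.node_size = static_cast<int>(u_id.size());
   g.neighbor_size = total_neighbors;
   g.node_list = new GpuPsGraphNode[g.node_size];
   g.neighbor_list = new int64_t[g.neighbor_size];
   int offset = 0;
   for (int i = 0; i < g.node_size; ++i) {
     g.node_list[i].node_id = u_id[i];
     g.node_list[i].neighbor_offset = offset;
     g.node_list[i].neighbor_size = static_cast<int>(adj[i].size());
     for (int64_t nb : adj[i]) g.neighbor_list[offset++] = nb;
   }
*/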
struct NeighborSampleResult {
int64_t *val;
int *actual_sample_size, sample_size, key_size;
NeighborSampleResult(int _sample_size, int _key_size)
: sample_size(_sample_size), key_size(_key_size) {
actual_sample_size = NULL;
val = NULL;
};
~NeighborSampleResult() {
if (val != NULL) cudaFree(val);
if (actual_sample_size != NULL) cudaFree(actual_sample_size);
}
};
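// Note on NeighborSampleResult above: val and actual_sample_size are device
// buffers owned by the struct (released with cudaFree in the destructor).
// Judging from graph_neighbor_sample in this change, val presumably holds
// key_size rows of sample_size slots, with actual_sample_size[i] giving how
// many slots of row i are valid.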
struct NodeQueryResult {
int64_t *val;
int actual_sample_size;
NodeQueryResult() {
val = NULL;
actual_sample_size = 0;
};
~NodeQueryResult() {
if (val != NULL) cudaFree(val);
}
};
}
};
#endif
...@@ -14,114 +14,25 @@ ...@@ -14,114 +14,25 @@
#pragma once #pragma once
#include "heter_comm.h" #include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct GpuPsGraphNode {
int64_t node_id;
int neighbor_size, neighbor_offset;
// this node's neighbor is stored on [neighbor_offset,neighbor_offset +
// neighbor_size) of int64_t *neighbor_list;
};
struct GpuPsCommGraph {
int64_t *neighbor_list;
GpuPsGraphNode *node_list;
int neighbor_size, node_size;
// the size of neighbor array and graph_node_list array
GpuPsCommGraph()
: neighbor_list(NULL), node_list(NULL), neighbor_size(0), node_size(0) {}
GpuPsCommGraph(int64_t *neighbor_list_, GpuPsGraphNode *node_list_,
int neighbor_size_, int node_size_)
: neighbor_list(neighbor_list_),
node_list(node_list_),
neighbor_size(neighbor_size_),
node_size(node_size_) {}
};
/*
suppose we have a graph like this
0----3-----5----7
\ |\ |\
17 8 9 1 2
we save the nodes in arbitrary order,
in this example,the order is
[0,5,1,2,7,3,8,9,17]
let us name this array u_id;
we record each node's neighbors:
0:3,17
5:3,7
1:7
2:7
7:1,2,5
3:0,5,8,9
8:3
9:3
17:0
by concatenating each node's neighbor_list in the order we save the node id.
we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0]
this is the neighbor_list of GpuPsCommGraph
given this neighbor_list and the order to save node id,
we know,
node 0's neighbors are in the range [0,1] of neighbor_list
node 5's neighbors are in the range [2,3] of neighbor_list
node 1's neighbors are in the range [4,4] of neighbor_list
node 2:[5,5]
node 7:[6,6]
node 3:[9,12]
node 8:[13,13]
node 9:[14,14]
node 17:[15,15]
...
by the above information,
we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph
of size 9,
where node_list[i].id = u_id[i]
then we have:
node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0
node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2
node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4
node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5
node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6
node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9
node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13
node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
struct NeighborSampleResult {
int64_t *val;
int *actual_sample_size, sample_size, key_size;
NeighborSampleResult(int _sample_size, int _key_size)
: sample_size(_sample_size), key_size(_key_size) {
actual_sample_size = NULL;
val = NULL;
};
~NeighborSampleResult() {
if (val != NULL) cudaFree(val);
if (actual_sample_size != NULL) cudaFree(actual_sample_size);
}
};
struct NodeQueryResult {
int64_t *val;
int actual_sample_size;
NodeQueryResult() {
val = NULL;
actual_sample_size = 0;
};
~NodeQueryResult() {
if (val != NULL) cudaFree(val);
}
};
class GpuPsGraphTable : public HeterComm<int64_t, int, int> { class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
public: public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource) GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource)
: HeterComm<int64_t, int, int>(1, resource) { : HeterComm<int64_t, int, int>(1, resource) {
load_factor_ = 0.25; load_factor_ = 0.25;
rw_lock.reset(new pthread_rwlock_t());
cpu_table_status = -1;
}
~GpuPsGraphTable() {
if (cpu_table_status != -1) {
end_graph_sampling();
}
} }
void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list); void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
NodeQueryResult *graph_node_sample(int gpu_id, int sample_size); NodeQueryResult *graph_node_sample(int gpu_id, int sample_size);
...@@ -134,9 +45,19 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> { ...@@ -134,9 +45,19 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
int *h_right, int *h_right,
int64_t *src_sample_res, int64_t *src_sample_res,
int *actual_sample_size); int *actual_sample_size);
int init_cpu_table(const paddle::distributed::GraphParameter &graph);
int load(const std::string &path, const std::string &param);
virtual int32_t end_graph_sampling() {
return cpu_graph_table->end_graph_sampling();
}
private: private:
std::vector<GpuPsCommGraph> gpu_graph_list; std::vector<GpuPsCommGraph> gpu_graph_list;
std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table;
std::shared_ptr<pthread_rwlock_t> rw_lock;
mutable std::mutex mutex_;
std::condition_variable cv_;
int cpu_table_status;
}; };
} }
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/* /*
...@@ -45,6 +46,33 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index, ...@@ -45,6 +46,33 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
} }
} }
int GpuPsGraphTable::init_cpu_table(
const paddle::distributed::GraphParameter& graph) {
cpu_graph_table.reset(new paddle::distributed::GraphTable);
cpu_table_status = cpu_graph_table->initialize(graph);
if (cpu_table_status != 0) return cpu_table_status;
std::function<void(std::vector<GpuPsCommGraph>&)> callback =
[this](std::vector<GpuPsCommGraph>& res) {
pthread_rwlock_wrlock(this->rw_lock.get());
this->clear_graph_info();
this->build_graph_from_cpu(res);
pthread_rwlock_unlock(this->rw_lock.get());
cv_.notify_one();
};
cpu_graph_table->set_graph_sample_callback(callback);
return cpu_table_status;
}
int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
int status = cpu_graph_table->load(path, param);
if (status != 0) {
return status;
}
std::unique_lock<std::mutex> lock(mutex_);
cpu_graph_table->start_graph_sampling();
cv_.wait(lock);
return 0;
}
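// --- Illustrative sketch (not part of this commit) -------------------------
// init_cpu_table() and load() cooperate through a callback and a condition
// variable: load() starts CPU-side graph sampling and blocks on cv_, the
// sampler later runs the callback registered in init_cpu_table(), which
// rebuilds the GPU tables under the write lock and then notifies the waiter.
// The stripped-down handshake below uses hypothetical names and adds a wait
// predicate for robustness (the code above waits without one).
#include <condition_variable>
#include <mutex>
#include <thread>

namespace handshake_sketch {

std::mutex mu;
std::condition_variable cv;
bool gpu_tables_ready = false;

void SampleAndRebuildCallback() {
  // ... rebuild the GPU graph tables from the sampled sub-graphs ...
  {
    std::lock_guard<std::mutex> guard(mu);
    gpu_tables_ready = true;
  }
  cv.notify_one();  // wake the thread blocked in Load()
}

int Load() {
  std::unique_lock<std::mutex> lock(mu);
  // stands in for cpu_graph_table->start_graph_sampling()
  std::thread sampler(SampleAndRebuildCallback);
  cv.wait(lock, [] { return gpu_tables_ready; });
  sampler.join();
  return 0;
}

}  // namespace handshake_sketch
// ---------------------------------------------------------------------------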
/* /*
comment 1 comment 1
...@@ -68,6 +96,7 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index, ...@@ -68,6 +96,7 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
that's what fill_dvals does. that's what fill_dvals does.
*/ */
void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
int gpu_id, int gpu_num, int sample_size, int* h_left, int* h_right, int gpu_id, int gpu_num, int sample_size, int* h_left, int* h_right,
int64_t* src_sample_res, int* actual_sample_size) { int64_t* src_sample_res, int* actual_sample_size) {
...@@ -258,7 +287,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -258,7 +287,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t)); auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr()); int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, len * sizeof(int64_t)); auto d_shard_vals = memory::Alloc(place, sample_size * len * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr()); int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int)); auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr = int* d_shard_actual_sample_size_ptr =
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include <queue> #include <queue>
namespace paddle { namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <fstream>
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
using namespace paddle::framework;
void prepare_file(char file_name[], std::vector<std::string> data) {
std::ofstream ofile;
ofile.open(file_name);
for (auto x : data) {
ofile << x << std::endl;
}
ofile.close();
}
char edge_file_name[] = "edges.txt";
TEST(TEST_FLEET, graph_sample) {
std::vector<std::string> edges;
int gpu_count = 3;
std::vector<int> dev_ids;
dev_ids.push_back(0);
dev_ids.push_back(1);
dev_ids.push_back(2);
std::shared_ptr<HeterPsResource> resource =
std::make_shared<HeterPsResource>(dev_ids);
resource->enable_p2p();
GpuPsGraphTable g(resource);
int node_count = 10;
std::vector<std::vector<int64_t>> neighbors(node_count);
int ind = 0;
int64_t node_id = 0;
// std::vector<GpuPsCommGraph> graph_list(gpu_count);
while (ind < node_count) {
int neighbor_size = ind + 1;
while (neighbor_size--) {
edges.push_back(std::to_string(ind) + "\t" + std::to_string(node_id) +
"\t1.0");
node_id++;
}
ind++;
}
/*
gpu 0:
0,3,6,9
gpu 1:
1,4,7
gpu 2:
2,5,8
query(2,6) returns nodes [6,9,1,4,7,2]
*/
::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(true);
table_proto.set_gpups_mode_shard_num(127);
table_proto.set_gpu_num(3);
table_proto.set_gpups_graph_sample_class("BasicBfsGraphSampler");
table_proto.set_gpups_graph_sample_args("5,5,1,1");
prepare_file(edge_file_name, edges);
g.init_cpu_table(table_proto);
g.load(std::string(edge_file_name), std::string("e>"));
/*
node x's neighbor list = [(1+x)*x/2,(1+x)*x/2 + 1,.....,(1+x)*x/2 + x]
so node 6's neighbors are [21,22...,27]
node 7's neighbors are [28,29,..35]
node 0's neighbors are [0]
query([7,0,6],sample_size=3) should return [28,29,30,0,x,x,21,22,23]
6 --index-->2
0 --index--->0
7 --index-->2
*/
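  // Why (1 + x) * x / 2: the edge-generation loop above gives node i exactly
  // i + 1 consecutive neighbor ids, so node x's first neighbor id is the
  // prefix sum 1 + 2 + ... + x = x * (x + 1) / 2. For example, node 7 starts
  // at 7 * 8 / 2 = 28, which is why a size-3 sample from node 7 is expected
  // to be {28, 29, 30} after sorting.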
int64_t cpu_key[3] = {7, 0, 6};
void *key;
cudaMalloc((void **)&key, 3 * sizeof(int64_t));
cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice);
auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 3, 3);
int64_t *res = new int64_t[9];
cudaMemcpy(res, neighbor_sample_res->val, 72, cudaMemcpyDeviceToHost);
std::sort(res, res + 3);
std::sort(res + 6, res + 9);
int64_t expected_sample_val[] = {28, 29, 30, 0, -1, -1, 21, 22, 23};
for (int i = 0; i < 9; i++) {
if (expected_sample_val[i] != -1) {
ASSERT_EQ(res[i], expected_sample_val[i]);
}
}
delete[] res;
delete neighbor_sample_res;
}
...@@ -78,6 +78,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -78,6 +78,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
return var_types[0] == proto::VarType::SELECTED_ROWS; return var_types[0] == proto::VarType::SELECTED_ROWS;
} }
bool IsDenseTensorVectorInput(const std::string& name) const override {
auto var_types = ctx_.GetInputsVarType(name);
return var_types[0] == proto::VarType::LOD_TENSOR_ARRAY;
}
bool IsDenseTensorOutput(const std::string& name) const override { bool IsDenseTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name); auto var_types = ctx_.GetOutputsVarType(name);
return var_types[0] == proto::VarType::LOD_TENSOR; return var_types[0] == proto::VarType::LOD_TENSOR;
...@@ -125,9 +130,14 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -125,9 +130,14 @@ class CompatMetaTensor : public phi::MetaTensor {
return var->Get<phi::DenseTensor>().dims(); return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dims(); return var->Get<phi::SelectedRows>().dims();
} else if (var->IsType<framework::LoDTensorArray>()) {
// use tensor array size as dims
auto& tensor_array = var->Get<framework::LoDTensorArray>();
return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dims from DenseTensor or SelectedRows.")); "Currently, only can get dims from DenseTensor or SelectedRows or "
"DenseTensorArray."));
} }
} else { } else {
auto* var = BOOST_GET_CONST(VarDesc*, var_); auto* var = BOOST_GET_CONST(VarDesc*, var_);
...@@ -144,6 +154,10 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -144,6 +154,10 @@ class CompatMetaTensor : public phi::MetaTensor {
return var->Get<phi::DenseTensor>().dtype(); return var->Get<phi::DenseTensor>().dtype();
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dtype(); return var->Get<phi::SelectedRows>().dtype();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get dtype from LoDTensorArray now
return phi::DataType::UNDEFINED;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dtype from DenseTensor or SelectedRows.")); "Currently, only can get dtype from DenseTensor or SelectedRows."));
...@@ -157,7 +171,19 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -157,7 +171,19 @@ class CompatMetaTensor : public phi::MetaTensor {
DataLayout layout() const override { DataLayout layout() const override {
if (is_runtime_) { if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_); auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<LoDTensor>().layout(); if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().layout();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().layout();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get layout from LoDTensorArray now
return phi::DataLayout::UNDEFINED;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get layout from DenseTensor or "
"SelectedRows."));
}
} else { } else {
// NOTE(chenweihang): do nothing // NOTE(chenweihang): do nothing
// Unsupported get layout for VarDesc now // Unsupported get layout for VarDesc now
...@@ -174,6 +200,16 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -174,6 +200,16 @@ class CompatMetaTensor : public phi::MetaTensor {
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value(); auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims; phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<framework::LoDTensorArray>()) {
auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
// Note: Here I want enforce `tensor_array->size() == 0UL`, because
// inplace using on LoDTensorArray is dangerous, but the unittest
// `test_list` contains this behavior
PADDLE_ENFORCE_EQ(dims.size(), 1UL,
platform::errors::InvalidArgument(
"LoDTensorArray can only have one dimension."));
// only set the array size for LoDTensorArray input
tensor_array->resize(dims[0]);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dims from DenseTensor or SelectedRows.")); "Currently, only can set dims from DenseTensor or SelectedRows."));
...@@ -193,6 +229,9 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -193,6 +229,9 @@ class CompatMetaTensor : public phi::MetaTensor {
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value(); auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype; phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dtype from DenseTensor or SelectedRows.")); "Currently, only can set dtype from DenseTensor or SelectedRows."));
...@@ -206,10 +245,20 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -206,10 +245,20 @@ class CompatMetaTensor : public phi::MetaTensor {
void set_layout(DataLayout layout) override { void set_layout(DataLayout layout) override {
if (is_runtime_) { if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_); auto* var = BOOST_GET(Variable*, var_);
LoDTensor* tensor = var->GetMutable<LoDTensor>(); if (var->IsType<phi::DenseTensor>()) {
phi::DenseTensorUtils::GetMutableMeta( auto* tensor = var->GetMutable<phi::DenseTensor>();
static_cast<phi::DenseTensor*>(tensor)) phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
->layout = layout; } else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set layout from DenseTensor or "
"SelectedRows."));
}
} else { } else {
// NOTE(chenweihang): do nothing // NOTE(chenweihang): do nothing
// Unsupported set layout for VarDesc now // Unsupported set layout for VarDesc now
...@@ -251,9 +300,7 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -251,9 +300,7 @@ class CompatMetaTensor : public phi::MetaTensor {
void share_meta(const MetaTensor& meta_tensor) override { void share_meta(const MetaTensor& meta_tensor) override {
share_dims(meta_tensor); share_dims(meta_tensor);
set_dtype(meta_tensor.dtype()); set_dtype(meta_tensor.dtype());
// VarDesc doesn't contains layout, so we cannot share layout set_layout(meta_tensor.layout());
// set_layout(meta_tensor.layout());
// special case: share lod of LoDTensor // special case: share lod of LoDTensor
share_lod(meta_tensor); share_lod(meta_tensor);
} }
...@@ -442,6 +489,51 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -442,6 +489,51 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name, infershape_input.size())); attr_name, infershape_input.size()));
} }
} }
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct InferMetaContext.",
attr_names[i]));
}
} else if (ctx->HasAttr(attr_name)) { } else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr. // Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
......
...@@ -97,6 +97,7 @@ pass_library(layer_norm_fuse_pass inference) ...@@ -97,6 +97,7 @@ pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference) pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference) pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(mixed_precision_configure_pass inference)
pass_library(generate_pass DEPS pass_desc_proto) pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto)
......
...@@ -2052,18 +2052,19 @@ PDNode *patterns::Pool::operator()() { ...@@ -2052,18 +2052,19 @@ PDNode *patterns::Pool::operator()() {
return output_var; return output_var;
} }
PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) { PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr()) const std::string elementwise_type) {
->assert_is_op("elementwise_add"); auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
x_var->AsInput()->assert_is_op_input("elementwise_add", "X");
y_var->AsInput()->assert_is_op_input("elementwise_add", "Y"); x_var->AsInput()->assert_is_op_input(elementwise_type, "X");
auto out_var = pattern->NewNode(elementwise_add_out_repr()) y_var->AsInput()->assert_is_op_input(elementwise_type, "Y");
auto out_var = pattern->NewNode(elementwise_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("elementwise_add", "Out"); ->assert_is_op_output(elementwise_type, "Out");
elementwise_add_op->LinksFrom({x_var, y_var}); elementwise_op->LinksFrom({x_var, y_var});
elementwise_add_op->LinksTo({out_var}); elementwise_op->LinksTo({out_var});
return out_var; return out_var;
} }
......
...@@ -1016,20 +1016,20 @@ struct Pool : public PatternBase { ...@@ -1016,20 +1016,20 @@ struct Pool : public PatternBase {
PATTERN_DECL_NODE(pool_output); PATTERN_DECL_NODE(pool_output);
}; };
// ElementwiseAdd used in residual connections. // Elementwise ops
// y_var is used and convolution output. // Forward pass for element-wise operators (add, mul)
// The operator is removed, when residual // elementwise_mul_out is the result of the operator
// connection fusion is on. struct Elementwise : public PatternBase {
struct ElementwiseAdd : public PatternBase { Elementwise(PDPattern* pattern, const std::string& name_scope)
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "elementwise") {}
: PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* x_var, PDNode* y_var,
PDNode* operator()(PDNode* x_var, PDNode* y_var); const std::string elementwise_type);
PATTERN_DECL_NODE(elementwise_add_op); PATTERN_DECL_NODE(elementwise_op);
PATTERN_DECL_NODE(elementwise_add_x); PATTERN_DECL_NODE(elementwise_x);
PATTERN_DECL_NODE(elementwise_add_y); PATTERN_DECL_NODE(elementwise_y);
PATTERN_DECL_NODE(elementwise_add_out); PATTERN_DECL_NODE(elementwise_out);
}; };
// Transpose op // Transpose op
......
...@@ -73,8 +73,10 @@ static void ShareVarInfoToCinnLaunch( ...@@ -73,8 +73,10 @@ static void ShareVarInfoToCinnLaunch(
varinfo_maps.at(cinn_launch_op->GetScopeIdx()); varinfo_maps.at(cinn_launch_op->GetScopeIdx());
// collect all MemOptVarInfos of external variables // collect all MemOptVarInfos of external variables
// that would be eager deleted after the cinn_launch subgraph executed, // that were eager deleted after the cinn_launch subgraph executed,
// and store them as attribute of the subgraph // and we will delete them in advance among eager_deletion_ops
// inside cinn_launch subgraph, so store them as attribute of the subgraph
// to pass to the inner eager_deletion_ops.
for (const auto& var_name : vars_to_delete) { for (const auto& var_name : vars_to_delete) {
auto it = src_varinfo_map.find(var_name); auto it = src_varinfo_map.find(var_name);
PADDLE_ENFORCE_NE(it, src_varinfo_map.end(), PADDLE_ENFORCE_NE(it, src_varinfo_map.end(),
...@@ -82,6 +84,8 @@ static void ShareVarInfoToCinnLaunch( ...@@ -82,6 +84,8 @@ static void ShareVarInfoToCinnLaunch(
"MemOptVarInfo of var[%s] not found", var_name)); "MemOptVarInfo of var[%s] not found", var_name));
dst_varinfo_map.emplace(var_name, it->second); dst_varinfo_map.emplace(var_name, it->second);
} }
// skip running of the followed eager_deletion_op
followed_eager_deletion_op->SetSkipRunning(true);
} }
static void TakeVarInfoFromMainGraph( static void TakeVarInfoFromMainGraph(
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mixed_precision_configure_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
void MixedPrecisionConfigurePass::InsertCastOps(
Graph* graph, const StringSet& blacklist) const {
VLOG(3) << "Insert the cast op before and after the kernel that does not "
"supports fp16 precision";
auto update_cast_desc = [&](
framework::OpDesc& desc, const std::string& x_name,
const std::string& out_name, const int in_dtype, const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};
auto cast_input = [&](Graph* graph, Node* op_node,
const StringSet& cast_list) {
auto inlinks = op_node->inputs;
for (auto* pre_node : inlinks) {
if (pre_node->IsVar()) {
const auto is_persistable = pre_node->Var()->Persistable();
const auto is_float =
pre_node->Var()->GetDataType() == proto::VarType::FP16 ||
pre_node->Var()->GetDataType() == proto::VarType::FP32 ||
pre_node->Var()->GetDataType() == proto::VarType::FP64;
if (!is_persistable && is_float) {
int suffix = 0;
for (auto* pre_node_input : pre_node->inputs) {
if (!pre_node_input->IsOp()) continue;
const auto& type = pre_node_input->Op()->Type();
if (!cast_list.count(type) && type != "cast") {
std::string old_name = pre_node->Name();
std::string new_name =
old_name + "_cast.tmp_" + std::to_string(suffix);
suffix++;
framework::OpDesc new_op_desc(op_node->Op()->Block());
// 4 for fp16, 5 for fp32
update_cast_desc(new_op_desc, old_name, new_name, 4, 5);
auto* new_op = graph->CreateOpNode(&new_op_desc);
VarDesc out_var(new_name);
out_var.SetPersistable(false);
auto* node_var = graph->CreateVarNode(&out_var);
op_node->Op()->RenameInput(old_name, new_name);
IR_NODE_LINK_TO(pre_node, new_op);
IR_NODE_LINK_TO(new_op, node_var);
IR_NODE_LINK_TO(node_var, op_node);
}
}
}
}
}
};
auto cast_output = [&](Graph* graph, Node* op_node,
const StringSet& cast_list) {
auto outlinks = op_node->outputs;
for (auto* next_node : outlinks) {
if (next_node->IsVar()) {
const auto is_persistable = next_node->Var()->Persistable();
const auto is_float =
next_node->Var()->GetDataType() == proto::VarType::FP16 ||
next_node->Var()->GetDataType() == proto::VarType::FP32 ||
next_node->Var()->GetDataType() == proto::VarType::FP64;
if (!is_persistable && is_float) {
int suffix = 0;
for (auto* next_node_output : next_node->outputs) {
if (!next_node_output->IsOp()) continue;
const auto& type = next_node_output->Op()->Type();
if (!cast_list.count(type) && type != "cast") {
std::string old_name = next_node->Name();
std::string new_name =
old_name + "_cast.tmp_" + std::to_string(suffix);
suffix++;
framework::OpDesc new_op_desc(op_node->Op()->Block());
// 4 for fp16, 5 for fp32
update_cast_desc(new_op_desc, old_name, new_name, 5, 4);
auto* new_op = graph->CreateOpNode(&new_op_desc);
VarDesc out_var(new_name);
out_var.SetPersistable(false);
auto* node_var = graph->CreateVarNode(&out_var);
next_node_output->Op()->RenameInput(old_name, new_name);
IR_NODE_LINK_TO(next_node, new_op);
IR_NODE_LINK_TO(new_op, node_var);
IR_NODE_LINK_TO(node_var, next_node_output);
}
}
}
}
}
};
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;
const auto& type = op_node->Op()->Type();
if (blacklist.count(type)) {
cast_input(graph, op_node, blacklist);
cast_output(graph, op_node, blacklist);
}
}
}
void MixedPrecisionConfigurePass::ApplyImpl(Graph* graph) const {
const auto blacklist =
Get<std::unordered_set<std::string>>("gpu_fp16_disabled_op_types");
InsertCastOps(graph, blacklist);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(mixed_precision_configure_pass,
paddle::framework::ir::MixedPrecisionConfigurePass);
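// --- Illustrative usage (not part of this commit) ---------------------------
// A pass registered this way is normally fetched from the pass registry and
// handed its required attribute before being applied. The sketch below
// assumes the standard Pass / PassRegistry API; the attribute name matches
// the Get<> call in ApplyImpl above, while the wrapper function and the op
// names in the blacklist are purely illustrative.
#include <memory>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

std::unique_ptr<paddle::framework::ir::Graph> ApplyMixedPrecisionConfigure(
    std::unique_ptr<paddle::framework::ir::Graph> graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "mixed_precision_configure_pass");
  // Ops listed here are treated as fp16-incapable, so casts are inserted
  // around them.
  pass->Set("gpu_fp16_disabled_op_types",
            new std::unordered_set<std::string>({"softmax", "layer_norm"}));
  graph.reset(pass->Apply(graph.release()));
  return graph;
}
// ---------------------------------------------------------------------------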
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
using StringSet = std::unordered_set<std::string>;
class MixedPrecisionConfigurePass : public FusePassBase {
public:
MixedPrecisionConfigurePass() = default;
virtual ~MixedPrecisionConfigurePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
private:
void InsertCastOps(Graph* graph, const StringSet& blacklist) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -145,10 +145,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( ...@@ -145,10 +145,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
patterns::Conv conv_pattern{pattern, name_scope}; patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern(); auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_add_pattern( elementwise_pattern(
conv_output, conv_output, pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); "elementwise_add");
conv_output->AsIntermediate(); conv_output->AsIntermediate();
int found_conv_as_x_count = 0; int found_conv_as_x_count = 0;
...@@ -160,16 +160,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( ...@@ -160,16 +160,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_identity, elementwise_add_y, GET_IR_NODE_FROM_SUBGRAPH(elementwise_identity, elementwise_y,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_add_pattern); elementwise_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, elementwise_add_identity, conv_output)) return; if (!IsReachable(g, elementwise_identity, conv_output)) return;
if (HasFusedActivation(conv_op)) return; if (HasFusedActivation(conv_op)) return;
...@@ -179,14 +179,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( ...@@ -179,14 +179,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
return; return;
} }
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()}); conv_op->Op()->SetInput("ResidualData", {elementwise_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true); conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op}); GraphSafeRemoveNodes(g, {conv_output, elementwise_op});
IR_NODE_LINK_TO(elementwise_add_identity, conv_op); IR_NODE_LINK_TO(elementwise_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out); IR_NODE_LINK_TO(conv_op, elementwise_out);
found_conv_as_x_count++; found_conv_as_x_count++;
}; };
...@@ -212,10 +212,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( ...@@ -212,10 +212,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
patterns::Conv conv_pattern{pattern, name_scope}; patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern(); auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_add_pattern( elementwise_pattern(
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()), pattern->NewNode(elementwise_pattern.elementwise_x_repr()), conv_output,
conv_output); "elementwise_add");
conv_output->AsIntermediate(); conv_output->AsIntermediate();
int found_conv_as_y_count = 0; int found_conv_as_y_count = 0;
...@@ -227,16 +227,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( ...@@ -227,16 +227,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_add_pattern); elementwise_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, elementwise_add_x, conv_output)) return; if (!IsReachable(g, elementwise_x, conv_output)) return;
if (HasFusedActivation(conv_op)) return; if (HasFusedActivation(conv_op)) return;
...@@ -246,14 +246,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( ...@@ -246,14 +246,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
return; return;
} }
conv_op->Op()->SetInput("ResidualData", {elementwise_add_x->Name()}); conv_op->Op()->SetInput("ResidualData", {elementwise_x->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true); conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op}); GraphSafeRemoveNodes(g, {conv_output, elementwise_op});
IR_NODE_LINK_TO(elementwise_add_x, conv_op); IR_NODE_LINK_TO(elementwise_x, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out); IR_NODE_LINK_TO(conv_op, elementwise_out);
found_conv_as_y_count++; found_conv_as_y_count++;
}; };
...@@ -282,8 +282,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( ...@@ -282,8 +282,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
patterns::Conv conv_y_pattern{pattern, name_scope}; patterns::Conv conv_y_pattern{pattern, name_scope};
auto conv_y_output = conv_y_pattern(); auto conv_y_output = conv_y_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_add_pattern(conv_x_output, conv_y_output); elementwise_pattern(conv_x_output, conv_y_output, "elementwise_add");
conv_x_output->AsIntermediate(); conv_x_output->AsIntermediate();
conv_y_output->AsIntermediate(); conv_y_output->AsIntermediate();
...@@ -301,10 +301,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( ...@@ -301,10 +301,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
GET_IR_NODE_FROM_SUBGRAPH(conv_y_filter, conv_filter, conv_y_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_y_filter, conv_filter, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_y_output, conv_output, conv_y_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_y_output, conv_output, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_add_pattern); elementwise_pattern);
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) LOG(WARNING)
...@@ -312,8 +312,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( ...@@ -312,8 +312,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
return; return;
} }
if (FindFuseOption(*conv_x_op, *elementwise_add_op) != FUSE_MKLDNN) return; if (FindFuseOption(*conv_x_op, *elementwise_op) != FUSE_MKLDNN) return;
if (FindFuseOption(*conv_y_op, *elementwise_add_op) != FUSE_MKLDNN) return; if (FindFuseOption(*conv_y_op, *elementwise_op) != FUSE_MKLDNN) return;
Node* projection_node; Node* projection_node;
Node* residual_conv_op; Node* residual_conv_op;
...@@ -333,14 +333,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( ...@@ -333,14 +333,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
if (HasFusedActivation(residual_conv_op)) return; if (HasFusedActivation(residual_conv_op)) return;
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()}); residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); residual_conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
residual_conv_op->Op()->SetAttr("fuse_residual_connection", true); residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_add_op}); GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_op});
IR_NODE_LINK_TO(projection_node, residual_conv_op); IR_NODE_LINK_TO(projection_node, residual_conv_op);
IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out); IR_NODE_LINK_TO(residual_conv_op, elementwise_out);
found_projection_conv_count++; found_projection_conv_count++;
}; };
......
...@@ -807,74 +807,74 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -807,74 +807,74 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
PrettyLogDetail("--- quantized %d matmul ops", quantize_matmul_count); PrettyLogDetail("--- quantized %d matmul ops", quantize_matmul_count);
} }
void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const { void CPUQuantizePass::QuantizeElementwise(
Graph* graph, const std::string elementwise_type) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; patterns::Elementwise elementwise_pattern{pattern, name_scope_};
elementwise_add_pattern( elementwise_pattern(
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()), pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
elementwise_type);
int quantize_elementwise_add_count = 0; int quantize_elementwise_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "Quantize elementwise_add op"; VLOG(4) << "Quantize " + elementwise_type + " op";
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_add_pattern); elementwise_pattern);
// skip if should not be quantized // skip if should not be quantized
if (!platform::HasOpINT8DataType(elementwise_add_op->Op())) { if (!platform::HasOpINT8DataType(elementwise_op->Op())) {
LogQuantizationDisabled(elementwise_add_op); LogQuantizationDisabled(elementwise_op);
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_add_pattern); elementwise_pattern);
if (!AreScalesPresentForNodes( if (!AreScalesPresentForNodes(
{elementwise_add_x, elementwise_add_y, elementwise_add_out})) { {elementwise_x, elementwise_y, elementwise_out})) {
LogCannotQuantizeOp(elementwise_add_op, LogCannotQuantizeOp(elementwise_op,
"No scale available for the operator"); "No scale available for the operator");
return; return;
} }
bool is_x_unsigned{false}, is_y_unsigned{false}; bool is_x_unsigned{false}, is_y_unsigned{false};
auto input_x_scale = auto input_x_scale = GetScaleValueForNode(elementwise_x, &is_x_unsigned);
GetScaleValueForNode(elementwise_add_x, &is_x_unsigned); auto input_y_scale = GetScaleValueForNode(elementwise_y, &is_y_unsigned);
auto input_y_scale =
GetScaleValueForNode(elementwise_add_y, &is_y_unsigned);
// TODO(sfraczek): add support for different signness // TODO(sfraczek): add support for different signness
if (is_x_unsigned != is_y_unsigned) { if (is_x_unsigned != is_y_unsigned) {
LogCannotQuantizeOp(elementwise_add_op, LogCannotQuantizeOp(elementwise_op,
"ElementwiseAdd inputs must be of the same type."); "Elementwise inputs must be of the same type.");
return; return;
} }
QuantizeInput(g, elementwise_add_op, elementwise_add_x, "X", input_x_scale, QuantizeInput(g, elementwise_op, elementwise_x, "X", input_x_scale,
is_x_unsigned, "Scale_x"); is_x_unsigned, "Scale_x");
QuantizeInput(g, elementwise_add_op, elementwise_add_y, "Y", input_y_scale, QuantizeInput(g, elementwise_op, elementwise_y, "Y", input_y_scale,
is_y_unsigned, "Scale_y"); is_y_unsigned, "Scale_y");
bool is_output_unsigned{false}; bool is_output_unsigned{false};
auto output_scale = auto output_scale =
GetScaleValueForNode(elementwise_add_out, &is_output_unsigned); GetScaleValueForNode(elementwise_out, &is_output_unsigned);
DequantizeOutput(g, elementwise_add_op, elementwise_add_out, "Out", DequantizeOutput(g, elementwise_op, elementwise_out, "Out", output_scale,
output_scale, is_output_unsigned, "Scale_out"); is_output_unsigned, "Scale_out");
++quantize_elementwise_add_count; ++quantize_elementwise_count;
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_elementwise_add_count); AddStatis(quantize_elementwise_count);
PrettyLogDetail("--- quantized %d elementwise_add ops", PrettyLogDetail("--- quantized %d %s ops", quantize_elementwise_count,
quantize_elementwise_add_count); elementwise_type);
} }
void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const { void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
...@@ -1146,7 +1146,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -1146,7 +1146,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeFc(graph); QuantizeFc(graph);
QuantizeReshape(graph); QuantizeReshape(graph);
QuantizeMatmul(graph); QuantizeMatmul(graph);
QuantizeElementwiseAdd(graph); QuantizeElementwise(graph, "elementwise_add");
QuantizeElementwise(graph, "elementwise_mul");
QuantizeFusionGru(graph); QuantizeFusionGru(graph);
QuantizeMultiGru(graph); QuantizeMultiGru(graph);
QuantizeFusionLSTM(graph); QuantizeFusionLSTM(graph);
......
...@@ -57,7 +57,8 @@ class CPUQuantizePass : public FusePassBase { ...@@ -57,7 +57,8 @@ class CPUQuantizePass : public FusePassBase {
void QuantizeTranspose(Graph* graph) const; void QuantizeTranspose(Graph* graph) const;
void QuantizeReshape(Graph* graph) const; void QuantizeReshape(Graph* graph) const;
void QuantizeMatmul(Graph* graph) const; void QuantizeMatmul(Graph* graph) const;
void QuantizeElementwiseAdd(Graph* graph) const; void QuantizeElementwise(Graph* graph,
const std::string elementwise_type) const;
void QuantizeFusionGru(Graph* graph) const; void QuantizeFusionGru(Graph* graph) const;
void QuantizeMultiGru(Graph* graph) const; void QuantizeMultiGru(Graph* graph) const;
void QuantizeFusionLSTM(Graph* graph) const; void QuantizeFusionLSTM(Graph* graph) const;
......
...@@ -90,7 +90,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -90,7 +90,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("Scale_x", 1.0f); op->SetAttr("Scale_x", 1.0f);
op->SetAttr("Scale_y", 1.0f); op->SetAttr("Scale_y", 1.0f);
op->SetAttr("Scale_out", 1.0f); op->SetAttr("Scale_out", 1.0f);
} else if (type == "elementwise_add") { } else if (type == "elementwise_add" || type == "elementwise_mul") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]}); if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
...@@ -167,7 +167,8 @@ void CheckScales(const OpDesc* op, float scale, float shift) { ...@@ -167,7 +167,8 @@ void CheckScales(const OpDesc* op, float scale, float shift) {
scale); scale);
scale_names.push_back("Scale_in"); scale_names.push_back("Scale_in");
scale_names.push_back("Scale_out"); scale_names.push_back("Scale_out");
} else if (type == "matmul" || type == "elementwise_add") { } else if (type == "matmul" || type == "elementwise_add" ||
type == "elementwise_mul") {
scale_names.push_back("Scale_x"); scale_names.push_back("Scale_x");
scale_names.push_back("Scale_y"); scale_names.push_back("Scale_y");
scale_names.push_back("Scale_out"); scale_names.push_back("Scale_out");
...@@ -546,46 +547,77 @@ TEST(CpuQuantizePass, matmul_not_quantized) { ...@@ -546,46 +547,77 @@ TEST(CpuQuantizePass, matmul_not_quantized) {
expected_operators, added_nodes, 1.0f); expected_operators, added_nodes, 1.0f);
} }
static const std::initializer_list<std::string> variable_names_elementwise_add = static const std::initializer_list<std::string> variable_names_elementwise = {
{"a", "b", "c", "d", "e", "f"}; "a", "b", "c", "d", "e", "f"};
ProgramDesc BuildProgramDescElementwiseAdd() { ProgramDesc BuildProgramDescElementwise(const std::string elementwise_type,
const std::string elementwise_name) {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : variable_names_elementwise_add) { for (auto& v : variable_names_elementwise) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true); SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true); SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true);
SetOp(&prog, "elementwise_add", "ElementwiseAdd", {"b", "d"}, {"e"}, true, SetOp(&prog, elementwise_type, elementwise_name, {"b", "d"}, {"e"}, true,
"int8"); "int8");
SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32"); SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32");
return prog; return prog;
} }
TEST(CpuQuantizePass, elementwise_add) { void TestElementwise(const std::string elementwise_type,
const std::string elementwise_name) {
// 2 Quant + 2 IN + 1 DeQuant + 1 OUT // 2 Quant + 2 IN + 1 DeQuant + 1 OUT
int added_nodes = 6; int added_nodes = 6;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{"elementwise_add", 1}, {"quantize", 2}, {"dequantize", 3}}; {elementwise_type, 1}, {"quantize", 2}, {"dequantize", 3}};
MainTest(BuildProgramDescElementwiseAdd(), variable_names_elementwise_add, MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
expected_operators, added_nodes, SCALE * S8_MAX); variable_names_elementwise, expected_operators, added_nodes,
SCALE * S8_MAX);
} }
TEST(CpuQuantizePass, elementwise_add_output_scale_missing) { void TestElementwiseOutputScaleMissing(const std::string elementwise_type,
const std::string elementwise_name) {
int added_nodes = 0; int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{"elementwise_add", 1}, {"quantize", 0}, {"dequantize", 2}}; {elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwiseAdd(), variable_names_elementwise_add, MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
expected_operators, added_nodes, 1.f, 1.f, "e"); variable_names_elementwise, expected_operators, added_nodes, 1.f,
1.f, "e");
} }
TEST(CpuQuantizePass, elementwise_add_unsigned_and_signed_input) { void TestElementwiseUnsignedAndSignedInput(const std::string elementwise_type,
const std::string elementwise_name) {
int added_nodes = 0; int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{"elementwise_add", 1}, {"quantize", 0}, {"dequantize", 2}}; {elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwiseAdd(), variable_names_elementwise_add, MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
expected_operators, added_nodes, 1.f, 1.f, "", "b"); variable_names_elementwise, expected_operators, added_nodes, 1.f,
1.f, "", "b");
}
TEST(CpuQuantizePass, elementwise_add) {
TestElementwise("elementwise_add", "ElementwiseAdd");
}
TEST(CpuQuantizePass, elementwise_add_output_scale_missing) {
TestElementwiseOutputScaleMissing("elementwise_add", "ElementwiseAdd");
}
TEST(CpuQuantizePass, elementwise_add_unsigned_and_signed_input) {
TestElementwiseUnsignedAndSignedInput("elementwise_add", "ElementwiseAdd");
}
TEST(CpuQuantizePass, elementwise_mul) {
TestElementwise("elementwise_mul", "ElementwiseMul");
}
TEST(CpuQuantizePass, elementwise_mul_output_scale_missing) {
TestElementwiseOutputScaleMissing("elementwise_mul", "ElementwiseMul");
}
TEST(CpuQuantizePass, elementwise_mul_unsigned_and_signed_input) {
TestElementwiseUnsignedAndSignedInput("elementwise_mul", "ElementwiseMul");
} }
const std::vector<std::string> churn_out_vars(ProgramDesc* prog, const std::vector<std::string> churn_out_vars(ProgramDesc* prog,
......
...@@ -26,10 +26,10 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { ...@@ -26,10 +26,10 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Marks operators which are to be quantized."; VLOG(3) << "Marks operators which are to be quantized.";
std::unordered_set<std::string> supported_op_types = std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>( std::unordered_set<std::string>(
{"concat", "conv2d", "depthwise_conv2d", "elementwise_add", "fc", {"concat", "conv2d", "depthwise_conv2d", "elementwise_add",
"matmul", "nearest_interp", "nearest_interp_v2", "pool2d", "elementwise_mul", "fc", "matmul", "nearest_interp",
"prior_box", "reshape2", "transpose2", "fusion_gru", "fusion_lstm", "nearest_interp_v2", "pool2d", "prior_box", "reshape2", "transpose2",
"multi_gru", "slice"}); "fusion_gru", "fusion_lstm", "multi_gru", "slice"});
const auto& excluded_ids_list = const auto& excluded_ids_list =
Get<std::unordered_set<int>>("quantize_excluded_op_ids"); Get<std::unordered_set<int>>("quantize_excluded_op_ids");
const auto& op_types_list = const auto& op_types_list =
......
...@@ -31,7 +31,7 @@ USE_OP(slice); ...@@ -31,7 +31,7 @@ USE_OP(slice);
USE_OP(concat); USE_OP(concat);
USE_OP(matmul); USE_OP(matmul);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP(sigmoid); USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(tanh); USE_OP_ITSELF(tanh);
USE_OP(elementwise_mul); USE_OP(elementwise_mul);
USE_OP(softmax_with_cross_entropy); USE_OP(softmax_with_cross_entropy);
...@@ -47,7 +47,7 @@ USE_OP(square); ...@@ -47,7 +47,7 @@ USE_OP(square);
USE_OP(transpose2_grad); USE_OP(transpose2_grad);
USE_OP(concat_grad); USE_OP(concat_grad);
USE_OP_ITSELF(elementwise_mul_grad); USE_OP_ITSELF(elementwise_mul_grad);
USE_OP(sigmoid_grad); USE_OP_ITSELF(sigmoid_grad);
USE_OP_ITSELF(tanh_grad); USE_OP_ITSELF(tanh_grad);
USE_OP(sum); USE_OP(sum);
USE_OP(slice_grad); USE_OP(slice_grad);
......
...@@ -2103,16 +2103,25 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2103,16 +2103,25 @@ void OperatorWithKernel::BuildPhiKernelContext(
auto* var = ins_vector[offset]; auto* var = ins_vector[offset];
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
tensor_in = &(var->Get<framework::LoDTensor>()); tensor_in = &(var->Get<framework::LoDTensor>());
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
tensor_in = &(var->Get<phi::SelectedRows>()); tensor_in = &(var->Get<phi::SelectedRows>());
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<const phi::TensorBase*> tensor_vector;
auto& tensor_array = var->Get<framework::LoDTensorArray>();
for (auto& t : tensor_array) {
tensor_vector.emplace_back(&t);
}
pt_kernel_context->EmplaceBackInputsWithoutSetRange(tensor_vector);
end_idx += tensor_array.size() - 1;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.", "Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var->Type()))); framework::ToTypeName(var->Type())));
} }
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} }
// Note: here cannot deal with vector<LoDTensorArray> input
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i); pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
VLOG(4) << "Done inputs"; VLOG(4) << "Done inputs";
...@@ -2140,22 +2149,33 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2140,22 +2149,33 @@ void OperatorWithKernel::BuildPhiKernelContext(
for (size_t offset = 0; offset < outs_vector.size(); ++offset) { for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
phi::TensorBase* tensor_out = nullptr; phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset]; auto* var = outs_vector[offset];
if (var) { if (var) {
if (var->template IsType<framework::LoDTensor>()) { if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>(); tensor_out = var->template GetMutable<framework::LoDTensor>();
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<phi::SelectedRows>()) { } else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>(); tensor_out = var->template GetMutable<phi::SelectedRows>();
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<phi::TensorBase*> tensor_vector;
auto* tensor_array =
var->template GetMutable<framework::LoDTensorArray>();
// Note: If the input LoDTensorArray size is 0, the output
// LoDTensorArray is also 0
for (auto& t : *tensor_array) {
tensor_vector.emplace_back(&t);
}
pt_kernel_context->EmplaceBackOutputsWithoutSetRange(tensor_vector);
end_idx += tensor_array->size() - 1;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.", "Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type()))); framework::ToTypeName(var->Type())));
} }
} else {
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i); pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
} }
VLOG(4) << "Done outputs"; VLOG(4) << "Done outputs";
......
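A quick worked example of the range bookkeeping above, assuming (as the surrounding code suggests) that end_idx initially counts one kernel-context slot per Variable: suppose an input name maps to three Variables and the middle one is a LoDTensorArray holding three tensors.

  end_idx = start_idx + 3   // one slot per Variable by default
  end_idx += 3 - 1          // the array expands into one slot per element
  // AssignInputRange then records [start_idx, start_idx + 5) for that name,
  // so the phi kernel sees five tensors where the op declared three Variables.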
(The remaining file diffs in this commit are collapsed.)