diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.cc b/paddle/fluid/distributed/ps/service/ps_local_client.cc index ecc8819102ed324abe24ead304539505cad422f3..3e7396c36a6363374cbfa494b98359b9eed04521 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_local_client.cc @@ -306,10 +306,10 @@ int32_t PsLocalClient::Initialize() { size_t threshold) { auto* table_ptr = GetTable(table_id); std::pair ret = table_ptr->PrintTableStat(); - VLOG(0) << "table id: " << table_id << ", feasign size: " << ret.first + VLOG(1) << "table id: " << table_id << ", feasign size: " << ret.first << ", mf size: " << ret.second; if (ret.first > (int64_t)threshold) { - VLOG(0) << "run cache table"; + VLOG(1) << "run cache table"; table_ptr->CacheTable(pass_id); } return done(); diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index 2b545047c3bfebc071678fb3d488ecd047fc9653..ec35fd3db4ff41dd66396c91e60aa20f72d88c9b 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -124,11 +124,13 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( } for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); - std::stringstream ss; - for (int k = 0; k < slot_num; ++k) { - ss << slot_feature_num_map_[k] << " "; + if (FLAGS_v > 0) { + std::stringstream ss; + for (int k = 0; k < slot_num; ++k) { + ss << slot_feature_num_map_[k] << " "; + } + VLOG(1) << "slot_feature_num_map: " << ss.str(); } - VLOG(0) << "slot_feature_num_map: " << ss.str(); tasks.clear(); @@ -137,7 +139,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( for (size_t i = 0; i < shard_num; i++) { tot_len += feature_array[i].size(); } - VLOG(0) << "Loaded feature table on cpu, feature_list_size[" << tot_len + VLOG(1) << "Loaded feature table on cpu, feature_list_size[" << tot_len << "] node_ids_size[" << node_ids.size() << "]"; res.init_on_cpu(tot_len, (unsigned int)node_ids.size(), slot_num); unsigned int offset = 0, ind = 0; @@ -494,6 +496,8 @@ void GraphTable::export_partition_files(int idx, std::string file_path) { for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); } +#endif + void GraphTable::clear_graph(int idx) { for (auto p : edge_shards[idx]) { p->clear(); @@ -506,6 +510,7 @@ void GraphTable::clear_graph(int idx) { } } +#ifdef PADDLE_WITH_HETERPS void GraphTable::release_graph() { // Before releasing graph, prepare for sampling ids and embedding keys. build_graph_type_keys(); @@ -545,6 +550,7 @@ void GraphTable::release_graph_node() { feature_shrink_to_fit(); } } +#endif void GraphTable::clear_edge_shard() { VLOG(0) << "begin clear edge shard"; @@ -590,6 +596,7 @@ void GraphTable::clear_feature_shard() { VLOG(0) << "finish clear feature shard"; } +#ifdef PADDLE_WITH_HETERPS void GraphTable::feature_shrink_to_fit() { std::vector> tasks; for (auto &type_shards : feature_shards) { @@ -619,6 +626,8 @@ void GraphTable::merge_feature_shard() { feature_shards.resize(1); } +#endif + void GraphTable::clear_graph() { VLOG(0) << "begin clear_graph"; clear_edge_shard(); @@ -626,6 +635,7 @@ void GraphTable::clear_graph() { VLOG(0) << "finish clear_graph"; } +#ifdef PADDLE_WITH_HETERPS int32_t GraphTable::load_next_partition(int idx) { if (next_partition >= static_cast(partitions[idx].size())) { VLOG(0) << "partition iteration is done"; @@ -1203,11 +1213,21 @@ int32_t GraphTable::Load(const std::string &path, const std::string ¶m) { if (load_edge) { bool reverse_edge = (param[1] == '<'); std::string edge_type = param.substr(2); - return this->load_edges(path, reverse_edge, edge_type); + int ret = this->load_edges(path, reverse_edge, edge_type); + if (ret != 0) { + VLOG(0) << "Fail to load edges, path[" << path << "] edge_type[" + << edge_type << "]"; + return -1; + } } if (load_node) { std::string node_type = param.substr(1); - return this->load_nodes(path, node_type); + int ret = this->load_nodes(path, node_type); + if (ret != 0) { + VLOG(0) << "Fail to load nodes, path[" << path << "] node_type[" + << node_type << "]"; + return -1; + } } return 0; } @@ -1319,10 +1339,19 @@ int32_t GraphTable::parse_node_and_load(std::string ntype2files, return 0; } if (FLAGS_graph_load_in_parallel) { - this->load_nodes(npath_str, ""); + int ret = this->load_nodes(npath_str, ""); + if (ret != 0) { + VLOG(0) << "Fail to load nodes, path[" << npath << "]"; + return -1; + } } else { for (size_t j = 0; j < ntypes.size(); j++) { - this->load_nodes(npath_str, ntypes[j]); + int ret = this->load_nodes(npath_str, ntypes[j]); + if (ret != 0) { + VLOG(0) << "Fail to load nodes, path[" << npath << "], ntypes[" + << ntypes[j] << "]"; + return -1; + } } } return 0; @@ -1397,10 +1426,19 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype2files, return 0; } if (FLAGS_graph_load_in_parallel) { - this->load_nodes(npath_str, ""); + int ret = this->load_nodes(npath_str, ""); + if (ret != 0) { + VLOG(0) << "Fail to load nodes, path[" << npath_str << "]"; + return -1; + } } else { for (size_t j = 0; j < ntypes.size(); j++) { - this->load_nodes(npath_str, ntypes[j]); + int ret = this->load_nodes(npath_str, ntypes[j]); + if (ret != 0) { + VLOG(0) << "Fail to load nodes, path[" << npath_str + << "], ntypes[" << ntypes[j] << "]"; + return -1; + } } } } @@ -1408,6 +1446,10 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype2files, })); } for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); + if (is_parse_node_fail_) { + VLOG(0) << "Fail to load node_and_edge_file"; + return -1; + } return 0; } @@ -1499,7 +1541,12 @@ std::pair GraphTable::parse_node_file( node->set_feature_size(feat_name[idx].size()); for (int i = 1; i < num; ++i) { auto &v = vals[i]; - parse_feature(idx, v.ptr, v.len, node); + int ret = parse_feature(idx, v.ptr, v.len, node); + if (ret != 0) { + VLOG(0) << "Fail to parse feature, node_id[" << id << "]"; + is_parse_node_fail_ = true; + return {0, 0}; + } } } local_valid_count++; @@ -1551,7 +1598,12 @@ std::pair GraphTable::parse_node_file( if (node != NULL) { for (int i = 2; i < num; ++i) { auto &v = vals[i]; - parse_feature(idx, v.ptr, v.len, node); + int ret = parse_feature(idx, v.ptr, v.len, node); + if (ret != 0) { + VLOG(0) << "Fail to parse feature, node_id[" << id << "]"; + is_parse_node_fail_ = true; + return {0, 0}; + } } } local_valid_count++; @@ -1603,6 +1655,11 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { valid_count += res.second; } } + if (is_parse_node_fail_) { + VLOG(0) << "Fail to load nodes, path[" << paths[0] << ".." + << paths[paths.size() - 1] << "] node_type[" << node_type << "]"; + return -1; + } VLOG(0) << valid_count << "/" << count << " nodes in node_type[ " << node_type << "] are loaded successfully!"; @@ -2103,28 +2160,48 @@ int GraphTable::parse_feature(int idx, if (dtype == "feasign") { // string_vector_2_string(fields.begin() + 1, fields.end(), ' ', // fea_ptr); - FeatureNode::parse_value_to_bytes( + int ret = FeatureNode::parse_value_to_bytes( fea_fields.begin(), fea_fields.end(), fea_ptr); + if (ret != 0) { + VLOG(0) << "Fail to parse value"; + return -1; + } return 0; } else if (dtype == "string") { string_vector_2_string( fea_fields.begin(), fea_fields.end(), ' ', fea_ptr); return 0; } else if (dtype == "float32") { - FeatureNode::parse_value_to_bytes( + int ret = FeatureNode::parse_value_to_bytes( fea_fields.begin(), fea_fields.end(), fea_ptr); + if (ret != 0) { + VLOG(0) << "Fail to parse value"; + return -1; + } return 0; } else if (dtype == "float64") { - FeatureNode::parse_value_to_bytes( + int ret = FeatureNode::parse_value_to_bytes( fea_fields.begin(), fea_fields.end(), fea_ptr); + if (ret != 0) { + VLOG(0) << "Fail to parse value"; + return -1; + } return 0; } else if (dtype == "int32") { - FeatureNode::parse_value_to_bytes( + int ret = FeatureNode::parse_value_to_bytes( fea_fields.begin(), fea_fields.end(), fea_ptr); + if (ret != 0) { + VLOG(0) << "Fail to parse value"; + return -1; + } return 0; } else if (dtype == "int64") { - FeatureNode::parse_value_to_bytes( + int ret = FeatureNode::parse_value_to_bytes( fea_fields.begin(), fea_fields.end(), fea_ptr); + if (ret != 0) { + VLOG(0) << "Fail to parse value"; + return -1; + } return 0; } } else { @@ -2132,7 +2209,7 @@ int GraphTable::parse_feature(int idx, << idx << "] feat_id_map_size[" << feat_id_map.size() << "]"; } - return -1; + return 0; } // thread safe shard vector merge class MergeShardVector { diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index b988ffa5fc3b56d0d1cb5f16119416db27104d91..79aef444d3555c74788d57ff6d3e57382b1fa966 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -789,6 +789,7 @@ class GraphTable : public Table { std::string slot_feature_separator_ = std::string(" "); std::string feature_separator_ = std::string(" "); std::vector slot_feature_num_map_; + bool is_parse_node_fail_ = false; }; /* diff --git a/paddle/fluid/distributed/ps/table/graph/graph_node.h b/paddle/fluid/distributed/ps/table/graph/graph_node.h index fc26b20da93907610e935996ab1f7a104f8741cb..e1b5143a5d876818455b16acad8bf287b2276c1e 100644 --- a/paddle/fluid/distributed/ps/table/graph/graph_node.h +++ b/paddle/fluid/distributed/ps/table/graph/graph_node.h @@ -255,7 +255,7 @@ class FeatureNode : public Node { } template - static void parse_value_to_bytes( + static int parse_value_to_bytes( std::vector::iterator feat_str_begin, std::vector::iterator feat_str_end, std::string *output) { @@ -269,8 +269,14 @@ class FeatureNode : public Node { thread_local paddle::string::str_ptr_stream ss; for (size_t i = 0; i < feat_str_size; i++) { ss.reset(*(feat_str_begin + i)); + int len = ss.end - ss.ptr; + char *old_ptr = ss.ptr; ss >> fea_ptrs[i]; + if (ss.ptr - old_ptr != len) { + return -1; + } } + return 0; } protected: diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 55ad228470cdfe37de4a148cf7112ec74d4a4ae6..b86d921b9c533fa11cb6145ea247acfeaf4864c1 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -866,7 +866,8 @@ if(WITH_DISTRIBUTE) fleet heter_server brpc - fleet_executor) + fleet_executor + flags) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses" ) diff --git a/paddle/fluid/framework/barrier.h b/paddle/fluid/framework/barrier.h index e7aa976cc9eda6a075e62948f92ab9909745db55..d7d55853f54264966a5dc1f8df235244132e53cc 100644 --- a/paddle/fluid/framework/barrier.h +++ b/paddle/fluid/framework/barrier.h @@ -14,6 +14,11 @@ #pragma once +#if defined _WIN32 || defined __APPLE__ +#else +#define __LINUX__ +#endif + #ifdef __LINUX__ #include #include @@ -48,7 +53,7 @@ class Barrier { void wait() { #ifdef __LINUX__ int err = pthread_barrier_wait(&_barrier); - if (err != 0 && err != PTHREAD_BARRIER_SERIAL_THREAD)) { + if (err != 0 && err != PTHREAD_BARRIER_SERIAL_THREAD) { CHECK_EQ(1, 0); } #endif diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index f1b7d696a4ec034641a0503fdbe0d1778ef658db..2ff06916b41c4af32c7e4c2a62dd9c5a1b6c20c3 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -2112,15 +2112,24 @@ void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) { #if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) gpu_graph_data_generator_.SetConfig(data_feed_desc); #endif + if (gpu_graph_mode_) { + train_mode_ = true; + } else { + train_mode_ = data_feed_desc.graph_config().gpu_graph_training(); + } } #if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) void SlotRecordInMemoryDataFeed::InitGraphResource() { +#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) gpu_graph_data_generator_.AllocResource(thread_id_, feed_vec_); +#endif } void SlotRecordInMemoryDataFeed::InitGraphTrainResource() { +#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) gpu_graph_data_generator_.AllocTrainResource(thread_id_); +#endif } #endif diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index aec9fee25573a5268c8bbc1362a5f6000b3e2a76..f13f58f4ed4593e9d09450e2090b8ff0f049c580 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -435,7 +435,6 @@ __global__ void CopyDuplicateKeys(int64_t *dist_tensor, } int GraphDataGenerator::AcquireInstance(BufState *state) { - // if (state->GetNextStep()) { DEBUG_STATE(state); return state->len; @@ -449,66 +448,21 @@ int GraphDataGenerator::AcquireInstance(BufState *state) { return 0; } -// TODO(fengdanlei): opt -__global__ void GraphFillFeatureKernel(uint64_t *id_tensor, - int *fill_ins_num, - uint64_t *walk, - uint64_t *feature, - int *row, - int central_word, - int step, - int len, - int col_num, - int slot_num) { - __shared__ int32_t local_key[CUDA_NUM_THREADS * 16]; - __shared__ int local_num; - __shared__ int global_num; - - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (threadIdx.x == 0) { - local_num = 0; - } - __syncthreads(); - if (idx < len) { - int src = row[idx] * col_num + central_word; - if (walk[src] != 0 && walk[src + step] != 0) { - size_t dst = atomicAdd(&local_num, 1); - for (int i = 0; i < slot_num; ++i) { - local_key[dst * 2 * slot_num + i * 2] = feature[src * slot_num + i]; - local_key[dst * 2 * slot_num + i * 2 + 1] = - feature[(src + step) * slot_num + i]; - } - } - } - - __syncthreads(); - - if (threadIdx.x == 0) { - global_num = atomicAdd(fill_ins_num, local_num); - } - __syncthreads(); - - if (threadIdx.x < local_num) { - for (int i = 0; i < slot_num; ++i) { - id_tensor[(global_num * 2 + 2 * threadIdx.x) * slot_num + i] = - local_key[(2 * threadIdx.x) * slot_num + i]; - id_tensor[(global_num * 2 + 2 * threadIdx.x + 1) * slot_num + i] = - local_key[(2 * threadIdx.x + 1) * slot_num + i]; - } - } -} - __global__ void GraphFillIdKernel(uint64_t *id_tensor, int *fill_ins_num, uint64_t *walk, + uint8_t *walk_ntype, int *row, int central_word, int step, int len, - int col_num) { + int col_num, + uint8_t *excluded_train_pair, + int excluded_train_pair_len) { __shared__ uint64_t local_key[CUDA_NUM_THREADS * 2]; __shared__ int local_num; __shared__ int global_num; + bool need_filter = false; size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (threadIdx.x == 0) { @@ -521,9 +475,19 @@ __global__ void GraphFillIdKernel(uint64_t *id_tensor, if (idx < len) { int src = row[idx] * col_num + central_word; if (walk[src] != 0 && walk[src + step] != 0) { - size_t dst = atomicAdd(&local_num, 1); - local_key[dst * 2] = walk[src]; - local_key[dst * 2 + 1] = walk[src + step]; + for (int i = 0; i < excluded_train_pair_len; i += 2) { + if (walk_ntype[src] == excluded_train_pair[i] && + walk_ntype[src + step] == excluded_train_pair[i + 1]) { + // filter this pair + need_filter = true; + break; + } + } + if (!need_filter) { + size_t dst = atomicAdd(&local_num, 1); + local_key[dst * 2] = walk[src]; + local_key[dst * 2 + 1] = walk[src + step]; + } } } @@ -651,6 +615,10 @@ int GraphDataGenerator::FillGraphIdShowClkTensor(int uniq_instance, int index_offset = 3 + slot_num_ * 2 + 5 * samples_.size(); index_tensor_ptr_ = feed_vec_[index_offset]->mutable_data( {total_instance}, this->place_); + if (get_degree_) { + degree_tensor_ptr_ = feed_vec_[index_offset + 1]->mutable_data( + {uniq_instance * edge_to_id_len_}, this->place_); + } int len_samples = samples_.size(); int *num_nodes_tensor_ptr_[len_samples]; @@ -717,6 +685,13 @@ int GraphDataGenerator::FillGraphIdShowClkTensor(int uniq_instance, sizeof(int) * total_instance, cudaMemcpyDeviceToDevice, train_stream_); + if (get_degree_) { + cudaMemcpyAsync(degree_tensor_ptr_, + node_degree_vec_[index]->ptr(), + sizeof(int) * uniq_instance * edge_to_id_len_, + cudaMemcpyDeviceToDevice, + train_stream_); + } GraphFillCVMKernel<<(d_walk_->ptr()); + uint8_t *walk_ntype = NULL; + uint8_t *excluded_train_pair = NULL; + if (excluded_train_pair_len_ > 0) { + walk_ntype = reinterpret_cast(d_walk_ntype_->ptr()); + excluded_train_pair = + reinterpret_cast(d_excluded_train_pair_->ptr()); + } uint64_t *ins_buf = reinterpret_cast(d_ins_buf_->ptr()); int *random_row = reinterpret_cast(d_random_row_->ptr()); int *d_pair_num = reinterpret_cast(d_pair_num_->ptr()); @@ -763,11 +745,14 @@ int GraphDataGenerator::MakeInsPair(cudaStream_t stream) { ins_buf + ins_buf_pair_len_ * 2, d_pair_num, walk, + walk_ntype, random_row + buf_state_.cursor, buf_state_.central_word, window_step_[buf_state_.step], len, - walk_len_); + walk_len_, + excluded_train_pair, + excluded_train_pair_len_); int h_pair_num; cudaMemcpyAsync( &h_pair_num, d_pair_num, sizeof(int), cudaMemcpyDeviceToHost, stream); @@ -782,8 +767,9 @@ int GraphDataGenerator::MakeInsPair(cudaStream_t stream) { cudaMemcpyDeviceToHost); VLOG(2) << "h_pair_num = " << h_pair_num << ", ins_buf_pair_len = " << ins_buf_pair_len_; - for (int xx = 0; xx < 2 * ins_buf_pair_len_; xx++) { - VLOG(2) << "h_ins_buf[" << xx << "]: " << h_ins_buf[xx]; + for (int xx = 0; xx < ins_buf_pair_len_; xx++) { + VLOG(2) << "h_ins_buf: " << h_ins_buf[xx * 2] << ", " + << h_ins_buf[xx * 2 + 1]; } } return ins_buf_pair_len_; @@ -809,6 +795,7 @@ int GraphDataGenerator::GenerateBatch() { platform::CUDADeviceGuard guard(gpuid_); int res = 0; if (!gpu_graph_training_) { + // infer if (!sage_mode_) { total_instance = (infer_node_start_ + batch_size_ <= infer_node_end_) ? batch_size_ @@ -829,6 +816,7 @@ int GraphDataGenerator::GenerateBatch() { sage_batch_count_); } } else { + // train if (!sage_mode_) { while (ins_buf_pair_len_ < batch_size_) { res = FillInsBuf(train_stream_); @@ -908,6 +896,7 @@ __global__ void GraphFillSampleKeysKernel(uint64_t *neighbors, __global__ void GraphDoWalkKernel(uint64_t *neighbors, uint64_t *walk, + uint8_t *walk_ntype, int *d_prefix_sum, int *actual_sample_size, int cur_degree, @@ -915,7 +904,8 @@ __global__ void GraphDoWalkKernel(uint64_t *neighbors, int len, int *id_cnt, int *sampleidx2row, - int col_size) { + int col_size, + uint8_t edge_dst_id) { CUDA_KERNEL_LOOP(i, len) { for (int k = 0; k < actual_sample_size[i]; k++) { // int idx = sampleidx2row[i]; @@ -924,6 +914,9 @@ __global__ void GraphDoWalkKernel(uint64_t *neighbors, size_t col = step; size_t offset = (row * col_size + col); walk[offset] = neighbors[i * cur_degree + k]; + if (walk_ntype != NULL) { + walk_ntype[offset] = edge_dst_id; + } } } } @@ -932,7 +925,10 @@ __global__ void GraphDoWalkKernel(uint64_t *neighbors, __global__ void GraphFillFirstStepKernel(int *prefix_sum, int *sampleidx2row, uint64_t *walk, + uint8_t *walk_ntype, uint64_t *keys, + uint8_t edge_src_id, + uint8_t edge_dst_id, int len, int walk_degree, int col_size, @@ -948,6 +944,10 @@ __global__ void GraphFillFirstStepKernel(int *prefix_sum, size_t offset = col_size * row; walk[offset] = keys[idx]; walk[offset + 1] = neighbors[idx * walk_degree + k]; + if (walk_ntype != NULL) { + walk_ntype[offset] = edge_src_id; + walk_ntype[offset + 1] = edge_dst_id; + } } } } @@ -1071,12 +1071,18 @@ __global__ void UniqueFeature(uint64_t *d_in, } // Fill sample_res to the stepth column of walk void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, + int etype_id, uint64_t *walk, + uint8_t *walk_ntype, int len, NeighborSampleResult &sample_res, int cur_degree, int step, int *len_per_row) { + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); + uint64_t node_id = gpu_graph_ptr->edge_to_node_map_[etype_id]; + uint8_t edge_src_id = node_id >> 32; + uint8_t edge_dst_id = node_id; size_t temp_storage_bytes = 0; int *d_actual_sample_size = sample_res.actual_sample_size; uint64_t *d_neighbors = sample_res.val; @@ -1114,7 +1120,10 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, sample_stream_>>>(d_prefix_sum, d_tmp_sampleidx2row, walk, + walk_ntype, d_start_ids, + edge_src_id, + edge_dst_id, len, walk_degree_, walk_len_, @@ -1138,6 +1147,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, GraphDoWalkKernel<<>>( d_neighbors, walk, + walk_ntype, d_prefix_sum, d_actual_sample_size, cur_degree, @@ -1145,7 +1155,8 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, len, len_per_row, d_tmp_sampleidx2row, - walk_len_); + walk_len_, + edge_dst_id); } if (debug_mode_) { size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; @@ -1844,6 +1855,22 @@ std::shared_ptr GraphDataGenerator::GenerateSampleGraph( return final_nodes_vec[len_samples - 1]; } +std::shared_ptr GraphDataGenerator::GetNodeDegree( + uint64_t *node_ids, int len) { + auto node_degree = memory::AllocShared( + place_, + len * edge_to_id_len_ * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_))); + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); + auto edge_to_id = gpu_graph_ptr->edge_to_id; + for (auto &iter : edge_to_id) { + int edge_idx = iter.second; + gpu_graph_ptr->get_node_degree( + gpuid_, edge_idx, node_ids, len, node_degree); + } + return node_degree; +} + uint64_t GraphDataGenerator::CopyUniqueNodes() { if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { uint64_t h_uniq_node_num = 0; @@ -1892,6 +1919,7 @@ void GraphDataGenerator::DoWalkandSage() { debug_gpu_memory_info(device_id, "DoWalkandSage start"); platform::CUDADeviceGuard guard(gpuid_); if (gpu_graph_training_) { + // train bool train_flag; if (FLAGS_graph_metapath_split_opt) { train_flag = FillWalkBufMultiPath(); @@ -1933,6 +1961,14 @@ void GraphDataGenerator::DoWalkandSage() { phi::Stream(reinterpret_cast(sample_stream_))); auto final_sage_nodes = GenerateSampleGraph( ins_cursor, total_instance, &uniq_instance, inverse); + uint64_t *final_sage_nodes_ptr = + reinterpret_cast(final_sage_nodes->ptr()); + if (get_degree_) { + auto node_degrees = + GetNodeDegree(final_sage_nodes_ptr, uniq_instance); + node_degree_vec_.emplace_back(node_degrees); + } + cudaStreamSynchronize(sample_stream_); if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { uint64_t *final_sage_nodes_ptr = reinterpret_cast(final_sage_nodes->ptr()); @@ -1946,10 +1982,11 @@ void GraphDataGenerator::DoWalkandSage() { sage_batch_num_ += 1; } uint64_t h_uniq_node_num = CopyUniqueNodes(); - VLOG(0) << "train sage_batch_num: " << sage_batch_num_; + VLOG(1) << "train sage_batch_num: " << sage_batch_num_; } } } else { + // infer bool infer_flag = FillInferBuf(); if (sage_mode_) { sage_batch_num_ = 0; @@ -1982,6 +2019,13 @@ void GraphDataGenerator::DoWalkandSage() { phi::Stream(reinterpret_cast(sample_stream_))); auto final_sage_nodes = GenerateSampleGraph( node_buf_ptr_, total_instance, &uniq_instance, inverse); + uint64_t *final_sage_nodes_ptr = + reinterpret_cast(final_sage_nodes->ptr()); + if (get_degree_) { + auto node_degrees = + GetNodeDegree(final_sage_nodes_ptr, uniq_instance); + node_degree_vec_.emplace_back(node_degrees); + } cudaStreamSynchronize(sample_stream_); if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { uint64_t *final_sage_nodes_ptr = @@ -2001,7 +2045,7 @@ void GraphDataGenerator::DoWalkandSage() { } uint64_t h_uniq_node_num = CopyUniqueNodes(); - VLOG(0) << "infer sage_batch_num: " << sage_batch_num_; + VLOG(1) << "infer sage_batch_num: " << sage_batch_num_; } } } @@ -2042,6 +2086,23 @@ int GraphDataGenerator::FillInferBuf() { return 0; } } + if (!infer_node_type_index_set_.empty()) { + while (infer_cursor < h_device_keys_len_.size()) { + if (infer_node_type_index_set_.find(infer_cursor) == + infer_node_type_index_set_.end()) { + VLOG(2) << "Skip cursor[" << infer_cursor << "]"; + infer_cursor++; + continue; + } else { + VLOG(2) << "Not skip cursor[" << infer_cursor << "]"; + break; + } + } + if (infer_cursor >= h_device_keys_len_.size()) { + return 0; + } + } + size_t device_key_size = h_device_keys_len_[infer_cursor]; total_row_ = (global_infer_node_type_start[infer_cursor] + infer_table_cap_ <= @@ -2104,6 +2165,11 @@ int GraphDataGenerator::FillWalkBuf() { int *len_per_row = reinterpret_cast(d_len_per_row_->ptr()); uint64_t *d_sample_keys = reinterpret_cast(d_sample_keys_->ptr()); cudaMemsetAsync(walk, 0, buf_size_ * sizeof(uint64_t), sample_stream_); + uint8_t *walk_ntype = NULL; + if (excluded_train_pair_len_ > 0) { + walk_ntype = reinterpret_cast(d_walk_ntype_->ptr()); + cudaMemsetAsync(walk_ntype, 0, buf_size_ * sizeof(uint8_t), sample_stream_); + } // cudaMemsetAsync( // len_per_row, 0, once_max_sample_keynum * sizeof(int), sample_stream_); int sample_times = 0; @@ -2157,6 +2223,10 @@ int GraphDataGenerator::FillWalkBuf() { VLOG(2) << "gpuid = " << gpuid_ << " path[0] = " << path[0]; uint64_t *cur_walk = walk + i; + uint8_t *cur_walk_ntype = NULL; + if (excluded_train_pair_len_ > 0) { + cur_walk_ntype = walk_ntype + i; + } NeighborSampleQuery q; q.initialize(gpuid_, @@ -2197,7 +2267,9 @@ int GraphDataGenerator::FillWalkBuf() { } } FillOneStep(d_type_keys + start, + path[0], cur_walk, + cur_walk_ntype, tmp_len, sample_res, walk_degree_, @@ -2248,7 +2320,9 @@ int GraphDataGenerator::FillWalkBuf() { } } FillOneStep(d_type_keys + start, + edge_type_id, cur_walk, + cur_walk_ntype, sample_key_len, sample_res, 1, @@ -2311,11 +2385,11 @@ int GraphDataGenerator::FillWalkBuf() { if (!sage_mode_) { uint64_t h_uniq_node_num = CopyUniqueNodes(); - VLOG(0) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_ + VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_ << ", d_walk_offset:" << i << ", total_rows:" << total_row_ << ", total_samples:" << total_samples; } else { - VLOG(0) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_ + VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_ << ", d_walk_offset:" << i << ", total_rows:" << total_row_ << ", total_samples:" << total_samples; } @@ -2341,6 +2415,10 @@ int GraphDataGenerator::FillWalkBufMultiPath() { /////// auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); uint64_t *walk = reinterpret_cast(d_walk_->ptr()); + uint8_t *walk_ntype = NULL; + if (excluded_train_pair_len_ > 0) { + walk_ntype = reinterpret_cast(d_walk_ntype_->ptr()); + } int *len_per_row = reinterpret_cast(d_len_per_row_->ptr()); uint64_t *d_sample_keys = reinterpret_cast(d_sample_keys_->ptr()); cudaMemsetAsync(walk, 0, buf_size_ * sizeof(uint64_t), sample_stream_); @@ -2359,7 +2437,7 @@ int GraphDataGenerator::FillWalkBufMultiPath() { size_t node_type_len = first_node_type.size(); std::string first_node = paddle::string::split_string(cur_metapath, "2")[0]; - auto it = gpu_graph_ptr->feature_to_id.find(first_node); + auto it = gpu_graph_ptr->node_to_id.find(first_node); auto node_type = it->second; int remain_size = @@ -2383,6 +2461,10 @@ int GraphDataGenerator::FillWalkBufMultiPath() { VLOG(2) << "gpuid = " << gpuid_ << " path[0] = " << path[0]; uint64_t *cur_walk = walk + i; + uint8_t *cur_walk_ntype = NULL; + if (excluded_train_pair_len_ > 0) { + cur_walk_ntype = walk_ntype + i; + } NeighborSampleQuery q; q.initialize(gpuid_, @@ -2421,7 +2503,9 @@ int GraphDataGenerator::FillWalkBufMultiPath() { } FillOneStep(d_type_keys + start, + path[0], cur_walk, + cur_walk_ntype, tmp_len, sample_res, walk_degree_, @@ -2472,7 +2556,9 @@ int GraphDataGenerator::FillWalkBufMultiPath() { } } FillOneStep(d_type_keys + start, + edge_type_id, cur_walk, + cur_walk_ntype, sample_key_len, sample_res, 1, @@ -2639,6 +2725,27 @@ void GraphDataGenerator::AllocResource( phi::Stream(reinterpret_cast(sample_stream_))); cudaMemsetAsync( d_walk_->ptr(), 0, buf_size_ * sizeof(uint64_t), sample_stream_); + + excluded_train_pair_len_ = gpu_graph_ptr->excluded_train_pair_.size(); + if (excluded_train_pair_len_ > 0) { + d_excluded_train_pair_ = memory::AllocShared( + place_, + excluded_train_pair_len_ * sizeof(uint8_t), + phi::Stream(reinterpret_cast(sample_stream_))); + CUDA_CHECK(cudaMemcpyAsync(d_excluded_train_pair_->ptr(), + gpu_graph_ptr->excluded_train_pair_.data(), + excluded_train_pair_len_ * sizeof(uint8_t), + cudaMemcpyHostToDevice, + sample_stream_)); + + d_walk_ntype_ = memory::AllocShared( + place_, + buf_size_ * sizeof(uint8_t), + phi::Stream(reinterpret_cast(sample_stream_))); + cudaMemsetAsync( + d_walk_ntype_->ptr(), 0, buf_size_ * sizeof(uint8_t), sample_stream_); + } + d_sample_keys_ = memory::AllocShared( place_, once_max_sample_keynum * sizeof(uint64_t), @@ -2735,6 +2842,29 @@ void GraphDataGenerator::AllocResource( phi::Stream(reinterpret_cast(sample_stream_))); } + // parse infer_node_type + auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index(); + if (!gpu_graph_training_) { + auto node_types = + paddle::string::split_string(infer_node_type_, ";"); + auto node_to_id = gpu_graph_ptr->node_to_id; + for (auto &type : node_types) { + auto iter = node_to_id.find(type); + PADDLE_ENFORCE_NE( + iter, + node_to_id.end(), + platform::errors::NotFound("(%s) is not found in node_to_id.", type)); + int node_type = iter->second; + int type_index = type_to_index[node_type]; + VLOG(2) << "add node[" << type + << "] into infer_node_type, type_index(cursor)[" << type_index + << "]"; + infer_node_type_index_set_.insert(type_index); + } + VLOG(2) << "infer_node_type_index_set_num: " + << infer_node_type_index_set_.size(); + } + cudaStreamSynchronize(sample_stream_); debug_gpu_memory_info(gpuid_, "AllocResource end"); @@ -2774,6 +2904,7 @@ void GraphDataGenerator::SetConfig( once_sample_startid_len_ * walk_len_ * walk_degree_ * repeat_time_; train_table_cap_ = graph_config.train_table_cap(); infer_table_cap_ = graph_config.infer_table_cap(); + get_degree_ = graph_config.get_degree(); epoch_finish_ = false; VLOG(1) << "Confirm GraphConfig, walk_degree : " << walk_degree_ << ", walk_len : " << walk_len_ << ", window : " << window_ @@ -2788,7 +2919,8 @@ void GraphDataGenerator::SetConfig( std::string str_samples = graph_config.samples(); auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); debug_gpu_memory_info("init_conf start"); - gpu_graph_ptr->init_conf(first_node_type, meta_path); + gpu_graph_ptr->init_conf( + first_node_type, meta_path, graph_config.excluded_train_pair()); debug_gpu_memory_info("init_conf end"); auto edge_to_id = gpu_graph_ptr->edge_to_id; @@ -2800,6 +2932,10 @@ void GraphDataGenerator::SetConfig( samples_.emplace_back(sample_size); } copy_unique_len_ = 0; + + if (!gpu_graph_training_) { + infer_node_type_ = graph_config.infer_node_type(); + } } } // namespace framework diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 77bce7933816195ae59b4e38fc180334713d5759..47c703402131b0883f0e2a3b4bad87d86ebd59b1 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -914,7 +914,9 @@ class GraphDataGenerator { int FillFeatureBuf(std::shared_ptr d_walk, std::shared_ptr d_feature); void FillOneStep(uint64_t* start_ids, + int etype_id, uint64_t* walk, + uint8_t* walk_ntype, int len, NeighborSampleResult& sample_res, // NOLINT int cur_degree, @@ -966,6 +968,7 @@ class GraphDataGenerator { int len, int* uniq_len, std::shared_ptr& inverse); // NOLINT + std::shared_ptr GetNodeDegree(uint64_t* node_ids, int len); int InsertTable(const uint64_t* d_keys, uint64_t len, std::shared_ptr d_uniq_node_num); @@ -988,6 +991,7 @@ class GraphDataGenerator { int* index_tensor_ptr_; int64_t* show_tensor_ptr_; int64_t* clk_tensor_ptr_; + int* degree_tensor_ptr_; cudaStream_t train_stream_; cudaStream_t sample_stream_; @@ -999,6 +1003,8 @@ class GraphDataGenerator { std::shared_ptr d_train_metapath_keys_; std::shared_ptr d_walk_; + std::shared_ptr d_walk_ntype_; + std::shared_ptr d_excluded_train_pair_; std::shared_ptr d_feature_list_; std::shared_ptr d_feature_; std::shared_ptr d_len_per_row_; @@ -1033,11 +1039,13 @@ class GraphDataGenerator { // sage mode batch data std::vector> inverse_vec_; std::vector> final_sage_nodes_vec_; + std::vector> node_degree_vec_; std::vector uniq_instance_vec_; std::vector total_instance_vec_; std::vector>> graph_edges_vec_; std::vector>> edges_split_num_vec_; + int excluded_train_pair_len_; int64_t reindex_table_size_; int sage_batch_count_; int sage_batch_num_; @@ -1067,6 +1075,9 @@ class GraphDataGenerator { int total_row_; size_t infer_node_start_; size_t infer_node_end_; + std::set infer_node_type_index_set_; + std::string infer_node_type_; + bool get_degree_; }; class DataFeed { diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto index e7ecf06e1551a914449f539d10b6d1715ed78658..7f81711b7c8e57d4d263d93c4aa105fe6585374b 100644 --- a/paddle/fluid/framework/data_feed.proto +++ b/paddle/fluid/framework/data_feed.proto @@ -42,6 +42,9 @@ message GraphConfig { optional string samples = 12; optional int64 train_table_cap = 13 [ default = 80000 ]; optional int64 infer_table_cap = 14 [ default = 80000 ]; + optional string excluded_train_pair = 15; + optional string infer_node_type = 16; + optional bool get_degree = 17 [ default = false ]; } message DataFeedDesc { diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index baa12950d0d6c929b41ed78ffc2a5b5b046ffdac..3e9224b96304c65b9e4ee5fffa2281fceda0c994 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -457,7 +457,7 @@ void DatasetImpl::LoadIntoMemory() { timeline.Start(); std::vector load_threads; if (gpu_graph_mode_) { - VLOG(0) << "in gpu_graph_mode"; + VLOG(1) << "in gpu_graph_mode"; #if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) for (size_t i = 0; i < readers_.size(); i++) { readers_[i]->SetGpuGraphMode(gpu_graph_mode_); @@ -470,7 +470,6 @@ void DatasetImpl::LoadIntoMemory() { readers_[i]->ResetPathNum(); readers_[i]->ResetEpochFinish(); } - return; } for (int64_t i = 0; i < thread_num_; ++i) { diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 921df0452c7ae77d9a0bfc3decffb7806b7a4986..349996aee3b8ae71816e3371a9798815b1cb6b2f 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -30,7 +30,7 @@ limitations under the License. */ #if defined(PADDLE_WITH_PSCORE) #include "paddle/fluid/distributed/ps/wrapper/fleet.h" #endif - +#include "paddle/fluid/framework/barrier.h" #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/heter_util.h" @@ -212,6 +212,9 @@ class DeviceWorker { virtual void SetDeviceContext(platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + + virtual void SetThreadNum(int thread_num) { thread_num_ = thread_num; } + virtual Scope* GetThreadScope() { return thread_scope_; } DataFeed* device_reader_ = nullptr; @@ -290,7 +293,8 @@ class HogwildWorker : public CPUWorkerBase { HogwildWorkerParameter param_; std::vector skip_ops_; std::map stat_var_name_map_; - static std::atomic worker_num_stat_; + static std::atomic quit_flag_; + // static bool quit_flag_2; }; class DownpourWorker : public HogwildWorker { @@ -724,7 +728,7 @@ class HeterSectionWorker : public DeviceWorker { const platform::Place& place() const { return place_; } void SetDeviceIndex(int tid) override { thread_id_ = tid; } - void SetThreadNum(int thread_num) { thread_num_ = thread_num; } + // void SetThreadNum(int thread_num) { thread_num_ = thread_num; } void SetMicrobatchNum(int num) { num_microbatches_ = num; } void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; } void SetPipelineStage(int stage) { pipeline_stage_ = stage; } @@ -767,7 +771,7 @@ class HeterSectionWorker : public DeviceWorker { protected: int trainer_id_; int trainers_; - int thread_num_; + // int thread_num_; int thread_id_; int num_microbatches_; int num_pipeline_stages_; diff --git a/paddle/fluid/framework/dist_multi_trainer_test.cc b/paddle/fluid/framework/dist_multi_trainer_test.cc index ae88dbab057b624f8a49745137c1520a64862590..e95ec8f1517b4df5f85c66d2c3150e0f3f6a27cb 100644 --- a/paddle/fluid/framework/dist_multi_trainer_test.cc +++ b/paddle/fluid/framework/dist_multi_trainer_test.cc @@ -14,6 +14,7 @@ #include "gtest/gtest.h" #include "paddle/fluid/framework/trainer.h" +#include "paddle/phi/core/flags.h" #ifdef PADDLE_WITH_GLOO #include "paddle/fluid/framework/fleet/gloo_wrapper.h" #endif @@ -21,6 +22,7 @@ #else #define _LINUX #endif +DECLARE_bool(enable_exit_when_partial_worker); namespace paddle { namespace framework { @@ -82,5 +84,92 @@ TEST(DisMultiTrainerTest, testforgpugraph) { #endif } +TEST(DisMultiTrainerTest, test2) { +#ifdef _LINUX + FLAGS_enable_exit_when_partial_worker = true; + std::shared_ptr tmp1 = std::make_shared(); + TrainerDesc t; + t.set_class_name("MultiTrainer"); + t.set_device_worker_name("HogwildWorker"); + t.set_thread_num(1); + auto* m = t.mutable_downpour_param()->add_program_config(); + m->set_program_id("123"); + std::string str; + // str += "name: \"MultiSlotDataFeed\"\nbatch_size: 2\nmulti_slot_desc {\n"; + str += + "name: \"SlotRecordInMemoryDataFeed\"\nbatch_size: 2\nmulti_slot_desc " + "{\n"; + str += "slots {\nname: \"words\"\ntype: \"uint64\"\nis_dense: false\n"; + str += "is_used: true\n}\nslots {\nname: \"label\"\ntype: \"uint64\"\n"; + str += "is_dense: false\nis_used: true\n}\n}\n"; + str += "graph_config {\n"; + str += "gpu_graph_training: true\n}"; + // std::shared_ptr dataset = + // std::make_shared(); + std::shared_ptr dataset = + std::make_shared(); + + dataset->SetFileList(std::vector()); + dataset->SetThreadNum(1); + dataset->SetTrainerNum(1); + dataset->SetDataFeedDesc(str); + dataset->CreateChannel(); + dataset->CreateReaders(); + Scope root_scope; + tmp1->SetScope(&root_scope); + tmp1->Initialize(t, dataset.get()); + tmp1->SetDebug(false); + ProgramDesc p; + tmp1->InitOtherEnv(p); + tmp1->Run(); + tmp1->Finalize(); +#endif +} + +TEST(DisMultiTrainerTest, test3) { +#ifdef _LINUX + FLAGS_enable_exit_when_partial_worker = true; + std::shared_ptr tmp1 = std::make_shared(); + TrainerDesc t; + t.set_class_name("MultiTrainer"); + t.set_device_worker_name("HogwildWorker"); + t.set_thread_num(1); + auto* m = t.mutable_downpour_param()->add_program_config(); + m->set_program_id("123"); + std::string str; + // str += "name: \"MultiSlotDataFeed\"\nbatch_size: 2\nmulti_slot_desc {\n"; + str += + "name: \"SlotRecordInMemoryDataFeed\"\nbatch_size: 2\nmulti_slot_desc " + "{\n"; + str += "slots {\nname: \"words\"\ntype: \"uint64\"\nis_dense: false\n"; + str += "is_used: true\n}\nslots {\nname: \"label\"\ntype: \"uint64\"\n"; + str += "is_dense: false\nis_used: true\n}\n}\n"; + str += "graph_config {\n"; + str += "gpu_graph_training: true\n}"; + // std::shared_ptr dataset = + // std::make_shared(); + std::shared_ptr dataset = + std::make_shared(); + + dataset->SetFileList(std::vector()); + dataset->SetThreadNum(1); + dataset->SetTrainerNum(1); + dataset->SetDataFeedDesc(str); + dataset->CreateChannel(); + dataset->SetGpuGraphMode(true); + dataset->CreateReaders(); + auto readers = dataset->GetReaders(); + readers[0]->SetGpuGraphMode(true); + Scope root_scope; + tmp1->SetScope(&root_scope); + tmp1->Initialize(t, dataset.get()); + tmp1->SetDebug(true); + ProgramDesc p; + tmp1->InitOtherEnv(p); + tmp1->Run(); + tmp1->Finalize(); +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h index 50895c2645853e564d17a7fb8321070153c692a4..10093d4cc22cab2cce940f768127224b0c64276b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h @@ -109,6 +109,11 @@ class GpuPsGraphTable std::vector> edge_type_graphs); std::vector> get_edge_type_graph( int gpu_id, int edge_type_len); + void get_node_degree(int gpu_id, + int edge_idx, + uint64_t *key, + int len, + std::shared_ptr node_degree); int get_feature_of_nodes(int gpu_id, uint64_t *d_walk, uint64_t *d_offset, @@ -146,6 +151,8 @@ class GpuPsGraphTable uint32_t *actual_feature_size, uint64_t *feature_list, uint8_t *slot_list); + void move_degree_to_source_gpu( + int gpu_id, int gpu_num, int *h_left, int *h_right, int *node_degree); void move_result_to_source_gpu_all_edge_type(int gpu_id, int gpu_num, int sample_size, diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index f8684af98f20346aeea9b3878f506b839a4471f9..d4bd392622d033423d4a05908bea5dd9b927dea9 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -155,6 +155,15 @@ __global__ void get_features_kernel(GpuPsCommGraphFea graph, } } +__global__ void get_node_degree_kernel(GpuPsNodeInfo* node_info_list, + int* node_degree, + int n) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + node_degree[i] = node_info_list[i].neighbor_size; + } +} + template __global__ void neighbor_sample_kernel_walking(GpuPsCommGraph graph, GpuPsNodeInfo* node_info_list, @@ -455,6 +464,41 @@ void GpuPsGraphTable::move_result_to_source_gpu(int start_index, } } +void GpuPsGraphTable::move_degree_to_source_gpu( + int start_index, int gpu_num, int* h_left, int* h_right, int* node_degree) { + int shard_len[gpu_num]; + for (int i = 0; i < gpu_num; i++) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + shard_len[i] = h_right[i] - h_left[i] + 1; + int cur_step = (int)path_[start_index][i].nodes_.size() - 1; + for (int j = cur_step; j > 0; j--) { + CUDA_CHECK( + cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage, + path_[start_index][i].nodes_[j].val_storage, + path_[start_index][i].nodes_[j - 1].val_bytes_len, + cudaMemcpyDefault, + path_[start_index][i].nodes_[j - 1].out_stream)); + } + auto& node = path_[start_index][i].nodes_.front(); + CUDA_CHECK( + cudaMemcpyAsync(reinterpret_cast(node_degree + h_left[i]), + node.val_storage + sizeof(int64_t) * shard_len[i], + sizeof(int) * shard_len[i], + cudaMemcpyDefault, + node.out_stream)); + } + + for (int i = 0; i < gpu_num; ++i) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + auto& node = path_[start_index][i].nodes_.front(); + CUDA_CHECK(cudaStreamSynchronize(node.out_stream)); + } +} + void GpuPsGraphTable::move_result_to_source_gpu_all_edge_type( int start_index, int gpu_num, @@ -570,6 +614,16 @@ __global__ void fill_dvalues(uint64_t* d_shard_vals, } } +__global__ void fill_dvalues(int* d_shard_degree, + int* d_degree, + int* idx, + int len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + d_degree[idx[i]] = d_shard_degree[i]; + } +} + __global__ void fill_dvalues_with_edge_type(uint64_t* d_shard_vals, uint64_t* d_vals, int* d_shard_actual_sample_size, @@ -756,7 +810,7 @@ void GpuPsGraphTable::build_graph_fea_on_single_gpu(const GpuPsCommGraphFea& g, } else { gpu_graph_fea_list_[offset].feature_size = 0; } - VLOG(0) << "gpu node_feature info card :" << gpu_id << " ,node_size is " + VLOG(1) << "gpu node_feature info card :" << gpu_id << " ,node_size is " << gpu_graph_fea_list_[offset].node_size << ", feature_size is " << gpu_graph_fea_list_[offset].feature_size; } @@ -1340,6 +1394,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( return result; } +// only for graphsage NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type( int gpu_id, int edge_type_len, @@ -1537,6 +1592,125 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type( return result; } +void GpuPsGraphTable::get_node_degree( + int gpu_id, + int edge_idx, + uint64_t* key, + int len, + std::shared_ptr node_degree) { + int* node_degree_ptr = + reinterpret_cast(node_degree->ptr()) + edge_idx * len; + int total_gpu = resource_->total_device(); + platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); + platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); + auto stream = resource_->local_stream(gpu_id, 0); + int grid_size = (len - 1) / block_size_ + 1; + int h_left[total_gpu]; // NOLINT + int h_right[total_gpu]; // NOLINT + auto d_left = + memory::Alloc(place, + total_gpu * sizeof(int), + phi::Stream(reinterpret_cast(stream))); + auto d_right = + memory::Alloc(place, + total_gpu * sizeof(int), + phi::Stream(reinterpret_cast(stream))); + int* d_left_ptr = reinterpret_cast(d_left->ptr()); + int* d_right_ptr = reinterpret_cast(d_right->ptr()); + CUDA_CHECK(cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream)); + CUDA_CHECK(cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream)); + auto d_idx = + memory::Alloc(place, + len * sizeof(int), + phi::Stream(reinterpret_cast(stream))); + int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); + auto d_shard_keys = + memory::Alloc(place, + len * sizeof(uint64_t), + phi::Stream(reinterpret_cast(stream))); + uint64_t* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); + auto d_shard_degree = + memory::Alloc(place, + len * sizeof(int), + phi::Stream(reinterpret_cast(stream))); + int* d_shard_degree_ptr = reinterpret_cast(d_shard_degree->ptr()); + split_input_to_shard( + (uint64_t*)(key), d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); + heter_comm_kernel_->fill_shard_key( + d_shard_keys_ptr, key, d_idx_ptr, len, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaMemcpyAsync(h_left, + d_left_ptr, + total_gpu * sizeof(int), + cudaMemcpyDeviceToHost, + stream)); + CUDA_CHECK(cudaMemcpyAsync(h_right, + d_right_ptr, + total_gpu * sizeof(int), + cudaMemcpyDeviceToHost, + stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + device_mutex_[gpu_id]->lock(); + for (int i = 0; i < total_gpu; ++i) { + int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; + if (shard_len == 0) { + continue; + } + create_storage( + gpu_id, + i, + shard_len * sizeof(uint64_t), + shard_len * sizeof(uint64_t) + sizeof(int) * shard_len + shard_len % 2); + } + walk_to_dest( + gpu_id, total_gpu, h_left, h_right, (uint64_t*)(d_shard_keys_ptr), NULL); + for (int i = 0; i < total_gpu; ++i) { + if (h_left[i] == -1) { + continue; + } + int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; + auto& node = path_[gpu_id][i].nodes_.back(); + CUDA_CHECK(cudaMemsetAsync( + node.val_storage, 0, shard_len * sizeof(uint64_t), node.in_stream)); + CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + int table_offset = + get_table_offset(i, GraphTableType::EDGE_TABLE, edge_idx); + tables_[table_offset]->get(reinterpret_cast(node.key_storage), + reinterpret_cast(node.val_storage), + (size_t)(h_right[i] - h_left[i] + 1), + resource_->remote_stream(i, gpu_id)); + GpuPsNodeInfo* node_info_list = + reinterpret_cast(node.val_storage); + int* node_degree_array = (int*)(node_info_list + shard_len); + int grid_size_ = (shard_len - 1) / block_size_ + 1; + get_node_degree_kernel<<remote_stream(i, gpu_id)>>>( + node_info_list, node_degree_array, shard_len); + } + for (int i = 0; i < total_gpu; ++i) { + if (h_left[i] == -1) { + continue; + } + CUDA_CHECK(cudaStreamSynchronize(resource_->remote_stream(i, gpu_id))); + } + move_degree_to_source_gpu( + gpu_id, total_gpu, h_left, h_right, d_shard_degree_ptr); + fill_dvalues<<>>( + d_shard_degree_ptr, node_degree_ptr, d_idx_ptr, len); + CUDA_CHECK(cudaStreamSynchronize(stream)); + for (int i = 0; i < total_gpu; i++) { + int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; + if (shard_len == 0) { + continue; + } + destroy_storage(gpu_id, i); + } + device_mutex_[gpu_id]->unlock(); +} + NodeQueryResult GpuPsGraphTable::graph_node_sample(int gpu_id, int sample_size) { return NodeQueryResult(); diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu index 40d69e13d57b72cdc9bcebf543b29bc7712ff059..a863cb40552598d5f21d2baa36ae8700636dca36 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu @@ -32,7 +32,8 @@ void GraphGpuWrapper::set_device(std::vector ids) { } void GraphGpuWrapper::init_conf(const std::string &first_node_type, - const std::string &meta_path) { + const std::string &meta_path, + const std::string &excluded_train_pair) { static std::mutex mutex; { std::lock_guard lock(mutex); @@ -45,12 +46,12 @@ void GraphGpuWrapper::init_conf(const std::string &first_node_type, paddle::string::split_string(first_node_type, ";"); VLOG(2) << "node_types: " << first_node_type; for (auto &type : node_types) { - auto iter = feature_to_id.find(type); - PADDLE_ENFORCE_NE(iter, - feature_to_id.end(), - platform::errors::NotFound( - "(%s) is not found in feature_to_id.", type)); - VLOG(2) << "feature_to_id[" << type << "] = " << iter->second; + auto iter = node_to_id.find(type); + PADDLE_ENFORCE_NE( + iter, + node_to_id.end(), + platform::errors::NotFound("(%s) is not found in node_to_id.", type)); + VLOG(2) << "node_to_id[" << type << "] = " << iter->second; first_node_type_.push_back(iter->second); } meta_path_.resize(first_node_type_.size()); @@ -58,17 +59,40 @@ void GraphGpuWrapper::init_conf(const std::string &first_node_type, for (size_t i = 0; i < meta_paths.size(); i++) { auto path = meta_paths[i]; - auto nodes = paddle::string::split_string(path, "-"); + auto edges = paddle::string::split_string(path, "-"); + for (auto &edge : edges) { + auto iter = edge_to_id.find(edge); + PADDLE_ENFORCE_NE(iter, + edge_to_id.end(), + platform::errors::NotFound( + "(%s) is not found in edge_to_id.", edge)); + VLOG(2) << "edge_to_id[" << edge << "] = " << iter->second; + meta_path_[i].push_back(iter->second); + if (edge_to_node_map_.find(iter->second) == edge_to_node_map_.end()) { + auto nodes = paddle::string::split_string(edge, "2"); + uint64_t src_node_id = node_to_id.find(nodes[0])->second; + uint64_t dst_node_id = node_to_id.find(nodes[1])->second; + edge_to_node_map_[iter->second] = src_node_id << 32 | dst_node_id; + } + } + } + + auto paths = + paddle::string::split_string(excluded_train_pair, ";"); + VLOG(2) << "excluded_train_pair[" << excluded_train_pair << "]"; + for (auto &path : paths) { + auto nodes = paddle::string::split_string(path, "2"); for (auto &node : nodes) { - auto iter = edge_to_id.find(node); + auto iter = node_to_id.find(node); PADDLE_ENFORCE_NE(iter, edge_to_id.end(), platform::errors::NotFound( "(%s) is not found in edge_to_id.", node)); VLOG(2) << "edge_to_id[" << node << "] = " << iter->second; - meta_path_[i].push_back(iter->second); + excluded_train_pair_.push_back(iter->second); } } + int max_dev_id = 0; for (size_t i = 0; i < device_id_mapping.size(); i++) { if (device_id_mapping[i] > max_dev_id) { @@ -85,11 +109,11 @@ void GraphGpuWrapper::init_conf(const std::string &first_node_type, auto &finish_node_type = finish_node_type_[i]; finish_node_type.clear(); - for (size_t idx = 0; idx < feature_to_id.size(); idx++) { + for (size_t idx = 0; idx < node_to_id.size(); idx++) { infer_node_type_start[idx] = 0; } for (auto &type : node_types) { - auto iter = feature_to_id.find(type); + auto iter = node_to_id.find(type); node_type_start[iter->second] = 0; infer_node_type_start[iter->second] = 0; } @@ -188,7 +212,7 @@ void GraphGpuWrapper::init_metapath(std::string cur_metapath, int first_node_idx; std::string first_node = paddle::string::split_string(cur_metapath_, "2")[0]; - auto it = feature_to_id.find(first_node); + auto it = node_to_id.find(first_node); first_node_idx = it->second; d_graph_train_total_keys_.resize(thread_num); h_graph_train_keys_len_.resize(thread_num); @@ -309,8 +333,8 @@ void GraphGpuWrapper::set_up_types(const std::vector &edge_types, } id_to_feature = node_types; for (size_t table_id = 0; table_id < node_types.size(); table_id++) { - int res = feature_to_id.size(); - feature_to_id[node_types[table_id]] = res; + int res = node_to_id.size(); + node_to_id[node_types[table_id]] = res; } table_feat_mapping.resize(node_types.size()); this->table_feat_conf_feat_name.resize(node_types.size()); @@ -389,21 +413,22 @@ void GraphGpuWrapper::load_edge_file(std::string etype2files, etype2files, graph_data_local_path, part_num, reverse); } -void GraphGpuWrapper::load_node_file(std::string name, std::string filepath) { +int GraphGpuWrapper::load_node_file(std::string name, std::string filepath) { // 'n' means load nodes and 'node_type' follows std::string params = "n" + name; - if (feature_to_id.find(name) != feature_to_id.end()) { - reinterpret_cast(graph_table) + if (node_to_id.find(name) != node_to_id.end()) { + return reinterpret_cast(graph_table) ->cpu_graph_table_->Load(std::string(filepath), params); } + return 0; } -void GraphGpuWrapper::load_node_file(std::string ntype2files, - std::string graph_data_local_path, - int part_num) { - reinterpret_cast(graph_table) +int GraphGpuWrapper::load_node_file(std::string ntype2files, + std::string graph_data_local_path, + int part_num) { + return reinterpret_cast(graph_table) ->cpu_graph_table_->parse_node_and_load( ntype2files, graph_data_local_path, part_num); } @@ -422,8 +447,8 @@ void GraphGpuWrapper::add_table_feat_conf(std::string table_name, std::string feat_name, std::string feat_dtype, int feat_shape) { - if (feature_to_id.find(table_name) != feature_to_id.end()) { - int idx = feature_to_id[table_name]; + if (node_to_id.find(table_name) != node_to_id.end()) { + int idx = node_to_id[table_name]; if (table_feat_mapping[idx].find(feat_name) == table_feat_mapping[idx].end()) { int res = table_feat_mapping[idx].size(); @@ -512,7 +537,7 @@ void GraphGpuWrapper::upload_batch(int type, g->cpu_graph_table_->make_gpu_ps_graph(idx, ids[i]); g->build_graph_on_single_gpu(sub_graph, i, idx); sub_graph.release_on_cpu(); - VLOG(0) << "sub graph on gpu " << i << " is built"; + VLOG(1) << "sub graph on gpu " << i << " is built"; return 0; })); } @@ -579,7 +604,7 @@ void GraphGpuWrapper::build_gpu_graph_fea(GpuPsCommGraphFea &sub_graph_fea, GpuPsGraphTable *g = reinterpret_cast(graph_table); g->build_graph_fea_on_single_gpu(sub_graph_fea, i); sub_graph_fea.release_on_cpu(); - VLOG(0) << "sub graph fea on gpu " << i << " is built"; + VLOG(1) << "sub graph fea on gpu " << i << " is built"; return; } @@ -607,6 +632,16 @@ GraphGpuWrapper::get_edge_type_graph(int gpu_id, int edge_type_len) { ->get_edge_type_graph(gpu_id, edge_type_len); } +void GraphGpuWrapper::get_node_degree( + int gpu_id, + int edge_idx, + uint64_t *key, + int len, + std::shared_ptr node_degree) { + return ((GpuPsGraphTable *)graph_table) + ->get_node_degree(gpu_id, edge_idx, key, len, node_degree); +} + int GraphGpuWrapper::get_feature_info_of_nodes( int gpu_id, uint64_t *d_nodes, @@ -776,7 +811,7 @@ std::string &GraphGpuWrapper::get_node_type_size(std::string first_node_type) { auto &type_to_index = get_graph_type_to_index(); std::vector node_type_size; for (auto node : uniq_first_node_) { - auto it = feature_to_id.find(node); + auto it = node_to_id.find(node); auto first_node_idx = it->second; size_t f_idx = type_to_index[first_node_idx]; int type_total_key_size = graph_all_type_total_keys[f_idx].size(); diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h index 52d7132a0e460e633502abaa7a8592219828b67c..ccfff6e999d05f52409139c0f58df41821b732c5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h @@ -41,7 +41,8 @@ class GraphGpuWrapper { } static std::shared_ptr s_instance_; void init_conf(const std::string& first_node_type, - const std::string& meta_path); + const std::string& meta_path, + const std::string& excluded_train_pair); void initialize(); void finalize(); void set_device(std::vector ids); @@ -66,10 +67,10 @@ class GraphGpuWrapper { int part_num, bool reverse); - void load_node_file(std::string name, std::string filepath); - void load_node_file(std::string ntype2files, - std::string graph_data_local_path, - int part_num); + int load_node_file(std::string name, std::string filepath); + int load_node_file(std::string ntype2files, + std::string graph_data_local_path, + int part_num); void load_node_and_edge(std::string etype2files, std::string ntype2files, std::string graph_data_local_path, @@ -120,6 +121,11 @@ class GraphGpuWrapper { int sample_size, int len, std::vector> edge_type_graphs); + void get_node_degree(int gpu_id, + int edge_idx, + uint64_t* key, + int len, + std::shared_ptr node_degree); gpuStream_t get_local_stream(int gpuid); std::vector graph_neighbor_sample( int gpu_id, @@ -160,7 +166,7 @@ class GraphGpuWrapper { std::string& get_node_type_size(std::string first_node_type); std::string& get_edge_type_size(); - std::unordered_map edge_to_id, feature_to_id; + std::unordered_map edge_to_id, node_to_id; std::vector id_to_feature, id_to_edge; std::vector> table_feat_mapping; std::vector> table_feat_conf_feat_name; @@ -175,6 +181,7 @@ class GraphGpuWrapper { std::string feature_separator_ = std::string(" "); bool conf_initialized_ = false; std::vector first_node_type_; + std::vector excluded_train_pair_; std::vector> meta_path_; std::vector> finish_node_type_; @@ -187,6 +194,11 @@ class GraphGpuWrapper { std::vector h_graph_train_keys_len_; std::vector>> d_graph_all_type_total_keys_; + std::map + edge_to_node_map_; + std::vector> h_graph_all_type_keys_len_; std::string slot_feature_separator_ = std::string(" "); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 159287a2c3775578454dea22cb4feb64c39fab2f..e094df929211fd8e3b762742a26d64f0655a774c 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -193,17 +193,17 @@ void PSGPUWrapper::add_key_to_gputask(std::shared_ptr gpu_task) { } timeline.Pause(); - VLOG(0) << "GpuPs task add keys cost " << timeline.ElapsedSec() + VLOG(1) << "GpuPs task add keys cost " << timeline.ElapsedSec() << " seconds."; timeline.Start(); - size_t slot_num = slot_vector_.size() - 1; + size_t slot_num = (size_t)slot_num_for_pull_feature_; // no slot_fea mode and whole_hbm mode, only keep one unique_sort action if (slot_num > 0 && FLAGS_gpugraph_storage_mode != paddle::framework::GpuGraphStorageMode::WHOLE_HBM) { gpu_task->UniqueKeys(); } timeline.Pause(); - VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; + VLOG(1) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; } void PSGPUWrapper::resize_gputask(std::shared_ptr gpu_task) { @@ -218,7 +218,8 @@ void PSGPUWrapper::resize_gputask(std::shared_ptr gpu_task) { } } -void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { +void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task, + Dataset* dataset_for_pull) { VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin"; platform::Timer timeline; timeline.Start(); @@ -341,12 +342,13 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { << " seconds."; } } else { - SlotRecordDataset* dataset = reinterpret_cast(dataset_); + SlotRecordDataset* dataset = + reinterpret_cast(dataset_for_pull); const std::vector& vec_data = dataset->GetGpuGraphTotalKeys(); timeline.Start(); add_key_to_local(vec_data); timeline.Pause(); - VLOG(0) << "GpuGraphTotalKeys: " << vec_data.size() + VLOG(1) << "GpuGraphTotalKeys: " << vec_data.size() << ", add_key_to_local cost " << timeline.ElapsedSec() << " seconds."; } @@ -361,7 +363,8 @@ void PSGPUWrapper::add_slot_feature(std::shared_ptr gpu_task) { // 8卡数据分片 size_t device_num = heter_devices_.size(); std::vector threads; - size_t slot_num = slot_vector_.size() - 1; // node slot 9008 in slot_vector + size_t slot_num = + (size_t)slot_num_for_pull_feature_; // node slot 9008 in slot_vector auto& local_dim_keys = gpu_task->feature_dim_keys_; // [shard_num, 0, keys]] double divide_nodeid_cost = 0; double get_feature_id_cost = 0; @@ -532,8 +535,7 @@ void PSGPUWrapper::add_slot_feature(std::shared_ptr gpu_task) { for (size_t i = 0; i < device_num; i++) { feature_num += feature_list_size[i]; } - VLOG(0) << "feature_num is " << feature_num << " node_num num is " - << node_num; + VLOG(1) << "feature_num is " << feature_num << " node_num is " << node_num; size_t set_num = thread_keys_shard_num_; std::vector> feature_id_set(set_num); @@ -635,7 +637,7 @@ void PSGPUWrapper::add_slot_feature(std::shared_ptr gpu_task) { add_feature_to_key_cost = time_stage.ElapsedSec(); threads.clear(); timeline.Pause(); - VLOG(0) << " add_slot_feature costs: " << timeline.ElapsedSec() << " s." + VLOG(1) << " add_slot_feature costs: " << timeline.ElapsedSec() << " s." << " divide_nodeid_cost " << divide_nodeid_cost << " get_feature_id_cost " << get_feature_id_cost << " add_feature_to_set_cost " << add_feature_to_set_cost @@ -644,9 +646,9 @@ void PSGPUWrapper::add_slot_feature(std::shared_ptr gpu_task) { void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { platform::Timer timeline; - size_t slot_num = slot_vector_.size() - 1; // node slot 9008 in slot_vector - if (slot_num > 0 && FLAGS_gpugraph_storage_mode != - paddle::framework::GpuGraphStorageMode::WHOLE_HBM) { + if (slot_num_for_pull_feature_ > 0 && + FLAGS_gpugraph_storage_mode != + paddle::framework::GpuGraphStorageMode::WHOLE_HBM) { add_slot_feature(gpu_task); } @@ -656,7 +658,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { time_stage.Start(); gpu_task->UniqueKeys(); time_stage.Pause(); - VLOG(0) << "BuildPull slot feature uniq and sort cost time: " + VLOG(1) << "BuildPull slot feature uniq and sort cost time: " << time_stage.ElapsedSec(); auto& local_dim_keys = gpu_task->feature_dim_keys_; @@ -795,7 +797,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { } task_futures.clear(); timeline.Pause(); - VLOG(0) << "pull sparse from CpuPS into GpuPS total keys " << total_key + VLOG(1) << "pull sparse from CpuPS into GpuPS total keys " << total_key << ", cost " << timeline.ElapsedSec() << " seconds."; if (multi_node_) { auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance(); @@ -879,7 +881,7 @@ void PSGPUWrapper::divide_to_device(std::shared_ptr gpu_task) { } } timeline.Pause(); - VLOG(0) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec() + VLOG(1) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec() << " seconds."; } @@ -1052,7 +1054,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { << " dim index: " << j << " contains feasign nums: " << gpu_task->device_dim_ptr_[i][j].size(); } - VLOG(0) << i << " card with dynamic mf contains feasign nums total: " + VLOG(1) << i << " card with dynamic mf contains feasign nums total: " << feature_keys_count[i]; size_max = std::max(size_max, feature_keys_count[i]); } @@ -1074,7 +1076,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { #endif } stagetime.Pause(); - VLOG(0) << "card: " + VLOG(1) << "card: " << " BuildGPUTask create HeterPs_ costs: " << stagetime.ElapsedSec() << " s."; stagetime.Start(); @@ -1191,8 +1193,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { threads[j] = std::thread(build_ps_thread, i, j, len, feature_value_size); } // build feature table - size_t slot_num = slot_vector_.size() - 1; // node slot 9008 in slot_vector - if (slot_num > 0 && + if (slot_num_for_pull_feature_ > 0 && (FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode:: MEM_EMB_FEATURE_AND_GPU_GRAPH || FLAGS_gpugraph_storage_mode == @@ -1228,7 +1229,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { t.join(); } stagetime.Pause(); - VLOG(0) << "card: " << i + VLOG(1) << "card: " << i << " BuildGPUTask build_ps async costs: " << stagetime.ElapsedSec() << " s."; }; @@ -1262,7 +1263,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { } cpu_task_futures.clear(); stagetime.Pause(); - VLOG(0) << " BuildGPUTask build_dynamic_mf_func " + VLOG(1) << " BuildGPUTask build_dynamic_mf_func " << " cost " << stagetime.ElapsedSec() << " s."; for (int i = 0; i < device_num; i++) { cpu_reday_channels_[i]->Close(); @@ -1282,7 +1283,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { gpu_task->sub_graph_feas = NULL; } stagetime.Pause(); - VLOG(0) << " build_dymf_hbm_pool " + VLOG(1) << " build_dymf_hbm_pool " << " cost " << stagetime.ElapsedSec() << " s."; } @@ -1292,7 +1293,7 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) { timer.Start(); dataset_->LoadIntoMemory(); timer.Pause(); - VLOG(0) << "LoadIntoMemory cost: " << timer.ElapsedSec() << "s"; + VLOG(1) << "LoadIntoMemory cost: " << timer.ElapsedSec() << "s"; gpu_graph_mode_ = dataset_->GetGpuGraphMode(); if (dataset_->GetMemoryDataSize() == 0) { VLOG(0) << "GetMemoryDataSize == 0"; @@ -1309,7 +1310,7 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) { std::shared_ptr gpu_task = gpu_task_pool_.Get(); gpu_task->Reset(); gpu_task->pass_id_ = (uint16_t)(dataset_->GetPassID()); - data_ready_channel_->Put(gpu_task); + data_ready_channel_->Put(std::make_pair(gpu_task, dataset_)); } else if (hbm_sparse_table_initialized_ == false) { SparseTableToHbm(); } @@ -1317,7 +1318,7 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) { std::shared_ptr gpu_task = gpu_task_pool_.Get(); gpu_task->Reset(); gpu_task->pass_id_ = (uint16_t)(dataset_->GetPassID()); - data_ready_channel_->Put(gpu_task); + data_ready_channel_->Put(std::make_pair(gpu_task, dataset_)); #endif VLOG(3) << "End LoadIntoMemory(), dataset[" << dataset_ << "]"; } @@ -1332,17 +1333,20 @@ void PSGPUWrapper::start_build_thread() { void PSGPUWrapper::pre_build_thread() { // prebuild: process load_data while (running_) { + std::pair, Dataset*> task = + std::make_pair(nullptr, nullptr); std::shared_ptr gpu_task = nullptr; - if (!data_ready_channel_->Get(gpu_task)) { + if (!data_ready_channel_->Get(task)) { continue; } + gpu_task = task.first; VLOG(3) << "thread PreBuildTask start."; platform::Timer timer; timer.Start(); // build cpu ps data process - PreBuildTask(gpu_task); + PreBuildTask(gpu_task, task.second); timer.Pause(); - VLOG(0) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec() + VLOG(1) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec() << " s"; buildcpu_ready_channel_->Put(gpu_task); } @@ -1382,7 +1386,7 @@ void PSGPUWrapper::build_task() { return; } - VLOG(0) << "PrepareGPUTask start."; + VLOG(1) << "PrepareGPUTask start."; platform::Timer timer; timer.Start(); if (!multi_mf_dim_) { @@ -1390,7 +1394,7 @@ void PSGPUWrapper::build_task() { } BuildGPUTask(gpu_task); timer.Pause(); - VLOG(0) << "PrepareGPUTask + BuildGPUTask end, cost time: " + VLOG(1) << "PrepareGPUTask + BuildGPUTask end, cost time: " << timer.ElapsedSec() << "s"; current_task_ = gpu_task; @@ -1419,11 +1423,11 @@ void PSGPUWrapper::BeginPass() { "[BeginPass] after build_task, current task is not null.")); } if (FLAGS_gpugraph_dedup_pull_push_mode) { - VLOG(0) << "BeginPass end, cost time: " << timer.ElapsedSec() + VLOG(1) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s, enable pull push dedup mode=" << FLAGS_gpugraph_dedup_pull_push_mode; } else { - VLOG(0) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s"; + VLOG(1) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s"; } } @@ -1433,11 +1437,14 @@ void PSGPUWrapper::EndPass() { return; } #endif + if (current_task_ == nullptr) { + return; + } platform::Timer stagetime; stagetime.Start(); HbmToSparseTable(); stagetime.Pause(); - VLOG(0) << "EndPass HbmToSparseTable cost time: " << stagetime.ElapsedSec() + VLOG(1) << "EndPass HbmToSparseTable cost time: " << stagetime.ElapsedSec() << "s"; gpu_task_pool_.Push(current_task_); @@ -1453,7 +1460,7 @@ void PSGPUWrapper::SparseTableToHbm() { gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_); gpu_task->pass_id_ = (uint16_t)(dataset_->GetPassID()); auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); - auto node_to_id = gpu_graph_ptr->feature_to_id; + auto node_to_id = gpu_graph_ptr->node_to_id; auto edge_to_id = gpu_graph_ptr->edge_to_id; std::vector vec_data = gpu_graph_ptr->get_graph_total_keys(); @@ -1569,7 +1576,7 @@ void PSGPUWrapper::HbmToSparseTable() { size_t thread_num = 16; size_t device_num = heter_devices_.size(); if (multi_mf_dim_) { - VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_; + VLOG(1) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_; for (size_t i = 0; i < device_num; i++) { cpu_reday_channels_[i]->Open(); for (int j = 0; j < multi_mf_dim_; j++) { @@ -1593,7 +1600,7 @@ void PSGPUWrapper::HbmToSparseTable() { f.wait(); } timer.Pause(); - VLOG(0) << " EndPass dump_pool_to_cpu_func " + VLOG(1) << " EndPass dump_pool_to_cpu_func " << " cost " << timer.ElapsedSec() << " s."; for (size_t i = 0; i < device_num; i++) { cpu_reday_channels_[i]->Close(); @@ -1605,7 +1612,7 @@ void PSGPUWrapper::HbmToSparseTable() { } cpu_task_futures.clear(); timer.Pause(); - VLOG(0) << " EndPass cpu_func " + VLOG(1) << " EndPass cpu_func " << " cost " << timer.ElapsedSec() << " s."; if (keysize_max != 0) { HeterPs_->end_pass(); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 119aebf4b2d69ab4bcc303d16fc836798fe697b4..87b0765f95502cc50414713ef75f27723ea6a2ee 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -24,6 +24,7 @@ limitations under the License. */ #include #include #include +#include #include #ifdef PADDLE_WITH_GLOO #include @@ -202,7 +203,8 @@ class PSGPUWrapper { void divide_to_device(std::shared_ptr gpu_task); void add_slot_feature(std::shared_ptr gpu_task); void BuildGPUTask(std::shared_ptr gpu_task); - void PreBuildTask(std::shared_ptr gpu_task); + void PreBuildTask(std::shared_ptr gpu_task, + Dataset* dataset_for_pull); void BuildPull(std::shared_ptr gpu_task); void PrepareGPUTask(std::shared_ptr gpu_task); void LoadIntoMemory(bool is_shuffle); @@ -251,8 +253,8 @@ class PSGPUWrapper { buildpull_threads_.join(); s_instance_ = nullptr; VLOG(3) << "PSGPUWrapper Finalize Finished."; - HeterPs_->show_table_collisions(); if (HeterPs_ != NULL) { + HeterPs_->show_table_collisions(); delete HeterPs_; HeterPs_ = NULL; } @@ -605,7 +607,10 @@ class PSGPUWrapper { slot_vector_ = slot_vector; VLOG(0) << "slot_vector size is " << slot_vector_.size(); } - + void SetPullFeatureSlotNum(int slot_num) { + slot_num_for_pull_feature_ = slot_num; + VLOG(0) << "slot_num_for_pull_feature_ is " << slot_num_for_pull_feature_; + } void SetSlotOffsetVector(const std::vector& slot_offset_vector) { slot_offset_vector_ = slot_offset_vector; std::cout << "yxf set: "; @@ -734,6 +739,7 @@ class PSGPUWrapper { std::vector index_dim_vec_; int multi_mf_dim_{0}; int max_mf_dim_{0}; + int slot_num_for_pull_feature_{0}; size_t val_type_size_{0}; size_t grad_type_size_{0}; size_t pull_type_size_{0}; @@ -781,10 +787,10 @@ class PSGPUWrapper { // hbm pools of totol dims number #endif - std::shared_ptr< - paddle::framework::ChannelObject>> - data_ready_channel_ = - paddle::framework::MakeChannel>(); + std::shared_ptr, Dataset*>>> + data_ready_channel_ = paddle::framework::MakeChannel< + std::pair, Dataset*>>(); std::shared_ptr< paddle::framework::ChannelObject>> buildcpu_ready_channel_ = diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index ef2ef8f8596c999239c3eff8e5b2700ea08fdca2..8c5512c14c566ac8be9b75178b9f97368e708741 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -31,7 +31,7 @@ DECLARE_bool(enable_exit_when_partial_worker); namespace paddle { namespace framework { -std::atomic HogwildWorker::worker_num_stat_(0); +std::atomic HogwildWorker::quit_flag_(false); Barrier g_barrier; void HogwildWorker::Initialize(const TrainerDesc &desc) { @@ -148,7 +148,7 @@ void HogwildWorker::TrainFilesWithProfiler() { int cur_batch; int batch_cnt = 0; if (thread_id_ == 0) { - worker_num_stat_.store(0); + quit_flag_.store(false); } g_barrier.wait(); bool train_mode = device_reader_->IsTrainMode(); @@ -160,11 +160,11 @@ void HogwildWorker::TrainFilesWithProfiler() { while (1) { cur_batch = device_reader_->Next(); if (FLAGS_enable_exit_when_partial_worker && train_mode) { - if (cur_batch > 0) { - worker_num_stat_.fetch_add(1, std::memory_order_relaxed); + if (cur_batch <= 0) { + quit_flag_.store(true, std::memory_order_relaxed); } g_barrier.wait(); - if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) { + if (quit_flag_.load(std::memory_order_relaxed) == true) { break; } } @@ -265,7 +265,8 @@ void HogwildWorker::TrainFiles() { int cur_batch; int batch_cnt = 0; if (thread_id_ == 0) { - worker_num_stat_.store(0); + quit_flag_.store(false); + // quit_flag_2 = false; } g_barrier.wait(); @@ -280,11 +281,11 @@ void HogwildWorker::TrainFiles() { while (1) { cur_batch = device_reader_->Next(); if (FLAGS_enable_exit_when_partial_worker && train_mode) { - if (cur_batch > 0) { - worker_num_stat_.fetch_add(1, std::memory_order_relaxed); + if (cur_batch <= 0) { + quit_flag_.store(true, std::memory_order_relaxed); } g_barrier.wait(); - if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) { + if (quit_flag_.load(std::memory_order_relaxed) == true) { break; } } @@ -320,7 +321,7 @@ void HogwildWorker::TrainFiles() { #endif } timeline.Pause(); - VLOG(0) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec() + VLOG(1) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec() << " seconds, batch_num: " << total_batch_num; if (need_dump_field_ || need_dump_param_) { diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 258da58ba7885a31d9c5af3e51a6baa11e999ed5..2b423730084d70104590d7a10e354ed27dfab5a5 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -25,6 +25,8 @@ limitations under the License. */ namespace paddle { namespace framework { +extern Barrier g_barrier; + void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { thread_num_ = trainer_desc.thread_num(); @@ -62,7 +64,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, thread_num_); } #endif - + g_barrier.reset(thread_num_); for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name()); @@ -74,6 +76,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, workers_[i]->Initialize(trainer_desc); workers_[i]->SetDeviceIndex(i); workers_[i]->SetDataFeed(readers[i]); + workers_[i]->SetThreadNum(thread_num_); } // set debug here @@ -177,7 +180,7 @@ void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) { // for unittest which call train_from_dataset but does not call // fleet.init_worker() first if (communicator == nullptr) { - VLOG(0) << "MultiTrainer::InitOtherEnv Communicator is null!"; + VLOG(1) << "MultiTrainer::InitOtherEnv Communicator is null!"; } else { auto& recv_ctx = communicator->GetRecvCtxMap(); communicator->PullDense(recv_ctx); @@ -299,13 +302,13 @@ void MultiTrainer::Finalize() { auto communicator = paddle::distributed::Communicator::GetInstance(); // for unittest which does not call fleet.init_worker() first if (communicator == nullptr) { - VLOG(0) << "MultiTrainer::Finalize communicator is null!"; + VLOG(1) << "MultiTrainer::Finalize communicator is null!"; } else { if (communicator->_worker_ptr != nullptr) { communicator->_worker_ptr->Flush(); VLOG(1) << "MultiTrainer::Finalize ps client flush done"; } else { - VLOG(0) << "communicator->_worker_ptr is null"; + VLOG(1) << "communicator->_worker_ptr is null"; } } #endif diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 891cb40ab28dfa79b1fee657032a3294370e5ad8..f1d629b5c86bd4b9c4b9b246120728be57aedc11 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -64,6 +64,7 @@ register_unity_group( cudnn_lstm_op.cc cumsum_op.cc cvm_op.cc + unzip_op.cc data_norm_op.cc deformable_conv_op.cc deformable_conv_v1_op.cc @@ -402,6 +403,7 @@ register_unity_group( ctc_align_op.cu cumsum_op.cu cvm_op.cu + unzip_op.cu data_norm_op.cu deformable_conv_op.cu deformable_conv_v1_op.cu @@ -579,3 +581,5 @@ register_unity_group(cu expand_op.cu) register_unity_group(cu matmul_v2_op.cu) register_unity_group(cu top_k_v2_op.cu) register_unity_group(cu set_value_op.cu) +register_unity_group(cu unzip.cu) +register_unity_group(cc unzip.cc) diff --git a/paddle/fluid/operators/unzip_op.cc b/paddle/fluid/operators/unzip_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffb46c2f4b56c1816428443c3f8ff8fdac856e77 --- /dev/null +++ b/paddle/fluid/operators/unzip_op.cc @@ -0,0 +1,171 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/unzip_op.h" + +#include + +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { + +class unzipOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lod"); + OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "lod"); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2UL, + platform::errors::InvalidArgument( + "Input(X)'s rank should be 2, but got %d", x_dims.size())); + + auto lod_dims = ctx->GetInputDim("lod"); + PADDLE_ENFORCE_EQ( + lod_dims.size(), + 1UL, + platform::errors::InvalidArgument( + "Input(X)'s rank should be 1, but got %d", lod_dims.size())); + + ctx->SetOutputDim("Y", {lod_dims[0] - 1, x_dims[1]}); + } + + protected: + // Explicitly set that the data type of computation kernel of + // unzip + // is determined by its input "X". + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); + } +}; + +class unzipGradientOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unzipGradient"); + OP_INOUT_CHECK(ctx->HasInput("lod"), "Input", "unzip", "unzipGradient"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), + "Input", + framework::GradVarName("Y"), + "unzipGradient"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), + "Output", + framework::GradVarName("X"), + "unzipGradient"); + + auto x_dims = ctx->GetInputDim("X"); + auto lod_dims = ctx->GetInputDim("lod"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2, + platform::errors::InvalidArgument( + "Expect Input(X)'s rank == 2, but got %d", x_dims.size())); + PADDLE_ENFORCE_EQ( + dy_dims.size(), + 2, + platform::errors::InvalidArgument( + "Expect Input(X)'s rank == 2, but got %d", dy_dims.size())); + PADDLE_ENFORCE_EQ( + lod_dims.size(), + 1, + platform::errors::InvalidArgument( + "Expect Input(X)'s rank == 1, but got %d", lod_dims.size())); + + PADDLE_ENFORCE_EQ( + x_dims[1], + dy_dims[1], + platform::errors::InvalidArgument( + "The 1st dimension of Input(X) and Input(Y@Grad) should " + "be equal, X is %d, Y@Grad is %d", + x_dims[1], + dy_dims[1])); + + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", framework::GradVarName("X")); + } + + protected: + // Explicitly set that the data type of computation kernel of + // unzip + // is determined by its input "X". + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context().GetPlace()); + } +}; + +class unzipOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(LodTensor, default LodTensor), a 2-D tensor with shape " + "[M x N]," + " where N is the batch size and D is the emebdding dim. "); + AddInput("lod", "(Tensor), a 1-D Tensor with shape [K]"); + AddOutput("Y", + "(LodTensor, default LodTensor), a 2-D tensor with shape " + "[K-1 x N]."); + AddComment(R"DOC( +unzip Operator. +)DOC"); + } +}; + +template +class unzipGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("unzip_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("lod", this->Input("lod")); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(unzip, + ops::unzipOp, + ops::unzipOpMaker, + ops::unzipGradOpMaker, + ops::unzipGradOpMaker); + +REGISTER_OPERATOR(unzip_grad, ops::unzipGradientOp); + +REGISTER_OP_CPU_KERNEL(unzip, + ops::unzipOpKernel, + ops::unzipOpKernel); + +REGISTER_OP_CPU_KERNEL(unzip_grad, + ops::unzipGradOpKernel, + ops::unzipGradOpKernel); diff --git a/paddle/fluid/operators/unzip_op.cu b/paddle/fluid/operators/unzip_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..0605ce4ab91191aa9a1f80805092eaac68de1e23 --- /dev/null +++ b/paddle/fluid/operators/unzip_op.cu @@ -0,0 +1,105 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/operators/unzip_op.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/backends/gpu/gpu_primitives.h" +namespace paddle { +namespace operators { + +using phi::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void unzipKernel( + const T* X, const LodType* lod, T* Y, size_t col_size, size_t n) { + CUDA_KERNEL_LOOP(i, n) { + int lod_idx = i / col_size; + if ((lod[lod_idx + 1] - lod[lod_idx]) > 0) { + assert((lod[lod_idx + 1] - lod[lod_idx]) == col_size); + int x_idx = 0; + for (int j = 0; j < lod_idx; ++j) { + if ((lod[j + 1] - lod[j]) > 0) { + x_idx++; + } + } + Y[i] = X[x_idx * col_size + (i % col_size)]; + } else { + Y[i] = 0; + } + } +} + +template +class unzipCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const auto* x = context.Input("X"); + const T* x_data = x->data(); + + const auto* lod = context.Input("lod"); + const LodType* lod_data = lod->data(); + + auto col_size = x->dims()[1]; + auto row_size = lod->dims()[0] - 1; + auto y_numel = col_size * row_size; + + auto* y = context.Output("Y"); + T* y_data = y->mutable_data(context.GetPlace()); + + // for Input X do not have lod Information. + auto stream = context.template device_context().stream(); + unzipKernel<<<(y_numel + PADDLE_CUDA_NUM_THREADS - 1) / + PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(x_data, lod_data, y_data, col_size, y_numel); + } +}; + +template +class unzipGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_THROW(phi::errors::Unimplemented("unzip_grad is unimplemented")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + unzip, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(unzip_grad, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel); diff --git a/paddle/fluid/operators/unzip_op.h b/paddle/fluid/operators/unzip_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f177f69476f1ed623a7941b0683f20c536ff2f52 --- /dev/null +++ b/paddle/fluid/operators/unzip_op.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class unzipOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_THROW(phi::errors::Unimplemented("unzip is unimplemented")); + } +}; + +template +class unzipGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_THROW(phi::errors::Unimplemented("unzip_grad is unimplemented")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 4ed5f32ff3088d7d45765066741f7813824725c1..f42c179b0a1970d595b3e4fe13a52f0d8fc3a910 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -290,6 +290,9 @@ void BindDataset(py::module *m) { .def("get_memory_data_size", &framework::Dataset::GetMemoryDataSize, py::call_guard()) + .def("get_epoch_finish", + &framework::Dataset::GetEpochFinish, + py::call_guard()) .def("get_pv_data_size", &framework::Dataset::GetPvDataSize, py::call_guard()) diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc index 7c02a02aff775cbde43d210ad5853664dc58df7a..7f0026580af376ebe6e30686cafa5b198c4e1bb0 100644 --- a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc @@ -41,6 +41,9 @@ void BindPSGPUWrapper(py::module* m) { .def("set_slot_vector", &framework::PSGPUWrapper::SetSlotVector, py::call_guard()) + .def("set_slot_num_for_pull_feature", + &framework::PSGPUWrapper::SetPullFeatureSlotNum, + py::call_guard()) #ifdef PADDLE_WITH_CUDA .def("set_slot_dim_vector", &framework::PSGPUWrapper::SetSlotDimVector, diff --git a/python/paddle/distributed/passes/ps_trainer_pass.py b/python/paddle/distributed/passes/ps_trainer_pass.py index 5bb15e40854a40ab58cf28865e26ca6fe108b578..2a60b0df5f5eb2bf1c6a81636a3693036e1bb6e6 100755 --- a/python/paddle/distributed/passes/ps_trainer_pass.py +++ b/python/paddle/distributed/passes/ps_trainer_pass.py @@ -19,10 +19,10 @@ from _collections import defaultdict import paddle import paddle.fluid.framework as framework from paddle.distributed.passes.pass_base import PassBase, register_pass +from paddle.fluid.transpiler.collective import SingleProcessMultiThread from paddle.framework import core from paddle.static import Parameter, Program -from ..ps.utils.collective_transpiler import SingleProcessMultiThread from ..ps.utils.public import * # noqa: F403 diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 650cf3756dabb8500bdf5e063b286692e73cdff4..b1b3afd730db4160da41e3166a04978634cf6345 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -967,6 +967,9 @@ class InMemoryDataset(DatasetBase): """ return self.dataset.get_pv_data_size() + def get_epoch_finish(self): + return self.dataset.get_epoch_finish() + @deprecated( since="2.0.0", update_to="paddle.distributed.InMemoryDataset.get_memory_data_size", @@ -1121,6 +1124,15 @@ class InMemoryDataset(DatasetBase): self.proto_desc.graph_config.infer_table_cap = config.get( "infer_table_cap", 800000 ) + self.proto_desc.graph_config.excluded_train_pair = config.get( + "excluded_train_pair", "" + ) + self.proto_desc.graph_config.infer_node_type = config.get( + "infer_node_type", "" + ) + self.proto_desc.graph_config.get_degree = config.get( + "get_degree", False + ) self.dataset.set_gpu_graph_mode(True) def set_pass_id(self, pass_id): diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index f98193ea64b1b4ef909c00ae3e79284241c82ec7..ba5d5a2c858b64f3bc13ab228d1838349b3adc33 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -313,6 +313,7 @@ class TestDataset(unittest.TestCase): dataset.set_graph_config(graph_config) dataset.set_pass_id(0) dataset.get_pass_id() + dataset.get_epoch_finish() def test_in_memory_dataset_masterpatch(self): """ diff --git a/python/paddle/fluid/tests/unittests/test_unzip_op.py b/python/paddle/fluid/tests/unittests/test_unzip_op.py new file mode 100644 index 0000000000000000000000000000000000000000..65a353822b12fbd185dce4e9e5bdbefcd2370e5a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unzip_op.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + + +class TestUnzipOp(unittest.TestCase): + def test_result(self): + """ + For unzip op + """ + paddle.enable_static() + if core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + x = fluid.data(name='X', shape=[3, 4], dtype='float64') + lod = fluid.data(name='lod', shape=[11], dtype='int64') + output = paddle.incubate.unzip(x, lod) + + input = [ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + [100.0, 200.0, 300.0, 400.0], + ] + lod = [0, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12] + + feed = { + 'X': np.array(input).astype("float64"), + 'lod': np.array(lod).astype("int64"), + } + + exe = fluid.Executor(place=place) + exe.run(fluid.default_startup_program()) + res = exe.run(feed=feed, fetch_list=[output]) + out = [ + [1.0, 2.0, 3.0, 4.0], + [0.0, 0.0, 0.0, 0.0], + [10.0, 20.0, 30.0, 40.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [100.0, 200.0, 300.0, 400.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ] + out_np = np.array(out, dtype="float64") + assert (res == out_np).all(), "output is not right" + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 97dd8353be50c800e03e9ee947368241b49cf44e..69d4ca2b0c8ae5c7edb607c954a887bf256074e4 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -23,6 +23,7 @@ from .operators import graph_send_recv from .operators import graph_khop_sampler from .operators import graph_sample_neighbors from .operators import graph_reindex +from .operators import unzip from .tensor import segment_sum from .tensor import segment_mean from .tensor import segment_max @@ -55,4 +56,5 @@ __all__ = [ 'segment_max', 'segment_min', 'identity_loss', + 'unzip', ] diff --git a/python/paddle/incubate/operators/__init__.py b/python/paddle/incubate/operators/__init__.py index eb105a12e1ab6429a46596387ee80c28e99166ad..e96c3641196574b859e8c891e50eafc0c320255c 100644 --- a/python/paddle/incubate/operators/__init__.py +++ b/python/paddle/incubate/operators/__init__.py @@ -21,3 +21,4 @@ from .graph_send_recv import graph_send_recv # noqa: F401 from .graph_khop_sampler import graph_khop_sampler # noqa: F401 from .graph_sample_neighbors import graph_sample_neighbors # noqa: F401 from .graph_reindex import graph_reindex # noqa: F401 +from .unzip import unzip # noqa: F401 diff --git a/python/paddle/incubate/operators/unzip.py b/python/paddle/incubate/operators/unzip.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ed13ed0142a449ec86759095bd79cc9fcddd75 --- /dev/null +++ b/python/paddle/incubate/operators/unzip.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.fluid.layer_helper import LayerHelper + + +def unzip(input, lod): + r""" + + **unzip layers** + + unzip 'input' accroding to 'lod' + + Args: + input (Variable): The zipped input, 2-D LodTensor with shape [N, M]. + lod (Variable): The original lod of unzipped input, 1-D LodTensor with shape[K]. + + Returns: + Variable: The original unzipped tensor, 2-D LodTensor with shape[K-1, M]. + + Examples: + + .. code-block:: python + import numpy as np + import paddle + import paddle.fluid as fluid + paddle.enable_static() + input_np = np.array([ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + [100.0, 200.0, 300.0, 400.0] + ]) + lod_np = np.array([0, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12]) + input = paddle.to_tensor(input_np, "int64") + lod = paddle.to_tensor(lod_np, "int64") + + unzipped_input = paddle.incubate.unzip(input, lod) + ''' + unzipped_input is [ + [1.0, 2.0, 3.0, 4.0], + [0.0, 0.0, 0.0, 0.0], + [10.0, 20.0, 30.0, 40.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [100.0, 200.0, 300.0, 400.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0] + ] + ''' + """ + helper = LayerHelper('unzip', **locals()) + out = helper.create_variable(dtype=input.dtype) + check_variable_and_dtype( + input, + 'input', + ['float16', 'float32', 'float64', 'int', 'bool', 'int64'], + 'unzip', + ) + check_variable_and_dtype(lod, 'lod', ['int', 'int64'], 'unzip') + helper.append_op( + type='unzip', inputs={'X': [input], 'lod': [lod]}, outputs={'Y': [out]} + ) + return out