Unverified commit 3f619290, authored by yaoxuefeng, committed by GitHub

merge dymf branch (#42714)

merge dymf branch
Parent e726960a
@@ -129,11 +129,6 @@ class HeterContext {
     for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
       feature_dim_keys_[i].resize(dim_num);
       value_dim_ptr_[i].resize(dim_num);
-      if (i == 0) {
-        for (int j = 0; j < dim_num; j++) {
-          feature_dim_keys_[i][j].push_back(0);
-        }
-      }
     }
     device_values_.resize(device_num);
     device_dim_values_.resize(device_num);
......
@@ -32,17 +32,33 @@ struct FeatureValue {
   float lr;
   float lr_g2sum;
   int mf_size;
-  float mf[MF_DIM + 1];
+  int mf_dim;
   uint64_t cpu_ptr;
+  float mf[0];
   friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
     out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
-        << " lr: " << val.lr << " mf_size: " << val.mf_size << " mf:";
-    for (int i = 0; i < val.mf_size; ++i) {
+        << " lr: " << val.lr << " mf_dim: " << val.mf_dim
+        << " cpuptr: " << val.cpu_ptr << " mf_size: " << val.mf_size << " mf:";
+    for (int i = 0; i < val.mf_dim + 1; ++i) {
       out << " " << val.mf[i];
     }
     return out;
   }
+  __device__ __forceinline__ void operator=(const FeatureValue& in) {
+    delta_score = in.delta_score;
+    show = in.show;
+    clk = in.clk;
+    slot = in.slot;
+    lr = in.lr;
+    lr_g2sum = in.lr_g2sum;
+    mf_size = in.mf_size;
+    mf_dim = in.mf_dim;
+    cpu_ptr = in.cpu_ptr;
+    for (int i = 0; i < mf_dim + 1; i++) {
+      mf[i] = in.mf[i];
+    }
+  }
 };
 
 struct FeaturePushValue {
@@ -50,20 +66,19 @@ struct FeaturePushValue {
   float clk;
   int slot;
   float lr_g;
-  float mf_g[MF_DIM];
-
-  // __device__ __forceinline__ FeaturePushValue
-  // operator+(const FeaturePushValue& a) const {
-  //   FeaturePushValue out;
-  //   out.slot = a.slot;
-  //   out.show = a.show + show;
-  //   out.clk = a.clk + clk;
-  //   out.lr_g = a.lr_g + lr_g;
-  //   for (int i = 0; i < MF_DIM; ++i) {
-  //     out.mf_g[i] = a.mf_g[i] + mf_g[i];
-  //   }
-  //   return out;
-  // }
+  int mf_dim;
+  float mf_g[0];
+
+  __device__ __forceinline__ void operator=(const FeaturePushValue& in) {
+    show = in.show;
+    clk = in.clk;
+    slot = in.slot;
+    lr_g = in.lr_g;
+    mf_dim = in.mf_dim;
+    for (int i = 0; i < mf_dim; i++) {
+      mf_g[i] = in.mf_g[i];
+    }
+  }
 };
 
 }  // end namespace framework
......
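The core of the dymf change is visible above: `FeatureValue` and `FeaturePushValue` drop their fixed `MF_DIM`-sized arrays for an `mf_dim` field plus a zero-length trailing array (`float mf[0]`, a GNU/NVCC extension), so a value's real footprint is the struct header plus `mf_dim + 1` (or `mf_dim`) trailing floats, and plain memberwise assignment can no longer move the embedding. That is why both structs gain an explicit `__device__` `operator=` that copies the tail. A host-only sketch of the layout arithmetic (names are illustrative, not part of the patch):

```cpp
#include <cstdint>
#include <cstdlib>

// Minimal stand-in for the variable-length value: fixed header + flexible tail.
struct VarLenValue {
  int mf_dim;    // embedding width of this particular feature
  float mf[0];   // tail: mf[0] = g2sum, mf[1..mf_dim] = embedding (GNU extension)
};

// Bytes one value occupies for a given embedding width.
size_t value_bytes(int mf_dim) {
  return sizeof(VarLenValue) + (mf_dim + 1) * sizeof(float);
}

int main() {
  // A pool sized for the *maximum* width lets entries of any smaller width
  // live at fixed strides, which is how the HBM pool in this patch is indexed.
  const int max_dim = 64;
  const size_t stride = value_bytes(max_dim);
  char* pool = static_cast<char*>(calloc(16, stride));

  VarLenValue* v = reinterpret_cast<VarLenValue*>(pool + 3 * stride);
  v->mf_dim = 8;  // this entry only uses an 8-wide embedding
  for (int j = 0; j <= v->mf_dim; ++j) v->mf[j] = 0.f;

  free(pool);
  return 0;
}
```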
@@ -118,8 +118,8 @@ class HashTable {
               StreamType stream);
 
   template <typename StreamType>
-  void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index,
-              StreamType stream);
+  void insert(const KeyType* d_keys, size_t len, char* pool,
+              size_t feature_value_size, size_t start_index, StreamType stream);
 
   template <typename StreamType>
   void get(const KeyType* d_keys, ValType* d_vals, size_t len,
......
@@ -50,7 +50,8 @@ __global__ void insert_kernel(Table* table,
 template <typename Table>
 __global__ void insert_kernel(Table* table,
                               const typename Table::key_type* const keys,
-                              size_t len, char* pool, int start_index) {
+                              size_t len, char* pool, size_t feature_value_size,
+                              int start_index) {
   ReplaceOp<typename Table::mapped_type> op;
   thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;
@@ -58,7 +59,8 @@ __global__ void insert_kernel(Table* table,
   if (i < len) {
     kv.first = keys[i];
-    kv.second = (Table::mapped_type)(pool + (start_index + i) * 80);
+    uint64_t offset = uint64_t(start_index + i) * feature_value_size;
+    kv.second = (Table::mapped_type)(pool + offset);
     auto it = table->insert(kv, op);
     assert(it != table->end() && "error: insert fails: table is full");
   }
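Two things change in `insert_kernel`: the hard-coded 80-byte stride (the old fixed `FeatureValue` size) becomes the runtime `feature_value_size`, and the offset is widened to `uint64_t` before the multiply. The widening matters: with tens of millions of pooled values at a few hundred bytes each, a 32-bit product wraps. A self-contained check of the failure mode (numbers hypothetical):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  size_t start_index = 50000000;    // hypothetical pool position
  int i = 0;
  size_t feature_value_size = 280;  // e.g. header + 64-dim embedding, aligned

  // 32-bit arithmetic overflows: 50M * 280 = 1.4e10 > 2^32.
  uint32_t bad = (uint32_t)(start_index + i) * (uint32_t)feature_value_size;
  // Widening first, as the patched kernel does, keeps the full offset.
  uint64_t good = uint64_t(start_index + i) * feature_value_size;

  printf("32-bit (wrapped): %u, 64-bit: %llu\n", bad,
         (unsigned long long)good);
  return 0;
}
```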
@@ -81,14 +83,16 @@ __global__ void search_kernel(Table* table,
 template <typename Table>
 __global__ void dy_mf_search_kernel(Table* table,
                                     const typename Table::key_type* const keys,
-                                    char* const vals, size_t len,
+                                    char* vals, size_t len,
                                     size_t pull_feature_value_size) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < len) {
     auto it = table->find(keys[i]);
     if (it != table->end()) {
-      *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
+      uint64_t offset = i * pull_feature_value_size;
+      FeatureValue& cur = *(FeatureValue*)(vals + offset);
+      FeatureValue& input = *(FeatureValue*)(it->second);
+      cur = input;  // overloaded operator= copies the variable-length mf tail
     }
   }
 }
@@ -121,7 +125,7 @@ __global__ void dy_mf_update_kernel(Table* table,
       FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
       sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
     } else {
-      printf("yxf::push miss key: %d", keys[i]);
+      printf("warning: push miss key: %lu", keys[i]);
     }
   }
 }
@@ -201,7 +205,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
 
 template <typename KeyType, typename ValType>
 template <typename StreamType>
 void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
-                                         char* pool, size_t start_index,
+                                         char* pool, size_t feature_value_size,
+                                         size_t start_index,
                                          StreamType stream) {
   if (len == 0) {
     return;
@@ -210,8 +215,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
     return;
   }
   const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
-  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
-                                                       pool, start_index);
+  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
+      container_, d_keys, len, pool, feature_value_size, start_index);
 }
 
 template <typename KeyType, typename ValType>
@@ -319,6 +324,7 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
 }
 
 template class HashTable<unsigned long, paddle::framework::FeatureValue>;
+template class HashTable<unsigned long, paddle::framework::FeatureValue*>;
 template class HashTable<long, int>;
 template class HashTable<unsigned long, int>;
 template class HashTable<unsigned long, unsigned long>;
@@ -331,6 +337,10 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
     paddle::framework::FeatureValue* d_vals, size_t len,
     cudaStream_t stream);
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::get<cudaStream_t>(
+    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream);
 template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                       int* d_vals, size_t len,
                                                       cudaStream_t stream);
@@ -354,6 +364,11 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
     const paddle::framework::FeatureValue* d_vals, size_t len,
     cudaStream_t stream);
+template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
+    insert<cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
+                         size_t feature_value_size, size_t start_index,
+                         cudaStream_t stream);
 template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                          const int* d_vals,
                                                          size_t len,
@@ -393,6 +408,16 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
         sgd,
     cudaStream_t stream);
 
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::update<
+    Optimizer<paddle::framework::FeatureValue,
+              paddle::framework::FeaturePushValue>,
+    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t len,
+                  Optimizer<paddle::framework::FeatureValue,
+                            paddle::framework::FeaturePushValue>
+                      sgd,
+                  cudaStream_t stream);
+
 // template void HashTable<unsigned long,
 // paddle::framework::FeatureValue>::update<
 //     Optimizer<paddle::framework::FeatureValue,
......
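The new instantiations above give a `HashTable<unsigned long, FeatureValue*>` whose values are raw pointers into a preallocated, packed HBM pool rather than inline fixed-size structs: `insert` records `pool + (start_index + i) * feature_value_size`, and `get` copies the packed bytes back out through `dy_mf_search_kernel`. A host analogue of that pointer-valued table (a sketch using `std::unordered_map`, not the CUDA path):

```cpp
#include <cstdint>
#include <unordered_map>
#include <vector>

int main() {
  const size_t len = 1000, feature_value_size = 96;  // hypothetical sizes
  std::vector<char> pool(len * feature_value_size);  // one packed value pool
  std::unordered_map<uint64_t, char*> table;         // key -> address in pool
  std::vector<uint64_t> keys(len);
  for (size_t i = 0; i < len; ++i) keys[i] = 1000 + i;

  // Mirrors insert_kernel: value = pool + (start_index + i) * feature_value_size.
  size_t start_index = 0;
  for (size_t i = 0; i < len; ++i)
    table[keys[i]] = pool.data() + (start_index + i) * feature_value_size;

  // Mirrors dy_mf_search_kernel: a hit yields the packed bytes to copy out.
  char* hit = table.at(keys[42]);
  (void)hit;
  return 0;
}
```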
@@ -15,10 +15,13 @@ limitations under the License. */
 #pragma once
 #include <thread>
 #include <vector>
+#include "cub/cub.cuh"
+#include "cub/util_allocator.cuh"
 #if defined(PADDLE_WITH_CUDA)
 #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/dynload/nccl.h"
+#include "paddle/fluid/platform/timer.h"
 #include "thrust/pair.h"
 #elif defined(PADDLE_WITH_XPU_KP)
 // #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
@@ -38,6 +41,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+#define TYPEALIGN(ALIGNVAL, LEN) \
+  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))
+
 template <typename KeyType, typename ValType, typename GradType>
 class HeterComm {
  public:
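`TYPEALIGN(ALIGNVAL, LEN)` rounds `LEN` up to the next multiple of `ALIGNVAL` (a power of two) by adding `ALIGNVAL - 1` and clearing the low bits; the dymf code needs this so variable-length values packed back-to-back in a pool keep 8-byte alignment for members like the `uint64_t cpu_ptr`. A quick host-side check:

```cpp
#include <cassert>
#include <cstdint>

#define TYPEALIGN(ALIGNVAL, LEN) \
  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))

int main() {
  // (77 + 7) & ~7 == 80: round up to the next multiple of 8.
  assert(TYPEALIGN(8, 77) == 80);
  assert(TYPEALIGN(8, 80) == 80);  // already aligned: no-op
  // Typical downstream use: pad a header plus (mf_dim + 1) floats so values
  // packed back-to-back in a pool stay 8-byte aligned.
  assert(TYPEALIGN(8, sizeof(uint64_t) + 9 * sizeof(float)) == 48);
  return 0;
}
```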
@@ -50,9 +56,13 @@ class HeterComm {
                   int* left, int* right, int gpu_num);
   void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
                   int& uniq_len);  // NOLINT
+  void dynamic_merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads,
+                          size_t len, int& uniq_len);
   void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
   void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
                 size_t chunk_size, int stream_num);
+  void build_ps(int num, KeyType* h_keys, char* pool, size_t len,
+                size_t feature_value_size, size_t chunk_size, int stream_num);
   void dump();
   void show_one_table(int gpu_num);
   int get_index_by_devid(int devid);
@@ -96,6 +106,11 @@ class HeterComm {
     nccl_inter_comms_ = inter_comms;
     node_size_ = comm_size;
   }
+
+  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
+    multi_mf_dim_ = multi_mf_dim;
+    max_mf_dim_ = max_mf_dim;
+  }
 #endif
 
   bool need_transfer(int send_id, int receive_id) {
@@ -114,8 +129,8 @@ class HeterComm {
     char* key_storage;
     char* val_storage;
     int sync;
-    int key_bytes_len;
-    int val_bytes_len;
+    size_t key_bytes_len;
+    size_t val_bytes_len;
     int dev_num;
   };
@@ -206,12 +221,18 @@ class HeterComm {
   void destroy_storage(int start_index, int end_index);
   void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
                     KeyType* src_key, GradType* src_val);
+  void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
+                    KeyType* src_key, char* src_val, size_t val_size);
   void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
                    ValType* src_val);
+  void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
+                   char* src_val, size_t val_size);
 
  protected:
   using Table = HashTable<KeyType, ValType>;
+  using PtrTable = HashTable<KeyType, ValType*>;
   std::vector<Table*> tables_;
+  std::vector<PtrTable*> ptr_tables_;
   std::shared_ptr<HeterPsResource> resource_;
   std::vector<std::vector<Path>> path_;
   float load_factor_{0.75};
@@ -221,6 +242,7 @@ class HeterComm {
  private:
   int topo_aware_{0};
   std::vector<LocalStorage> storage_;
+  DynamicGradMerger merger_;
   int feanum_{1800 * 2048};
   int multi_node_{0};
   int node_size_;
@@ -228,6 +250,8 @@ class HeterComm {
 #if defined(PADDLE_WITH_CUDA)
   std::vector<ncclComm_t> nccl_inner_comms_;
   std::vector<ncclComm_t> nccl_inter_comms_;
+  int multi_mf_dim_{8};
+  int max_mf_dim_ = 8;
   std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
 #endif
 };
......
@@ -117,6 +117,52 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
   }
 }
 
+template <typename KeyType, typename GradType, typename T>
+__global__ void dy_mf_fill_shard_grads_kernel(
+    KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads,
+    GradType* d_grads, T* idx, size_t len, size_t grad_value_size) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < len) {
+    d_shard_keys[i] = d_keys[idx[i]];
+    *(GradType*)((char*)d_shard_grads + i * grad_value_size) =
+        *(GradType*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size);
+  }
+}
+
+__global__ void merge_gradients_kernel(const uint32_t* offset,
+                                       const uint32_t* fea_num,
+                                       const uint32_t* index, const char* input,
+                                       char* output, int n,
+                                       size_t grad_value_size,
+                                       DynamicGradMerger& merger_) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < n) {
+    uint32_t start = offset[i];
+    uint32_t num = fea_num[i];
+    int ori_index = index[start];
+    FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size);
+    FeaturePushValue& in =
+        *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
+    merger_.update_one(out, in);
+    for (int j = 1; j < num; ++j) {
+      ori_index = index[start + j];
+      in = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
+      merger_.merge_one(out, in);
+    }
+  }
+}
+
+template <typename ValType, typename T>
+__global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
+                                        T* idx, size_t len, size_t val_size) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < len) {
+    uint64_t new_offset = uint64_t(idx[i]) * val_size;
+    *(ValType*)((char*)d_vals + new_offset) =
+        *(ValType*)((char*)d_shard_vals + i * val_size);
+  }
+}
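`merge_gradients_kernel` consumes the grouping that the new `dynamic_merge_grad` path prepares: for unique key `i`, `offset[i]` and `fea_num[i]` delimit a run in `index` whose entries point at that key's duplicate gradients in the original buffer; `update_one` seeds the output with the first duplicate and `merge_one` accumulates the rest. A host analogue with toy data (fixed-width structs instead of the packed variable-length layout):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct Grad { float show, clk, lr_g; };  // fixed-width stand-in

int main() {
  // Three gradients; positions 0 and 2 belong to the same key.
  std::vector<Grad> input = {{1, 0, 0.1f}, {1, 1, 0.2f}, {1, 0, 0.3f}};
  std::vector<uint32_t> offset = {0, 2};    // group starts within `index`
  std::vector<uint32_t> fea_num = {2, 1};   // duplicates per unique key
  std::vector<uint32_t> index = {0, 2, 1};  // duplicate positions, grouped

  std::vector<Grad> output(offset.size());
  for (size_t i = 0; i < offset.size(); ++i) {
    Grad out = input[index[offset[i]]];               // update_one: seed
    for (uint32_t j = 1; j < fea_num[i]; ++j) {
      const Grad& in = input[index[offset[i] + j]];   // merge_one: accumulate
      out.show += in.show; out.clk += in.clk; out.lr_g += in.lr_g;
    }
    output[i] = out;
  }
  printf("key0: show=%g lr_g=%g\n", output[0].show, output[0].lr_g);  // 2, 0.4
  return 0;
}
```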
 // cuda implemention of heter_comm_kernel.h
 template <typename T, typename StreamType>
 void HeterCommKernel::fill_idx(T* idx, long long len,
@@ -207,8 +253,42 @@ void HeterCommKernel::reduce_by_key(void* d_temp_storage,
                                       debug_synchronous));
 }
 
+template <typename KeyType, typename GradType, typename T, typename StreamType>
+void HeterCommKernel::dy_mf_fill_shard_grads(
+    KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads,
+    GradType* d_grads, T* idx, long long len, size_t grad_value_size,
+    const StreamType& stream) {
+  int grid_size = (len - 1) / block_size_ + 1;
+  size_t c_len = (size_t)len;
+  dy_mf_fill_shard_grads_kernel<<<grid_size, block_size_, 0, stream>>>(
+      d_shard_keys, d_keys, d_shard_grads, d_grads, idx, c_len,
+      grad_value_size);
+}
+
+template <typename StreamType>
+void HeterCommKernel::merge_gradient(
+    const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
+    const char* input, char* output, int n, size_t grad_value_size,
+    DynamicGradMerger& merger_, const StreamType& stream) {
+  int grid_size = (n - 1) / block_size_ + 1;
+  merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
+      offset, fea_num, index, input, output, n, grad_value_size, merger_);
+}
+
+template <typename ValType, typename T, typename StreamType>
+void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals,
+                                       T* idx, long long len, size_t val_size,
+                                       const StreamType& stream) {
+  int grid_size = (len - 1) / block_size_ + 1;
+  size_t c_len = (size_t)len;
+  dy_mf_fill_dvals_kernel<<<grid_size, block_size_, 0, stream>>>(
+      d_shard_vals, d_vals, idx, c_len, val_size);
+}
+
 template void HeterCommKernel::fill_idx<int, cudaStream_t>(
     int* idx, long long len, const cudaStream_t& stream);
+template void HeterCommKernel::fill_idx<uint32_t, cudaStream_t>(
+    uint32_t* idx, long long len, const cudaStream_t& stream);
 
 template void HeterCommKernel::calc_shard_offset<int, cudaStream_t>(
     int* idx, int* left, int* right, long long len, int total_devs,
@@ -270,6 +350,23 @@ template void HeterCommKernel::reduce_by_key<
     paddle::framework::FeaturePushValue* d_aggregates_out, int* d_num_runs_out,
     int num_items, cudaStream_t stream, bool debug_synchronous);
 
+template void HeterCommKernel::dy_mf_fill_shard_grads<
+    unsigned long, paddle::framework::FeaturePushValue, int, cudaStream_t>(
+    unsigned long* d_shard_keys, unsigned long* d_keys,
+    paddle::framework::FeaturePushValue* d_shard_grads,
+    paddle::framework::FeaturePushValue* d_grads, int* idx, long long len,
+    size_t grad_value_size, const cudaStream_t& stream);
+
+template void HeterCommKernel::merge_gradient<cudaStream_t>(
+    const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
+    const char* input, char* output, int n, size_t grad_value_size,
+    DynamicGradMerger& merger_, const cudaStream_t& stream);
+
+template void HeterCommKernel::dy_mf_fill_dvals<paddle::framework::FeatureValue,
+                                                int, cudaStream_t>(
+    paddle::framework::FeatureValue* d_shard_vals,
+    paddle::framework::FeatureValue* d_vals, int* idx, long long len,
+    size_t val_size, const cudaStream_t& stream);
+
 #endif
 
 }  // namespace framework
......
@@ -27,6 +27,42 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+struct DynamicGradMerger {
+  template <typename T>
+  CUB_RUNTIME_FUNCTION __forceinline__ __device__ T
+  operator()(const T& a, const T& b) const {
+    T out;
+    out.slot = a.slot;
+    out.mf_dim = a.mf_dim;
+    out.show = a.show + b.show;
+    out.clk = a.clk + b.clk;
+    out.lr_g = a.lr_g + b.lr_g;
+    return out;
+  }
+
+  template <typename T>
+  __device__ __forceinline__ void update_one(T& output, const T& input) {
+    output.slot = input.slot;
+    output.show = input.show;
+    output.clk = input.clk;
+    output.mf_dim = input.mf_dim;
+    output.lr_g = input.lr_g;
+    for (int i = 0; i < output.mf_dim; ++i) {
+      output.mf_g[i] = input.mf_g[i];
+    }
+  }
+
+  template <typename T>
+  __device__ __forceinline__ void merge_one(T& output, const T& input) {
+    output.show += input.show;
+    output.clk += input.clk;
+    output.lr_g += input.lr_g;
+    for (int i = 0; i < input.mf_dim; ++i) {
+      output.mf_g[i] += input.mf_g[i];
+    }
+  }
+};
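`DynamicGradMerger` carries two merge paths: the by-value `operator()` (a cub-style reduction functor) can only sum the fixed header fields, because a by-value `FeaturePushValue` has no storage behind its zero-length `mf_g` tail; merging the embedding gradients therefore goes through `update_one`/`merge_one` on the packed byte buffers, as `merge_gradients_kernel` does. The semantics, replayed on the host with a fixed-width stand-in (illustrative, not the real type):

```cpp
#include <cassert>

// Fixed-width stand-in so the update_one/merge_one logic can run on the host.
struct TestGrad {
  int slot = 0, mf_dim = 2;
  float show = 0, clk = 0, lr_g = 0;
  float mf_g[2] = {0, 0};
};

struct Merger {  // host copy of DynamicGradMerger's tail-aware merge
  template <typename T> void update_one(T& o, const T& i) {
    o.slot = i.slot; o.show = i.show; o.clk = i.clk;
    o.mf_dim = i.mf_dim; o.lr_g = i.lr_g;
    for (int k = 0; k < o.mf_dim; ++k) o.mf_g[k] = i.mf_g[k];
  }
  template <typename T> void merge_one(T& o, const T& i) {
    o.show += i.show; o.clk += i.clk; o.lr_g += i.lr_g;
    for (int k = 0; k < i.mf_dim; ++k) o.mf_g[k] += i.mf_g[k];
  }
};

int main() {
  TestGrad a, b, out;
  a.show = 1; a.mf_g[0] = 0.5f;
  b.show = 1; b.mf_g[0] = 0.25f;
  Merger m;
  m.update_one(out, a);  // seed with the first duplicate
  m.merge_one(out, b);   // accumulate the rest
  assert(out.show == 2 && out.mf_g[0] == 0.75f);
  return 0;
}
```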
 class HeterCommKernel {
  public:
   HeterCommKernel() {}
@@ -80,6 +116,24 @@ class HeterCommKernel {
                      StreamType stream = NULL, bool debug_synchronous = false);
 
+  template <typename KeyType, typename GradType, typename T,
+            typename StreamType>
+  void dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
+                              GradType* d_shard_grads, GradType* d_grads,
+                              T* idx, long long len, size_t grad_value_size,
+                              const StreamType& stream);
+
+  template <typename StreamType>
+  void merge_gradient(const uint32_t* offset, const uint32_t* fea_num,
+                      const uint32_t* index, const char* input, char* output,
+                      int n, size_t grad_value_size, DynamicGradMerger& merger_,
+                      const StreamType& stream);
+
+  template <typename ValType, typename T, typename StreamType>
+  void dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
+                        long long len, size_t val_size,
+                        const StreamType& stream);
+
  private:
   int block_size_{256};
 };
......
@@ -44,6 +44,13 @@ void HeterPs::build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
   comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num);
 }
 
+void HeterPs::build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
+                       size_t feature_value_size, size_t chunk_size,
+                       int stream_num) {
+  comm_->build_ps(num, h_keys, pool, len, feature_value_size, chunk_size,
+                  stream_num);
+}
+
 int HeterPs::get_index_by_devid(int devid) {
   return comm_->get_index_by_devid(devid);
 }
@@ -72,6 +79,10 @@ void HeterPs::set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
   comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size);
 }
 
+void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
+  comm_->set_multi_mf_dim(multi_mf_dim, max_mf_dim);
+}
+
 }  // end namespace framework
 }  // end namespace paddle
 #endif
@@ -37,11 +37,14 @@ class HeterPs : public HeterPsBase {
                    size_t len) override;
   void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len,
                 size_t chunk_size, int stream_num) override;
+  void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
+                size_t feature_value_size, size_t chunk_size,
+                int stream_num) override;
 #if defined(PADDLE_WITH_CUDA)
   void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
                               const std::vector<ncclComm_t>& inter_comms,
                               int comm_size) override;
+  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) override;
 #endif
 
   void set_sparse_sgd(const OptimizerConfig& optimizer_config) override;
@@ -35,11 +35,15 @@ class HeterPsBase {
                    size_t len) = 0;
   virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
                         size_t len, size_t chunk_size, int stream_num) = 0;
+  virtual void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
+                        size_t feature_value_size, size_t chunk_size,
+                        int stream_num) = 0;
   virtual int get_index_by_devid(int devid) = 0;
 #if defined(PADDLE_WITH_CUDA)
   virtual void set_nccl_comm_and_size(
       const std::vector<ncclComm_t>& inner_comms,
       const std::vector<ncclComm_t>& inter_comms, int comm_size) = 0;
+  virtual void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) = 0;
 #endif
   virtual void end_pass() = 0;
   virtual void show_one_table(int gpu_num) = 0;
......
@@ -107,6 +107,8 @@ class HeterPsResource {
   int get_index_by_devid(int devid);
   int dev_id(int num);
   void set_multi_mf(int multi_mf_dim, int max_mf_dim);
+  int multi_mf() { return multi_mf_dim_; }
+  int max_mf_dim() { return max_mf_dim_; }
 
   ppStream local_stream(int dev_num, int stream_num);
   ppStream remote_stream(int dev_num, int stream_num);
......
@@ -125,20 +125,21 @@ class Optimizer {
       if (optimizer_config.mf_create_thresholds <=
           optimizer_config.nonclk_coeff * (ptr->show - ptr->clk) +
               optimizer_config.clk_coeff * ptr->clk) {
-        // ptr->mf_size = ptr->mf_dim + 1;
-        ptr->mf_size = MF_DIM + 1;
+        ptr->mf_size = ptr->mf_dim + 1;
+        // ptr->mf_size = MF_DIM + 1;
         ptr->mf[0] = 0;
         int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
         curandState state;
         curand_init(clock64(), tid_x, 0, &state);
-        for (int i = 0; i < MF_DIM; ++i) {
+        for (int i = 0; i < ptr->mf_dim; ++i) {
           ptr->mf[i + 1] =
               (curand_uniform(&state)) * optimizer_config.mf_initial_range;
         }
       }
     } else {
-      update_mf(optimizer_config, MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g,
+      update_mf(optimizer_config, ptr->mf_dim, &(ptr->mf[1]), ptr->mf[0],
+                grad.mf_g,
                 grad.show);  // for local test
     }
   }
......
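The optimizer now trusts each value's own `mf_dim` instead of the compile-time `MF_DIM`, and keeps the convention `mf_size = mf_dim + 1`: slot 0 of the tail is the g2sum accumulator and slots `1..mf_dim` the embedding, which is why creation writes `mf[0] = 0` and fills `mf[i + 1]`, and why `update_mf` is handed `&(ptr->mf[1])` and `ptr->mf[0]`. A host sketch of that tail contract:

```cpp
#include <cstdio>

// Host sketch of the mf tail layout: mf[0] = g2sum, mf[1..mf_dim] = embedding.
int main() {
  const int mf_dim = 8;
  float mf[1 + 8] = {0};        // mf_size = mf_dim + 1 floats in total
  float* embedding = &mf[1];    // what update_mf receives as the weight vector
  float& g2sum = mf[0];         // accumulator updated in place by update_mf
  for (int i = 0; i < mf_dim; ++i) embedding[i] = 0.01f * i;
  g2sum += 0.5f;                // stand-in for the optimizer's accumulation
  printf("g2sum=%g e0=%g e7=%g\n", g2sum, embedding[0], embedding[7]);
  return 0;
}
```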
@@ -61,6 +61,45 @@ __global__ void PullCopy(float** dest, const FeatureValue* src,
   }
 }
 
+__global__ void PullCopy(float** dest, const FeatureValue* src,
+                         const int64_t* len, int slot_num, int total_len,
+                         uint64_t** keys, uint64_t max_val_size, int* gpu_dim) {
+  CUDA_KERNEL_LOOP(i, total_len) {
+    int low = 0;
+    int high = slot_num - 1;
+    while (low < high) {
+      int mid = (low + high) / 2;
+      if (i < len[mid])
+        high = mid;
+      else
+        low = mid + 1;
+    }
+    int x = low;
+    int y = i - (x ? len[x - 1] : 0);
+    FeatureValue* feature_value_ptr =
+        (FeatureValue*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
+    int mf_dim = gpu_dim[x] - 3;
+    if (*(keys[x] + y) == 0) {
+      *(dest[x] + y * (mf_dim + 3)) = 0;
+      *(dest[x] + y * (mf_dim + 3) + 1) = 0;
+      *(dest[x] + y * (mf_dim + 3) + 2) = 0;
+    } else {
+      *(dest[x] + y * (mf_dim + 3)) = feature_value_ptr->show;
+      *(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr->clk;
+      *(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr->lr;
+    }
+    if ((feature_value_ptr)->mf_size == 0 || *(keys[x] + y) == 0) {
+      for (int j = 0; j < mf_dim; j++) {
+        *(dest[x] + y * (mf_dim + 3) + 3 + j) = 0;
+      }
+    } else {
+      for (int j = 0; j < mf_dim; j++) {
+        *(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr->mf[1 + j];
+      }
+    }
+  }
+}
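The new `PullCopy` maps the flat element index `i` to a slot by lower-bound binary search over `len`, which holds inclusive prefix sums of per-slot lengths, then derives the row inside that slot; each output row is `mf_dim + 3` floats wide (`show`, `clk`, `lr`, then the embedding, copied from `mf[1 + j]` to skip the pooled value's g2sum). The same index mapping on the host, with a small test on assumed data:

```cpp
#include <cassert>
#include <vector>

// Host version of the index mapping used by PullCopy / PushCopyWithPool:
// `len` holds inclusive prefix sums of per-slot lengths.
void locate(const std::vector<long long>& len, int i, int& slot, int& row) {
  int low = 0, high = (int)len.size() - 1;
  while (low < high) {  // lower bound: first slot with i < len[slot]
    int mid = (low + high) / 2;
    if (i < len[mid]) high = mid; else low = mid + 1;
  }
  slot = low;
  row = i - (slot ? (int)len[slot - 1] : 0);
}

int main() {
  std::vector<long long> len = {3, 8, 10};  // slots of size 3, 5, 2
  int s, r;
  locate(len, 0, s, r); assert(s == 0 && r == 0);
  locate(len, 4, s, r); assert(s == 1 && r == 1);
  locate(len, 9, s, r); assert(s == 2 && r == 1);
  return 0;
}
```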
 __global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
                                const int64_t* len, int slot_num,
                                int total_len) {
@@ -105,6 +144,35 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len,
   }
 }
 
+__global__ void PushCopyWithPool(FeaturePushValue* dest, float** src,
+                                 int64_t* len, int slot_num, uint64_t total_len,
+                                 int bs, int* slot_vector, int* mf_dim_vector,
+                                 size_t grad_value_size) {
+  CUDA_KERNEL_LOOP(i, total_len) {
+    int low = 0;
+    int high = slot_num - 1;
+    while (low < high) {
+      int mid = (low + high) / 2;
+      if (i < len[mid])
+        high = mid;
+      else
+        low = mid + 1;
+    }
+    int x = low;
+    int y = i - (x ? len[low - 1] : 0);
+    FeaturePushValue* cur =
+        (FeaturePushValue*)((char*)dest + i * grad_value_size);
+    cur->slot = slot_vector[x];
+    int mf_dim = mf_dim_vector[x];
+    cur->mf_dim = mf_dim;
+    cur->show = *(src[x] + y * (mf_dim + 3));
+    cur->clk = *(src[x] + y * (mf_dim + 3) + 1);
+    cur->lr_g = *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
+    for (int j = 0; j < cur->mf_dim; j++) {
+      cur->mf_g[j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
+    }
+  }
+}
+
 PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }
 
 void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
@@ -128,6 +196,26 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
   cudaStreamSynchronize(stream);
 }
 
+void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
+                               uint64_t** gpu_keys,
+                               const std::vector<float*>& values,
+                               const FeatureValue* total_values_gpu,
+                               const int64_t* gpu_len, const int slot_num,
+                               const int hidden_size,
+                               const int64_t total_length, int* gpu_dim) {
+  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
+                    platform::DeviceContextPool::Instance().Get(place))
+                    ->stream();
+  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
+  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
+  cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
+             cudaMemcpyHostToDevice);
+  PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
+      gpu_values, total_values_gpu, gpu_len, slot_num, total_length, gpu_keys,
+      val_type_size_, gpu_dim);
+  cudaStreamSynchronize(stream);
+}
+
 void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
                             uint64_t** origin_keys, uint64_t* total_keys,
                             const int64_t* gpu_len, int slot_num,
@@ -177,6 +265,45 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
   cudaStreamSynchronize(stream);
 }
 
+void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
+                               const std::vector<const float*>& grad_values,
+                               FeaturePushValue* total_grad_values_gpu,
+                               const std::vector<int64_t>& slot_lengths,
+                               const uint64_t total_length,
+                               const int batch_size, size_t grad_value_size) {
+  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
+                    platform::DeviceContextPool::Instance().Get(place))
+                    ->stream();
+  auto slot_lengths_lod = slot_lengths;
+  for (int i = 1; i < slot_lengths_lod.size(); i++) {
+    slot_lengths_lod[i] += slot_lengths_lod[i - 1];
+  }
+  auto buf_grad_value =
+      memory::Alloc(place, grad_values.size() * sizeof(float*));
+  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
+  auto buf_slot_vector =
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
+  auto buf_mf_dim_vector =
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
+  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
+  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
+  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
+  int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
+  cudaMemcpy(gpu_values, grad_values.data(),
+             grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
+  cudaMemcpy(gpu_len, slot_lengths_lod.data(),
+             slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_slot_vector, slot_vector_.data(),
+             slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_mf_dim_vector, slot_mf_dim_vector_.data(),
+             slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
+  PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
+      total_grad_values_gpu, gpu_values, gpu_len, slot_lengths.size(),
+      total_length, batch_size, d_slot_vector, d_mf_dim_vector,
+      grad_value_size);
+  cudaStreamSynchronize(stream);
+}
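On the host side, `CopyForPush` turns the per-slot lengths into the inclusive prefix sum (`slot_lengths_lod`) that `PushCopyWithPool` binary-searches, and ships the slot ids and per-slot mf dims to the device; note the kernel also scales every gradient by `-1. * bs` on the way into the optimizer. The scan in isolation:

```cpp
#include <cassert>
#include <vector>

int main() {
  // Same in-place scan as CopyForPush builds in slot_lengths_lod.
  std::vector<long long> lod = {3, 5, 2};
  for (size_t i = 1; i < lod.size(); i++) lod[i] += lod[i - 1];
  assert(lod[0] == 3 && lod[1] == 8 && lod[2] == 10);
  // total_length == lod.back(); the kernels binary-search this array
  // exactly as in the locate() sketch shown earlier.
  return 0;
}
```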
 void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
                                 float min_bound, float max_bound,
                                 float learning_rate, float initial_g2sum,
......
@@ -27,6 +27,7 @@ limitations under the License. */
 #include <vector>
 #ifdef PADDLE_WITH_GLOO
 #include <gloo/broadcast.h>
+#include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
 #endif
 #include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
@@ -54,6 +55,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_PSLIB
 #include "afs_api.h"
 #endif
+#ifdef PADDLE_WITH_PSLIB
+#include "downpour_accessor.h"  // NOLINT
+#endif
 
 namespace paddle {
 namespace framework {
@@ -95,12 +99,21 @@ class PSGPUWrapper {
   PSGPUWrapper() {
     HeterPs_ = NULL;
     sleep_seconds_before_fail_exit_ = 300;
+    pull_thread_pool_.resize(thread_keys_shard_num_);
+    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
+      pull_thread_pool_[i].reset(new ::ThreadPool(1));
+    }
     hbm_thread_pool_.resize(thread_keys_shard_num_);
     for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
       hbm_thread_pool_[i].reset(new ::ThreadPool(1));
     }
   }
+  void PullSparse(const paddle::platform::Place& place, const int table_id,
+                  const std::vector<const uint64_t*>& keys,
+                  const std::vector<float*>& values,
+                  const std::vector<int64_t>& slot_lengths,
+                  const std::vector<int>& slot_dim, const int hidden_size);
   void PullSparse(const paddle::platform::Place& place, const int table_id,
                   const std::vector<const uint64_t*>& keys,
                   const std::vector<float*>& values,
@@ -119,13 +132,23 @@ class PSGPUWrapper {
                    const FeatureValue* total_values_gpu, const int64_t* gpu_len,
                    const int slot_num, const int hidden_size,
                    const int64_t total_length);
+  void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys,
+                   const std::vector<float*>& values,
+                   const FeatureValue* total_values_gpu, const int64_t* gpu_len,
+                   const int slot_num, const int hidden_size,
+                   const int64_t total_length, int* gpu_dim);
   void CopyForPush(const paddle::platform::Place& place,
                    const std::vector<const float*>& grad_values,
                    FeaturePushValue* total_grad_values_gpu,
                    const std::vector<int64_t>& slot_lengths,
                    const int hidden_size, const int64_t total_length,
                    const int batch_size);
+  void CopyForPush(const paddle::platform::Place& place,
+                   const std::vector<const float*>& grad_values,
+                   FeaturePushValue* total_grad_values_gpu,
+                   const std::vector<int64_t>& slot_lengths,
+                   const uint64_t total_length, const int batch_size,
+                   size_t grad_value_size);
 
   void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
   void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
@@ -428,6 +451,7 @@ class PSGPUWrapper {
   std::shared_ptr<HeterContext> current_task_ = nullptr;
   std::thread pre_build_threads_;
   bool running_ = false;
+  std::vector<std::shared_ptr<ThreadPool>> pull_thread_pool_;
   std::vector<std::shared_ptr<ThreadPool>> hbm_thread_pool_;
 
  protected:
......
@@ -26,6 +26,7 @@ template <typename T>
 static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) {
   auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
   auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
+  auto embedding_size_vec = ctx.Attr<std::vector<int>>("size");
   const auto slot_size = inputs.size();
   std::vector<const uint64_t *> all_keys(slot_size);
   // GpuPS only supports float now
@@ -44,7 +45,7 @@ static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) {
 #ifdef PADDLE_WITH_HETERPS
   auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance();
   gpu_ps_ptr->PullSparse(ctx.GetPlace(), 0, all_keys, all_values, slot_lengths,
-                         0);
+                         embedding_size_vec, 0);
 #endif
 }
......
@@ -737,7 +737,7 @@ def _pull_gpups_sparse(input,
         for i in range(len(inputs))
     ]
     w = helper.create_parameter(
-        attr=helper.param_attr, shape=[11], dtype=dtype, is_bias=False)
+        attr=helper.param_attr, shape=[size[0]], dtype=dtype, is_bias=False)
     helper.append_op(
         type='pull_gpups_sparse',
         inputs={'Ids': inputs,
......