Unverified commit b8d106e1, authored by danleifeng, committed by GitHub

【GPUPS】Adam accessor (#43919)

* add adam/sharedadam optimizer for gpups;edit optimizer struct;test=develop
Parent 1882ffd5
......@@ -31,6 +31,7 @@ int CtrDymfAccessor::Initialize() {
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim());
common_feature_value.optimizer_name = name;
common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
common_feature_value.embedx_dim = _config.embedx_dim();
......@@ -42,7 +43,10 @@ int CtrDymfAccessor::Initialize() {
if (_config.ctr_accessor_param().show_scale()) {
_show_scale = true;
}
VLOG(0) << " INTO CtrDymfAccessor::Initialize()";
VLOG(0) << " INTO CtrDymfAccessor::Initialize(); embed_sgd_dim:"
<< common_feature_value.embed_sgd_dim
<< " embedx_dim:" << common_feature_value.embedx_dim
<< " embedx_sgd_dim:" << common_feature_value.embedx_sgd_dim;
InitAccessorInfo();
return 0;
}
......@@ -53,9 +57,9 @@ void CtrDymfAccessor::InitAccessorInfo() {
auto embedx_dim = _config.embedx_dim();
VLOG(0) << "InitAccessorInfo embedx_dim:" << embedx_dim;
_accessor_info.select_dim = 3 + embedx_dim;
_accessor_info.select_dim = 4 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_dim = 5 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size =
(embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
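A quick size check for the updated accessor info (a sketch, assuming embedx_dim = 8 and embedx SGD state dims that follow the Dim() logic added later in this commit):

```cpp
// select_dim = 4 + embedx_dim            -> 4 + 8 = 12 floats (48 bytes)
// update_dim = 5 + embedx_dim            -> 5 + 8 = 13 floats (52 bytes)
// mf_size    = (embedx_dim + embedx_sgd_dim) * sizeof(float)
//            = (8 + (2 * 8 + 2)) * 4 = 104 bytes for SparseAdamSGDRule
//            = (8 + 4) * 4          =  48 bytes for SparseSharedAdamSGDRule
```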
......@@ -179,8 +183,10 @@ int32_t CtrDymfAccessor::Create(float** values, size_t num) {
value[common_feature_value.ClickIndex()] = 0;
value[common_feature_value.SlotIndex()] = -1;
value[common_feature_value.MfDimIndex()] = -1;
_embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
value + common_feature_value.EmbedG2SumIndex());
_embed_sgd_rule->InitValue(
value + common_feature_value.EmbedWIndex(),
value + common_feature_value.EmbedG2SumIndex(),
false); // adam embed init not zero, adagrad embed init zero
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.EmbedxG2SumIndex(),
false);
......@@ -293,22 +299,14 @@ std::string CtrDymfAccessor::ParseToString(const float* v, int param) {
i++) {
os << " " << v[i];
}
// os << " " << common_feature_value.Slot(const_cast<float*>(v)) << " "
// << common_feature_value.MfDim(const_cast<float*>(v));
auto show = common_feature_value.Show(const_cast<float*>(v));
auto click = common_feature_value.Click(const_cast<float*>(v));
auto score = ShowClickScore(show, click);
auto mf_dim = int(common_feature_value.MfDim(const_cast<float*>(v)));
if (score >= _config.embedx_threshold() &&
param > common_feature_value.EmbedxG2SumIndex()) {
// VLOG(1) << "common_feature_value.EmbedxG2SumIndex():"
// << common_feature_value.EmbedxG2SumIndex();
// VLOG(1) << "common_feature_value.EmbedxWIndex():"
// << common_feature_value.EmbedxWIndex();
// VLOG(1) << "common_feature_value.MfDim():"
// << common_feature_value.MfDim(const_cast<float*>(v));
for (auto i = common_feature_value.EmbedxG2SumIndex();
i < common_feature_value.EmbedxWIndex() +
common_feature_value.MfDim(const_cast<float*>(v));
i < common_feature_value.Dim(mf_dim);
++i) {
os << " " << v[i];
}
......
......@@ -54,10 +54,24 @@ class CtrDymfAccessor : public ValueAccessor {
int ClickIndex() { return ShowIndex() + 1; }
int EmbedWIndex() { return ClickIndex() + 1; }
int EmbedG2SumIndex() { return EmbedWIndex() + 1; }
int SlotIndex() { return EmbedG2SumIndex() + 1; }
int SlotIndex() { return EmbedG2SumIndex() + embed_sgd_dim; }
int MfDimIndex() { return SlotIndex() + 1; }
int EmbedxG2SumIndex() { return MfDimIndex() + 1; }
int EmbedxWIndex() { return EmbedxG2SumIndex() + 1; }
int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; }
// total value length (number of floats) computed from mf_dim
int Dim(int& mf_dim) {
int tmp_embedx_sgd_dim = 1;
if (optimizer_name == "SparseAdamSGDRule") { // adam
tmp_embedx_sgd_dim = mf_dim * 2 + 2;
} else if (optimizer_name == "SparseSharedAdamSGDRule") { // shared_adam
tmp_embedx_sgd_dim = 4;
}
return 7 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim;
}
// total size in bytes computed from mf_dim
int Size(int& mf_dim) { return (Dim(mf_dim)) * sizeof(float); }
float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
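A worked example of the new layout arithmetic (a sketch; it assumes the Adam embed rules report Dim() = 4 and an AdaGrad-style rule reports Dim() = 1):

```cpp
// Fixed fields (unseen_days, delta_score, show, click, embed_w, slot, mf_dim) = 7 floats.
// With mf_dim = 8:
//   SparseAdamSGDRule       : Dim = 7 + 4 + (2 * 8 + 2) + 8 = 37 floats, Size = 148 bytes
//   SparseSharedAdamSGDRule : Dim = 7 + 4 + 4            + 8 = 23 floats, Size =  92 bytes
//   SparseAdaGradSGDRule    : Dim = 7 + 1 + 1            + 8 = 17 floats, Size =  68 bytes
```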
......@@ -73,6 +87,7 @@ class CtrDymfAccessor : public ValueAccessor {
int embed_sgd_dim;
int embedx_dim;
int embedx_sgd_dim;
std::string optimizer_name;
};
struct CtrDymfPushValue {
......
......@@ -213,7 +213,6 @@ void SparseAdamSGDRule::UpdateValueWork(float* w,
float beta1_pow_ = *beta1_pow;
float beta2_pow_ = *beta2_pow;
// lr not change in one update
lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);
for (size_t i = 0; i < _embedding_dim; i++) {
// Calculation
......@@ -252,5 +251,88 @@ void SparseAdamSGDRule::InitValueWork(float* value,
*(sgd + Beta1PowIndex()) = _beta1_decay_rate;
*(sgd + Beta2PowIndex()) = _beta2_decay_rate;
}
void SparseSharedAdamSGDRule::LoadConfig(
const SparseCommonSGDRuleParameter& param, size_t emb_dim) {
_embedding_dim = emb_dim;
auto adam_param = param.adam();
learning_rate_ = adam_param.learning_rate();
_initial_range = adam_param.initial_range();
_beta1_decay_rate = adam_param.beta1_decay_rate();
_beta2_decay_rate = adam_param.beta2_decay_rate();
_ada_epsilon = adam_param.ada_epsilon();
if (adam_param.weight_bounds_size() == 0) {
_min_bound = -std::numeric_limits<float>::max();
_max_bound = std::numeric_limits<float>::max();
} else {
CHECK(adam_param.weight_bounds_size() >= 2)
<< "invalid repeated size for weight_bounds:"
<< adam_param.weight_bounds_size();
_min_bound = adam_param.weight_bounds(0);
_max_bound = adam_param.weight_bounds(1);
}
}
void SparseSharedAdamSGDRule::UpdateValueWork(float* w,
float* sgd,
const float* grad,
float scale) {
float* gsum = sgd + GSumIndex();
float* g2sum = sgd + G2SumIndex();
float* beta1_pow = sgd + Beta1PowIndex();
float* beta2_pow = sgd + Beta2PowIndex();
const float* g = grad;
float lr = learning_rate_;
float beta1_pow_ = *beta1_pow;
float beta2_pow_ = *beta2_pow;
float gsum_ = *gsum;
float g2sum_ = *g2sum;
lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);
double sum_gsum = 0.0;
double sum_g2sum = 0.0;
for (int i = 0; i < _embedding_dim; i++) {
// Calculation
double new_gsum =
_beta1_decay_rate * gsum_ + (1 - _beta1_decay_rate) * g[i];
double new_g2sum =
_beta2_decay_rate * g2sum_ + (1 - _beta2_decay_rate) * g[i] * g[i];
w[i] = w[i] - lr * (new_gsum / (sqrt(new_g2sum) + _ada_epsilon));
BoundValue(w[i]);
sum_gsum += new_gsum;
sum_g2sum += new_g2sum;
}
// update beta_pow_decay
(*gsum) = sum_gsum / _embedding_dim;
(*g2sum) = sum_g2sum / _embedding_dim;
(*beta1_pow) *= _beta1_decay_rate;
(*beta2_pow) *= _beta2_decay_rate;
}
void SparseSharedAdamSGDRule::InitValueWork(float* value,
float* sgd,
bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) {
value[i] = 0.0;
BoundValue(value[i]);
} else {
value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) *
2 -
1) *
_initial_range;
BoundValue(value[i]);
}
}
// init rule gsum and g2sum
for (int i = GSumIndex(); i < Beta1PowIndex(); i++) {
sgd[i] = 0.0;
}
// init beta1_pow and beta2_pow
*(sgd + Beta1PowIndex()) = _beta1_decay_rate;
*(sgd + Beta2PowIndex()) = _beta2_decay_rate;
}
} // namespace distributed
} // namespace paddle
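For reference, the update implemented by SparseSharedAdamSGDRule::UpdateValueWork above, written out: one first/second moment pair is shared across the d embedding dimensions and refreshed with their mean.

```latex
\eta_t = \eta\,\frac{\sqrt{1-\beta_2^{t}}}{1-\beta_1^{t}},\qquad
m_i = \beta_1 m + (1-\beta_1)\,g_i,\qquad
v_i = \beta_2 v + (1-\beta_2)\,g_i^2,

w_i \leftarrow \operatorname{clip}\!\Big(w_i - \eta_t\,\frac{m_i}{\sqrt{v_i}+\epsilon}\Big),\qquad
m \leftarrow \tfrac{1}{d}\sum_i m_i,\quad
v \leftarrow \tfrac{1}{d}\sum_i v_i,\quad
\beta_1^{t} \leftarrow \beta_1^{t}\beta_1,\quad
\beta_2^{t} \leftarrow \beta_2^{t}\beta_2.
```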
......@@ -144,5 +144,28 @@ class SparseAdamSGDRule : public SparseValueSGDRule {
float _beta2_decay_rate;
float _ada_epsilon;
};
class SparseSharedAdamSGDRule : public SparseValueSGDRule {
public:
virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void UpdateValueWork(float* w,
float* sgd,
const float* push_value,
float scale);
virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t Dim() { return 4; }
size_t GSumIndex() { return 0; }
size_t G2SumIndex() { return GSumIndex() + 1; }
size_t Beta1PowIndex() { return G2SumIndex() + 1; }
size_t Beta2PowIndex() { return Beta1PowIndex() + 1; }
protected:
float learning_rate_;
float _beta1_decay_rate;
float _beta2_decay_rate;
float _ada_epsilon;
};
} // namespace distributed
} // namespace paddle
......@@ -49,6 +49,7 @@ REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseSharedAdamSGDRule);
int32_t TableManager::Initialize() {
static bool initialized = false;
......
......@@ -13,6 +13,7 @@ cc_library(
op_registry
fs
shell
ps_gpu_wrapper
${RPC_DEPS})
target_link_libraries(fleet z)
......@@ -18,6 +18,10 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#endif
namespace paddle {
namespace distributed {
......@@ -129,6 +133,13 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::Create(ps_param));
worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index);
#if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE
VLOG(3) << "FleetWrapper::InitWorker InitializeGPUServer";
auto* accessor = worker_ptr_->GetTableAccessor(0);
auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance();
ps_gpu_wrapper->InitializeGPUServer(ps_param);
ps_gpu_wrapper->SetTableAccessor(accessor);
#endif
}
} else {
VLOG(3) << "Client can be initialized only once";
......@@ -525,11 +536,11 @@ void FleetWrapper::PushSparseFromTensorAsync(
int batch_size = -1;
bool batch_size_consist = true;
for (auto* input : *inputs) {
int cur_batch_size =
size_t cur_batch_size =
input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
if (batch_size == -1) {
batch_size = cur_batch_size;
} else if (batch_size != cur_batch_size) {
batch_size = int(cur_batch_size);
} else if (batch_size != int(cur_batch_size)) {
// CHECK(batch_size == cur_batch_size); // NOLINT
batch_size_consist = false;
break;
......@@ -537,12 +548,12 @@ void FleetWrapper::PushSparseFromTensorAsync(
}
CHECK(batch_size > 0); // NOLINT
int show_size =
size_t show_size =
shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
CHECK(show_size == batch_size || show_size == 1);
int clk_size =
CHECK(show_size == size_t(batch_size) || show_size == 1);
size_t clk_size =
clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
CHECK(clk_size == batch_size || clk_size == 1);
CHECK(clk_size == size_t(batch_size) || clk_size == 1);
CHECK(outputs->size() == inputs->size());
std::vector<uint64_t> push_keys;
......@@ -601,12 +612,10 @@ void FleetWrapper::PushSparseFromTensorAsync(
// in
// ctr_accessor.h
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot
push_values.back()[1] = (static_cast<int>(i) >= show_size
? 1
: static_cast<float>(show_tensor[i]));
push_values.back()[2] = (static_cast<int>(i) >= clk_size
? 0
: static_cast<float>(clk_tensor[i]));
push_values.back()[1] =
(i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
push_values.back()[2] =
(i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
float* data = push_values.back().data() + 3;
memcpy(data, g + output_len, sizeof(float) * fea_dim);
}
......@@ -630,12 +639,10 @@ void FleetWrapper::PushSparseFromTensorAsync(
// slot show clk grad... consistent with CtrCommonPushValue defined in
// ctr_accessor.h
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot
push_values.back()[1] = (static_cast<int>(i) >= show_size
? 1
: static_cast<float>(show_tensor[i]));
push_values.back()[2] = (static_cast<int>(i) >= clk_size
? 0
: static_cast<float>(clk_tensor[i]));
push_values.back()[1] =
(i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
push_values.back()[2] =
(i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
float* data = push_values.back().data() + 3;
memcpy(data, g + output_len, sizeof(float) * fea_dim);
}
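For clarity, the per-feature layout each push_values entry ends up with (this mirrors the CtrCommonPushValue comment above; fea_dim is the slot's embedding width):

```cpp
// push_values.back() layout, length 3 + fea_dim:
//   [0]              slot   (hard-coded 2 here, see the TODO)
//   [1]              show   (defaults to 1 when i >= show_size)
//   [2]              click  (defaults to 0 when i >= clk_size)
//   [3, 3 + fea_dim) gradient, memcpy'd from g + output_len
```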
......
......@@ -197,14 +197,14 @@ message TableParameter {
message TableAccessorParameter {
optional string accessor_class = 1;
optional SGDParameter embed_sgd_param = 2;
optional SGDParameter embedx_sgd_param = 3;
optional uint32 fea_dim = 4 [ default = 11 ]; // field size of one value
optional uint32 embedx_dim = 5 [ default = 8 ]; // embedx feature size
optional uint32 embedx_threshold = 6
[ default = 10 ]; // embedx feature create threshold
optional CtrAccessorParameter ctr_accessor_param = 7;
repeated TableAccessorSaveParameter table_accessor_save_param = 8;
optional SGDParameter embed_sgd_param = 10;
optional SGDParameter embedx_sgd_param = 11;
}
message SGDParameter {
......@@ -228,7 +228,7 @@ message
repeated float weight_bounds = 4;
}
message SparseAdamSGDParameter { // SparseAdamSGDRule
message SparseAdamSGDParameter { // SparseAdamSGDRule | SparseSharedAdamSGDRule
optional double learning_rate = 1 [ default = 0.001 ];
optional double initial_range = 2 [ default = 0.0001 ];
optional double beta1_decay_rate = 3 [ default = 0.9 ];
......
......@@ -25,10 +25,17 @@ endif()
if(WITH_HETERPS)
if(WITH_NCCL AND WITH_GPU)
nv_library(
ps_gpu_wrapper
SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
DEPS heter_ps gloo_wrapper ${BRPC_DEPS})
if(WITH_PSCORE)
nv_library(
ps_gpu_wrapper
SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
DEPS heter_ps gloo_wrapper ps_framework_proto ${BRPC_DEPS})
else()
nv_library(
ps_gpu_wrapper
SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
DEPS heter_ps gloo_wrapper ${BRPC_DEPS})
endif()
add_subdirectory(heter_ps)
elseif(WITH_XPU_KP)
xpu_library(
......
......@@ -81,7 +81,6 @@ class HeterContext {
std::vector<std::vector<FeatureValue>> device_values_;
std::vector<std::vector<FeatureKey>> device_keys_;
std::vector<std::vector<std::vector<FeatureKey>>> device_dim_keys_;
std::vector<std::vector<std::vector<FeatureValue>>> device_dim_values_;
std::vector<std::mutex*> mutex_;
std::vector<std::vector<std::mutex*>> dim_mutex_;
int multi_mf_dim_ = 0;
......@@ -114,7 +113,6 @@ class HeterContext {
value_dim_ptr_[i].resize(dim_num);
}
device_values_.resize(device_num);
device_dim_values_.resize(device_num);
device_keys_.resize(device_num);
device_dim_keys_.resize(device_num);
......
......@@ -9,16 +9,16 @@ if(WITH_GPU)
endif()
nv_library(
heter_comm_kernel
SRCS heter_comm_kernel.cu feature_value.h
SRCS heter_comm_kernel.cu feature_value.h feature_value.cu
DEPS ${HETERPS_DEPS})
nv_library(
hashtable_kernel
SRCS hashtable_kernel.cu feature_value.h
SRCS hashtable_kernel.cu feature_value.h feature_value.cu
DEPS ${HETERPS_DEPS})
nv_library(
heter_comm
SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h
mem_pool.h
SRCS heter_comm.h feature_value.h feature_value.cu heter_resource.cc
heter_resource.h mem_pool.h
DEPS ${HETERPS_DEPS} heter_comm_kernel hashtable_kernel)
nv_test(
test_heter_comm
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
namespace paddle {
namespace framework {
template <typename FVAccessor>
__global__ void PullCopy(float** dest,
const float* src,
const int64_t* len,
int slot_num,
int total_len,
uint64_t** keys,
uint64_t max_val_size,
int* gpu_dim,
FVAccessor feature_value_accessor) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
float* feature_value_ptr =
(float*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
int mf_dim = gpu_dim[x] - 3;
feature_value_accessor.Select(
dest[x] + y * (mf_dim + 3), feature_value_ptr, keys[x] + y, mf_dim);
}
}
template <typename FVAccessor>
__global__ void PushCopyWithPool(float* dest,
float** src,
int64_t* len,
int slot_num,
uint64_t total_len,
int bs,
int* slot_vector,
int* mf_dim_vector,
size_t grad_value_size,
FVAccessor feature_value_accessor) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[low - 1] : 0);
float* cur = (float*)((char*)dest + i * grad_value_size);
cur[feature_value_accessor.common_push_value.SlotIndex()] =
(float)slot_vector[x];
int mf_dim = mf_dim_vector[x];
cur[feature_value_accessor.common_push_value.MfDimIndex()] = mf_dim;
cur[feature_value_accessor.common_push_value.ShowIndex()] =
*(src[x] + y * (mf_dim + 3));
cur[feature_value_accessor.common_push_value.ClickIndex()] =
*(src[x] + y * (mf_dim + 3) + 1);
cur[feature_value_accessor.common_push_value.EmbedGIndex()] =
*(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
for (int j = 0; j < mf_dim; j++) {
cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] =
*(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
}
}
}
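Both kernels address per-slot rows with the same (mf_dim + 3) stride; a sketch of that row layout (descriptive names, not identifiers from the patch):

```cpp
// For slot x, row y:  base = dest[x] + y * (mf_dim + 3)  (pull)
//                     base = src[x]  + y * (mf_dim + 3)  (push)
//   base[0]               show
//   base[1]               click
//   base[2]               embed_w on pull / embed gradient * -1 * bs on push
//   base[3 .. 3+mf_dim)   embedx values on pull / embedx gradients * -1 * bs on push
// The binary search at the top of each kernel maps the flat thread index i to
// (slot x, row y) using the prefix-summed lengths in len[].
```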
template <typename GPUAccessor>
void AccessorWrapper<GPUAccessor>::CopyForPullImpl(
const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const float* total_values_gpu,
const int64_t* gpu_len,
const int slot_num,
const int hidden_size,
const int64_t total_length,
int* gpu_dim,
int feature_value_size) {
auto stream = dynamic_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place))
->stream();
auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
cudaMemcpy(gpu_values,
values.data(),
values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
gpu_values,
total_values_gpu,
gpu_len,
slot_num,
total_length,
gpu_keys,
feature_value_size,
gpu_dim,
gpu_accessor_);
cudaStreamSynchronize(stream);
}
template <typename GPUAccessor>
void AccessorWrapper<GPUAccessor>::CopyForPushImpl(
const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
float* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const uint64_t total_length,
const int batch_size,
size_t grad_value_size,
std::vector<int>& slot_vector,
std::vector<int>& slot_mf_dim_vector) {
auto stream = dynamic_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place))
->stream();
auto slot_lengths_lod = slot_lengths;
for (int i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_grad_value =
memory::Alloc(place, grad_values.size() * sizeof(float*));
auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
auto buf_slot_vector =
memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
auto buf_mf_dim_vector =
memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
cudaMemcpy(gpu_values,
grad_values.data(),
grad_values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len,
slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t),
cudaMemcpyHostToDevice);
cudaMemcpy(d_slot_vector,
slot_vector.data(),
slot_lengths_lod.size() * sizeof(int),
cudaMemcpyHostToDevice);
cudaMemcpy(d_mf_dim_vector,
slot_mf_dim_vector.data(),
slot_lengths_lod.size() * sizeof(int),
cudaMemcpyHostToDevice);
PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
total_grad_values_gpu,
gpu_values,
gpu_len,
slot_lengths.size(),
total_length,
batch_size,
d_slot_vector,
d_mf_dim_vector,
grad_value_size,
gpu_accessor_);
cudaStreamSynchronize(stream);
}
#ifdef PADDLE_WITH_PSCORE
template class AccessorWrapper<CommonFeatureValueAccessor>;
#endif
} // namespace framework
} // namespace paddle
#endif
......@@ -25,10 +25,12 @@
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
class GpuPsGraphTable
: public HeterComm<uint64_t, int64_t, int, CommonFeatureValueAccessor> {
public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
: HeterComm<uint64_t, int64_t, int>(1, resource) {
: HeterComm<uint64_t, int64_t, int, CommonFeatureValueAccessor>(
1, resource) {
load_factor_ = 0.25;
rw_lock.reset(new pthread_rwlock_t());
gpu_num = resource_->total_device();
......
......@@ -137,8 +137,12 @@ class HashTable {
size_t len,
StreamType stream);
template <typename StreamType>
void get(const KeyType* d_keys, char* d_vals, size_t len, StreamType stream);
template <typename StreamType, typename FVAccessor>
void get(const KeyType* d_keys,
char* d_vals,
size_t len,
StreamType stream,
FVAccessor& fv_accessor);
void show();
......@@ -150,9 +154,9 @@ class HashTable {
#if defined(PADDLE_WITH_CUDA)
template <typename GradType, typename Sgd, typename StreamType>
template <typename Sgd, typename StreamType>
void update(const KeyType* d_keys,
const GradType* d_grads,
const float* d_grads,
size_t len,
Sgd sgd,
StreamType stream);
......
......@@ -83,36 +83,25 @@ __global__ void search_kernel(Table* table,
}
}
template <typename Table>
template <typename Table, typename FVAccessor>
__global__ void dy_mf_search_kernel(Table* table,
const typename Table::key_type* const keys,
char* vals,
size_t len,
size_t pull_feature_value_size) {
size_t pull_feature_value_size,
FVAccessor feature_value_accessor) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
auto it = table->find(keys[i]);
if (it != table->end()) {
uint64_t offset = i * pull_feature_value_size;
FeatureValue* cur = (FeatureValue*)(vals + offset);
FeatureValue& input = *(FeatureValue*)(it->second);
cur->slot = input.slot;
cur->show = input.show;
cur->clk = input.clk;
cur->mf_dim = input.mf_dim;
cur->lr = input.lr;
cur->mf_size = input.mf_size;
cur->cpu_ptr = input.cpu_ptr;
cur->delta_score = input.delta_score;
cur->lr_g2sum = input.lr_g2sum;
for (int j = 0; j < cur->mf_dim + 1; ++j) {
cur->mf[j] = input.mf[j];
}
} else {
if (keys[i] != 0) {
printf("warning::pull miss key: %llu", keys[i]);
}
float* cur = (float*)(vals + offset);
float* input = it->second;
int mf_dim =
int(input[feature_value_accessor.common_feature_value.MfDimIndex()]);
feature_value_accessor.FeatureValueFill(cur, input, mf_dim);
}
}
}
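The hand-written field-by-field FeatureValue copy is gone; FeatureValueFill (defined in feature_value.h, not shown in this diff) is expected to perform the equivalent copy according to the CommonFeatureValue layout:

```cpp
// Roughly what the accessor call replaces (a sketch, not the actual body):
//   copy the fixed fields (show, click, embed_w, slot, mf_dim, sgd state, ...)
//   from `input` to `cur`, then copy the mf_dim-sized embedx tail; mf_dim is
//   read back out of the stored value via MfDimIndex(), so the pulled size can
//   vary per key.
```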
......@@ -145,8 +134,8 @@ __global__ void dy_mf_update_kernel(Table* table,
if (i < len) {
auto it = table->find(keys[i]);
if (it != table->end()) {
FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
float* cur = (float*)(grads + i * grad_value_size);
sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, cur);
} else {
if (keys[i] != 0) {
printf("warning::push miss key: %llu", keys[i]);
......@@ -212,17 +201,18 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
}
template <typename KeyType, typename ValType>
template <typename StreamType>
template <typename StreamType, typename FVAccessor>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
char* d_vals,
size_t len,
StreamType stream) {
StreamType stream,
FVAccessor& fv_accessor) {
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
container_, d_keys, d_vals, len, pull_feature_value_size_);
container_, d_keys, d_vals, len, pull_feature_value_size_, fv_accessor);
}
template <typename KeyType, typename ValType>
......@@ -298,27 +288,6 @@ void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
cpu_val[x + 7] = gpu_val.mf[x];
}
}
#endif
#ifdef PADDLE_WITH_PSCORE
auto* downpour_value =
(paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr);
int downpour_value_size = downpour_value->size();
if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
downpour_value->resize(gpu_val.mf_size + downpour_value_size);
}
float* cpu_val = downpour_value->data();
// cpu_val[0] = 0;
cpu_val[2] = gpu_val.delta_score;
cpu_val[3] = gpu_val.show;
cpu_val[4] = gpu_val.clk;
cpu_val[5] = gpu_val.lr;
cpu_val[6] = gpu_val.lr_g2sum;
cpu_val[0] = gpu_val.slot;
if (gpu_val.mf_size > 0) {
for (int x = 0; x < gpu_val.mf_size; x++) {
cpu_val[x + 7] = gpu_val.mf[x];
}
}
#endif
}
};
......@@ -336,9 +305,9 @@ void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
}
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd, typename StreamType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
const GradType* d_grads,
const float* d_grads,
size_t len,
Sgd sgd,
StreamType stream) {
......@@ -371,8 +340,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
push_grad_value_size_);
}
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template class HashTable<unsigned long, paddle::framework::FeatureValue*>;
template class HashTable<unsigned long, float>;
template class HashTable<unsigned long, float*>;
template class HashTable<long, int>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
......@@ -382,15 +351,19 @@ template class HashTable<long, long>;
template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>;
template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
cudaStream_t>(const unsigned long* d_keys,
paddle::framework::FeatureValue* d_vals,
size_t len,
cudaStream_t stream);
template void HashTable<unsigned long, float>::get<cudaStream_t>(
const unsigned long* d_keys,
float* d_vals,
size_t len,
cudaStream_t stream);
template void
HashTable<unsigned long, paddle::framework::FeatureValue*>::get<cudaStream_t>(
const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream);
HashTable<unsigned long, float*>::get<cudaStream_t, CommonFeatureValueAccessor>(
const unsigned long* d_keys,
char* d_vals,
size_t len,
cudaStream_t stream,
CommonFeatureValueAccessor& fv_accessor);
template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
int* d_vals,
......@@ -399,6 +372,12 @@ template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
template void HashTable<unsigned long, int>::get<cudaStream_t>(
const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<unsigned long, unsigned long>::get<cudaStream_t>(
const unsigned long* d_keys,
unsigned long* d_vals,
size_t len,
cudaStream_t stream);
template void HashTable<long, unsigned long>::get<cudaStream_t>(
const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys,
......@@ -414,19 +393,19 @@ template void HashTable<unsigned long, long>::get<cudaStream_t>(
// const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
// stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
cudaStream_t>(const unsigned long* d_keys,
const paddle::framework::FeatureValue* d_vals,
size_t len,
cudaStream_t stream);
template void HashTable<unsigned long, float>::insert<cudaStream_t>(
const unsigned long* d_keys,
const float* d_vals,
size_t len,
cudaStream_t stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
insert<cudaStream_t>(const unsigned long* d_keys,
size_t len,
char* pool,
size_t feature_value_size,
size_t start_index,
cudaStream_t stream);
template void HashTable<unsigned long, float*>::insert<cudaStream_t>(
const unsigned long* d_keys,
size_t len,
char* pool,
size_t feature_value_size,
size_t start_index,
cudaStream_t stream);
template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
const int* d_vals,
......@@ -460,30 +439,37 @@ template void HashTable<unsigned long, long>::insert<cudaStream_t>(
size_t len,
cudaStream_t stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::
dump_to_cpu<cudaStream_t>(int devid, cudaStream_t stream);
template void HashTable<unsigned long, unsigned long>::insert<cudaStream_t>(
const unsigned long* d_keys,
const unsigned long* d_vals,
size_t len,
cudaStream_t stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
paddle::framework::FeaturePushValue,
Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue>,
template void HashTable<unsigned long, float*>::dump_to_cpu<cudaStream_t>(
int devid, cudaStream_t stream);
template void
HashTable<unsigned long, float*>::update<SparseAdagradOptimizer, cudaStream_t>(
const unsigned long* d_keys,
const char* d_grads,
size_t len,
SparseAdagradOptimizer sgd,
cudaStream_t stream);
template void
HashTable<unsigned long, float*>::update<SparseAdamOptimizer, cudaStream_t>(
const unsigned long* d_keys,
const char* d_grads,
size_t len,
SparseAdamOptimizer sgd,
cudaStream_t stream);
template void HashTable<unsigned long, float*>::update<
SparseAdamSharedOptimizer,
cudaStream_t>(const unsigned long* d_keys,
const paddle::framework::FeaturePushValue* d_grads,
const char* d_grads,
size_t len,
Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue> sgd,
SparseAdamSharedOptimizer sgd,
cudaStream_t stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
update<Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue>,
cudaStream_t>(const unsigned long* d_keys,
const char* d_grads,
size_t len,
Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue> sgd,
cudaStream_t stream);
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
// Optimizer<paddle::framework::FeatureValue,
......
......@@ -46,7 +46,10 @@ namespace framework {
#define TYPEALIGN(ALIGNVAL, LEN) \
(((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))
template <typename KeyType, typename ValType, typename GradType>
template <typename KeyType,
typename ValType,
typename GradType,
typename FVAccessor>
class HeterComm {
public:
HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
......@@ -65,12 +68,9 @@ class HeterComm {
GradType* d_grads,
size_t len,
int& uniq_len); // NOLINT
void dynamic_merge_grad(int gpu_num,
KeyType* d_keys,
GradType* d_grads,
size_t len,
int& uniq_len);
void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
void dynamic_merge_grad(
int gpu_num, KeyType* d_keys, float* d_grads, size_t len, int& uniq_len);
void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);
void build_ps(int num,
KeyType* h_keys,
ValType* h_vals,
......@@ -92,7 +92,7 @@ class HeterComm {
template <typename Sgd>
void push_sparse(int num,
KeyType* d_keys,
GradType* d_grads,
float* d_grads,
size_t len,
Sgd& sgd); // NOLINT
#elif defined(PADDLE_WITH_XPU_KP)
......@@ -149,6 +149,13 @@ class HeterComm {
multi_mf_dim_ = multi_mf_dim;
max_mf_dim_ = max_mf_dim;
}
void set_accessor(FVAccessor& accessor) {
feature_value_accessor_ = accessor;
// for (auto& ptr_table: ptr_tables_) {
// ptr_table->set_accessor(feature_value_accessor_);
// }
}
#endif
bool need_transfer(int send_id, int receive_id) {
......@@ -282,9 +289,11 @@ class HeterComm {
char* src_val,
size_t val_size);
FVAccessor feature_value_accessor_;
protected:
using Table = HashTable<KeyType, ValType>;
using PtrTable = HashTable<KeyType, ValType*>;
using PtrTable = HashTable<KeyType, float*>;
std::vector<Table*> tables_;
std::vector<PtrTable*> ptr_tables_;
std::shared_ptr<HeterPsResource> resource_;
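A note on the two table flavors kept by HeterComm (a descriptive sketch, consistent with the build_ps/dy_mf kernels elsewhere in this commit):

```cpp
// Table    = HashTable<KeyType, ValType> : fixed-size values stored inline.
// PtrTable = HashTable<KeyType, float*>  : the value is a pointer into an HBM
//   memory pool (see build_ps taking char* pool + feature_value_size), so each
//   key can carry a variable-length, mf_dim-dependent float array laid out by
//   the FVAccessor instead of a hard-coded FeatureValue struct.
```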
......
......@@ -128,22 +128,28 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals,
}
}
template <typename KeyType, typename GradType, typename T>
__global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys,
KeyType* d_keys,
GradType* d_shard_grads,
GradType* d_grads,
T* idx,
size_t len,
size_t grad_value_size) {
template <typename KeyType, typename T, typename FVAccessor>
__global__ void dy_mf_fill_shard_grads_kernel(
KeyType* d_shard_keys,
KeyType* d_keys,
float* d_shard_grads,
float* d_grads,
T* idx,
size_t len,
size_t grad_value_size,
FVAccessor feature_value_accessor) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_shard_keys[i] = d_keys[idx[i]];
*(GradType*)((char*)d_shard_grads + i * grad_value_size) =
*(GradType*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size);
float* cur = (float*)((char*)d_shard_grads + i * grad_value_size);
float* shard_val =
(float*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size);
feature_value_accessor.PushValueFill(cur, shard_val);
}
}
template <typename FVAccessor>
__global__ void merge_gradients_kernel(const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index,
......@@ -151,36 +157,40 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
char* output,
int n,
size_t grad_value_size,
DynamicGradMerger& merger_) {
DynamicGradMerger& merger,
FVAccessor& feature_value_accessor) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
uint32_t start = offset[i];
uint32_t num = fea_num[i];
int ori_index = index[start];
FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size);
FeaturePushValue& in =
*(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.update_one(out, in);
float* out = (float*)(output + i * grad_value_size);
float* in = (float*)(input + size_t(ori_index) * grad_value_size);
merger.update_one(out, in, feature_value_accessor);
for (int j = 1; j < num; ++j) {
ori_index = index[start + j];
FeaturePushValue& rhs =
*(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.merge_one(out, rhs);
in = (float*)(input + size_t(ori_index) * grad_value_size);
merger.merge_one(out, in, feature_value_accessor);
}
}
}
template <typename ValType, typename T>
__global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals,
ValType* d_vals,
template <typename T, typename FVAccessor>
__global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals,
float* d_vals,
T* idx,
size_t len,
size_t val_size) {
size_t val_size,
FVAccessor feature_value_accessor) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
uint64_t new_offset = uint64_t(idx[i]) * val_size;
*(ValType*)((char*)d_vals + new_offset) =
*(ValType*)((char*)d_shard_vals + i * val_size);
float* cur = (float*)((char*)d_vals + new_offset);
float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size);
int mf_dim = int(
shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]);
feature_value_accessor.FeatureValueFill(cur, shard_val, mf_dim);
}
}
......@@ -312,15 +322,20 @@ void HeterCommKernel::reduce_by_key(void* d_temp_storage,
debug_synchronous));
}
template <typename KeyType, typename GradType, typename T, typename StreamType>
void HeterCommKernel::dy_mf_fill_shard_grads(KeyType* d_shard_keys,
KeyType* d_keys,
GradType* d_shard_grads,
GradType* d_grads,
T* idx,
long long len,
size_t grad_value_size,
const StreamType& stream) {
template <typename KeyType,
typename T,
typename StreamType,
typename FVAccessor>
void HeterCommKernel::dy_mf_fill_shard_grads(
KeyType* d_shard_keys,
KeyType* d_keys,
float* d_shard_grads,
float* d_grads,
T* idx,
long long len,
size_t grad_value_size,
const StreamType& stream,
FVAccessor& feature_value_accessor) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
dy_mf_fill_shard_grads_kernel<<<grid_size, block_size_, 0, stream>>>(
......@@ -330,10 +345,11 @@ void HeterCommKernel::dy_mf_fill_shard_grads(KeyType* d_shard_keys,
d_grads,
idx,
c_len,
grad_value_size);
grad_value_size,
feature_value_accessor);
}
template <typename StreamType>
template <typename StreamType, typename FVAccessor>
void HeterCommKernel::merge_gradient(const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index,
......@@ -342,23 +358,33 @@ void HeterCommKernel::merge_gradient(const uint32_t* offset,
int n,
size_t grad_value_size,
DynamicGradMerger& merger_,
const StreamType& stream) {
const StreamType& stream,
FVAccessor& feature_value_accessor) {
int grid_size = (n - 1) / block_size_ + 1;
merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
offset, fea_num, index, input, output, n, grad_value_size, merger_);
offset,
fea_num,
index,
input,
output,
n,
grad_value_size,
merger_,
feature_value_accessor);
}
template <typename ValType, typename T, typename StreamType>
void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals,
ValType* d_vals,
template <typename T, typename StreamType, typename FVAccessor>
void HeterCommKernel::dy_mf_fill_dvals(float* d_shard_vals,
float* d_vals,
T* idx,
long long len,
size_t val_size,
const StreamType& stream) {
const StreamType& stream,
FVAccessor& feature_value_accessor) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
dy_mf_fill_dvals_kernel<<<grid_size, block_size_, 0, stream>>>(
d_shard_vals, d_vals, idx, c_len, val_size);
d_shard_vals, d_vals, idx, c_len, val_size, feature_value_accessor);
}
template void HeterCommKernel::fill_idx<int, cudaStream_t>(
......@@ -402,17 +428,15 @@ template void HeterCommKernel::fill_shard_key<unsigned long, int, cudaStream_t>(
long long len,
const cudaStream_t& stream);
template void HeterCommKernel::fill_shard_grads<
unsigned long,
paddle::framework::FeaturePushValue,
int,
cudaStream_t>(unsigned long* d_shard_keys,
unsigned long* d_keys,
paddle::framework::FeaturePushValue* d_shard_grads,
paddle::framework::FeaturePushValue* d_grads,
int* idx,
long long len,
const cudaStream_t& stream);
template void
HeterCommKernel::fill_shard_grads<unsigned long, float, int, cudaStream_t>(
unsigned long* d_shard_keys,
unsigned long* d_keys,
float* d_shard_grads,
float* d_grads,
int* idx,
long long len,
const cudaStream_t& stream);
template void
HeterCommKernel::fill_dvals<paddle::framework::FeatureValue, int, cudaStream_t>(
......@@ -467,20 +491,23 @@ template void HeterCommKernel::reduce_by_key<
cudaStream_t stream,
bool debug_synchronous);
template void HeterCommKernel::dy_mf_fill_shard_grads<
unsigned long,
paddle::framework::FeaturePushValue,
int,
cudaStream_t>(unsigned long* d_shard_keys,
unsigned long* d_keys,
paddle::framework::FeaturePushValue* d_shard_grads,
paddle::framework::FeaturePushValue* d_grads,
int* idx,
long long len,
size_t grad_value_size,
const cudaStream_t& stream);
template void HeterCommKernel::merge_gradient<cudaStream_t>(
template void
HeterCommKernel::dy_mf_fill_shard_grads<unsigned long,
int,
cudaStream_t,
CommonFeatureValueAccessor>(
unsigned long* d_shard_keys,
unsigned long* d_keys,
float* d_shard_grads,
float* d_grads,
int* idx,
long long len,
size_t grad_value_size,
const cudaStream_t& stream,
CommonFeatureValueAccessor& feature_value_accessor);
template void
HeterCommKernel::merge_gradient<cudaStream_t, CommonFeatureValueAccessor>(
const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index,
......@@ -489,16 +516,18 @@ template void HeterCommKernel::merge_gradient<cudaStream_t>(
int n,
size_t grad_value_size,
DynamicGradMerger& merger_,
const cudaStream_t& stream);
const cudaStream_t& stream,
CommonFeatureValueAccessor& feature_value_accessor);
template void HeterCommKernel::
dy_mf_fill_dvals<paddle::framework::FeatureValue, int, cudaStream_t>(
paddle::framework::FeatureValue* d_shard_vals,
paddle::framework::FeatureValue* d_vals,
dy_mf_fill_dvals<int, cudaStream_t, CommonFeatureValueAccessor>(
float* d_shard_vals,
float* d_vals,
int* idx,
long long len,
size_t val_size,
const cudaStream_t& stream);
const cudaStream_t& stream,
CommonFeatureValueAccessor& feature_value_accessor);
#endif
} // namespace framework
......
......@@ -41,25 +41,16 @@ struct DynamicGradMerger {
return out;
}
template <typename T>
__device__ __forceinline__ void update_one(T& output, const T& input) {
output.slot = input.slot;
output.show = input.show;
output.clk = input.clk;
output.mf_dim = input.mf_dim;
output.lr_g = input.lr_g;
for (int i = 0; i < output.mf_dim; ++i) {
output.mf_g[i] = input.mf_g[i];
}
template <typename FVAccessor>
__device__ __forceinline__ void update_one(
float* output, const float* input, FVAccessor& feature_value_accessor) {
feature_value_accessor.PushValueFill(output, input);
}
template <typename T>
__device__ __forceinline__ void merge_one(T& output, const T& input) {
output.show += input.show;
output.clk += input.clk;
output.lr_g += input.lr_g;
for (int i = 0; i < input.mf_dim; ++i) {
output.mf_g[i] += input.mf_g[i];
}
template <typename FVAccessor>
__device__ __forceinline__ void merge_one(
float* output, const float* input, FVAccessor& feature_value_accessor) {
feature_value_accessor.MergePushValue(output, input);
}
};
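The merger now defers to the accessor: update_one seeds the merged output from the first gradient record and merge_one accumulates the rest (the old struct-based code summed show, clk, lr_g and mf_g; PushValueFill/MergePushValue are expected to do the same over the accessor's push-value layout). As used in merge_gradients_kernel:

```cpp
// Per merged key, the kernel effectively does:
//   merger.update_one(out, input_of(first_index), accessor);              // copy first record
//   for each remaining index:
//     merger.merge_one(out, input_of(index), accessor);                   // accumulate gradients
```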
......@@ -146,19 +137,20 @@ class HeterCommKernel {
bool debug_synchronous = false);
template <typename KeyType,
typename GradType,
typename T,
typename StreamType>
typename StreamType,
typename FVAccessor>
void dy_mf_fill_shard_grads(KeyType* d_shard_keys,
KeyType* d_keys,
GradType* d_shard_grads,
GradType* d_grads,
float* d_shard_grads,
float* d_grads,
T* idx,
long long len,
size_t grad_value_size,
const StreamType& stream);
const StreamType& stream,
FVAccessor& feature_value_accessor);
template <typename StreamType>
template <typename StreamType, typename FVAccessor>
void merge_gradient(const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index,
......@@ -167,15 +159,17 @@ class HeterCommKernel {
int n,
size_t grad_value_size,
DynamicGradMerger& merger_,
const StreamType& stream);
const StreamType& stream,
FVAccessor& feature_value_accessor);
template <typename ValType, typename T, typename StreamType>
void dy_mf_fill_dvals(ValType* d_shard_vals,
ValType* d_vals,
template <typename T, typename StreamType, typename FVAccessor>
void dy_mf_fill_dvals(float* d_shard_vals,
float* d_vals,
T* idx,
long long len,
size_t val_size,
const StreamType& stream);
const StreamType& stream,
FVAccessor& feature_value_accessor);
private:
int block_size_{256};
......
......@@ -22,34 +22,43 @@ namespace paddle {
namespace framework {
HeterPsBase* HeterPsBase::get_instance(
size_t capacity, std::shared_ptr<HeterPsResource> resource) {
return new HeterPs(capacity, resource);
size_t capacity,
std::shared_ptr<HeterPsResource> resource,
std::unordered_map<std::string, float> fleet_config,
std::string accessor_type,
int optimizer_type) {
if (accessor_type == "CtrDymfAccessor" &&
(optimizer_type == 1 || optimizer_type == 3 || optimizer_type == 4)) {
return new HeterPs<CommonFeatureValueAccessor>(
capacity, resource, fleet_config, accessor_type, optimizer_type);
} else {
VLOG(0) << " HeterPsBase get_instance Warning: now only support "
"CtrDymfAccessor, but get "
<< accessor_type_;
return new HeterPs<CommonFeatureValueAccessor>(
capacity, resource, fleet_config, accessor_type, optimizer_type);
}
}
HeterPs::HeterPs(size_t capacity, std::shared_ptr<HeterPsResource> resource) {
comm_ =
std::make_shared<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>>(
capacity, resource);
HeterPs::HeterPs(size_t capacity,
std::shared_ptr<HeterPsResource> resource,
std::unordered_map<std::string, float> fleet_config,
std::string accessor_type,
int optimizer_type) {
comm_ = std::make_shared<HeterComm<FeatureKey, float*, float*, FVAccessor>>(
capacity, resource);
optimizer_type_ = optimizer_type;
}
HeterPs::~HeterPs() {}
void HeterPs::pull_sparse(int num,
FeatureKey* d_keys,
FeatureValue* d_vals,
float* d_vals,
size_t len) {
comm_->pull_sparse(num, d_keys, d_vals, len);
}
void HeterPs::build_ps(int num,
FeatureKey* h_keys,
FeatureValue* h_vals,
size_t len,
size_t chunk_size,
int stream_num) {
comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num);
}
int HeterPs::get_index_by_devid(int devid) {
return comm_->get_index_by_devid(devid);
}
......@@ -68,7 +77,7 @@ void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
void HeterPs::push_sparse(int num,
FeatureKey* d_keys,
FeaturePushValue* d_grads,
float* d_grads,
size_t len) {
comm_->push_sparse(num, d_keys, d_grads, len);
// comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_);
......
......@@ -22,80 +22,139 @@ namespace paddle {
namespace framework {
HeterPsBase* HeterPsBase::get_instance(
size_t capacity, std::shared_ptr<HeterPsResource> resource) {
return new HeterPs(capacity, resource);
size_t capacity,
std::shared_ptr<HeterPsResource> resource,
std::unordered_map<std::string, float> fleet_config,
std::string accessor_type,
int optimizer_type) {
if (accessor_type == "CtrDymfAccessor" &&
(optimizer_type == 1 || optimizer_type == 3 || optimizer_type == 4)) {
return new HeterPs<CommonFeatureValueAccessor>(
capacity, resource, fleet_config, accessor_type, optimizer_type);
} else {
VLOG(0) << " HeterPsBase get_instance Warning: now only support "
"CtrDymfAccessor, but get "
<< accessor_type;
return new HeterPs<CommonFeatureValueAccessor>(
capacity, resource, fleet_config, accessor_type, optimizer_type);
}
}
HeterPs::HeterPs(size_t capacity, std::shared_ptr<HeterPsResource> resource) {
comm_ =
std::make_shared<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>>(
capacity, resource);
opt_ = Optimizer<FeatureValue, FeaturePushValue>();
template <typename FVAccessor>
HeterPs<FVAccessor>::HeterPs(
size_t capacity,
std::shared_ptr<HeterPsResource> resource,
std::unordered_map<std::string, float> fleet_config,
std::string accessor_type,
int optimizer_type) {
comm_ = std::make_shared<HeterComm<FeatureKey, float*, float*, FVAccessor>>(
capacity, resource);
feature_value_accessor_.Configure(fleet_config);
set_accessor(feature_value_accessor_);
accessor_type_ = accessor_type;
optimizer_type_ = optimizer_type;
}
HeterPs::~HeterPs() {}
template <typename FVAccessor>
HeterPs<FVAccessor>::~HeterPs() {}
void HeterPs::pull_sparse(int num,
FeatureKey* d_keys,
FeatureValue* d_vals,
size_t len) {
template <typename FVAccessor>
void HeterPs<FVAccessor>::pull_sparse(int num,
FeatureKey* d_keys,
float* d_vals,
size_t len) {
comm_->pull_sparse(num, d_keys, d_vals, len);
}
void HeterPs::build_ps(int num,
FeatureKey* h_keys,
FeatureValue* h_vals,
size_t len,
size_t chunk_size,
int stream_num) {
comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num);
}
void HeterPs::build_ps(int num,
FeatureKey* h_keys,
char* pool,
size_t len,
size_t feature_value_size,
size_t chunk_size,
int stream_num) {
template <typename FVAccessor>
void HeterPs<FVAccessor>::build_ps(int num,
FeatureKey* h_keys,
char* pool,
size_t len,
size_t feature_value_size,
size_t chunk_size,
int stream_num) {
comm_->build_ps(
num, h_keys, pool, len, feature_value_size, chunk_size, stream_num);
}
int HeterPs::get_index_by_devid(int devid) {
template <typename FVAccessor>
int HeterPs<FVAccessor>::get_index_by_devid(int devid) {
return comm_->get_index_by_devid(devid);
}
void HeterPs::set_sparse_sgd(const OptimizerConfig& optimizer_config) {
template <typename FVAccessor>
void HeterPs<FVAccessor>::set_sparse_sgd(
const OptimizerConfig& optimizer_config) {
comm_->set_sparse_sgd(optimizer_config);
}
void HeterPs::set_embedx_sgd(const OptimizerConfig& optimizer_config) {
template <typename FVAccessor>
void HeterPs<FVAccessor>::set_embedx_sgd(
const OptimizerConfig& optimizer_config) {
comm_->set_embedx_sgd(optimizer_config);
}
void HeterPs::end_pass() { comm_->end_pass(); }
template <typename FVAccessor>
void HeterPs<FVAccessor>::end_pass() {
comm_->end_pass();
}
void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
template <typename FVAccessor>
void HeterPs<FVAccessor>::show_one_table(int gpu_num) {
comm_->show_one_table(gpu_num);
}
void HeterPs::push_sparse(int num,
FeatureKey* d_keys,
FeaturePushValue* d_grads,
size_t len) {
comm_->push_sparse(num, d_keys, d_grads, len, opt_);
// comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_);
template <typename FVAccessor>
void HeterPs<FVAccessor>::push_sparse(int num,
FeatureKey* d_keys,
float* d_grads,
size_t len) {
if (accessor_type_ == "CtrDymfAccessor") {
if (optimizer_type_ == 3) { // adam
auto optimizer = SparseAdamOptimizer(feature_value_accessor_);
VLOG(5) << "INTO push_sparse SparseAdamOptimizer, EmbedDim():"
<< optimizer.EmbedDim();
comm_->push_sparse(num, d_keys, d_grads, len, optimizer);
} else if (optimizer_type_ == 4) { // shared_adam
auto optimizer = SparseAdamSharedOptimizer(feature_value_accessor_);
VLOG(5) << "INTO push_sparse SparseAdamSharedOptimizer, EmbedDim():"
<< optimizer.EmbedDim();
comm_->push_sparse(num, d_keys, d_grads, len, optimizer);
} else if (optimizer_type_ == 1) { // adagrad
auto optimizer = SparseAdagradOptimizer(feature_value_accessor_);
VLOG(5) << "INTO push_sparse SparseAdagradOptimizer, EmbedDim():"
<< optimizer.EmbedDim();
comm_->push_sparse(num, d_keys, d_grads, len, optimizer);
} else {
VLOG(0) << " push sparse Error: CtrDymfAccessor only support adagrad(1),"
"adam(3) or shared_adam(4), bug get optimizer type:"
<< optimizer_type_;
}
} else {
VLOG(0) << " push sparse Error: now only support CtrDymfAccessor, but get "
<< accessor_type_;
}
}
void HeterPs::set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
template <typename FVAccessor>
void HeterPs<FVAccessor>::set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size);
}
void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
template <typename FVAccessor>
void HeterPs<FVAccessor>::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
comm_->set_multi_mf_dim(multi_mf_dim, max_mf_dim);
}
template <typename FVAccessor>
void HeterPs<FVAccessor>::set_accessor(FVAccessor& accessor) {
comm_->set_accessor(accessor);
}
} // end namespace framework
} // end namespace paddle
#endif
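The dispatch above fixes the mapping between the optimizer_type integers from the config and the GPU-side optimizers (the CPU-side rule names in the right column are the assumed counterparts, matching the sparse SGD rules touched by this commit):

```cpp
// optimizer_type   GPU optimizer                CPU sparse rule (assumed)
//   1              SparseAdagradOptimizer       SparseAdaGradSGDRule
//   3              SparseAdamOptimizer          SparseAdamSGDRule
//   4              SparseAdamSharedOptimizer    SparseSharedAdamSGDRule
// Any other value (or a non-CtrDymfAccessor accessor) only logs an error and
// skips the push.
```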
......@@ -26,24 +26,23 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename FVAccessor>
class HeterPs : public HeterPsBase {
public:
HeterPs() {}
HeterPs(size_t capacity, std::shared_ptr<HeterPsResource> resource);
HeterPs(size_t capacity,
std::shared_ptr<HeterPsResource> resource,
std::unordered_map<std::string, float> fleet_config,
std::string accessor_type,
int optimizer_type);
virtual ~HeterPs();
HeterPs(const HeterPs&) = delete;
HeterPs& operator=(const HeterPs&) = delete;
void pull_sparse(int num,
FeatureKey* d_keys,
FeatureValue* d_vals,
float* d_vals,
size_t len) override;
void build_ps(int num,
FeatureKey* h_keys,
FeatureValue* h_vals,
size_t len,
size_t chunk_size,
int stream_num) override;
void build_ps(int num,
FeatureKey* h_keys,
char* pool,
......@@ -56,6 +55,8 @@ class HeterPs : public HeterPsBase {
const std::vector<ncclComm_t>& inter_comms,
int comm_size) override;
void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) override;
void set_accessor(FVAccessor& accessor);
#endif
void set_sparse_sgd(const OptimizerConfig& optimizer_config) override;
......@@ -66,13 +67,15 @@ class HeterPs : public HeterPsBase {
void show_one_table(int gpu_num) override;
void push_sparse(int num,
FeatureKey* d_keys,
FeaturePushValue* d_grads,
float* d_grads,
size_t len) override;
private:
std::shared_ptr<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>> comm_;
std::shared_ptr<HeterComm<FeatureKey, float*, float*, FVAccessor>> comm_;
#if defined(PADDLE_WITH_CUDA)
Optimizer<FeatureValue, FeaturePushValue> opt_;
FVAccessor feature_value_accessor_;
std::string accessor_type_;
int optimizer_type_;
#endif
};
......
......@@ -34,14 +34,8 @@ class HeterPsBase {
virtual void pull_sparse(int num,
FeatureKey* d_keys,
FeatureValue* d_vals,
float* d_vals,
size_t len) = 0;
virtual void build_ps(int num,
FeatureKey* h_keys,
FeatureValue* h_vals,
size_t len,
size_t chunk_size,
int stream_num) = 0;
virtual void build_ps(int num,
FeatureKey* h_keys,
char* pool,
......@@ -56,19 +50,25 @@ class HeterPsBase {
const std::vector<ncclComm_t>& inter_comms,
int comm_size) = 0;
virtual void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) = 0;
#endif
virtual void end_pass() = 0;
virtual void show_one_table(int gpu_num) = 0;
virtual void push_sparse(int num,
FeatureKey* d_keys,
FeaturePushValue* d_grads,
float* d_grads,
size_t len) = 0;
virtual void set_sparse_sgd(const OptimizerConfig& optimizer_config) = 0;
virtual void set_embedx_sgd(const OptimizerConfig& optimizer_config) = 0;
static HeterPsBase* get_instance(size_t capacity,
std::shared_ptr<HeterPsResource> resource);
static HeterPsBase* get_instance(
size_t capacity,
std::shared_ptr<HeterPsResource> resource,
// CommonFeatureValueAccessor feature_value_accessor,
std::unordered_map<std::string, float> fleet_config,
std::string accessor_type,
int optimizer_type);
};
} // end namespace framework
......
......@@ -82,20 +82,6 @@ class HBMMemoryPool : public managed {
cudaMemset(mem_, 0, block_size_ * capacity);
}
friend std::ostream& operator<<(std::ostream& out, HBMMemoryPool& p) {
for (size_t k = 0; k < 5; k++) {
auto x = (FeatureValue*)(p.mem() + k * p.capacity());
out << "show: " << x->show << " clk: " << x->clk << " slot: " << x->slot
<< " lr: " << x->lr << " mf_dim: " << x->mf_size
<< " mf_size: " << x->mf_size << " mf:";
for (int i = 0; i < x->mf_size + 1; ++i) {
out << " " << x->mf[i];
}
out << "\n";
}
return out;
}
char* mem() { return mem_; }
size_t capacity() { return capacity_; }
......
......@@ -27,13 +27,19 @@ class OptimizerConfig {
float learning_rate = 0.05;
float initial_g2sum = 3.0;
float initial_range = 0;
float beta1_decay_rate = 0.9; // adam
float beta2_decay_rate = 0.999; // adam
float ada_epsilon = 1e-8;
float mf_create_thresholds = 10;
float mf_learning_rate = 0.05;
float mf_initial_g2sum = 3.0;
float mf_initial_range = 1e-4;
float mf_beta1_decay_rate = 0.9; // adam
float mf_beta2_decay_rate = 0.999; // adam
float mf_min_bound = -10;
float mf_max_bound = 10;
float mf_ada_epsilon = 1e-8;
void set_sparse_sgd(float nonclk_coeff,
float clk_coeff,
......@@ -41,7 +47,10 @@ class OptimizerConfig {
float max_bound,
float learning_rate,
float initial_g2sum,
float initial_range) {
float initial_range,
float beta1_decay_rate,
float beta2_decay_rate,
float ada_epsilon) {
this->nonclk_coeff = nonclk_coeff;
this->clk_coeff = clk_coeff;
this->min_bound = min_bound;
......@@ -49,6 +58,9 @@ class OptimizerConfig {
this->learning_rate = learning_rate;
this->initial_g2sum = initial_g2sum;
this->initial_range = initial_range;
this->beta1_decay_rate = beta1_decay_rate;
this->beta2_decay_rate = beta2_decay_rate;
this->ada_epsilon = ada_epsilon;
}
void set_sparse_sgd(const OptimizerConfig& optimizer_config) {
......@@ -59,6 +71,9 @@ class OptimizerConfig {
this->learning_rate = optimizer_config.learning_rate;
this->initial_g2sum = optimizer_config.initial_g2sum;
this->initial_range = optimizer_config.initial_range;
this->beta1_decay_rate = optimizer_config.beta1_decay_rate;
this->beta2_decay_rate = optimizer_config.beta2_decay_rate;
this->ada_epsilon = optimizer_config.ada_epsilon;
}
void set_embedx_sgd(float mf_create_thresholds,
......@@ -66,13 +81,19 @@ class OptimizerConfig {
float mf_initial_g2sum,
float mf_initial_range,
float mf_min_bound,
float mf_max_bound) {
float mf_max_bound,
float mf_beta1_decay_rate,
float mf_beta2_decay_rate,
float mf_ada_epsilon) {
this->mf_create_thresholds = mf_create_thresholds;
this->mf_learning_rate = mf_learning_rate;
this->mf_initial_g2sum = mf_initial_g2sum;
this->mf_initial_range = mf_initial_range;
this->mf_min_bound = mf_min_bound;
this->mf_max_bound = mf_max_bound;
this->mf_beta1_decay_rate = mf_beta1_decay_rate;
this->mf_beta2_decay_rate = mf_beta2_decay_rate;
this->mf_ada_epsilon = mf_ada_epsilon;
}
void set_embedx_sgd(const OptimizerConfig& optimizer_config) {
......@@ -82,6 +103,9 @@ class OptimizerConfig {
this->mf_initial_range = optimizer_config.mf_initial_range;
this->mf_min_bound = optimizer_config.mf_min_bound;
this->mf_max_bound = optimizer_config.mf_max_bound;
this->mf_beta1_decay_rate = optimizer_config.mf_beta1_decay_rate;
this->mf_beta2_decay_rate = optimizer_config.mf_beta2_decay_rate;
this->mf_ada_epsilon = optimizer_config.mf_ada_epsilon;
}
};
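// Hedged illustration (not code from this diff; adam_step is a hypothetical
// helper): the standard Adam update that beta1_decay_rate, beta2_decay_rate
// and ada_epsilon parameterize. The sparse GPU optimizer presumably tracks the
// moments m/v (and the decayed beta powers) per feature value, instead of the
// single g2sum state used by the adagrad path.
#include <cmath>
static inline void adam_step(float* w, float* m, float* v, float g, float lr,
                             float beta1, float beta2, float epsilon,
                             float beta1_pow, float beta2_pow) {
  *m = beta1 * (*m) + (1.0f - beta1) * g;      // first-moment estimate
  *v = beta2 * (*v) + (1.0f - beta2) * g * g;  // second-moment estimate
  float m_hat = (*m) / (1.0f - beta1_pow);     // bias correction
  float v_hat = (*v) / (1.0f - beta2_pow);
  *w -= lr * m_hat / (std::sqrt(v_hat) + epsilon);
}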
......
......@@ -26,90 +26,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
__global__ void PullCopy(float** dest,
const FeatureValue* src,
const int64_t* len,
int hidden,
int slot_num,
int total_len,
uint64_t** keys) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
if (*(keys[x] + y) == 0) {
*(dest[x] + y * hidden) = 0;
*(dest[x] + y * hidden + 1) = 0;
*(dest[x] + y * hidden + 2) = 0;
} else {
*(dest[x] + y * hidden) = (src + i)->show;
*(dest[x] + y * hidden + 1) = (src + i)->clk;
*(dest[x] + y * hidden + 2) = (src + i)->lr;
}
if ((src + i)->mf_size == 0 || *(keys[x] + y) == 0) {
for (int j = 0; j < hidden - 3; j++) {
*(dest[x] + y * hidden + 3 + j) = 0;
}
} else {
for (int j = 0; j < hidden - 3; j++) {
*(dest[x] + y * hidden + 3 + j) = (src + i)->mf[1 + j];
}
}
}
}
__global__ void PullCopy(float** dest,
const FeatureValue* src,
const int64_t* len,
int slot_num,
int total_len,
uint64_t** keys,
uint64_t max_val_size,
int* gpu_dim) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
FeatureValue* feature_value_ptr =
(FeatureValue*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
int mf_dim = gpu_dim[x] - 3;
if (*(keys[x] + y) == 0) {
*(dest[x] + y * (mf_dim + 3)) = 0;
*(dest[x] + y * (mf_dim + 3) + 1) = 0;
*(dest[x] + y * (mf_dim + 3) + 2) = 0;
} else {
*(dest[x] + y * (mf_dim + 3)) = feature_value_ptr->show;
*(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr->clk;
*(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr->lr;
}
if ((feature_value_ptr)->mf_size == 0 || *(keys[x] + y) == 0) {
for (int j = 0; j < mf_dim; j++) {
*(dest[x] + y * (mf_dim + 3) + 3 + j) = 0;
}
} else {
for (int j = 0; j < mf_dim; j++) {
*(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr->mf[1 + j];
}
}
}
}
__global__ void CopyKeysKernel(uint64_t** src_keys,
uint64_t* dest_total_keys,
const int64_t* len,
......@@ -161,101 +77,8 @@ __global__ void PushCopy(FeaturePushValue* dest,
}
}
__global__ void PushCopyWithPool(FeaturePushValue* dest,
float** src,
int64_t* len,
int slot_num,
uint64_t total_len,
int bs,
int* slot_vector,
int* mf_dim_vector,
size_t grad_value_size) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[low - 1] : 0);
FeaturePushValue* cur =
(FeaturePushValue*)((char*)dest + i * grad_value_size);
cur->slot = slot_vector[x];
int mf_dim = mf_dim_vector[x];
cur->mf_dim = mf_dim;
cur->show = *(src[x] + y * (mf_dim + 3));
cur->clk = *(src[x] + y * (mf_dim + 3) + 1);
cur->lr_g = *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
for (int j = 0; j < cur->mf_dim; j++) {
cur->mf_g[j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
}
}
}
PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const FeatureValue* total_values_gpu,
const int64_t* gpu_len,
const int slot_num,
const int hidden_size,
const int64_t total_length) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
cudaMemcpy(gpu_values,
values.data(),
values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
gpu_values,
total_values_gpu,
gpu_len,
hidden_size,
slot_num,
total_length,
gpu_keys);
cudaStreamSynchronize(stream);
}
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const FeatureValue* total_values_gpu,
const int64_t* gpu_len,
const int slot_num,
const int hidden_size,
const int64_t total_length,
int* gpu_dim) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
cudaMemcpy(gpu_values,
values.data(),
values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
gpu_values,
total_values_gpu,
gpu_len,
slot_num,
total_length,
gpu_keys,
val_type_size_,
gpu_dim);
cudaStreamSynchronize(stream);
}
void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
uint64_t** origin_keys,
uint64_t* total_keys,
......@@ -270,125 +93,26 @@ void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
cudaStreamSynchronize(stream);
}
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size,
const int64_t total_length,
const int batch_size) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto slot_lengths_lod = slot_lengths;
for (int i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_grad_value =
memory::Alloc(place, grad_values.size() * sizeof(float*));
auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
auto buf_slot_vector =
memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
cudaMemcpy(gpu_values,
grad_values.data(),
grad_values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len,
slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t),
cudaMemcpyHostToDevice);
cudaMemcpy(d_slot_vector,
slot_vector_.data(),
slot_lengths_lod.size() * sizeof(int),
cudaMemcpyHostToDevice);
PushCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
total_grad_values_gpu,
gpu_values,
gpu_len,
hidden_size,
slot_lengths.size(),
total_length,
batch_size,
d_slot_vector);
cudaStreamSynchronize(stream);
}
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const uint64_t total_length,
const int batch_size,
size_t grad_value_size) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto slot_lengths_lod = slot_lengths;
for (int i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_grad_value =
memory::Alloc(place, grad_values.size() * sizeof(float*));
auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
auto buf_slot_vector =
memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
auto buf_mf_dim_vector =
memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
cudaMemcpy(gpu_values,
grad_values.data(),
grad_values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len,
slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t),
cudaMemcpyHostToDevice);
cudaMemcpy(d_slot_vector,
slot_vector_.data(),
slot_lengths_lod.size() * sizeof(int),
cudaMemcpyHostToDevice);
cudaMemcpy(d_mf_dim_vector,
slot_mf_dim_vector_.data(),
slot_lengths_lod.size() * sizeof(int),
cudaMemcpyHostToDevice);
PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
total_grad_values_gpu,
gpu_values,
gpu_len,
slot_lengths.size(),
total_length,
batch_size,
d_slot_vector,
d_mf_dim_vector,
grad_value_size);
cudaStreamSynchronize(stream);
}
void PSGPUWrapper::SetSparseSGD(float nonclk_coeff,
float clk_coeff,
float min_bound,
float max_bound,
float learning_rate,
float initial_g2sum,
float initial_range) {
OptimizerConfig optimizer_config;
optimizer_config.set_sparse_sgd(nonclk_coeff,
clk_coeff,
min_bound,
max_bound,
learning_rate,
initial_g2sum,
initial_range);
HeterPs_->set_sparse_sgd(optimizer_config);
float initial_range,
float beta1_decay_rate,
float beta2_decay_rate,
float ada_epsilon) {
optimizer_config_.set_sparse_sgd(nonclk_coeff,
clk_coeff,
min_bound,
max_bound,
learning_rate,
initial_g2sum,
initial_range,
beta1_decay_rate,
beta2_decay_rate,
ada_epsilon);
}
void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
......@@ -396,15 +120,19 @@ void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
float mf_initial_g2sum,
float mf_initial_range,
float mf_min_bound,
float mf_max_bound) {
OptimizerConfig optimizer_config;
optimizer_config.set_embedx_sgd(mf_create_thresholds,
mf_learning_rate,
mf_initial_g2sum,
mf_initial_range,
mf_min_bound,
mf_max_bound);
HeterPs_->set_embedx_sgd(optimizer_config);
float mf_max_bound,
float mf_beta1_decay_rate,
float mf_beta2_decay_rate,
float mf_ada_epsilon) {
optimizer_config_.set_embedx_sgd(mf_create_thresholds,
mf_learning_rate,
mf_initial_g2sum,
mf_initial_range,
mf_min_bound,
mf_max_bound,
mf_beta1_decay_rate,
mf_beta2_decay_rate,
mf_ada_epsilon);
}
} // end namespace framework
......
......@@ -51,7 +51,10 @@ limitations under the License. */
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
#endif
#ifdef PADDLE_WITH_PSLIB
#include "afs_api.h"
......@@ -64,9 +67,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
#define TYPEALIGN(ALIGNVAL, LEN) \
(((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))
class Dataset;
#ifdef PADDLE_WITH_PSLIB
......@@ -98,7 +98,7 @@ class AfsWrapper {
class PSGPUWrapper {
public:
virtual ~PSGPUWrapper();
~PSGPUWrapper();
PSGPUWrapper() {
HeterPs_ = NULL;
......@@ -139,37 +139,6 @@ class PSGPUWrapper {
const int64_t* gpu_len,
int slot_num,
int total_len);
void CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const FeatureValue* total_values_gpu,
const int64_t* gpu_len,
const int slot_num,
const int hidden_size,
const int64_t total_length);
void CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const FeatureValue* total_values_gpu,
const int64_t* gpu_len,
const int slot_num,
const int hidden_size,
const int64_t total_length,
int* gpu_dim);
void CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size,
const int64_t total_length,
const int batch_size);
void CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const uint64_t total_length,
const int batch_size,
size_t grad_value_size);
void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
......@@ -274,13 +243,96 @@ class PSGPUWrapper {
float max_bound,
float learning_rate,
float initial_g2sum,
float initial_range);
float initial_range,
float beta1_decay_rate,
float beta2_decay_rate,
float ada_epsilon);
void SetEmbedxSGD(float mf_create_thresholds,
float mf_learning_rate,
float mf_initial_g2sum,
float mf_initial_range,
float mf_min_bound,
float mf_max_bound);
float mf_max_bound,
float mf_beta1_decay_rate,
float mf_beta2_decay_rate,
float mf_ada_epsilon);
#ifdef PADDLE_WITH_PSCORE
void add_sparse_optimizer(
std::unordered_map<std::string, float>& config, // NOLINT
const ::paddle::distributed::SparseCommonSGDRuleParameter& sgd_param,
const std::string& prefix = "") {
auto optimizer_name = sgd_param.name();
if (optimizer_name == "SparseNaiveSGDRule") {
config[prefix + "optimizer_type"] = 0;
config[prefix + "learning_rate"] = sgd_param.naive().learning_rate();
config[prefix + "initial_range"] = sgd_param.naive().initial_range();
config[prefix + "min_bound"] = sgd_param.naive().weight_bounds()[0];
config[prefix + "max_bound"] = sgd_param.naive().weight_bounds()[1];
} else if (optimizer_name == "SparseAdaGradSGDRule") {
config[prefix + "optimizer_type"] = 1;
config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate();
config[prefix + "initial_range"] = sgd_param.adagrad().initial_range();
config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum();
config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0];
config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1];
} else if (optimizer_name == "StdAdaGradSGDRule") {
config[prefix + "optimizer_type"] = 2;
config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate();
config[prefix + "initial_range"] = sgd_param.adagrad().initial_range();
config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum();
config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0];
config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1];
} else if (optimizer_name == "SparseAdamSGDRule") {
config[prefix + "optimizer_type"] = 3;
config[prefix + "learning_rate"] = sgd_param.adam().learning_rate();
config[prefix + "initial_range"] = sgd_param.adam().initial_range();
config[prefix + "beta1_decay_rate"] = sgd_param.adam().beta1_decay_rate();
config[prefix + "beta2_decay_rate"] = sgd_param.adam().beta2_decay_rate();
config[prefix + "ada_epsilon"] = sgd_param.adam().ada_epsilon();
config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0];
config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1];
} else if (optimizer_name == "SparseSharedAdamSGDRule") {
config[prefix + "optimizer_type"] = 4;
config[prefix + "learning_rate"] = sgd_param.adam().learning_rate();
config[prefix + "initial_range"] = sgd_param.adam().initial_range();
config[prefix + "beta1_decay_rate"] = sgd_param.adam().beta1_decay_rate();
config[prefix + "beta2_decay_rate"] = sgd_param.adam().beta2_decay_rate();
config[prefix + "ada_epsilon"] = sgd_param.adam().ada_epsilon();
config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0];
config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1];
}
}
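// Hedged example of the resulting flat config (derived from the branches
// above, not from extra code in this diff): for an embedx rule named
// "SparseSharedAdamSGDRule", the call
//   add_sparse_optimizer(config, sparse_table_accessor.embedx_sgd_param(), "mf_");
// leaves entries such as
//   config["mf_optimizer_type"]   = 4;
//   config["mf_learning_rate"]    = ...;  // adam().learning_rate()
//   config["mf_beta1_decay_rate"] = ...;  // adam().beta1_decay_rate()
//   config["mf_beta2_decay_rate"] = ...;  // adam().beta2_decay_rate()
//   config["mf_ada_epsilon"]      = ...;  // adam().ada_epsilon()
//   config["mf_min_bound"], config["mf_max_bound"]  // adam().weight_bounds()
// which InitializeGPUServer(config) below reads back under the same keys.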
void InitializeGPUServer(paddle::distributed::PSParameter ps_param) {
auto sparse_table =
ps_param.server_param().downpour_server_param().downpour_table_param(0);
auto sparse_table_accessor = sparse_table.accessor();
auto sparse_table_accessor_parameter =
sparse_table_accessor.ctr_accessor_param();
accessor_class_ = sparse_table_accessor.accessor_class();
std::unordered_map<std::string, float> config;
config["embedx_dim"] = sparse_table_accessor.embedx_dim();
config["nonclk_coeff"] = sparse_table_accessor_parameter.nonclk_coeff();
config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff();
config["mf_create_thresholds"] = sparse_table_accessor.embedx_threshold();
if (accessor_class_ == "CtrDymfAccessor") {
// optimizer config for embed_w and embedx
add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param());
add_sparse_optimizer(
config, sparse_table_accessor.embedx_sgd_param(), "mf_");
}
fleet_config_ = config;
GlobalAccessorTransfor::GetInstance().Init(accessor_class_);
GlobalAccessorTransfor::GetInstance().GetAccessorWrapper()->Configure(
config);
InitializeGPUServer(config);
}
#endif
void InitializeGPUServer(std::unordered_map<std::string, float> config) {
float nonclk_coeff = (config.find("nonclk_coeff") == config.end())
? 1.0
......@@ -288,54 +340,83 @@ class PSGPUWrapper {
float clk_coeff =
(config.find("clk_coeff") == config.end()) ? 1.0 : config["clk_coeff"];
float min_bound = (config.find("min_bound") == config.end())
? -10000.0
? -10.0
: config["min_bound"];
float max_bound = (config.find("max_bound") == config.end())
? 10000.0
: config["max_bound"];
float max_bound =
(config.find("max_bound") == config.end()) ? 10.0 : config["max_bound"];
float learning_rate = (config.find("learning_rate") == config.end())
? 1.0
? 0.05
: config["learning_rate"];
float initial_g2sum = (config.find("initial_g2sum") == config.end())
? 1.0
? 3.0
: config["initial_g2sum"];
float initial_range = (config.find("initial_range") == config.end())
? 1.0
? 1e-4
: config["initial_range"];
float beta1_decay_rate = (config.find("beta1_decay_rate") == config.end())
? 0.9
: config["beta1_decay_rate"];
float beta2_decay_rate = (config.find("beta2_decay_rate") == config.end())
? 0.999
: config["beta2_decay_rate"];
float ada_epsilon = (config.find("ada_epsilon") == config.end())
? 1e-8
: config["ada_epsilon"];
// mf config settings
float mf_create_thresholds =
(config.find("mf_create_thresholds") == config.end())
? static_cast<float>(1.0)
: config["mf_create_thresholds"];
float mf_learning_rate = (config.find("mf_learning_rate") == config.end())
? 1.0
? 0.05
: config["mf_learning_rate"];
float mf_initial_g2sum = (config.find("mf_initial_g2sum") == config.end())
? 1.0
? 3.0
: config["mf_initial_g2sum"];
float mf_initial_range = (config.find("mf_initial_range") == config.end())
? 1.0
? 1e-4
: config["mf_initial_range"];
float mf_min_bound = (config.find("mf_min_bound") == config.end())
? 1.0
? -10.0
: config["mf_min_bound"];
float mf_max_bound = (config.find("mf_max_bound") == config.end())
? 1.0
? 10.0
: config["mf_max_bound"];
float mf_beta1_decay_rate =
(config.find("mf_beta1_decay_rate") == config.end())
? 0.9
: config["mf_beta1_decay_rate"];
float mf_beta2_decay_rate =
(config.find("mf_beta2_decay_rate") == config.end())
? 0.999
: config["mf_beta2_decay_rate"];
float mf_ada_epsilon = (config.find("mf_ada_epsilon") == config.end())
? 1e-8
: config["mf_ada_epsilon"];
this->SetSparseSGD(nonclk_coeff,
clk_coeff,
min_bound,
max_bound,
learning_rate,
initial_g2sum,
initial_range);
initial_range,
beta1_decay_rate,
beta2_decay_rate,
ada_epsilon);
this->SetEmbedxSGD(mf_create_thresholds,
mf_learning_rate,
mf_initial_g2sum,
mf_initial_range,
mf_min_bound,
mf_max_bound);
mf_max_bound,
mf_beta1_decay_rate,
mf_beta2_decay_rate,
mf_ada_epsilon);
// set optimizer type (naive, adagrad, std_adagrad, adam, shared_adam)
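// 0 = SparseNaiveSGDRule, 1 = SparseAdaGradSGDRule, 2 = StdAdaGradSGDRule,
// 3 = SparseAdamSGDRule, 4 = SparseSharedAdamSGDRule (same numbering as
// add_sparse_optimizer above); adagrad (1) is the fallback when unset.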
optimizer_type_ = (config.find("optimizer_type") == config.end())
? 1
: static_cast<int>(config["optimizer_type"]);
}
void SetDate(int year, int month, int day) {
......@@ -380,7 +461,7 @@ class PSGPUWrapper {
if (slot_info_initialized_) {
return;
}
SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
SlotRecordDataset* dataset = (SlotRecordDataset*)(dataset_);
auto slots_vec = dataset->GetSlots();
slot_offset_vector_.clear();
for (auto& slot : slot_vector_) {
......@@ -421,10 +502,13 @@ class PSGPUWrapper {
for (size_t i = 0; i < slot_index_vec_.size(); i++) {
slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]];
}
val_type_size_ =
TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
grad_type_size_ =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
auto accessor_wrapper_ptr =
GlobalAccessorTransfor::GetInstance().GetAccessorWrapper();
val_type_size_ = accessor_wrapper_ptr->GetFeatureValueSize(max_mf_dim_);
grad_type_size_ = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_);
VLOG(0) << "InitSlotInfo: val_type_size_" << val_type_size_
<< " grad_type_size_:" << grad_type_size_;
slot_info_initialized_ = true;
}
#endif
......@@ -445,6 +529,12 @@ class PSGPUWrapper {
const std::string& conf);
#endif
#ifdef PADDLE_WITH_PSCORE
void SetTableAccessor(paddle::distributed::ValueAccessor* accessor) {
cpu_table_accessor_ = accessor;
}
#endif
private:
static std::shared_ptr<PSGPUWrapper> s_instance_;
Dataset* dataset_;
......@@ -497,6 +587,12 @@ class PSGPUWrapper {
int day_;
bool slot_info_initialized_ = false;
int use_afs_api_ = 0;
int optimizer_type_ = 1;
std::string accessor_class_;
std::unordered_map<std::string, float> fleet_config_;
#ifdef PADDLE_WITH_PSCORE
paddle::distributed::ValueAccessor* cpu_table_accessor_;
#endif
#ifdef PADDLE_WITH_CUDA
std::vector<MemoryPool*> mem_pools_;
......@@ -521,6 +617,7 @@ class PSGPUWrapper {
bool running_ = false;
std::vector<std::shared_ptr<ThreadPool>> pull_thread_pool_;
std::vector<std::shared_ptr<ThreadPool>> hbm_thread_pool_;
OptimizerConfig optimizer_config_;
protected:
static bool is_initialized_;
......
......@@ -594,6 +594,21 @@ class DistributedStrategy(object):
bounds = strategy.get(prefix + 'sparse_weight_bounds',
[-10, 10])
sgd.adam.weight_bounds.extend(bounds)
elif optimizer_name == "shared_adam":
sgd.name = 'SparseSharedAdamSGDRule'
sgd.adam.learning_rate = strategy.get(
prefix + 'sparse_learning_rate', 0.001)
sgd.adam.initial_range = strategy.get(
prefix + 'sparse_initial_range', 1e-4)
sgd.adam.beta1_decay_rate = strategy.get(
prefix + 'sparse_beta1_decay_rate', 0.9)
sgd.adam.beta2_decay_rate = strategy.get(
prefix + 'sparse_beta2_decay_rate', 0.999)
sgd.adam.ada_epsilon = strategy.get(
prefix + 'sparse_ada_epsilon', 1e-8)
bounds = strategy.get(prefix + 'sparse_weight_bounds',
[-10, 10])
sgd.adam.weight_bounds.extend(bounds)
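# Hedged usage note, mirroring the unit test added later in this diff: the new
# rule is selected like the other sparse optimizers, e.g.
#   strategy.fleet_desc_configs = {"emb": {"sparse_optimizer": "shared_adam"}}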
def set_sparse_table_config(table_data, config):
for key in config:
......
......@@ -195,7 +195,7 @@ class Accessor:
sgd_param.naive.initial_range = 0.0001
if len(sgd_param.naive.weight_bounds) == 0:
sgd_param.naive.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseAdamSGDRule":
if sgd_param.name == "SparseAdamSGDRule" or sgd_param.name == "SparseSharedAdamSGDRule":
if not sgd_param.adam.HasField("learning_rate"):
sgd_param.adam.learning_rate = 0.001
if not sgd_param.adam.HasField("initial_range"):
......
......@@ -334,6 +334,14 @@ class TestStrategyConfig(unittest.TestCase):
strategy.sparse_table_configs[0].accessor.embed_sgd_param.adagrad.
initial_range, 0.0001)
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {}
configs['emb'] = {"sparse_optimizer": "shared_adam"}
strategy.fleet_desc_configs = configs
self.assertEqual(
strategy.sparse_table_configs[0].accessor.embed_sgd_param.adam.
beta1_decay_rate, 0.9)
def test_trainer_desc_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {
......
......@@ -671,7 +671,8 @@ HIGH_PARALLEL_JOB_NEW = [
'test_trt_convert_reduce_sum',
'save_quant2_model_lstm',
'test_trt_convert_slice',
'test_quant2_int8_lstm_mkldnn'
'test_quant2_int8_lstm_mkldnn',
'test_dist_fleet_ps13'
]
# mem=0 but always timeout or failed: it runs 15 jobs each time in Single cases;
......