Unverified Commit 1a7962be authored by H hutuxian, committed by GitHub

Paddlebox about box_wrapper (#22497)

Refine the PaddleBox framework. Main changes:
* Add a MetricMsg util class that can calculate metrics such as AUC, bucket_error, and COPC.
* Replace FeedPass with a new interface: BeginFeedPass & EndFeedPass (call sequence sketched after this list).
* Refactor the Pull/Push Sparse functions in box_wrapper.
* Use CUDA kernels to copy keys and feasigns between tensors and BoxPS structs.
* Cache the keys copied during pull sparse so they can be reused during push.
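
A minimal sketch of the new pass lifecycle, assuming InitializeGPU has already been called; the method names and signatures come from this diff, while the date value and the call-site framing are illustrative assumptions:

auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
boxps::PSAgentBase* agent = box_ptr->GetAgent();
int date = 20200131;  // hypothetical pass date in YYYYMMDD form
box_ptr->BeginFeedPass(date, &agent);  // BoxPS begins loading this pass's feasigns
// ... feed the pass's feasigns (done by BoxHelper::FeedPass in this diff) ...
box_ptr->EndFeedPass(agent);           // loading for the pass is finished
box_ptr->BeginPass();                  // train the pass via PullSparse / PushSparseGrad
box_ptr->EndPass();

PullSparse also caches the copied keys per device in keys_tensor so that PushSparseGrad can reuse them, per the last bullet above.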
Parent 9e29d3eb
......@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BOX_PS
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
......@@ -23,43 +25,101 @@ namespace paddle {
namespace framework {
std::shared_ptr<BoxWrapper> BoxWrapper::s_instance_ = nullptr;
#ifdef PADDLE_WITH_BOX_PS
std::shared_ptr<paddle::boxps::BoxPSBase> BoxWrapper::boxps_ptr_ = nullptr;
#endif
cudaStream_t BoxWrapper::stream_list_[8];
std::shared_ptr<boxps::BoxPSBase> BoxWrapper::boxps_ptr_ = nullptr;
int BoxWrapper::GetDate() const {
time_t now = time(0);
tm t;
#ifdef _WIN32
localtime_s(&t, &now);
#else
localtime_r(&now, &t);
#endif
char buf[10];
snprintf(buf, sizeof(buf), "%04d%02d%02d", (1900 + t.tm_year), (1 + t.tm_mon),
t.tm_mday);
return atoi(buf);
void BasicAucCalculator::compute() {
double* table[2] = {&_table[0][0], &_table[1][0]};
double area = 0;
double fp = 0;
double tp = 0;
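// Sweep the histogram from the highest prediction bucket down to the
// lowest: at each step (fp, tp) are the cumulative negative/positive counts
// above the current threshold, and each bucket adds a trapezoid
// (newfp - fp) * (tp + newtp) / 2 to the area under the ROC curve, which is
// finally normalized by fp * tp (total negatives * total positives).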
for (int i = _table_size - 1; i >= 0; i--) {
double newfp = fp + table[0][i];
double newtp = tp + table[1][i];
area += (newfp - fp) * (tp + newtp) / 2;
fp = newfp;
tp = newtp;
}
if (fp < 1e-3 || tp < 1e-3) {
_auc = -0.5; // means all samples are nonclick, or all are click
} else {
_auc = area / (fp * tp);
}
_mae = _local_abserr / (fp + tp);
_rmse = sqrt(_local_sqrerr / (fp + tp));
_actual_ctr = tp / (fp + tp);
_predicted_ctr = _local_pred / (fp + tp);
_size = fp + tp;
}
void BoxWrapper::FeedPass(const std::vector<uint64_t>& feasgin_to_box) const {
#ifdef PADDLE_WITH_BOX_PS
int ret = boxps_ptr_->FeedPass(GetDate(), feasgin_to_box);
PADDLE_ENFORCE_EQ(ret, 0, "FeedPass failed in BoxPS.");
#endif
void BasicAucCalculator::calculate_bucket_error() {
double last_ctr = -1;
double impression_sum = 0;
double ctr_sum = 0.0;
double click_sum = 0.0;
double error_sum = 0.0;
double error_count = 0;
double* table[2] = {&_table[0][0], &_table[1][0]};
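// Scan buckets in ascending predicted-CTR order, pooling adjacent buckets
// into a group no wider than kMaxSpan; once the pooled impressions make the
// group's predicted CTR statistically reliable (relative error below
// kRelativeErrorBound), accumulate |actual_ctr / predicted_ctr - 1|
// weighted by impressions, then start a new group.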
for (int i = 0; i < _table_size; i++) {
double click = table[1][i];
double show = table[0][i] + table[1][i];
double ctr = static_cast<double>(i) / _table_size;
if (fabs(ctr - last_ctr) > kMaxSpan) {
last_ctr = ctr;
impression_sum = 0.0;
ctr_sum = 0.0;
click_sum = 0.0;
}
impression_sum += show;
ctr_sum += ctr * show;
click_sum += click;
double adjust_ctr = ctr_sum / impression_sum;
double relative_error =
sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum));
if (relative_error < kRelativeErrorBound) {
double actual_ctr = click_sum / impression_sum;
double relative_ctr_error = fabs(actual_ctr / adjust_ctr - 1);
error_sum += relative_ctr_error * impression_sum;
error_count += impression_sum;
last_ctr = -1;
}
}
_bucket_error = error_count > 0 ? error_sum / error_count : 0.0;
}
void BoxWrapper::FeedPass(int date,
const std::vector<uint64_t>& feasgin_to_box) const {
int ret = boxps_ptr_->FeedPass(date, feasgin_to_box);
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"FeedPass failed in BoxPS."));
}
void BoxWrapper::BeginFeedPass(int date, boxps::PSAgentBase** agent) const {
int ret = boxps_ptr_->BeginFeedPass(date, *agent);
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"BeginFeedPass failed in BoxPS."));
}
void BoxWrapper::EndFeedPass(boxps::PSAgentBase* agent) const {
int ret = boxps_ptr_->EndFeedPass(agent);
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"EndFeedPass failed in BoxPS."));
}
void BoxWrapper::BeginPass() const {
#ifdef PADDLE_WITH_BOX_PS
int ret = boxps_ptr_->BeginPass();
PADDLE_ENFORCE_EQ(ret, 0, "BeginPass failed in BoxPS.");
#endif
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"BeginPass failed in BoxPS."));
}
void BoxWrapper::EndPass() const {
#ifdef PADDLE_WITH_BOX_PS
int ret = boxps_ptr_->EndPass();
PADDLE_ENFORCE_EQ(ret, 0, "EndPass failed in BoxPS.");
#endif
PADDLE_ENFORCE_EQ(
ret, 0, platform::errors::PreconditionNotMet("EndPass failed in BoxPS."));
}
void BoxWrapper::PullSparse(const paddle::platform::Place& place,
......@@ -67,181 +127,202 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size) {
#ifdef PADDLE_WITH_BOX_PS
if (platform::is_cpu_place(place) || platform::is_gpu_place(place)) {
VLOG(3) << "Begin PullSparse";
platform::Timer all_timer;
platform::Timer pull_boxps_timer;
all_timer.Start();
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
auto buf =
memory::AllocShared(place, total_length * sizeof(boxps::FeatureValueGpu));
boxps::FeatureValueGpu* total_values_gpu =
reinterpret_cast<boxps::FeatureValueGpu*>(buf->ptr());
if (platform::is_cpu_place(place)) {
// Note: Only GPU is supported in PaddleBox now, and the following code has
// not been fully tested yet
LoDTensor total_keys_tensor;
int64_t* total_keys =
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place);
uint64_t* total_keys = reinterpret_cast<uint64_t*>(
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place));
int64_t offset = 0;
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
for (size_t i = 0; i < keys.size(); ++i) {
if (platform::is_cpu_place(place)) {
memory::Copy(boost::get<platform::CPUPlace>(place), total_keys + offset,
boost::get<platform::CPUPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t));
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
memory::Copy(boost::get<platform::CUDAPlace>(place),
total_keys + offset,
boost::get<platform::CUDAPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t), nullptr);
#else
PADDLE_THROW(
"Please compile WITH_GPU option, and NCCL doesn't support "
"windows.");
#endif
}
offset += slot_lengths[i];
}
PADDLE_ENFORCE_EQ(offset, total_length,
"BoxWrapper::PullSparse: total feasign keys length "
"should be equal to the sum of length of all input "
"tensors.");
// Space allocation for FeatureValue is left for boxps
paddle::boxps::FeatureValue* total_values;
if (platform::is_cpu_place(place)) {
int ret = boxps_ptr_->PullSparseCPU(
reinterpret_cast<uint64_t*>(total_keys), &total_values,
VLOG(3) << "Begin call PullSparseCPU in BoxPS";
pull_boxps_timer.Start();
// TODO(hutuxian): should use boxps::FeatureValue in the future
int ret = boxps_ptr_->PullSparseCPU(total_keys, total_values_gpu,
static_cast<int>(total_length));
PADDLE_ENFORCE_EQ(ret, 0, "PullSparseCPU failed in BoxPS.");
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int ret = boxps_ptr_->PullSparseGPU(
reinterpret_cast<uint64_t*>(total_keys), &total_values,
static_cast<int>(total_length),
boost::get<platform::CUDAPlace>(place).GetDeviceId());
PADDLE_ENFORCE_EQ(ret, 0, "PullSparseGPU failed in BoxPS.");
#endif
}
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PullSparseCPU failed in BoxPS."));
pull_boxps_timer.Pause();
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
<< "]";
offset = 0;
for (size_t i = 0; i < values.size(); ++i) {
int64_t fea_num = slot_lengths[i];
VLOG(3) << "Begin Copy slot[" << i << "] fea_num[" << fea_num << "]";
for (auto j = 0; j < fea_num; ++j) {
// Copy the emb from BoxPS to the paddle tensor. Since 'show', 'click', and
// 'emb' are contiguous in memory, we copy here using the 'show' address
if (platform::is_cpu_place(place)) {
memory::Copy(
boost::get<platform::CPUPlace>(place), values[i] + j * hidden_size,
boost::get<platform::CPUPlace>(place),
values[i] + j * hidden_size,
boost::get<platform::CPUPlace>(place),
reinterpret_cast<float*>(&((total_values + offset)->show)),
reinterpret_cast<float*>(&((total_values_gpu + offset)->show)),
sizeof(float) * hidden_size);
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
memory::Copy(
boost::get<platform::CUDAPlace>(place),
values[i] + j * hidden_size,
boost::get<platform::CUDAPlace>(place),
reinterpret_cast<float*>(&((total_values + offset)->show)),
sizeof(float) * hidden_size, nullptr);
#endif
}
++offset;
}
}
PADDLE_ENFORCE_EQ(offset, total_length,
"BoxWrapper::PullSparse: total emb values length should "
"be equal to the sum of length of all input tensors.");
} else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
int device_id = boost::get<platform::CUDAPlace>(place).GetDeviceId();
LoDTensor& total_keys_tensor = keys_tensor[device_id];
uint64_t* total_keys = reinterpret_cast<uint64_t*>(
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place));
} else {
PADDLE_THROW(
"PaddleBox: PullSparse Only Support CPUPlace and CUDAPlace Now.");
// construct slot_level lod info
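// slot_lengths_lod[i] becomes the inclusive prefix sum of slot lengths, so
// the CUDA kernels can binary-search it to map a flat feasign index back to
// its (slot, offset-in-slot) pair.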
auto slot_lengths_lod = slot_lengths;
for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_key = memory::AllocShared(place, keys.size() * sizeof(uint64_t*));
auto buf_length =
memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
this->CopyKeys(place, gpu_keys, total_keys, gpu_len,
static_cast<int>(slot_lengths.size()),
static_cast<int>(total_length));
VLOG(3) << "Begin call PullSparseGPU in BoxPS";
pull_boxps_timer.Start();
int ret =
boxps_ptr_->PullSparseGPU(total_keys, total_values_gpu,
static_cast<int>(total_length), device_id);
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PullSparseGPU failed in BoxPS."));
pull_boxps_timer.Pause();
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
<< "]";
this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
static_cast<int>(slot_lengths.size()), hidden_size,
total_length);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."));
#endif
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddleBox: PullSparse Only Support CPUPlace or CUDAPlace Now."));
}
all_timer.Pause();
VLOG(1) << "PullSparse total costs: " << all_timer.ElapsedSec()
<< " s, of which BoxPS costs: " << pull_boxps_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PullSparse";
}
void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size) {
#ifdef PADDLE_WITH_BOX_PS
if (platform::is_cpu_place(place) || platform::is_gpu_place(place)) {
const int hidden_size, const int batch_size) {
VLOG(3) << "Begin PushSparseGrad";
platform::Timer all_timer;
platform::Timer push_boxps_timer;
all_timer.Start();
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
auto buf = memory::AllocShared(
place, total_length * sizeof(boxps::FeaturePushValueGpu));
boxps::FeaturePushValueGpu* total_grad_values_gpu =
reinterpret_cast<boxps::FeaturePushValueGpu*>(buf->ptr());
if (platform::is_cpu_place(place)) {
// Note: only GPU is supported in PaddleBox now, and the following code has
// not been fully tested yet
LoDTensor total_keys_tensor;
int64_t* total_keys =
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place);
uint64_t* total_keys = reinterpret_cast<uint64_t*>(
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place));
int64_t offset = 0;
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
for (size_t i = 0; i < keys.size(); ++i) {
if (platform::is_cpu_place(place)) {
memory::Copy(boost::get<platform::CPUPlace>(place), total_keys + offset,
boost::get<platform::CPUPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t));
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
memory::Copy(boost::get<platform::CUDAPlace>(place),
total_keys + offset,
boost::get<platform::CUDAPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t), nullptr);
#else
PADDLE_THROW(
"Please compile WITH_GPU option, and for now NCCL doesn't support "
"windows.");
#endif
}
offset += slot_lengths[i];
}
PADDLE_ENFORCE_EQ(offset, total_length,
"BoxWrapper::PushSparseGrad: total feasign keys length "
"should be equal to the sum of length of all input "
"tensors.");
auto buf = memory::AllocShared(
place, total_length * sizeof(paddle::boxps::FeaturePushValue));
paddle::boxps::FeaturePushValue* total_grad_values =
reinterpret_cast<paddle::boxps::FeaturePushValue*>(buf->ptr());
offset = 0;
VLOG(3) << "Begin copy grad tensor to BoxPS struct";
for (size_t i = 0; i < grad_values.size(); ++i) {
int64_t fea_num = slot_lengths[i];
for (auto j = 0; j < fea_num; ++j) {
// Copy the emb grad from the paddle tensor to BoxPS. Since 'show',
// 'click', and 'emb' are contiguous in memory, we copy here using the
// 'show' address
if (platform::is_cpu_place(place)) {
memory::Copy(
boost::get<platform::CPUPlace>(place),
reinterpret_cast<float*>(&((total_grad_values + offset)->show)),
reinterpret_cast<float*>(&((total_grad_values_gpu + offset)->show)),
boost::get<platform::CPUPlace>(place),
grad_values[i] + j * hidden_size, sizeof(float) * hidden_size);
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
memory::Copy(
boost::get<platform::CUDAPlace>(place),
reinterpret_cast<float*>(&((total_grad_values + offset)->show)),
boost::get<platform::CUDAPlace>(place),
grad_values[i] + j * hidden_size, sizeof(float) * hidden_size,
nullptr);
#endif
}
++offset;
}
}
PADDLE_ENFORCE_EQ(offset, total_length,
"BoxWrapper::PushSparseGrad: total emb grad values "
"length should be equal to the sum of length of all "
"input tensors.");
if (platform::is_cpu_place(place)) {
int ret = boxps_ptr_->PushSparseCPU(
reinterpret_cast<uint64_t*>(total_keys), total_grad_values,
VLOG(3) << "Begin call PushSparseCPU in BoxPS";
push_boxps_timer.Start();
int ret = boxps_ptr_->PushSparseCPU(total_keys, total_grad_values_gpu,
static_cast<int>(total_length));
PADDLE_ENFORCE_EQ(ret, 0, "PushSparseCPU failed in BoxPS.");
} else {
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PushSparseCPU failed in BoxPS."));
push_boxps_timer.Pause();
} else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int device_id = boost::get<platform::CUDAPlace>(place).GetDeviceId();
LoDTensor& cached_total_keys_tensor = keys_tensor[device_id];
uint64_t* total_keys =
reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>());
VLOG(3) << "Begin copy grad tensor to boxps struct";
this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
hidden_size, total_length, batch_size);
VLOG(3) << "Begin call PushSparseGPU in BoxPS";
push_boxps_timer.Start();
int ret = boxps_ptr_->PushSparseGPU(
reinterpret_cast<uint64_t*>(total_keys), total_grad_values,
static_cast<int>(total_length),
total_keys, total_grad_values_gpu, static_cast<int>(total_length),
boost::get<platform::CUDAPlace>(place).GetDeviceId());
PADDLE_ENFORCE_EQ(ret, 0, "PushSparseGPU failed in BoxPS.");
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PushSparseGPU failed in BoxPS."));
push_boxps_timer.Pause();
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."));
#endif
}
} else {
PADDLE_THROW(
"PaddleBox: PushSparse Only Support CPUPlace and CUDAPlace Now.");
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddleBox: PushSparseGrad Only Support CPUPlace or CUDAPlace Now."));
}
#endif
all_timer.Pause();
VLOG(1) << "PushSparseGrad total cost: " << all_timer.ElapsedSec()
<< " s, of which BoxPS cost: " << push_boxps_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PushSparseGrad";
}
} // end namespace framework
} // end namespace paddle
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BOX_PS
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace framework {
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
__global__ void PullCopy(float** dest, const boxps::FeatureValueGpu* src,
const int64_t* len, int hidden, int slot_num,
int total_len, uint64_t** keys) {
CUDA_KERNEL_LOOP(i, total_len) {
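// Binary-search the inclusive prefix sums in len: x becomes the index of
// the slot containing flat feasign i, and y its offset within that slot.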
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
if (*(keys[x] + y) == 0) {
*(dest[x] + y * hidden) = 0;
*(dest[x] + y * hidden + 1) = 0;
*(dest[x] + y * hidden + 2) = 0;
} else {
*(dest[x] + y * hidden) = (src + i)->show;
*(dest[x] + y * hidden + 1) = (src + i)->clk;
*(dest[x] + y * hidden + 2) = (src + i)->embed_w;
}
if ((src + i)->embedding_size == 0 || *(keys[x] + y) == 0) {
for (int j = 0; j < 8; j++) {
*(dest[x] + y * hidden + 3 + j) = 0;
}
} else {
for (int j = 0; j < 8; j++) {
*(dest[x] + y * hidden + 3 + j) = (src + i)->embedx[1 + j];
}
}
}
}
__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
const int64_t* len, int slot_num,
int total_len) {
CUDA_KERNEL_LOOP(i, total_len) {
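// Same slot lookup as in PullCopy: binary-search the prefix sums for the
// owning slot x and the in-slot offset y.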
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
dest_total_keys[i] = src_keys[x][y];
}
}
__global__ void PushCopy(boxps::FeaturePushValueGpu* dest, float** src,
int64_t* len, int hidden, int slot_num, int total_len,
int bs, int* slot_vector) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[low - 1] : 0);
(dest + i)->slot = slot_vector[x];
(dest + i)->show = *(src[x] + y * hidden);
(dest + i)->clk = *(src[x] + y * hidden + 1);
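// The gradients below are negated and scaled by the batch size bs,
// presumably because BoxPS applies them additively while Paddle's
// gradients are mean-reduced over the batch -- an assumption, since the
// BoxPS side is not part of this diff.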
(dest + i)->embed_g = *(src[x] + y * hidden + 2) * -1. * bs;
for (int j = 0; j < 8; j++) {
(dest + i)->embedx_g[j] = *(src[x] + y * hidden + 3 + j) * -1. * bs;
}
}
}
void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const boxps::FeatureValueGpu* total_values_gpu,
const int64_t* gpu_len, const int slot_num,
const int hidden_size,
const int64_t total_length) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
boost::get<platform::CUDAPlace>(place)))
->stream();
auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*));
float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
PullCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
gpu_values, total_values_gpu, gpu_len, hidden_size, slot_num,
total_length, gpu_keys);
cudaStreamSynchronize(stream);
}
void BoxWrapper::CopyKeys(const paddle::platform::Place& place,
uint64_t** origin_keys, uint64_t* total_keys,
const int64_t* gpu_len, int slot_num, int total_len) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
boost::get<platform::CUDAPlace>(place)))
->stream();
CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>(
origin_keys, total_keys, gpu_len, slot_num, total_len);
cudaStreamSynchronize(stream);
}
void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
boxps::FeaturePushValueGpu* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int64_t total_length,
const int batch_size) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
boost::get<platform::CUDAPlace>(place)))
->stream();
auto slot_lengths_lod = slot_lengths;
for (int i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_grad_value =
memory::AllocShared(place, grad_values.size() * sizeof(float*));
auto buf_length =
memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
auto buf_slot_vector =
memory::AllocShared(place, slot_lengths_lod.size() * sizeof(int));
float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
cudaMemcpy(gpu_values, grad_values.data(),
grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_slot_vector, slot_vector_.data(),
slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
PushCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
total_grad_values_gpu, gpu_values, gpu_len, hidden_size,
slot_lengths.size(), total_length, batch_size, d_slot_vector);
cudaStreamSynchronize(stream);
}
} // end namespace framework
} // end namespace paddle
#endif
......@@ -14,27 +14,117 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_BOX_PS
#include <boxps_public.h>
#endif
#include <glog/logging.h>
#include <algorithm>
#include <atomic>
#include <ctime>
#include <deque>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/data_set.h"
#ifdef PADDLE_WITH_BOX_PS
#include <boxps.h>
#endif
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
#ifdef PADDLE_WITH_BOX_PS
class BasicAucCalculator {
public:
BasicAucCalculator() {}
void init(int table_size) { set_table_size(table_size); }
void reset() {
for (int i = 0; i < 2; i++) {
_table[i].assign(_table_size, 0.0);
}
_local_abserr = 0;
_local_sqrerr = 0;
_local_pred = 0;
}
void add_data(double pred, int label) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be lower than 1"));
PADDLE_ENFORCE_EQ(
label * label, label,
platform::errors::PreconditionNotMet(
"label must be equal to 0 or 1, but its value is: %d", label));
int pos = std::min(static_cast<int>(pred * _table_size), _table_size - 1);
PADDLE_ENFORCE_GE(
pos, 0,
platform::errors::PreconditionNotMet(
"pos must be equal or greater than 0, but its value is: %d", pos));
PADDLE_ENFORCE_LT(
pos, _table_size,
platform::errors::PreconditionNotMet(
"pos must be less than table_size, but its value is: %d", pos));
std::lock_guard<std::mutex> lock(_table_mutex);
_local_abserr += fabs(pred - label);
_local_sqrerr += (pred - label) * (pred - label);
_local_pred += pred;
_table[label][pos]++;
}
void compute();
int table_size() const { return _table_size; }
double bucket_error() const { return _bucket_error; }
double auc() const { return _auc; }
double mae() const { return _mae; }
double actual_ctr() const { return _actual_ctr; }
double predicted_ctr() const { return _predicted_ctr; }
double size() const { return _size; }
double rmse() const { return _rmse; }
std::vector<double>& get_negative() { return _table[0]; }
std::vector<double>& get_postive() { return _table[1]; }
double& local_abserr() { return _local_abserr; }
double& local_sqrerr() { return _local_sqrerr; }
double& local_pred() { return _local_pred; }
void calculate_bucket_error();
protected:
double _local_abserr = 0;
double _local_sqrerr = 0;
double _local_pred = 0;
double _auc = 0;
double _mae = 0;
double _rmse = 0;
double _actual_ctr = 0;
double _predicted_ctr = 0;
double _size;
double _bucket_error = 0;
private:
void set_table_size(int table_size) {
_table_size = table_size;
for (int i = 0; i < 2; i++) {
_table[i] = std::vector<double>();
}
reset();
}
int _table_size;
std::vector<double> _table[2];
static constexpr double kRelativeErrorBound = 0.05;
static constexpr double kMaxSpan = 0.01;
std::mutex _table_mutex;
};
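// A minimal usage sketch of BasicAucCalculator (the values are illustrative
// assumptions, not taken from this commit):
//   BasicAucCalculator calc;
//   calc.init(1000000);             // histogram bucket count
//   calc.add_data(0.8, 1);          // (prediction, label), thread-safe
//   calc.add_data(0.3, 0);
//   calc.compute();                 // fills auc(), mae(), rmse(), ctr stats
//   calc.calculate_bucket_error();  // fills bucket_error()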
class BoxWrapper {
public:
virtual ~BoxWrapper() {}
BoxWrapper() {}
void FeedPass(const std::vector<uint64_t>& feasgin_to_box) const;
void FeedPass(int date, const std::vector<uint64_t>& feasgin_to_box) const;
void BeginFeedPass(int date, boxps::PSAgentBase** agent) const;
void EndFeedPass(boxps::PSAgentBase* agent) const;
void BeginPass() const;
void EndPass() const;
void PullSparse(const paddle::platform::Place& place,
......@@ -46,7 +136,74 @@ class BoxWrapper {
const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size);
const int hidden_size, const int batch_size);
void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys,
const std::vector<float*>& values,
const boxps::FeatureValueGpu* total_values_gpu,
const int64_t* gpu_len, const int slot_num,
const int hidden_size, const int64_t total_length);
void CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
boxps::FeaturePushValueGpu* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int64_t total_length,
const int batch_size);
void CopyKeys(const paddle::platform::Place& place, uint64_t** origin_keys,
uint64_t* total_keys, const int64_t* gpu_len, int slot_num,
int total_len);
boxps::PSAgentBase* GetAgent() { return p_agent_; }
void InitializeGPU(const char* conf_file, const std::vector<int>& slot_vector,
const std::vector<std::string>& slot_omit_in_feedpass) {
if (nullptr != s_instance_) {
VLOG(3) << "Begin InitializeGPU";
std::vector<cudaStream_t*> stream_list;
for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) {
VLOG(3) << "before get context i[" << i << "]";
platform::CUDADeviceContext* context =
dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(i)));
stream_list_[i] = context->stream();
stream_list.push_back(&stream_list_[i]);
}
VLOG(2) << "Begin call InitializeGPU in BoxPS";
// the second parameter is currently unused
s_instance_->boxps_ptr_->InitializeGPU(conf_file, -1, stream_list);
p_agent_ = boxps::PSAgentBase::GetIns(feedpass_thread_num_);
p_agent_->Init();
for (const auto& slot_name : slot_omit_in_feedpass) {
slot_name_omited_in_feedpass_.insert(slot_name);
}
slot_vector_ = slot_vector;
keys_tensor.resize(platform::GetCUDADeviceCount());
}
}
int GetFeedpassThreadNum() const { return feedpass_thread_num_; }
void Finalize() {
VLOG(3) << "Begin Finalize";
if (nullptr != s_instance_) {
s_instance_->boxps_ptr_->Finalize();
}
}
void SaveBase(const char* batch_model_path, const char* xbox_model_path,
boxps::SaveModelStat& stat) { // NOLINT
VLOG(3) << "Begin SaveBase";
if (nullptr != s_instance_) {
s_instance_->boxps_ptr_->SaveBase(batch_model_path, xbox_model_path,
stat);
}
}
void SaveDelta(const char* xbox_model_path,
boxps::SaveModelStat& stat) { // NOLINT
VLOG(3) << "Begin SaveDelta";
if (nullptr != s_instance_) {
s_instance_->boxps_ptr_->SaveDelta(xbox_model_path, stat);
}
}
static std::shared_ptr<BoxWrapper> GetInstance() {
if (nullptr == s_instance_) {
......@@ -54,22 +211,92 @@ class BoxWrapper {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (nullptr == s_instance_) {
VLOG(3) << "s_instance_ is null";
s_instance_.reset(new paddle::framework::BoxWrapper());
#ifdef PADDLE_WITH_BOX_PS
s_instance_->boxps_ptr_.reset(new paddle::boxps::FakeBoxPS());
#endif
s_instance_->boxps_ptr_.reset(boxps::BoxPSBase::GetIns());
}
}
return s_instance_;
}
const std::unordered_set<std::string>& GetOmitedSlot() const {
return slot_name_omited_in_feedpass_;
}
struct MetricMsg {
public:
MetricMsg() {}
MetricMsg(const std::string& label_varname, const std::string& pred_varname,
int is_join, int bucket_size = 1000000)
: label_varname_(label_varname),
pred_varname_(pred_varname),
is_join_(is_join) {
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
}
const std::string& LabelVarname() const { return label_varname_; }
const std::string& PredVarname() const { return pred_varname_; }
int IsJoin() const { return is_join_; }
BasicAucCalculator* GetCalculator() { return calculator; }
private:
#ifdef PADDLE_WITH_BOX_PS
static std::shared_ptr<paddle::boxps::BoxPSBase> boxps_ptr_;
#endif
std::string label_varname_;
std::string pred_varname_;
int is_join_;
BasicAucCalculator* calculator;
};
int PassFlag() const { return pass_flag_; }
void FlipPassFlag() { pass_flag_ = 1 - pass_flag_; }
bool NeedMetric() const { return need_metric_; }
std::map<std::string, MetricMsg>& GetMetricList() { return metric_lists_; }
void InitMetric(const std::string& name, const std::string& label_varname,
const std::string& pred_varname, bool is_join,
int bucket_size = 1000000) {
metric_lists_.emplace(name, MetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, bucket_size));
need_metric_ = true;
}
const std::vector<float> GetMetricMsg(const std::string& name) {
const auto iter = metric_lists_.find(name);
PADDLE_ENFORCE_NE(iter, metric_lists_.end(),
platform::errors::InvalidArgument(
"The metric name you provided is not registered."));
std::vector<float> metric_return_values_(8, 0.0);
auto* auc_cal_ = iter->second.GetCalculator();
auc_cal_->calculate_bucket_error();
auc_cal_->compute();
metric_return_values_[0] = auc_cal_->auc();
metric_return_values_[1] = auc_cal_->bucket_error();
metric_return_values_[2] = auc_cal_->mae();
metric_return_values_[3] = auc_cal_->rmse();
metric_return_values_[4] = auc_cal_->actual_ctr();
metric_return_values_[5] = auc_cal_->predicted_ctr();
metric_return_values_[6] =
auc_cal_->actual_ctr() / auc_cal_->predicted_ctr();
metric_return_values_[7] = auc_cal_->size();
auc_cal_->reset();
return metric_return_values_;
}
private:
static cudaStream_t stream_list_[8];
static std::shared_ptr<boxps::BoxPSBase> boxps_ptr_;
boxps::PSAgentBase* p_agent_ = nullptr;
const int feedpass_thread_num_ = 30; // magic number
static std::shared_ptr<BoxWrapper> s_instance_;
int GetDate() const;
std::unordered_set<std::string> slot_name_omited_in_feedpass_;
// Metric Related
int pass_flag_ = 1; // join: 1, update: 0
bool need_metric_ = false;
std::map<std::string, MetricMsg> metric_lists_;
std::vector<int> slot_vector_;
std::vector<LoDTensor> keys_tensor; // Cache for pull_sparse
};
#endif
class BoxHelper {
public:
......@@ -77,13 +304,17 @@ class BoxHelper {
virtual ~BoxHelper() {}
void BeginPass() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance();
box_ptr->BeginPass();
#endif
}
void EndPass() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance();
box_ptr->EndPass();
#endif
}
void LoadIntoMemory() {
dataset_->LoadIntoMemory();
......@@ -103,6 +334,7 @@ class BoxHelper {
std::shared_ptr<std::thread> feed_data_thread_;
// notify boxps to feed this pass feasigns from SSD to memory
void FeedPass() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance();
auto input_channel_ =
dynamic_cast<MultiSlotDataset*>(dataset_)->GetInputChannel();
......@@ -119,6 +351,7 @@ class BoxHelper {
input_channel_->Write(pass_data);
input_channel_->Close();
box_ptr->FeedPass(feasign_to_box);
#endif
}
};
......
......@@ -26,7 +26,6 @@ template <typename T>
static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
auto hidden_size = ctx.Attr<int>("size");
const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size);
// BoxPS only supports float now
......@@ -41,33 +40,49 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto *output = outputs[i]->mutable_data<T>(ctx.GetPlace());
all_values[i] = output;
}
#ifdef PADDLE_WITH_BOX_PS
auto hidden_size = ctx.Attr<int>("size");
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths,
hidden_size);
#endif
}
template <typename T>
static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
auto inputs = ctx.MultiInput<framework::LoDTensor>("Ids");
auto d_output =
ctx.MultiInput<framework::Tensor>(framework::GradVarName("Out"));
auto hidden_size = ctx.Attr<int>("size");
const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size);
std::vector<const float *> all_grad_values(slot_size);
std::vector<int64_t> slot_lengths(slot_size);
int batch_size = -1;
for (size_t i = 0; i < slot_size; i++) {
const auto *slot = inputs[i];
const uint64_t *single_slot_keys =
reinterpret_cast<const uint64_t *>(slot->data<int64_t>());
all_keys[i] = single_slot_keys;
slot_lengths[i] = slot->numel();
int cur_batch_size =
slot->lod().size() ? slot->lod()[0].size() - 1 : slot->dims()[0];
if (batch_size == -1) {
batch_size = cur_batch_size;
} else {
PADDLE_ENFORCE_EQ(batch_size, cur_batch_size,
platform::errors::PreconditionNotMet(
"The batch size of all input slots should be same, "
"please cheack"));
}
const float *grad_value = d_output[i]->data<float>();
all_grad_values[i] = grad_value;
}
#ifdef PADDLE_WITH_BOX_PS
auto hidden_size = ctx.Attr<int>("size");
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values,
slot_lengths, hidden_size);
slot_lengths, hidden_size, batch_size);
#endif
}
using LoDTensor = framework::LoDTensor;
......