未验证 提交 1a7962be 编写于 作者: H hutuxian 提交者: GitHub

Paddlebox about box_wrapper (#22497)

Refine PaddleBox Framework, Main functions: 
* Add MetricMsg util class, which can calculate metrics like AUC, bucket_error, COPC.
* Replace FeedPass with new interface: BeginFeedPass & EndFeedPass
* Refactor Pull/Push Sparse Function in box_wrapper.
* Use CUDA Kernel to copy keys and copy feasign between tensor and boxps struct.
* Cache copied keys in pull sparse in order to reuse it in push period.
上级 9e29d3eb
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BOX_PS
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace framework {
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
__global__ void PullCopy(float** dest, const boxps::FeatureValueGpu* src,
const int64_t* len, int hidden, int slot_num,
int total_len, uint64_t** keys) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
if (*(keys[x] + y) == 0) {
*(dest[x] + y * hidden) = 0;
*(dest[x] + y * hidden + 1) = 0;
*(dest[x] + y * hidden + 2) = 0;
} else {
*(dest[x] + y * hidden) = (src + i)->show;
*(dest[x] + y * hidden + 1) = (src + i)->clk;
*(dest[x] + y * hidden + 2) = (src + i)->embed_w;
}
if ((src + i)->embedding_size == 0 || *(keys[x] + y) == 0) {
for (int j = 0; j < 8; j++) {
*(dest[x] + y * hidden + 3 + j) = 0;
}
} else {
for (int j = 0; j < 8; j++) {
*(dest[x] + y * hidden + 3 + j) = (src + i)->embedx[1 + j];
}
}
}
}
__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
const int64_t* len, int slot_num,
int total_len) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
dest_total_keys[i] = src_keys[x][y];
}
}
__global__ void PushCopy(boxps::FeaturePushValueGpu* dest, float** src,
int64_t* len, int hidden, int slot_num, int total_len,
int bs, int* slot_vector) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[low - 1] : 0);
(dest + i)->slot = slot_vector[x];
(dest + i)->show = *(src[x] + y * hidden);
(dest + i)->clk = *(src[x] + y * hidden + 1);
(dest + i)->embed_g = *(src[x] + y * hidden + 2) * -1. * bs;
for (int j = 0; j < 8; j++) {
(dest + i)->embedx_g[j] = *(src[x] + y * hidden + 3 + j) * -1. * bs;
}
}
}
void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const boxps::FeatureValueGpu* total_values_gpu,
const int64_t* gpu_len, const int slot_num,
const int hidden_size,
const int64_t total_length) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
boost::get<platform::CUDAPlace>(place)))
->stream();
auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*));
float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
PullCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
gpu_values, total_values_gpu, gpu_len, hidden_size, slot_num,
total_length, gpu_keys);
cudaStreamSynchronize(stream);
}
void BoxWrapper::CopyKeys(const paddle::platform::Place& place,
uint64_t** origin_keys, uint64_t* total_keys,
const int64_t* gpu_len, int slot_num, int total_len) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
boost::get<platform::CUDAPlace>(place)))
->stream();
CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>(
origin_keys, total_keys, gpu_len, slot_num, total_len);
cudaStreamSynchronize(stream);
}
void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
boxps::FeaturePushValueGpu* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int64_t total_length,
const int batch_size) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
boost::get<platform::CUDAPlace>(place)))
->stream();
auto slot_lengths_lod = slot_lengths;
for (int i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_grad_value =
memory::AllocShared(place, grad_values.size() * sizeof(float*));
auto buf_length =
memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
auto buf_slot_vector =
memory::AllocShared(place, slot_lengths_lod.size() * sizeof(int));
float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
cudaMemcpy(gpu_values, grad_values.data(),
grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_slot_vector, slot_vector_.data(),
slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
PushCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
total_grad_values_gpu, gpu_values, gpu_len, hidden_size,
slot_lengths.size(), total_length, batch_size, d_slot_vector);
cudaStreamSynchronize(stream);
}
} // end namespace framework
} // end namespace paddle
#endif
...@@ -14,27 +14,117 @@ limitations under the License. */ ...@@ -14,27 +14,117 @@ limitations under the License. */
#pragma once #pragma once
#ifdef PADDLE_WITH_BOX_PS
#include <boxps_public.h>
#endif
#include <glog/logging.h> #include <glog/logging.h>
#include <algorithm>
#include <atomic>
#include <ctime>
#include <deque>
#include <map>
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <string> #include <string>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_set.h"
#ifdef PADDLE_WITH_BOX_PS #include "paddle/fluid/framework/lod_tensor.h"
#include <boxps.h>
#endif
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#ifdef PADDLE_WITH_BOX_PS
class BasicAucCalculator {
public:
BasicAucCalculator() {}
void init(int table_size) { set_table_size(table_size); }
void reset() {
for (int i = 0; i < 2; i++) {
_table[i].assign(_table_size, 0.0);
}
_local_abserr = 0;
_local_sqrerr = 0;
_local_pred = 0;
}
void add_data(double pred, int label) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be lower than 1"));
PADDLE_ENFORCE_EQ(
label * label, label,
platform::errors::PreconditionNotMet(
"label must be equal to 0 or 1, but its value is: %d", label));
int pos = std::min(static_cast<int>(pred * _table_size), _table_size - 1);
PADDLE_ENFORCE_GE(
pos, 0,
platform::errors::PreconditionNotMet(
"pos must be equal or greater than 0, but its value is: %d", pos));
PADDLE_ENFORCE_LT(
pos, _table_size,
platform::errors::PreconditionNotMet(
"pos must be less than table_size, but its value is: %d", pos));
std::lock_guard<std::mutex> lock(_table_mutex);
_local_abserr += fabs(pred - label);
_local_sqrerr += (pred - label) * (pred - label);
_local_pred += pred;
_table[label][pos]++;
}
void compute();
int table_size() const { return _table_size; }
double bucket_error() const { return _bucket_error; }
double auc() const { return _auc; }
double mae() const { return _mae; }
double actual_ctr() const { return _actual_ctr; }
double predicted_ctr() const { return _predicted_ctr; }
double size() const { return _size; }
double rmse() const { return _rmse; }
std::vector<double>& get_negative() { return _table[0]; }
std::vector<double>& get_postive() { return _table[1]; }
double& local_abserr() { return _local_abserr; }
double& local_sqrerr() { return _local_sqrerr; }
double& local_pred() { return _local_pred; }
void calculate_bucket_error();
protected:
double _local_abserr = 0;
double _local_sqrerr = 0;
double _local_pred = 0;
double _auc = 0;
double _mae = 0;
double _rmse = 0;
double _actual_ctr = 0;
double _predicted_ctr = 0;
double _size;
double _bucket_error = 0;
private:
void set_table_size(int table_size) {
_table_size = table_size;
for (int i = 0; i < 2; i++) {
_table[i] = std::vector<double>();
}
reset();
}
int _table_size;
std::vector<double> _table[2];
static constexpr double kRelativeErrorBound = 0.05;
static constexpr double kMaxSpan = 0.01;
std::mutex _table_mutex;
};
class BoxWrapper { class BoxWrapper {
public: public:
virtual ~BoxWrapper() {} virtual ~BoxWrapper() {}
BoxWrapper() {} BoxWrapper() {}
void FeedPass(const std::vector<uint64_t>& feasgin_to_box) const; void FeedPass(int date, const std::vector<uint64_t>& feasgin_to_box) const;
void BeginFeedPass(int date, boxps::PSAgentBase** agent) const;
void EndFeedPass(boxps::PSAgentBase* agent) const;
void BeginPass() const; void BeginPass() const;
void EndPass() const; void EndPass() const;
void PullSparse(const paddle::platform::Place& place, void PullSparse(const paddle::platform::Place& place,
...@@ -46,7 +136,74 @@ class BoxWrapper { ...@@ -46,7 +136,74 @@ class BoxWrapper {
const std::vector<const uint64_t*>& keys, const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values, const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths, const std::vector<int64_t>& slot_lengths,
const int hidden_size); const int hidden_size, const int batch_size);
void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys,
const std::vector<float*>& values,
const boxps::FeatureValueGpu* total_values_gpu,
const int64_t* gpu_len, const int slot_num,
const int hidden_size, const int64_t total_length);
void CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
boxps::FeaturePushValueGpu* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int64_t total_length,
const int batch_size);
void CopyKeys(const paddle::platform::Place& place, uint64_t** origin_keys,
uint64_t* total_keys, const int64_t* gpu_len, int slot_num,
int total_len);
boxps::PSAgentBase* GetAgent() { return p_agent_; }
void InitializeGPU(const char* conf_file, const std::vector<int>& slot_vector,
const std::vector<std::string>& slot_omit_in_feedpass) {
if (nullptr != s_instance_) {
VLOG(3) << "Begin InitializeGPU";
std::vector<cudaStream_t*> stream_list;
for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) {
VLOG(3) << "before get context i[" << i << "]";
platform::CUDADeviceContext* context =
dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(i)));
stream_list_[i] = context->stream();
stream_list.push_back(&stream_list_[i]);
}
VLOG(2) << "Begin call InitializeGPU in BoxPS";
// the second parameter is useless
s_instance_->boxps_ptr_->InitializeGPU(conf_file, -1, stream_list);
p_agent_ = boxps::PSAgentBase::GetIns(feedpass_thread_num_);
p_agent_->Init();
for (const auto& slot_name : slot_omit_in_feedpass) {
slot_name_omited_in_feedpass_.insert(slot_name);
}
slot_vector_ = slot_vector;
keys_tensor.resize(platform::GetCUDADeviceCount());
}
}
int GetFeedpassThreadNum() const { return feedpass_thread_num_; }
void Finalize() {
VLOG(3) << "Begin Finalize";
if (nullptr != s_instance_) {
s_instance_->boxps_ptr_->Finalize();
}
}
void SaveBase(const char* batch_model_path, const char* xbox_model_path,
boxps::SaveModelStat& stat) { // NOLINT
VLOG(3) << "Begin SaveBase";
if (nullptr != s_instance_) {
s_instance_->boxps_ptr_->SaveBase(batch_model_path, xbox_model_path,
stat);
}
}
void SaveDelta(const char* xbox_model_path,
boxps::SaveModelStat& stat) { // NOLINT
VLOG(3) << "Begin SaveDelta";
if (nullptr != s_instance_) {
s_instance_->boxps_ptr_->SaveDelta(xbox_model_path, stat);
}
}
static std::shared_ptr<BoxWrapper> GetInstance() { static std::shared_ptr<BoxWrapper> GetInstance() {
if (nullptr == s_instance_) { if (nullptr == s_instance_) {
...@@ -54,22 +211,92 @@ class BoxWrapper { ...@@ -54,22 +211,92 @@ class BoxWrapper {
static std::mutex mutex; static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
if (nullptr == s_instance_) { if (nullptr == s_instance_) {
VLOG(3) << "s_instance_ is null";
s_instance_.reset(new paddle::framework::BoxWrapper()); s_instance_.reset(new paddle::framework::BoxWrapper());
#ifdef PADDLE_WITH_BOX_PS s_instance_->boxps_ptr_.reset(boxps::BoxPSBase::GetIns());
s_instance_->boxps_ptr_.reset(new paddle::boxps::FakeBoxPS());
#endif
} }
} }
return s_instance_; return s_instance_;
} }
const std::unordered_set<std::string>& GetOmitedSlot() const {
return slot_name_omited_in_feedpass_;
}
struct MetricMsg {
public:
MetricMsg() {}
MetricMsg(const std::string& label_varname, const std::string& pred_varname,
int is_join, int bucket_size = 1000000)
: label_varname_(label_varname),
pred_varname_(pred_varname),
is_join_(is_join) {
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
}
const std::string& LabelVarname() const { return label_varname_; }
const std::string& PredVarname() const { return pred_varname_; }
int IsJoin() const { return is_join_; }
BasicAucCalculator* GetCalculator() { return calculator; }
private:
std::string label_varname_;
std::string pred_varname_;
int is_join_;
BasicAucCalculator* calculator;
};
int PassFlag() const { return pass_flag_; }
void FlipPassFlag() { pass_flag_ = 1 - pass_flag_; }
bool NeedMetric() const { return need_metric_; }
std::map<std::string, MetricMsg>& GetMetricList() { return metric_lists_; }
void InitMetric(const std::string& name, const std::string& label_varname,
const std::string& pred_varname, bool is_join,
int bucket_size = 1000000) {
metric_lists_.emplace(name, MetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, bucket_size));
need_metric_ = true;
}
const std::vector<float> GetMetricMsg(const std::string& name) {
const auto iter = metric_lists_.find(name);
PADDLE_ENFORCE_NE(iter, metric_lists_.end(),
platform::errors::InvalidArgument(
"The metric name you provided is not registered."));
std::vector<float> metric_return_values_(8, 0.0);
auto* auc_cal_ = iter->second.GetCalculator();
auc_cal_->calculate_bucket_error();
auc_cal_->compute();
metric_return_values_[0] = auc_cal_->auc();
metric_return_values_[1] = auc_cal_->bucket_error();
metric_return_values_[2] = auc_cal_->mae();
metric_return_values_[3] = auc_cal_->rmse();
metric_return_values_[4] = auc_cal_->actual_ctr();
metric_return_values_[5] = auc_cal_->predicted_ctr();
metric_return_values_[6] =
auc_cal_->actual_ctr() / auc_cal_->predicted_ctr();
metric_return_values_[7] = auc_cal_->size();
auc_cal_->reset();
return metric_return_values_;
}
private: private:
#ifdef PADDLE_WITH_BOX_PS static cudaStream_t stream_list_[8];
static std::shared_ptr<paddle::boxps::BoxPSBase> boxps_ptr_; static std::shared_ptr<boxps::BoxPSBase> boxps_ptr_;
#endif boxps::PSAgentBase* p_agent_ = nullptr;
const int feedpass_thread_num_ = 30; // magic number
static std::shared_ptr<BoxWrapper> s_instance_; static std::shared_ptr<BoxWrapper> s_instance_;
int GetDate() const; std::unordered_set<std::string> slot_name_omited_in_feedpass_;
// Metric Related
int pass_flag_ = 1; // join: 1, update: 0
bool need_metric_ = false;
std::map<std::string, MetricMsg> metric_lists_;
std::vector<int> slot_vector_;
std::vector<LoDTensor> keys_tensor; // Cache for pull_sparse
}; };
#endif
class BoxHelper { class BoxHelper {
public: public:
...@@ -77,13 +304,17 @@ class BoxHelper { ...@@ -77,13 +304,17 @@ class BoxHelper {
virtual ~BoxHelper() {} virtual ~BoxHelper() {}
void BeginPass() { void BeginPass() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance(); auto box_ptr = BoxWrapper::GetInstance();
box_ptr->BeginPass(); box_ptr->BeginPass();
#endif
} }
void EndPass() { void EndPass() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance(); auto box_ptr = BoxWrapper::GetInstance();
box_ptr->EndPass(); box_ptr->EndPass();
#endif
} }
void LoadIntoMemory() { void LoadIntoMemory() {
dataset_->LoadIntoMemory(); dataset_->LoadIntoMemory();
...@@ -103,6 +334,7 @@ class BoxHelper { ...@@ -103,6 +334,7 @@ class BoxHelper {
std::shared_ptr<std::thread> feed_data_thread_; std::shared_ptr<std::thread> feed_data_thread_;
// notify boxps to feed this pass feasigns from SSD to memory // notify boxps to feed this pass feasigns from SSD to memory
void FeedPass() { void FeedPass() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance(); auto box_ptr = BoxWrapper::GetInstance();
auto input_channel_ = auto input_channel_ =
dynamic_cast<MultiSlotDataset*>(dataset_)->GetInputChannel(); dynamic_cast<MultiSlotDataset*>(dataset_)->GetInputChannel();
...@@ -119,6 +351,7 @@ class BoxHelper { ...@@ -119,6 +351,7 @@ class BoxHelper {
input_channel_->Write(pass_data); input_channel_->Write(pass_data);
input_channel_->Close(); input_channel_->Close();
box_ptr->FeedPass(feasign_to_box); box_ptr->FeedPass(feasign_to_box);
#endif
} }
}; };
......
...@@ -26,7 +26,6 @@ template <typename T> ...@@ -26,7 +26,6 @@ template <typename T>
static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) { static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("Ids"); auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
auto outputs = ctx.MultiOutput<framework::Tensor>("Out"); auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
auto hidden_size = ctx.Attr<int>("size");
const auto slot_size = inputs.size(); const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size); std::vector<const uint64_t *> all_keys(slot_size);
// BoxPS only supports float now // BoxPS only supports float now
...@@ -41,33 +40,49 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) { ...@@ -41,33 +40,49 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto *output = outputs[i]->mutable_data<T>(ctx.GetPlace()); auto *output = outputs[i]->mutable_data<T>(ctx.GetPlace());
all_values[i] = output; all_values[i] = output;
} }
#ifdef PADDLE_WITH_BOX_PS
auto hidden_size = ctx.Attr<int>("size");
auto box_ptr = paddle::framework::BoxWrapper::GetInstance(); auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths, box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths,
hidden_size); hidden_size);
#endif
} }
template <typename T> template <typename T>
static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) { static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("Ids"); auto inputs = ctx.MultiInput<framework::LoDTensor>("Ids");
auto d_output = auto d_output =
ctx.MultiInput<framework::Tensor>(framework::GradVarName("Out")); ctx.MultiInput<framework::Tensor>(framework::GradVarName("Out"));
auto hidden_size = ctx.Attr<int>("size");
const auto slot_size = inputs.size(); const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size); std::vector<const uint64_t *> all_keys(slot_size);
std::vector<const float *> all_grad_values(slot_size); std::vector<const float *> all_grad_values(slot_size);
std::vector<int64_t> slot_lengths(slot_size); std::vector<int64_t> slot_lengths(slot_size);
int batch_size = -1;
for (size_t i = 0; i < slot_size; i++) { for (size_t i = 0; i < slot_size; i++) {
const auto *slot = inputs[i]; const auto *slot = inputs[i];
const uint64_t *single_slot_keys = const uint64_t *single_slot_keys =
reinterpret_cast<const uint64_t *>(slot->data<int64_t>()); reinterpret_cast<const uint64_t *>(slot->data<int64_t>());
all_keys[i] = single_slot_keys; all_keys[i] = single_slot_keys;
slot_lengths[i] = slot->numel(); slot_lengths[i] = slot->numel();
int cur_batch_size =
slot->lod().size() ? slot->lod()[0].size() - 1 : slot->dims()[0];
if (batch_size == -1) {
batch_size = cur_batch_size;
} else {
PADDLE_ENFORCE_EQ(batch_size, cur_batch_size,
platform::errors::PreconditionNotMet(
"The batch size of all input slots should be same, "
"please cheack"));
}
const float *grad_value = d_output[i]->data<float>(); const float *grad_value = d_output[i]->data<float>();
all_grad_values[i] = grad_value; all_grad_values[i] = grad_value;
} }
#ifdef PADDLE_WITH_BOX_PS
auto hidden_size = ctx.Attr<int>("size");
auto box_ptr = paddle::framework::BoxWrapper::GetInstance(); auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values, box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values,
slot_lengths, hidden_size); slot_lengths, hidden_size, batch_size);
#endif
} }
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册