Unverified commit 8ca86d72, authored by pangengzheng, committed by GitHub

Speedup worker (#51760)

* support run haokanctr model in heterps-models

* polish setup.py

* polish JVM_LIB in env_dict

* align infer auc with DistPsArch pre-stable

* async and multi thread data feed (sketched just below this header)

* rewrite dense tensor initialization

* async infer shape and reuse memory
Parent 16ec22c4
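The "async and multi thread data feed" and "async infer shape and reuse memory" items boil down to a producer/consumer pattern that the diff below implements with paired blocking queues (free_pack_queue_/using_pack_queue_ on the data feed, free_task_queue_/using_task_queue_ on PSGPUWorker): background task threads fill reusable mini-batch packs and pre-run RuntimeInferShape into spare scopes, while the training thread only pops finished work and hands each pack back for reuse. The following is a minimal, self-contained sketch of that pattern, not Paddle code; BlockingQueue, Pack, and the producer loop here are hypothetical stand-ins for the classes in the diff.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>

// Minimal blocking queue; stands in for paddle::framework::BlockingQueue.
template <typename T>
class BlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      q_.push(std::move(v));
    }
    cv_.notify_one();
  }
  // Blocks until an item is available; returns false once Close()d and drained.
  bool Pop(T* v) {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !q_.empty() || closed_; });
    if (q_.empty()) return false;
    *v = std::move(q_.front());
    q_.pop();
    return true;
  }
  void Close() {
    {
      std::lock_guard<std::mutex> lk(mu_);
      closed_ = true;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> q_;
  bool closed_ = false;
};

// Stand-in for a MiniBatchGpuPack plus the scope it will be unpacked into.
struct Pack {
  int batch_id = -1;
};

int main() {
  BlockingQueue<Pack> free_q;   // recycled packs   (cf. free_pack_queue_)
  BlockingQueue<Pack> ready_q;  // prepared packs   (cf. using_pack_queue_)
  for (int i = 0; i < 3; ++i) free_q.Push(Pack{});  // bounded pool of reusable packs

  // Background producer: reads a batch, pre-runs infer-shape, hands it over.
  std::thread producer([&] {
    Pack p;
    for (int b = 0; b < 10; ++b) {
      free_q.Pop(&p);   // reuse a pack instead of allocating a new one
      p.batch_id = b;   // "fill tensors + RuntimeInferShape" would happen here
      ready_q.Push(p);
    }
    ready_q.Close();    // cf. pack_is_end_ = true
  });

  // Trainer thread: only runs ops on batches that are already prepared.
  Pack p;
  while (ready_q.Pop(&p)) {
    std::cout << "train on batch " << p.batch_id << "\n";
    free_q.Push(p);     // give the pack back for reuse
  }
  producer.join();
  return 0;
}

Recycling packs through the free queue bounds memory use, and it is what makes the buffer reuse in TrainFiles possible: the same tensors are re-bound per batch instead of reallocated.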
......@@ -2510,10 +2510,22 @@ bool SlotRecordInMemoryDataFeed::ParseOneInstance(const std::string& line,
void SlotRecordInMemoryDataFeed::AssignFeedVar(const Scope& scope) {
CheckInit();
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
if (scpoe_feed_vec_.count(&scope) > 0) {
return;
}
auto& feed_vec = scpoe_feed_vec_[&scope];
feed_vec.resize(used_slots_info_.size());
for (int i = 0; i < use_slot_size_; ++i) {
feed_vec[i] =
scope.FindVar(used_slots_info_[i].slot)->GetMutable<phi::DenseTensor>();
}
#else
for (int i = 0; i < use_slot_size_; ++i) {
feed_vec_[i] =
scope.FindVar(used_slots_info_[i].slot)->GetMutable<phi::DenseTensor>();
}
#endif
}
void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec,
......@@ -2985,6 +2997,29 @@ void SlotRecordInMemoryDataFeed::PackToScope(MiniBatchGpuPack* pack,
}
}
MiniBatchGpuPack* SlotRecordInMemoryDataFeed::get_pack(
MiniBatchGpuPack* last_pack) {
if (last_pack != nullptr) {
free_pack_queue_.Push(last_pack);
return nullptr;
}
std::unique_lock<std::mutex> lock(pack_mutex_);
while (true) {
if (using_pack_queue_.Size() != 0) {
auto* pack = using_pack_queue_.Pop();
return pack;
}
bool is_end = pack_is_end_.load();
if (is_end) {
if (using_pack_queue_.Size() == 0) {
return nullptr;
}
}
std::this_thread::sleep_for(std::chrono::microseconds(200));
}
}
MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place,
const std::vector<UsedSlotInfo>& infos,
phi::StreamId stream_id) {
......
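For context on how the handshake above is meant to be driven: a consumer calls get_pack(nullptr) to block for the next prepared pack, materializes it with PackToScope, and returns it via get_pack(pack) so it goes back to free_pack_queue_. The loop below is illustrative only (ConsumePacks is not a Paddle function) and assumes the CUDA/HeterPS build in which these methods exist.

#include "paddle/fluid/framework/data_feed.h"  // assumed header for DataFeed/MiniBatchGpuPack

// Illustrative consumer of the get_pack()/PackToScope() handshake (hypothetical helper).
void ConsumePacks(paddle::framework::DataFeed* reader,
                  const paddle::framework::Scope* scope) {
  while (auto* pack = reader->get_pack(nullptr)) {  // nullptr result => feeding finished
    reader->PackToScope(pack, scope);  // bind this mini-batch into the scope
    // ... run the program's ops against `scope` ...
    reader->get_pack(pack);            // recycle the pack into free_pack_queue_
  }
}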
......@@ -1154,6 +1154,10 @@ class DataFeed {
// This function is used for binding feed_vec memory in a given scope
virtual void AssignFeedVar(const Scope& scope);
virtual std::vector<std::string> GetInputVarNames() {
return std::vector<std::string>();
}
// This function will do nothing at default
virtual void SetInputPvChannel(void* channel) {}
// This function will do nothing at default
......@@ -1201,6 +1205,9 @@ class DataFeed {
virtual const std::vector<std::string>& GetInsContentVec() const {
return ins_content_vec_;
}
virtual void SetCurBatchSize(const int batch_size) {
batch_size_ = batch_size;
}
virtual int GetCurBatchSize() { return batch_size_; }
virtual int GetGraphPathNum() {
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
......@@ -1248,10 +1255,15 @@ class DataFeed {
virtual const paddle::platform::Place& GetPlace() const { return place_; }
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
virtual MiniBatchGpuPack* get_pack(MiniBatchGpuPack* last_pack) {
return nullptr;
}
virtual void PackToScope(MiniBatchGpuPack* pack, const Scope* scope) {
PADDLE_THROW(platform::errors::Unimplemented(
"This function(PackToScope) is not implemented."));
}
virtual void SetInsIdVec(MiniBatchGpuPack* pack) {}
#endif
virtual void DumpWalkPath(std::string dump_path, size_t dump_rate) {
......@@ -1809,32 +1821,41 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
public:
SlotRecordInMemoryDataFeed() {}
virtual ~SlotRecordInMemoryDataFeed();
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual void LoadIntoMemory();
void Init(const DataFeedDesc& data_feed_desc) override;
void LoadIntoMemory() override;
void ExpandSlotRecord(SlotRecord* ins);
protected:
virtual bool Start();
virtual int Next();
virtual bool ParseOneInstance(SlotRecord* instance) { return false; }
virtual bool ParseOneInstanceFromPipe(SlotRecord* instance) { return false; }
bool Start() override;
int Next() override;
bool ParseOneInstance(SlotRecord* instance) override { return false; }
bool ParseOneInstanceFromPipe(SlotRecord* instance) override { return false; }
// virtual void ParseOneInstanceFromSo(const char* str, T* instance,
// CustomParser* parser) {}
virtual void PutToFeedVec(const std::vector<SlotRecord>& ins_vec) {}
void PutToFeedVec(const std::vector<SlotRecord>& ins_vec) override {}
virtual void LoadIntoMemoryByCommand(void);
virtual void LoadIntoMemoryByLib(void);
virtual void LoadIntoMemoryByLine(void);
virtual void LoadIntoMemoryByFile(void);
virtual void SetInputChannel(void* channel) {
void SetInputChannel(void* channel) override {
input_channel_ = static_cast<ChannelObject<SlotRecord>*>(channel);
}
bool ParseOneInstance(const std::string& line, SlotRecord* rec);
virtual void PutToFeedVec(const SlotRecord* ins_vec, int num);
virtual void AssignFeedVar(const Scope& scope);
void PutToFeedVec(const SlotRecord* ins_vec, int num) override;
void AssignFeedVar(const Scope& scope) override;
std::vector<std::string> GetInputVarNames() override {
std::vector<std::string> var_names;
for (int i = 0; i < use_slot_size_; ++i) {
var_names.push_back(used_slots_info_[i].slot);
}
return var_names;
}
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void BuildSlotBatchGPU(const int ins_num, MiniBatchGpuPack* pack);
virtual MiniBatchGpuPack* get_pack(MiniBatchGpuPack* last_pack);
virtual void PackToScope(MiniBatchGpuPack* pack,
const Scope* scope = nullptr);
......@@ -1867,8 +1888,18 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
virtual void InitGraphResource(void);
virtual void InitGraphTrainResource(void);
virtual void DoWalkandSage();
void SetInsIdVec(MiniBatchGpuPack* pack) override {
if (parse_ins_id_) {
size_t ins_num = pack->ins_num();
ins_id_vec_.clear();
ins_id_vec_.resize(ins_num);
for (size_t i = 0; i < ins_num; i++) {
ins_id_vec_[i] = pack->get_lineid(i);
}
}
}
#endif
virtual void DumpWalkPath(std::string dump_path, size_t dump_rate);
void DumpWalkPath(std::string dump_path, size_t dump_rate) override;
float sample_rate_ = 1.0f;
int use_slot_size_ = 0;
......
......@@ -550,7 +550,7 @@ class HeterCpuWorker : public HogwildWorker {
class PSGPUWorker : public HogwildWorker {
public:
PSGPUWorker() {}
virtual ~PSGPUWorker() {}
virtual ~PSGPUWorker();
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
......@@ -566,12 +566,27 @@ class PSGPUWorker : public HogwildWorker {
#endif
void ResetStat();
// async infershape
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory();
protected:
void PushGradients();
void CopySparseTable();
void CopyDenseTable();
void CopyDenseVars();
struct InferShapeCheckData {
std::vector<std::vector<DDim>> pre_dims;
std::vector<std::vector<LoD>> pre_lods;
std::vector<std::vector<DDim>> after_dims;
std::vector<std::vector<LoD>> after_lods;
};
int OpRunAndShapeCheck(OperatorBase& op, // NOLINT
const Scope& scope,
const platform::Place& place);
private:
int mpi_rank_;
std::mutex mutex_;
......@@ -634,6 +649,28 @@ class PSGPUWorker : public HogwildWorker {
double gpu_2_cpu_time_;
double cpu_2_gpu_time_;
uint64_t total_inst_;
// async infershape
int task_threads_num_{6};
int scope_num_{task_threads_num_ + 1};
std::atomic<int> thread_count_{0};
std::atomic<bool> stop_token_{false};
std::atomic<bool> pack_is_end_{false};
std::vector<std::thread> task_threads_;
std::vector<Scope*> thread_scope_vec_;
std::map<Scope*, std::vector<Variable*>> need_reuse_var_vec_;
std::vector<Variable*> need_reuse_var_;
struct TaskData {
int ins_num;
Scope* scope;
MiniBatchGpuPack* pack;
};
paddle::framework::BlockingQueue<TaskData> free_task_queue_;
paddle::framework::BlockingQueue<TaskData> using_task_queue_;
static std::atomic<int> shape_check_count_;
static std::atomic<bool> shape_check_flag_;
};
#endif
......
......@@ -610,6 +610,34 @@ RuntimeInferShapeContext::GetPhiDefaultKernelSignature() const {
void RuntimeInferShapeContext::SetSkipLoD(bool skip) { can_skip_lod_ = skip; }
std::vector<LoD> RuntimeInferShapeContext::GetOutputsLod(
const std::string& out) const {
auto out_it = ctx_.outputs.find(out);
auto& out_var_list = out_it->second;
std::vector<LoD> ret;
for (size_t i = 0; i < out_var_list.size(); ++i) {
Variable* out_var = out_var_list[i];
if (out_var != nullptr) {
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
ret.push_back(out_tensor->lod());
}
}
return ret;
}
std::vector<DDim> RuntimeInferShapeContext::GetOutputsDim(
const std::string& name) const {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<Variable*> vars_res;
for (auto var : vars) {
if (var != nullptr) {
vars_res.push_back(var);
}
}
return GetDims(vars_res);
}
DDim RuntimeInferShapeContext::GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input variable is nullptr."));
......
......@@ -224,6 +224,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
void SetSkipLoD(bool skip);
std::vector<LoD> GetOutputsLod(const std::string& out) const;
std::vector<DDim> GetOutputsDim(const std::string& name) const;
protected:
DDim GetDim(Variable* var) const;
......@@ -351,6 +355,8 @@ class OperatorBase {
void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
virtual void SetIsRuntimeInferShape(bool x) {}
virtual void RuntimeInferShape(const Scope& scope,
const platform::Place& place,
const RuntimeContext& ctx) const {}
......@@ -775,6 +781,10 @@ class OperatorWithKernel : public OperatorBase {
virtual void InferShape(InferShapeContext* ctx) const;
void SetIsRuntimeInferShape(bool x) override {
all_kernels_must_compute_runtime_shape_ = x;
}
void RuntimeInferShape(const Scope& scope,
const platform::Place& place,
const RuntimeContext& ctx) const override;
......
......@@ -276,14 +276,14 @@ void PSGPUTrainer::InitTrainerEnv(const ProgramDesc& main_program,
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto name = var->Name();
auto* ptr = scope->Var(name);
InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
Variable* root_var = root_scope_->FindVar(name);
if (!root_var) {
continue;
}
phi::DenseTensor* root_tensor =
root_var->GetMutable<phi::DenseTensor>();
auto* ptr = scope->Var(name);
InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
phi::DenseTensor* thread_tensor = ptr->GetMutable<phi::DenseTensor>();
TensorCopy(*root_tensor, place, thread_tensor);
}
......@@ -300,6 +300,19 @@ void PSGPUTrainer::InitTrainerEnv(const ProgramDesc& main_program,
}
}
}
for (size_t num = 0; num < places_.size(); ++num) {
Scope* scope = workers_[num]->GetThreadScope();
for (size_t i = 0; i < need_merge_var_names_.size(); i++) {
Variable* thread_var = scope->FindVar(need_merge_var_names_[i]);
if (thread_var != nullptr) {
continue;
}
auto* ptr = scope->Var(need_merge_var_names_[i]);
InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
}
}
place_ = place;
return;
}
......
......@@ -34,6 +34,83 @@ limitations under the License. */
namespace paddle {
namespace framework {
std::atomic<int> PSGPUWorker::shape_check_count_(16);
std::atomic<bool> PSGPUWorker::shape_check_flag_(true);
void PSGPUWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
this->HogwildWorker::CreateDeviceResource(main_prog);
if (scope_num_ != 1) {
auto& block = main_prog.Block(0);
for (int i = 0; i < scope_num_; i++) {
auto thread_tmp = &thread_scope_->NewScope();
thread_scope_vec_.push_back(thread_tmp);
}
for (auto& scope : thread_scope_vec_) {
for (auto& var : block.AllVars()) {
std::string name = var->Name();
if (!var->Persistable()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
}
}
}
VLOG(1) << "ops_ size:" << ops_.size();
for (auto& op : ops_) {
op->SetIsRuntimeInferShape(true);
}
// reusing memory
auto input_names = device_reader_->GetInputVarNames();
std::set<std::string> input_names_set(input_names.begin(),
input_names.end());
for (auto& scope : thread_scope_vec_) {
std::vector<Variable*> need_reuse;
for (auto& var : block.AllVars()) {
std::string name = var->Name();
if (!var->Persistable()) {
if (input_names_set.find(var->Name()) != input_names_set.end()) {
continue;
}
auto* ptr = scope->FindLocalVar(var->Name());
PADDLE_ENFORCE_NE(
ptr,
nullptr,
phi::errors::NotFound("The var %s is not found.", var->Name()));
need_reuse.push_back(ptr);
}
}
need_reuse_var_vec_[scope] = std::move(need_reuse);
}
{
need_reuse_var_.clear();
for (auto& var : block.AllVars()) {
std::string name = var->Name();
if (!var->Persistable()) {
if (input_names_set.find(var->Name()) != input_names_set.end()) {
continue;
}
auto* ptr = thread_scope_->FindLocalVar(var->Name());
PADDLE_ENFORCE_NE(
ptr,
nullptr,
phi::errors::NotFound("The var %s is not found.", var->Name()));
need_reuse_var_.push_back(ptr);
}
}
}
}
}
void PSGPUWorker::BindingDataFeedMemory() {
if (scope_num_ == 1) {
this->HogwildWorker::BindingDataFeedMemory();
} else {
for (auto& scope : thread_scope_vec_) {
device_reader_->AssignFeedVar(*scope);
}
}
}
void PSGPUWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
......@@ -122,6 +199,86 @@ void PSGPUWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
writer_.Reset(queue);
}
PSGPUWorker::~PSGPUWorker() {
stop_token_.store(true);
for (auto& thread : task_threads_) {
if (thread.joinable()) {
thread.join();
}
}
}
int PSGPUWorker::OpRunAndShapeCheck(OperatorBase& op,
const Scope& scope,
const platform::Place& place) {
if (shape_check_flag_.load()) {
// before op run
InferShapeCheckData check_data;
auto& pre_dims = check_data.pre_dims;
auto& pre_lods = check_data.pre_lods;
auto& after_dims = check_data.after_dims;
auto& after_lods = check_data.after_lods;
RuntimeContext ctx(op.Inputs(), op.Outputs(), scope);
RuntimeInferShapeContext infer_shape_ctx(op, ctx);
auto outnames = op.Outputs();
for (auto& var_name_item : outnames) {
pre_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first));
pre_lods.push_back(infer_shape_ctx.GetOutputsLod(var_name_item.first));
}
// op run
op.Run(scope, place);
// after op run
for (auto& var_name_item : outnames) {
after_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first));
after_lods.push_back(infer_shape_ctx.GetOutputsLod(var_name_item.first));
}
std::string op_name = "unknow_op";
if (op.Info().HasOpProtoAndChecker()) {
op_name = op.Info().Proto().type();
}
#define SHAPE_CHECK_EQ(__VAL0, __VAL1) \
PADDLE_ENFORCE_EQ( \
__VAL0, \
__VAL1, \
platform::errors::Fatal("Shape check dims/lods error, op name: %s .", \
op_name))
SHAPE_CHECK_EQ(pre_dims.size(), after_dims.size());
for (size_t i = 0; i < pre_dims.size(); i++) {
SHAPE_CHECK_EQ(pre_dims[i].size(), after_dims[i].size());
for (size_t j = 0; j < pre_dims[i].size(); j++) {
SHAPE_CHECK_EQ(pre_dims[i][j], after_dims[i][j]);
}
}
SHAPE_CHECK_EQ(pre_lods.size(), after_lods.size());
for (size_t i = 0; i < pre_lods.size(); i++) {
SHAPE_CHECK_EQ(pre_lods[i].size(), after_lods[i].size());
for (size_t j = 0; j < pre_lods[i].size(); j++) {
auto& x = pre_lods[i][j];
auto& y = after_lods[i][j];
SHAPE_CHECK_EQ(x.size(), y.size());
for (size_t i = 0; i < x.size(); i++) {
const auto& x_level = x[i];
const auto& y_level = y[i];
SHAPE_CHECK_EQ(x_level.size(), y_level.size());
for (size_t j = 0; j < x_level.size(); j++) {
SHAPE_CHECK_EQ(x_level[j], y_level[j]);
}
}
}
}
#undef SHAPE_CHECK_EQ
} else {
op.Run(scope, place);
}
return 0;
}
void PSGPUWorker::TrainFiles() {
VLOG(0) << "Begin to train files";
platform::SetNumThreads(1);
......@@ -139,8 +296,107 @@ void PSGPUWorker::TrainFiles() {
device_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = device_reader_->Next()) > 0) {
// async infershape
pack_is_end_.store(false);
if (scope_num_ != 1) {
for (size_t i = 0; i < thread_scope_vec_.size(); i++) {
TaskData task;
task.scope = thread_scope_vec_[i];
free_task_queue_.Push(task);
}
thread_count_.store(task_threads_num_);
task_threads_.reserve(task_threads_num_);
for (int i = 0; i < task_threads_num_; i++) {
task_threads_.emplace_back(std::thread([this]() -> void {
while (true) {
auto pack = device_reader_->get_pack(nullptr);
if (pack == nullptr) {
int thread_num = thread_count_.fetch_sub(1);
if (thread_num == 1) {
pack_is_end_.store(true);
}
return;
}
auto task = free_task_queue_.Pop();
task.pack = pack;
task.ins_num = pack->ins_num();
device_reader_->PackToScope(task.pack, task.scope);
for (size_t i = 0; i < ops_.size(); i++) {
auto& op = ops_[i];
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
paddle::framework::RuntimeContext ctx(
op->Inputs(), op->Outputs(), *task.scope);
op->RuntimeInferShape(*task.scope, place_, ctx);
}
}
using_task_queue_.Push(task);
}
}));
}
}
while (true) {
auto thread_scope = thread_scope_;
TaskData cur_task;
if (scope_num_ == 1) {
cur_batch = device_reader_->Next();
} else {
while (true) {
if (using_task_queue_.Size() != 0) {
cur_task = using_task_queue_.Pop();
cur_batch = cur_task.ins_num;
break;
}
bool is_end = pack_is_end_.load();
if (is_end) {
if (using_task_queue_.Size() == 0) {
cur_batch = 0;
break;
}
}
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
thread_scope = cur_task.scope;
auto pack = cur_task.pack;
device_reader_->SetInsIdVec(pack);
// tensor share buffer
std::vector<Variable*>& cur_scope_vars =
need_reuse_var_vec_[thread_scope];
PADDLE_ENFORCE_EQ(
cur_scope_vars.size(),
need_reuse_var_.size(),
platform::errors::Fatal("reuse vars size must be same."));
for (size_t i = 0; i < need_reuse_var_.size(); i++) {
Variable* child = cur_scope_vars[i];
Variable* parent = need_reuse_var_[i];
if (child->IsType<phi::DenseTensor>()) {
child->GetMutable<phi::DenseTensor>()->ShareBufferWith(
*(parent->GetMutable<phi::DenseTensor>()));
}
}
}
if (cur_batch <= 0) {
break;
}
device_reader_->SetCurBatchSize(cur_batch);
total_ins_num += cur_batch;
if (shape_check_flag_.load()) {
if (scope_num_ == 1 || shape_check_count_.fetch_sub(1) <= 0) {
shape_check_flag_ = false;
}
}
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
......@@ -150,18 +406,19 @@ void PSGPUWorker::TrainFiles() {
}
}
if (!need_skip) {
op->Run(*thread_scope_, place_);
OpRunAndShapeCheck(*op, *thread_scope, place_);
}
}
if (need_dump_field_) {
DumpField(*thread_scope_, dump_mode_, dump_interval_);
DumpField(*thread_scope, dump_mode_, dump_interval_);
}
if (need_dump_param_ && thread_id_ == 0) {
DumpParam(*thread_scope_, batch_cnt);
DumpParam(*thread_scope, batch_cnt);
}
for (std::string& var_name : check_nan_var_names_) {
Variable* var = thread_scope_->FindVar(var_name);
Variable* var = thread_scope->FindVar(var_name);
if (var == nullptr) {
continue;
}
......@@ -176,11 +433,11 @@ void PSGPUWorker::TrainFiles() {
std::lock_guard<std::mutex> lock(mutex);
VLOG(0) << "worker " << thread_id_ << ": " << var_name
<< " cantains inf or nan";
auto all_vars = thread_scope_->LocalVarNames();
auto all_vars = thread_scope->LocalVarNames();
std::stringstream ss;
ss << "====== worker " << thread_id_ << "======\n";
for (auto& local_var : all_vars) {
platform::PrintVar(thread_scope_, local_var, local_var, &ss);
platform::PrintVar(thread_scope, local_var, local_var, &ss);
ss << "\n";
}
std::cout << ss.str() << std::endl;
......@@ -193,9 +450,28 @@ void PSGPUWorker::TrainFiles() {
dev_ctx_->Wait();
PrintFetchVars();
thread_scope_->DropKids();
thread_scope->DropKids();
++batch_cnt;
if (scope_num_ != 1) {
std::vector<Variable*>& cur_scope_vars =
need_reuse_var_vec_[thread_scope];
PADDLE_ENFORCE_EQ(
cur_scope_vars.size(),
need_reuse_var_.size(),
platform::errors::Fatal("reuse vars size must be same."));
for (size_t i = 0; i < need_reuse_var_.size(); i++) {
Variable* child = cur_scope_vars[i];
Variable* parent = need_reuse_var_[i];
if (child->IsType<phi::DenseTensor>()) {
parent->GetMutable<phi::DenseTensor>()->ShareBufferWith(
*(child->GetMutable<phi::DenseTensor>()));
}
}
device_reader_->get_pack(cur_task.pack);
free_task_queue_.Push(cur_task);
}
}
if (need_dump_field_ || need_dump_param_) {
writer_.Flush();
}
......
......@@ -365,9 +365,6 @@ class CoalesceTensorOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->IsRuntime()) {
return;
}
auto use_align = ctx->Attrs().Get<bool>("use_align");
auto align_size = ctx->Attrs().Get<int>("align_size");
auto size_of_dtype = ctx->Attrs().Get<int>("user_defined_size_of_dtype");
......@@ -377,30 +374,50 @@ class CoalesceTensorOp : public framework::OperatorWithKernel {
if (size_of_dtype == -1) {
size_of_dtype = framework::SizeOfType(dtype);
}
auto alignment = [](size_t size, size_t align_size) {
size_t remaining = size % align_size;
auto aligned_size =
remaining == 0 ? size : size + (align_size - remaining);
VLOG(4) << remaining << " " << size << " " << align_size << " "
<< aligned_size;
return aligned_size;
};
VLOG(4) << "align_size: " << align_size;
if (use_align && align_size > 0) {
if (ctx->IsRuntime()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
int64_t numel = 0;
auto dims = ctx->GetInputsDim("Input");
for (const auto &dim : dims) {
auto size = phi::product(dim);
        auto len = use_align
                       ? alignment(static_cast<size_t>(size) * size_of_dtype,
                                   align_size) /
                             size_of_dtype
                       : static_cast<size_t>(size);
        auto len = use_align ? phi::Alignment(
                                   static_cast<size_t>(size) * size_of_dtype,
                                   phi::GPUPlace(),
                                   align_size) /
                                   size_of_dtype
                             : static_cast<size_t>(size);
numel += len;
}
ctx->SetOutputDim("FusedOutput", phi::make_ddim({numel}));
VLOG(4) << "FusedOutput size:" << phi::make_ddim({numel});
#else
return;
#endif
} else {
auto alignment = [](size_t size, size_t align_size) {
size_t remaining = size % align_size;
auto aligned_size =
remaining == 0 ? size : size + (align_size - remaining);
VLOG(4) << remaining << " " << size << " " << align_size << " "
<< aligned_size;
return aligned_size;
};
VLOG(4) << "align_size: " << align_size;
if (use_align && align_size > 0) {
int64_t numel = 0;
auto dims = ctx->GetInputsDim("Input");
for (const auto &dim : dims) {
auto size = phi::product(dim);
auto len = use_align
? alignment(static_cast<size_t>(size) * size_of_dtype,
align_size) /
size_of_dtype
: static_cast<size_t>(size);
numel += len;
}
ctx->SetOutputDim("FusedOutput", phi::make_ddim({numel}));
VLOG(4) << "FusedOutput size:" << phi::make_ddim({numel});
}
}
}
......
......@@ -67,26 +67,64 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
"but received value is %d.",
ins_dims[0].size()));
for (size_t i = 0; i < num_inputs; ++i) {
const auto dims = ins_dims[i];
int rank = dims.size();
if (use_cvm) {
PADDLE_ENFORCE_GT(
dims[rank - 1],
2,
platform::errors::InvalidArgument(
"Shape error in %lu id, the last dimension(embedding) of the "
"'X' tensor must be larger than 2.",
i));
if (ctx->IsRuntime()) {
int batch_size = -1;
auto inputs_tensor = ctx->GetInputVarPtrs("X");
for (size_t i = 0; i < num_inputs; ++i) {
const auto dims = ins_dims[i];
int rank = dims.size();
int cur_batch_size = 0;
framework::Variable* x_var =
PADDLE_GET(framework::Variable*, inputs_tensor[i]);
const auto& x_tensor = x_var->Get<phi::DenseTensor>();
const auto& x_lod = x_tensor.lod();
if (x_lod.size() > 0) {
cur_batch_size = x_lod[0].size() - 1;
} else {
cur_batch_size = x_tensor.dims()[0];
}
if (batch_size == -1) {
batch_size = cur_batch_size;
} else {
PADDLE_ENFORCE_EQ(batch_size,
cur_batch_size,
platform::errors::PreconditionNotMet(
"The batch size of all input should be same, "
"please check, last batch_size is %d, current "
"batch_size is %d",
batch_size,
cur_batch_size));
}
std::vector<int64_t> out_dim;
if (use_cvm) {
out_dim = {batch_size, dims[rank - 1]};
} else {
out_dim = {batch_size, dims[rank - 1] - cvm_offset};
}
outs_dims[i] = phi::make_ddim(out_dim);
}
        // input lod is not accessible here
        std::vector<int64_t> out_dim;
        if (use_cvm) {
          out_dim = {-1, dims[rank - 1]};
        } else {
          out_dim = {-1, dims[rank - 1] - cvm_offset};
        }
        outs_dims[i] = phi::make_ddim(out_dim);
      }
    } else {
for (size_t i = 0; i < num_inputs; ++i) {
const auto dims = ins_dims[i];
int rank = dims.size();
if (use_cvm) {
PADDLE_ENFORCE_GT(
dims[rank - 1],
2,
platform::errors::InvalidArgument(
"Shape error in %lu id, the last dimension(embedding) of the "
"'X' tensor must be larger than 2.",
i));
}
// input lod is not accessible here
std::vector<int64_t> out_dim;
if (use_cvm) {
out_dim = {-1, dims[rank - 1]};
} else {
out_dim = {-1, dims[rank - 1] - cvm_offset};
}
outs_dims[i] = phi::make_ddim(out_dim);
}
}
ctx->SetOutputsDim("Out", outs_dims);
ctx->ShareLoD("X", /*->*/ "Out");
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector>
#include "paddle/phi/backends/device_memory_aligment.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
......@@ -849,30 +850,19 @@ void CoalesceTensorInferMeta(const std::vector<const MetaTensor*>& input,
std::vector<MetaTensor*> output,
MetaTensor* fused_output,
MetaConfig config) {
if (config.is_runtime) {
return;
}
if (size_of_dtype == -1) {
size_of_dtype = phi::SizeOf(dtype);
}
auto alignment = [](size_t size, size_t align_size) {
size_t remaining = size % align_size;
auto aligned_size = remaining == 0 ? size : size + (align_size - remaining);
VLOG(4) << remaining << " " << size << " " << align_size << " "
<< aligned_size;
return aligned_size;
};
VLOG(4) << "align_size: " << align_size;
if (use_align && align_size > 0) {
if (config.is_runtime) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
int64_t numel = 0;
for (size_t i = 0; i < input.size(); ++i) {
const auto& dim = input[i]->dims();
auto size = phi::product(dim);
      auto len = use_align
                     ? alignment(static_cast<size_t>(size) * size_of_dtype,
                                 align_size) /
                           size_of_dtype
                     : static_cast<size_t>(size);
      auto len = use_align
                     ? phi::Alignment(static_cast<size_t>(size) * size_of_dtype,
                                      phi::GPUPlace(),
                                      align_size) /
                           size_of_dtype
                     : static_cast<size_t>(size);
numel += len;
......@@ -882,6 +872,38 @@ void CoalesceTensorInferMeta(const std::vector<const MetaTensor*>& input,
fused_output->set_dtype(dtype);
VLOG(4) << "fused_output size:" << phi::make_ddim({numel});
}
#else
return;
#endif
} else {
auto alignment = [](size_t size, size_t align_size) {
size_t remaining = size % align_size;
auto aligned_size =
remaining == 0 ? size : size + (align_size - remaining);
VLOG(4) << remaining << " " << size << " " << align_size << " "
<< aligned_size;
return aligned_size;
};
VLOG(4) << "align_size: " << align_size;
if (use_align && align_size > 0) {
int64_t numel = 0;
for (size_t i = 0; i < input.size(); ++i) {
const auto& dim = input[i]->dims();
auto size = phi::product(dim);
auto len = use_align
? alignment(static_cast<size_t>(size) * size_of_dtype,
align_size) /
size_of_dtype
: static_cast<size_t>(size);
numel += len;
}
if (fused_output) {
fused_output->set_dims(phi::make_ddim({numel}));
fused_output->set_dtype(dtype);
VLOG(4) << "fused_output size:" << phi::make_ddim({numel});
}
}
}
}
......