Unverified Commit ededcda2 · Authored by: Yuang Liu · Committed by: GitHub

[fleet_executor] framework for big model inference (#38795)

Parent 31b1f707
@@ -12,7 +12,7 @@ endif()
 cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)
-cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
+cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc
                interceptor.cc compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc
                DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
                op_registry executor_gc_helper gflags glog ${BRPC_DEPS})
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>

#include <algorithm>
#include <fstream>

#include "paddle/fluid/distributed/fleet_executor/dist_model.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace distributed {
namespace {
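// A variable counts as a model parameter only if it is persistable and is not
// a feed/fetch I/O slot (FEED_MINIBATCH / FETCH_LIST) or an untyped RAW
// placeholder.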
bool IsPersistable(const framework::VarDesc *var) {
if (var->Persistable() &&
var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
var->GetType() != framework::proto::VarType::FETCH_LIST &&
var->GetType() != framework::proto::VarType::RAW) {
return true;
}
return false;
}
} // namespace
bool DistModel::Init() {
  /* TODO(fleet exe dev): implement this function */
  // For now the model is always placed on the CUDA device given by device_id.
  place_ = paddle::platform::CUDAPlace(config_.device_id);
if (!PrepareScope()) {
return false;
}
if (!PrepareProgram()) {
return false;
}
if (!CommInit()) {
return false;
}
return true;
}
bool DistModel::CommInit() {
// TODO(fleet executor): init the comm
return true;
}
bool DistModel::PrepareScope() {
scope_.reset(new framework::Scope());
return true;
}
bool DistModel::PrepareProgram() {
if (!LoadProgram()) {
return false;
}
if (!LoadParameters()) {
return false;
}
return true;
}
bool DistModel::LoadProgram() {
VLOG(3) << "Loading program from " << config_.model_dir;
PADDLE_ENFORCE_NE(config_.model_dir, "", platform::errors::InvalidArgument(
"Model dir must be provided."));
std::string model_path = config_.model_dir + ".pdmodel";
framework::proto::ProgramDesc program_proto;
std::string pb_content;
  // Read the whole binary program proto into pb_content.
std::ifstream fin(model_path, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin.is_open()), true,
platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file is normal.",
model_path));
fin.seekg(0, std::ios::end);
pb_content.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(pb_content.at(0)), pb_content.size());
fin.close();
program_proto.ParseFromString(pb_content);
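  // NOTE: ParseFromString() above returns false on a malformed proto; the
  // result is not checked here.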
VLOG(5) << pb_content;
program_.reset(new framework::ProgramDesc(program_proto));
return true;
}
bool DistModel::LoadParameters() {
VLOG(3) << "Loading parameters from " << config_.model_dir;
PADDLE_ENFORCE_NOT_NULL(program_.get(),
platform::errors::PreconditionNotMet(
"The program should be loaded first."));
const auto &global_block = program_->MutableBlock(0);
// create a temporary program to load parameters.
std::unique_ptr<framework::ProgramDesc> load_program(
new framework::ProgramDesc());
framework::BlockDesc *load_block = load_program->MutableBlock(0);
std::vector<std::string> params;
for (auto *var : global_block->AllVars()) {
if (IsPersistable(var)) {
VLOG(3) << "persistable variable's name: " << var->Name();
framework::VarDesc *new_var = load_block->Var(var->Name());
new_var->SetShape(var->GetShape());
new_var->SetDataType(var->GetDataType());
new_var->SetType(var->GetType());
new_var->SetLoDLevel(var->GetLoDLevel());
new_var->SetPersistable(true);
params.push_back(new_var->Name());
}
}
std::string param_path = config_.model_dir + ".pdiparams";
// sort paramlist to have consistent ordering
std::sort(params.begin(), params.end());
// append just the load_combine op
framework::OpDesc *op = load_block->AppendOp();
op->SetType("load_combine");
op->SetOutput("Out", params);
op->SetAttr("file_path", {param_path});
op->CheckAttrs();
framework::NaiveExecutor e(place_);
// Create all persistable variables in root scope to load them from ckpt.
// Other non-persistable variables will be created in the micro scope
// managed by fleet executor.
e.CreateVariables(*program_, 0, true, scope_.get());
e.Prepare(scope_.get(), *load_program, 0, false);
e.Run();
VLOG(3) << "After loading there are " << scope_->LocalVarNames().size()
<< " vars.";
return true;
}
void DistModel::Run(const std::vector<framework::Tensor> &input_data,
std::vector<framework::Tensor> *output_data) {
  /* TODO(fleet exe dev): implement this function */
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
class ProgramDesc;
class Scope;
class Tensor;
}  // namespace framework
namespace distributed {
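// Configuration for distributed inference. model_dir is the path prefix of the
// combined model files (<model_dir>.pdmodel / <model_dir>.pdiparams);
// trainer_endpoints lists the endpoints of all ranks and current_endpoint is
// this rank's own endpoint; mp_degree and pp_degree are the model-parallel and
// pipeline-parallel degrees.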
struct DistModelConfig {
std::string model_dir{};
std::vector<std::string> trainer_endpoints{};
std::string current_endpoint{};
int64_t nranks{1};
int64_t local_rank{0};
int64_t device_id{0};
int64_t mp_degree{1};
int64_t pp_degree{1};
};
class DistModel {
public:
explicit DistModel(const DistModelConfig& config) : config_(config) {}
bool Init();
void Run(const std::vector<framework::Tensor>& input_data,
std::vector<framework::Tensor>* output_data);
~DistModel() = default;
private:
DISABLE_COPY_AND_ASSIGN(DistModel);
bool PrepareScope();
bool PrepareProgram();
bool LoadProgram();
bool LoadParameters();
bool CommInit();
DistModelConfig config_;
FleetExecutorDesc executor_desc_;
platform::Place place_;
std::shared_ptr<framework::Scope> scope_;
std::shared_ptr<framework::ProgramDesc> program_;
};
} // namespace distributed
} // namespace paddle
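Below is a minimal usage sketch of the new C++ API; the model path and endpoint values are illustrative, and Run() is still a stub in this commit:

#include "paddle/fluid/distributed/fleet_executor/dist_model.h"

int main() {
  using paddle::distributed::DistModel;
  using paddle::distributed::DistModelConfig;

  DistModelConfig config;
  config.model_dir = "/path/to/model";  // loads /path/to/model.pdmodel + .pdiparams (illustrative path)
  config.trainer_endpoints = {"127.0.0.1:6170"};  // illustrative endpoint
  config.current_endpoint = "127.0.0.1:6170";
  config.nranks = 1;
  config.local_rank = 0;
  config.device_id = 0;  // DistModel::Init places the model on this CUDA device

  DistModel model(config);
  if (!model.Init()) return 1;
  // model.Run(inputs, &outputs);  // inference entry point, not yet implemented
  return 0;
}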
@@ -14,6 +14,7 @@
 #include "paddle/fluid/pybind/bind_fleet_executor.h"
 #include <pybind11/stl.h>
+#include "paddle/fluid/distributed/fleet_executor/dist_model.h"
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
 #include "paddle/fluid/framework/operator.h"
@@ -28,6 +29,8 @@ namespace pybind {
 using paddle::distributed::FleetExecutor;
 using paddle::distributed::TaskNode;
+using paddle::distributed::DistModelConfig;
+using paddle::distributed::DistModel;
 using paddle::framework::OpDesc;
 using paddle::framework::ProgramDesc;
@@ -51,6 +54,22 @@ void BindFleetExecutor(py::module* m) {
       .def("role", &TaskNode::role)
       .def("init", &TaskNode::Init)
       .def("set_program", &TaskNode::SetProgram);
+  py::class_<DistModelConfig>(*m, "DistModelConfig")
+      .def(py::init<>())
+      .def_readwrite("model_dir", &DistModelConfig::model_dir)
+      .def_readwrite("trainer_endpoints", &DistModelConfig::trainer_endpoints)
+      .def_readwrite("current_endpoint", &DistModelConfig::current_endpoint)
+      .def_readwrite("nranks", &DistModelConfig::nranks)
+      .def_readwrite("local_rank", &DistModelConfig::local_rank)
+      .def_readwrite("device_id", &DistModelConfig::device_id)
+      .def_readwrite("mp_degree", &DistModelConfig::mp_degree)
+      .def_readwrite("pp_degree", &DistModelConfig::pp_degree);
+  py::class_<DistModel>(*m, "DistModel")
+      .def(py::init<const DistModelConfig&>())
+      .def("init", &DistModel::Init)
+      .def("run", &DistModel::Run, py::call_guard<py::gil_scoped_release>());
 }
} // namespace pybind
} // namespace paddle
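Note: binding DistModel::run with py::call_guard<py::gil_scoped_release> releases the Python GIL for the duration of the call, so other Python threads can keep running while a long inference request executes.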