Unverified commit 7673b39a authored by W WangZhen, committed by GitHub

[JITLayer]Polish Layer implement and refine interface code (#43607)

* Fix some TODO and polish code

* Support GPU place

* Fix layer_test ci error

* Polish some code

* Make GetFunction as const function

* Remove COMMAND tar to fix CI error

* Fix comments

* Merge develop to fix codestyle error
Parent ad88cbb8
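Taken together, the changes below move the JIT layer interface from name-keyed `VariableNameMap` inputs to ordered `std::vector<Variable>` inputs and thread a `phi::Place` through functions, the layer, and the serializer. The following is a minimal usage sketch distilled from the updated `layer_test.cc` further down; it assumes a build inside the Paddle source tree, and the helper name `RunForwardOnCpu` is illustrative, not part of this commit.

```cpp
// Sketch only: exercises the refined interface from this commit.
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/jit/serializer.h"

namespace paddle {
namespace jit {

std::vector<Variable> RunForwardOnCpu(const std::string &path,
                                      const std::vector<Variable> &inputs) {
  // Functions now pick up their place from the current tracer instead of a
  // hard-coded phi::CPUPlace().
  auto tracer = std::make_shared<imperative::Tracer>();
  imperative::SetCurrentTracer(tracer);
  imperative::GetCurrentTracer()->SetExpectedPlace(phi::CPUPlace());

  auto layer = jit::Load(path);   // deserializes onto the expected place
  return layer.forward(inputs);   // ordered Variables, not a name map
}

}  // namespace jit
}  // namespace paddle
```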
@@ -16,8 +16,11 @@ cc_library(
 if(WITH_TESTING AND NOT WIN32)
   add_custom_target(
     jit_download_program
-    COMMAND wget -nc https://paddle-ci.gz.bcebos.com/dy2st/Testing.tar.gz
-    COMMAND tar zxvf Testing.tar.gz)
+    COMMAND wget -nc
+            https://paddle-ci.gz.bcebos.com/dy2st/export.forward.pdiparams
+    COMMAND wget -nc
+            https://paddle-ci.gz.bcebos.com/dy2st/export.forward.pdmodel
+    COMMAND wget -nc https://paddle-ci.gz.bcebos.com/dy2st/export.infer.pdmodel)
   set(JIT_DEPS
     phi
     elementwise_add_op
...
@@ -38,29 +38,28 @@ std::vector<std::string> FunctionSchema::GetOutputArgNames() {
   return output_arg_names;
 }

-void FunctionSchema::AddInputArg(std::string name, bool is_output) {
-  input_args.emplace_back(name, is_output);
+void FunctionSchema::AddInputArg(std::string name) {
+  input_args.emplace_back(name, false);
 }

-void FunctionSchema::AddOutputArg(std::string name, bool is_output) {
-  output_args.emplace_back(name, is_output);
+void FunctionSchema::AddOutputArg(std::string name) {
+  output_args.emplace_back(name, true);
 }

-BaseFunction::BaseFunction(
-    const framework::ProgramDesc &program_desc,
-    const std::vector<std::string> param_names_for_program,
-    const VariableNameMap &params_dict)
-    : program_desc_(program_desc) {
+BaseFunction::BaseFunction(const framework::ProgramDesc &program_desc,
+                           const std::vector<std::string> &param_names,
+                           const VariableNameMap &params_dict,
+                           const phi::Place &place)
+    : program_desc_(program_desc), place_(place) {
   // Parse FunctionSchema
-  // skip_var_name_ = program_desc_.GetFetchTargetNames();
   for (auto &in_name : program_desc_.GetFeedTargetNames()) {
-    schema_.AddInputArg(in_name, false);
+    schema_.AddInputArg(in_name);
   }
   for (auto &out_name : program_desc_.GetFetchTargetNames()) {
-    schema_.AddOutputArg(out_name, true);
+    schema_.AddOutputArg(out_name);
   }
   // share params into scope
-  SharePartialIntoScope(param_names_for_program, params_dict);
+  ShareParamsIntoScope(param_names, params_dict);
   VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
   // remove feed fetch op
   RemoveFeedFetch();
@@ -70,7 +69,9 @@ void BaseFunction::FetchOutput(std::vector<Variable> *outs) {
   for (auto &out_name : schema_.GetOutputArgNames()) {
     VLOG(3) << "fetch out: " << out_name;
     auto *var = scope_.FindVar(out_name);
+    VLOG(3) << "after scope_.FindVar(out_name);";
     auto &src_tensor = var->Get<phi::DenseTensor>();
+    VLOG(3) << "var->Get<phi::DenseTensor>();";
     Variable v;
     auto *p = v.GetMutable<DenseTensor>();
     *p = src_tensor;
@@ -78,23 +79,30 @@ void BaseFunction::FetchOutput(std::vector<Variable> *outs) {
   }
 }

-void BaseFunction::ShareIntoScope(const VariableNameMap &ivals) {
-  VLOG(3) << "ivals size: " << ivals.size();
-  for (auto it = ivals.begin(); it != ivals.end(); ++it) {
-    VLOG(3) << "share into scope: " << it->first;
-    DenseTensor dense_tensor = it->second.Get<DenseTensor>();
-    auto *var = scope_.Var(it->first);
+void BaseFunction::ShareInputsIntoScope(const std::vector<Variable> &vars) {
+  VLOG(3) << "vars size: " << vars.size();
+  std::vector<std::string> ordered_input_names = schema_.GetInputArgNames();
+  PADDLE_ENFORCE_EQ(
+      vars.size(),
+      ordered_input_names.size(),
+      platform::errors::InvalidArgument(
+          "vars.size() should be equal to ordered_input_names.size()."));
+  for (size_t i = 0; i < vars.size(); i++) {
+    VLOG(3) << "share into scope: " << ordered_input_names[i];
+    auto &dense_tensor = vars[i].Get<DenseTensor>();
+    auto *var = scope_.Var(ordered_input_names[i]);
     auto *dst_tensor = var->GetMutable<DenseTensor>();
     *dst_tensor = dense_tensor;
   }
 }

-void BaseFunction::SharePartialIntoScope(
-    const std::vector<std::string> param_names_for_program,
+void BaseFunction::ShareParamsIntoScope(
+    const std::vector<std::string> &param_names,
     const VariableNameMap &params_dict) {
-  VLOG(3) << "ivals size: " << param_names_for_program.size();
-  for (size_t i = 0; i < param_names_for_program.size(); ++i) {
-    std::string name = param_names_for_program[i];
+  VLOG(3) << "param_names size: " << param_names.size();
+  for (size_t i = 0; i < param_names.size(); ++i) {
+    std::string name = param_names[i];
     Variable val = params_dict.find(name)->second;
     auto &dense_tensor = val.Get<DenseTensor>();
     VLOG(3) << "share into scope: " << name;
@@ -112,8 +120,15 @@ void BaseFunction::RemoveFeedFetch() {
   VLOG(3) << "op_size: " << op_size;
   for (int i = op_size - 1; i >= 0; i--) {
     auto op = all_ops[i];
-    if (op->Type() == "feed" || op->Type() == "fetch") {
-      VLOG(3) << "remove op type: " << op->Type() << ", index: " << i;
+    if (op->Type() == "feed") {
+      VLOG(3) << "remove op type: " << op->Type() << ", index: " << i
+              << ", var name: " << op->Input("X")[0];
+      block->RemoveVar(op->Input("X")[0]);
+      block->RemoveOp(i, i + 1);
+    } else if (op->Type() == "fetch") {
+      VLOG(3) << "remove op type: " << op->Type() << ", index: " << i
+              << ", var name: " << op->Output("Out")[0];
+      block->RemoveVar(op->Output("Out")[0]);
       block->RemoveOp(i, i + 1);
     }
   }
...
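The reworked `ShareInputsIntoScope` above binds inputs to feed targets by position rather than by name: callers must pass Variables in the schema's input order, and a size mismatch is rejected. A standalone sketch of that contract (plain C++, not Paddle code; the function name and `float` value type are illustrative stand-ins for Variables holding DenseTensors):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Binds the i-th input to the i-th feed target name, mirroring the new
// positional contract; the real code stores DenseTensors into a
// framework::Scope and enforces the size check with PADDLE_ENFORCE_EQ.
std::map<std::string, float> BindByPosition(
    const std::vector<std::string> &ordered_input_names,
    const std::vector<float> &inputs) {
  assert(inputs.size() == ordered_input_names.size());
  std::map<std::string, float> scope;
  for (size_t i = 0; i < inputs.size(); ++i) {
    scope[ordered_input_names[i]] = inputs[i];
  }
  return scope;
}
```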
@@ -20,8 +20,7 @@
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/utils/none.h"
-#include "paddle/utils/optional.h"
+#include "paddle/phi/core/enforce.h"

 namespace paddle {
 namespace jit {
@@ -50,11 +49,12 @@ class FunctionSchema {
   std::vector<std::string> GetOutputArgNames();

-  void AddInputArg(std::string name, bool is_output);
-  void AddOutputArg(std::string name, bool is_output);
+  void AddInputArg(std::string name);
+  void AddOutputArg(std::string name);

  private:
+  // input_args and output_args are ordered
   std::vector<Argument> input_args;
   std::vector<Argument> output_args;
 };
@@ -63,33 +63,31 @@ class FunctionSchema {
 class BaseFunction {
  public:
   BaseFunction(const framework::ProgramDesc &program_desc,
-               const std::vector<std::string> param_names_for_program,
-               const VariableNameMap &params_dict);
+               const std::vector<std::string> &param_names,
+               const VariableNameMap &params_dict,
+               const phi::Place &place);

   virtual ~BaseFunction() {}

-  virtual std::vector<Variable> operator()(const VariableNameMap &inputs) = 0;
+  virtual std::vector<Variable> operator()(
+      const std::vector<Variable> &inputs) = 0;

  protected:
   void FetchOutput(std::vector<Variable> *outs);

-  void ShareIntoScope(const VariableNameMap &ivals);
+  void ShareInputsIntoScope(const std::vector<Variable> &vars);

-  void SharePartialIntoScope(
-      const std::vector<std::string> param_names_for_program,
-      const VariableNameMap &params_dict);
+  void ShareParamsIntoScope(const std::vector<std::string> &param_names,
+                            const VariableNameMap &params_dict);

   void RemoveFeedFetch();

  protected:
   framework::ProgramDesc program_desc_;
-  // TODO(dev): need a better way to share params
-  // std::vector<Variable> &param_for_program_;
-  // std::vector<std::string> skip_var_name_;
   FunctionSchema schema_;
   // global_scope place params
   framework::Scope scope_;
-  // framework::Executor inner_exe_;
+  phi::Place place_;
 };

 }  // namespace jit
...
@@ -22,18 +22,23 @@ namespace jit {
 class ExectorFunction : public BaseFunction {
  public:
   ExectorFunction(const framework::ProgramDesc &program_desc,
-                  const std::vector<std::string> param_names_for_program,
-                  const VariableNameMap &params_dict)
-      : BaseFunction(program_desc, param_names_for_program, params_dict),
-        inner_exe_(phi::CPUPlace()) {}
+                  const std::vector<std::string> param_names,
+                  const VariableNameMap &params_dict,
+                  const phi::Place &place)
+      : BaseFunction(program_desc, param_names, params_dict, place),
+        inner_exe_(place_) {}

   ~ExectorFunction() {}

-  std::vector<Variable> operator()(const VariableNameMap &inputs) {
+  std::vector<Variable> operator()(const std::vector<Variable> &inputs) {
     // share input into scope
-    ShareIntoScope(inputs);
+    ShareInputsIntoScope(inputs);
     // run program
-    inner_exe_.Run(program_desc_, &scope_, /*blockID=*/0, false, true,
+    inner_exe_.Run(program_desc_,
+                   &scope_,
+                   /*blockID=*/0,
+                   false,
+                   true,
                    schema_.GetOutputArgNames());
     VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
     // fetch outputs
@@ -43,7 +48,6 @@ class ExectorFunction : public BaseFunction {
   }

  private:
-  // TODO(dev): support other devices exe
   framework::Executor inner_exe_;
 };
...
@@ -23,23 +23,24 @@ Layer::Layer(
     const std::vector<std::string>& func_names,
     const std::vector<framework::ProgramDesc>& program_descs,
     const std::vector<std::vector<std::string>>& param_names_for_each_program,
-    const VariableNameMap& params_dict) {
+    const VariableNameMap& params_dict,
+    const phi::Place& place) {
   VLOG(3) << "program size: " << program_descs.size();
   // Layer manage the life time of all parameter.
   for (size_t i = 0; i < func_names.size(); ++i) {
     // TODO(dev): choose exector or pe by flag
     function_dict[func_names[i]] = std::make_shared<ExectorFunction>(
-        program_descs[i], param_names_for_each_program[i], params_dict);
+        program_descs[i], param_names_for_each_program[i], params_dict, place);
   }
 }

-// TODO(dev): make it as const function
-std::shared_ptr<BaseFunction> Layer::GetFunction(const std::string& name) {
+std::shared_ptr<BaseFunction> Layer::GetFunction(
+    const std::string& name) const {
   VLOG(3) << "funcs_ size: " << function_dict.size();
-  return function_dict[name];
+  return function_dict.at(name);
 }

-std::vector<Variable> Layer::forward(const VariableNameMap& inputs) {
+std::vector<Variable> Layer::forward(const std::vector<Variable>& inputs) {
   auto func = GetFunction("forward");
   return (*func)(inputs);
 }
...
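Making `Layer::GetFunction` const goes hand in hand with switching the lookup from `operator[]` to `at()`. A toy illustration (standard C++, not Paddle code) of the difference in behavior:

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// operator[] default-inserts an entry for an unknown key and cannot be used
// in a const member function; at() leaves the map untouched and throws
// std::out_of_range for unknown names instead.
int main() {
  std::map<std::string, std::shared_ptr<int>> function_dict;
  function_dict["forward"] = std::make_shared<int>(0);

  try {
    auto fn = function_dict.at("missing");
  } catch (const std::out_of_range &) {
    std::cout << "no function named 'missing'\n";
  }
  std::cout << "dict size is still " << function_dict.size() << '\n';  // 1
  return 0;
}
```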
@@ -36,16 +36,17 @@ class Layer {
   // TODO(dev): Make vector<string>, num_slot as in argument
   // Layer(const std::shared_ptr<ClassType>& type) : obj_(type, /*num_slot*/ 0U)
   // {}
+  // TODO(dev): consider make `func_name, program_desc, param_nams` as a class
   Layer(
       const std::vector<std::string>& func_names,
       const std::vector<framework::ProgramDesc>& program_descs,
       const std::vector<std::vector<std::string>>& param_names_for_each_program,
-      const VariableNameMap& params_dict);
+      const VariableNameMap& params_dict,
+      const phi::Place& place);

-  // TODO(dev): make it as const function
-  std::shared_ptr<BaseFunction> GetFunction(const std::string& name);
+  std::shared_ptr<BaseFunction> GetFunction(const std::string& name) const;

-  std::vector<Variable> forward(const VariableNameMap& inputs);
+  std::vector<Variable> forward(const std::vector<Variable>& inputs);

  private:
   // internal::Object obj_;
...
@@ -23,11 +23,13 @@
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/jit/serializer.h"
 #include "paddle/fluid/memory/allocation/allocator_facade.h"
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/copy_kernel.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

 USE_OP_ITSELF(elementwise_add);
@@ -44,28 +46,37 @@ PD_DECLARE_KERNEL(relu, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(mean, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);

+#if defined(PADDLE_WITH_CUDA)
+PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(relu, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
+#endif
+
 namespace paddle {
 namespace jit {

-VariableNameMap PrepareInputs() {
-  auto temp = DenseTensor();
-  temp.Resize(phi::make_ddim({2, 4}));
-  phi::CPUContext cpu_ctx;
-  cpu_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
-                           .GetAllocator(paddle::platform::CPUPlace())
-                           .get());
-  cpu_ctx.Init();
-  cpu_ctx.Alloc<float>(&temp);
-  phi::funcs::set_constant(cpu_ctx, &temp, 2.);
+std::vector<Variable> PrepareInputs() {
+  auto default_place = imperative::GetCurrentTracer()->ExpectedPlace();
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto& dev_ctx = *pool.Get(default_place);

   Variable v;
-  auto *p = v.GetMutable<DenseTensor>();
-  *p = temp;
-  // TODO(dev): associate the input name
-  return {{"x", v}};
+  auto* dense_tensor = v.GetMutable<DenseTensor>();
+  dense_tensor->Resize(phi::make_ddim({2, 4}));
+  dense_tensor->mutable_data<float>(default_place);
+  phi::funcs::set_constant(dev_ctx, dense_tensor, 2.);
+
+  return {v};
 }

-TEST(layer, Construct) {
-  std::string path = "./Testing/";
+TEST(CpuLayerTest, Construct) {
+  auto tracer = std::make_shared<paddle::imperative::Tracer>();
+  paddle::imperative::SetCurrentTracer(tracer);
+  imperative::GetCurrentTracer()->SetExpectedPlace(phi::CPUPlace());
+
+  std::string path = "./";
   auto layer = jit::Load(path);
   auto inputs = PrepareInputs();
@@ -83,5 +94,39 @@ TEST(layer, Construct) {
   EXPECT_NEAR(out_data[0], 1.41562390, 1e-6);
 }

+#if defined(PADDLE_WITH_CUDA)
+TEST(GpuLayerTest, Construct) {
+  auto tracer = std::make_shared<paddle::imperative::Tracer>();
+  paddle::imperative::SetCurrentTracer(tracer);
+  imperative::GetCurrentTracer()->SetExpectedPlace(phi::GPUPlace(0));
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto& dev_ctx = *pool.Get(imperative::GetCurrentTracer()->ExpectedPlace());
+  const auto* dev_ctx_gpu = static_cast<const phi::GPUContext*>(&dev_ctx);
+  DenseTensor cpu_dense_tensor;
+
+  std::string path = "./";
+  auto layer = jit::Load(path);
+  auto inputs = PrepareInputs();
+
+  auto outs = layer.forward(inputs);
+  auto out_vars = outs[0];
+  auto out_dense_tensor = out_vars.Get<DenseTensor>();
+  phi::Copy(
+      *dev_ctx_gpu, out_dense_tensor, phi::CPUPlace(), true, &cpu_dense_tensor);
+  auto out_data = cpu_dense_tensor.data<float>();
+  EXPECT_NEAR(out_data[0], 0.02194316, 1e-6);
+
+  auto func = layer.GetFunction("infer");
+  outs = (*func)(inputs);
+  out_vars = outs[0];
+  out_dense_tensor = out_vars.Get<DenseTensor>();
+  phi::Copy(
+      *dev_ctx_gpu, out_dense_tensor, phi::CPUPlace(), true, &cpu_dense_tensor);
+  out_data = cpu_dense_tensor.data<float>();
+  EXPECT_NEAR(out_data[0], 1.41562390, 1e-6);
+}
+#endif
+
 }  // namespace jit
 }  // namespace paddle
@@ -27,13 +27,14 @@ namespace jit {
 class PEFunction : public BaseFunction {
  public:
   PEFunction(const framework::ProgramDesc &program_desc,
-             const std::vector<std::string> param_names_for_program,
-             const VariableNameMap &params_dict)
-      : BaseFunction(program_desc, param_names_for_program, params_dict) {}
+             const std::vector<std::string> param_names,
+             const VariableNameMap &params_dict,
+             const phi::Place &place)
+      : BaseFunction(program_desc, param_names, params_dict, place) {}

   ~PEFunction() {}

-  std::vector<Variable> operator()(const VariableNameMap &inputs) {
+  std::vector<Variable> operator()(const std::vector<Variable> &inputs) {
     // bool is_test = true;
     std::string prog_string;
     std::hash<std::string> string_hash;
@@ -43,15 +44,19 @@ class PEFunction : public BaseFunction {
     int64_t start_op_index = 0;
     int64_t end_op_index = static_cast<int64_t>(global_block.OpSize());

-    ShareIntoScope(inputs);
+    ShareInputsIntoScope(inputs);
     std::vector<std::string> input_var_names = schema_.GetInputArgNames();
     std::vector<std::string> output_var_names = schema_.GetOutputArgNames();
     std::vector<std::string> dout_var_names;
     if (end_op_index > start_op_index) {
       // TODO(dev): support other devices
-      auto cache_info = framework::GetExecutorInfoFromCache(
-          program_desc_, phi::CPUPlace(), start_op_index, end_op_index,
-          /*is_grad=*/false, program_id, &scope_);
+      auto cache_info = framework::GetExecutorInfoFromCache(program_desc_,
+                                                            place_,
+                                                            start_op_index,
+                                                            end_op_index,
+                                                            /*is_grad=*/false,
+                                                            program_id,
+                                                            &scope_);
       auto &parallel_executor = cache_info.first;
       auto &skip_eager_delete_vars =
           framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
@@ -65,7 +70,9 @@ class PEFunction : public BaseFunction {
             dout_var_names.begin(),
             dout_var_names.end());
         framework::details::ParseSafeEagerDeletionSkipVars(
-            program_desc_, end_op_index, output_var_names,
+            program_desc_,
+            end_op_index,
+            output_var_names,
             &skip_eager_delete_vars);
       }
       parallel_executor->RunWithoutFetch(skip_eager_delete_vars);
...
@@ -45,14 +45,18 @@ Layer Deserializer::operator()(const std::string& dir_path) {
         persistable_var_names.end());
   }

+  auto default_place = imperative::GetCurrentTracer()->ExpectedPlace();
   // Read from one pdiparams file, refine here
-  auto params_for_all_program =
-      ReadTensorData(dir_path + "export.forward.pdiparams", param_names_set);
-  params_dict.insert(params_for_all_program.begin(),
-                     params_for_all_program.end());
+  ReadTensorData(dir_path + "export.forward.pdiparams",
+                 param_names_set,
+                 default_place,
+                 &params_dict);

-  return Layer(func_names, program_descs, param_names_for_each_program,
-               params_dict);
+  return Layer(func_names,
+               program_descs,
+               param_names_for_each_program,
+               params_dict,
+               default_place);
 }

 bool Deserializer::IsPersistable(framework::VarDesc* desc_ptr) {
@@ -74,6 +78,7 @@ bool Deserializer::EndsWith(const std::string& str, const std::string& suffix) {
          0;
 }

+// process filename like `export.forward.pdmodel` and `export.infer.pdmodel`
 const std::vector<std::pair<std::string, std::string>>
 Deserializer::GetPdmodelFileNamePrefix(const std::string& path) {
   std::vector<std::pair<std::string, std::string>> file_name_prefixs;
@@ -92,23 +97,22 @@ Deserializer::GetPdmodelFileNamePrefix(const std::string& path) {
   return file_name_prefixs;
 }

-VariableNameMap Deserializer::ReadTensorData(
-    const std::string& file_name, const std::set<std::string>& var_name) const {
+void Deserializer::ReadTensorData(const std::string& file_name,
+                                  const std::set<std::string>& var_name,
+                                  const phi::Place& place,
+                                  VariableNameMap* params_dict) const {
   VLOG(3) << "ReadTensorData from: " << file_name;
   std::ifstream fin(file_name, std::ios::binary);
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-  // TODO(dev): Support other devices
-  auto& dev_ctx = *pool.Get(phi::CPUPlace());
-  VariableNameMap res;
+  auto& dev_ctx = *pool.Get(place);
   for (auto it = var_name.begin(); it != var_name.end(); it++) {
     VLOG(3) << "load Tensor: " << *it;
     Variable v;
     // TODO(dev): Support framework::Vocab
     DenseTensor* dense_tesnor = v.GetMutable<DenseTensor>();
     framework::DeserializeFromStream(fin, dense_tesnor, dev_ctx);
-    res[*it] = v;
+    (*params_dict)[*it] = v;
   }
-  return res;
 }

 framework::ProgramDesc Deserializer::LoadProgram(const std::string& file_name) {
...
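With the deserializer changes above, the device that parameters are loaded onto is no longer fixed to CPU: `ReadTensorData` fills a caller-owned `VariableNameMap` and targets the place taken from the current tracer. A sketch of selecting the load device before calling `jit::Load`, modeled on the GPU test added in this commit; the helper name `LoadOnGpu` is illustrative, and the GPU place requires a `PADDLE_WITH_CUDA` build:

```cpp
// Sketch only: choose the parameter-loading device via the tracer.
#include <memory>
#include <string>

#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/jit/serializer.h"

namespace paddle {
namespace jit {

Layer LoadOnGpu(const std::string &path) {
  auto tracer = std::make_shared<imperative::Tracer>();
  imperative::SetCurrentTracer(tracer);
  // ReadTensorData deserializes parameters to this place instead of
  // always using phi::CPUPlace().
  imperative::GetCurrentTracer()->SetExpectedPlace(phi::GPUPlace(0));
  return jit::Load(path);
}

}  // namespace jit
}  // namespace paddle
```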
@@ -23,6 +23,7 @@
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/jit/layer.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -58,8 +59,10 @@ class Deserializer {
   const std::vector<std::pair<std::string, std::string>>
   GetPdmodelFileNamePrefix(const std::string& path);

-  VariableNameMap ReadTensorData(const std::string& file_name,
-                                 const std::set<std::string>& var_name) const;
+  void ReadTensorData(const std::string& file_name,
+                      const std::set<std::string>& var_name,
+                      const phi::Place& place,
+                      VariableNameMap* params_dict) const;

   // void ReadExtraInfo(const std::string& file_name) const;
   // void ReadByteCode(const std::string& file_name) const;
...