Unverified commit 3be36a82, authored by WangZhen, committed by GitHub

[JitLayer]Split layer_utils.cc and polish interface BaseFunction (#43754)

* Split layer_utils.cc and polish interface BaseFunction

* Polish include

* Fix cmake for function_schema

* Use unordered_map instead of map and rename VariableNameMap

* Polish Layer Constructor and remove useless include

* Refine compilation_unit
Parent cefbf800
......@@ -3,15 +3,25 @@ cc_library(
SRCS serializer.cc
DEPS lod_tensor device_context)
cc_library(
jit_layer_utils
SRCS layer_utils.cc
DEPS scope proto_desc)
cc_library(
jit_compilation_unit
SRCS compilation_unit.cc
DEPS proto_desc executor parallel_executor executor_cache)
cc_library(
jit_layer
SRCS layer.cc
DEPS executor parallel_executor executor_cache)
DEPS jit_compilation_unit)
cc_library(
jit_base_function
SRCS base_function.cc
DEPS scope proto_desc)
jit_function_schema
SRCS function_schema.cc
DEPS jit_layer_utils)
if(WITH_TESTING
AND NOT WIN32
......@@ -31,7 +41,9 @@ if(WITH_TESTING
scale_op
jit_serializer
jit_layer
jit_base_function)
jit_layer_utils
jit_function_schema
jit_compilation_unit)
cc_test(
layer_test
SRCS layer_test.cc
......
......@@ -17,77 +17,20 @@
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/common/place.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace jit {
using Variable = paddle::framework::Variable;
using VariableNameMap = std::map<std::string, Variable>;
using DenseTensor = phi::DenseTensor;
class Argument {
public:
explicit Argument(const std::string &name, bool is_out = false);
const std::string &Name() const;
private:
std::string name_;
// paddle::optional<Variable> default_val_;
bool is_output_;
};
class FunctionSchema {
public:
FunctionSchema() = default;
std::vector<std::string> GetInputArgNames();
std::vector<std::string> GetOutputArgNames();
void AddInputArg(std::string name);
void AddOutputArg(std::string name);
private:
// input_args and output_args are ordered
std::vector<Argument> input_args;
std::vector<Argument> output_args;
};
// TODO(dev): make it as abstract class
class BaseFunction {
public:
BaseFunction(const framework::ProgramDesc &program_desc,
const std::vector<std::string> &param_names,
const VariableNameMap &params_dict,
const phi::Place &place);
virtual ~BaseFunction() {}
virtual std::vector<Variable> operator()(
const std::vector<Variable> &inputs) = 0;
protected:
void FetchOutput(std::vector<Variable> *outs);
void ShareInputsIntoScope(const std::vector<Variable> &vars);
void ShareParamsIntoScope(const std::vector<std::string> &param_names,
const VariableNameMap &params_dict);
void RemoveFeedFetch();
protected:
framework::ProgramDesc program_desc_;
FunctionSchema schema_;
// global_scope place params
framework::Scope scope_;
phi::Place place_;
virtual ~BaseFunction() {}
// virtual void SetPalce(const phi::Place &place);
};
} // namespace jit
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/jit/compilation_unit.h"
namespace paddle {
namespace jit {
void CompilationUnit::AddExecutorFunction(
const std::string &func_name,
const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place) {
function_dict_[func_name] =
std::make_shared<ExecutorFunction>(info, params_dict, place);
}
void CompilationUnit::AddPEFunction(const std::string &func_name,
const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place) {
function_dict_[func_name] =
std::make_shared<PEFunction>(info, params_dict, place);
}
std::shared_ptr<BaseFunction> CompilationUnit::GetFunction(
const std::string &name) const {
return function_dict_.at(name);
}
} // namespace jit
} // namespace paddle
......@@ -14,23 +14,35 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/jit/executor_function.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/pe_function.h"
namespace paddle {
namespace jit {
class BaseFunction;
class CompilationUnit {
public:
CompilationUnit() = default;
~CompilationUnit() {}
void AddExecutorFunction(const std::string &func_name,
const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place);
void AddPEFunction(const std::string &func_name,
const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place);
std::shared_ptr<BaseFunction> GetFunction(const std::string &name) const;
private:
std::vector<std::unique_ptr<BaseFunction>> functions_;
std::unordered_map<std::string, size_t> functions_idx_;
std::unordered_map<std::string, std::shared_ptr<BaseFunction>> function_dict_;
};
} // namespace jit
......
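With this change, CompilationUnit owns the name-to-function map that Layer previously held directly. A minimal sketch of how a caller might populate and query it, assuming an existing FunctionInfo and parameter map are at hand (the helper name BuildUnit is illustrative, not part of the PR):
// Sketch only: registering and looking up a function in a CompilationUnit.
#include "paddle/fluid/jit/compilation_unit.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/phi/common/place.h"
void BuildUnit(const std::shared_ptr<paddle::jit::FunctionInfo> &info,
               const paddle::jit::Name2VariableMap &params_dict,
               const phi::Place &place) {
  paddle::jit::CompilationUnit unit;
  // Register an Executor-backed function under its name.
  unit.AddExecutorFunction(info->GetFunctionName(), info, params_dict, place);
  // Later, look the function up by name; (*func)(inputs) would run it.
  auto func = unit.GetFunction(info->GetFunctionName());
  (void)func;
}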
......@@ -14,40 +14,52 @@
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/jit/base_function.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/layer_utils.h"
namespace paddle {
namespace jit {
class ExectorFunction : public BaseFunction {
class ExecutorFunction : public BaseFunction {
public:
ExectorFunction(const framework::ProgramDesc &program_desc,
const std::vector<std::string> param_names,
const VariableNameMap &params_dict,
ExecutorFunction(const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place)
: BaseFunction(program_desc, param_names, params_dict, place),
inner_exe_(place_) {}
: info_(info), place_(place), inner_exe_(place_) {
ShareParamsIntoScope(info_->GetParamNames(), params_dict, &scope_);
VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
}
~ExectorFunction() {}
~ExecutorFunction() noexcept {}
std::vector<Variable> operator()(const std::vector<Variable> &inputs) {
// share input into scope
ShareInputsIntoScope(inputs);
// run program
inner_exe_.Run(program_desc_,
ShareInputsIntoScope(info_->GetInputArgNames(), inputs, &scope_);
inner_exe_.Run(info_->GetProgramDesc(),
&scope_,
/*blockID=*/0,
false,
true,
schema_.GetOutputArgNames());
info_->GetOutputArgNames());
VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
// fetch outputs
std::vector<Variable> res;
FetchOutput(&res);
FetchVarsByNames(info_->GetOutputArgNames(), scope_, &res);
return res;
}
private:
std::shared_ptr<FunctionInfo> info_;
framework::Scope scope_;
phi::Place place_;
framework::Executor inner_exe_;
};
......
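ExecutorFunction now takes a shared FunctionInfo instead of a raw ProgramDesc plus parameter names, and keeps its own scope and place. A hedged sketch of constructing and invoking one directly (normally CompilationUnit does this; the helper name RunOnce is illustrative):
// Sketch only: direct use of ExecutorFunction, assuming info/params/place exist.
#include "paddle/fluid/jit/executor_function.h"
#include "paddle/phi/common/place.h"
std::vector<paddle::jit::Variable> RunOnce(
    const std::shared_ptr<paddle::jit::FunctionInfo> &info,
    const paddle::jit::Name2VariableMap &params_dict,
    const phi::Place &place,
    const std::vector<paddle::jit::Variable> &inputs) {
  // Parameters are shared into the function's private scope on construction.
  paddle::jit::ExecutorFunction func(info, params_dict, place);
  // operator() shares the inputs into the scope, runs the program, fetches outputs.
  return func(inputs);
}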
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/jit/function_schema.h"
namespace paddle {
namespace jit {
Argument::Argument(const std::string& name, bool is_out)
: name_(name), is_output_(is_out) {}
const std::string& Argument::Name() const { return name_; }
const std::vector<std::string> FunctionSchema::GetInputArgNames() const {
std::vector<std::string> input_arg_names;
for (auto& arg : input_args) {
input_arg_names.emplace_back(arg.Name());
}
return input_arg_names;
}
const std::vector<std::string> FunctionSchema::GetOutputArgNames() const {
std::vector<std::string> output_arg_names;
for (auto& arg : output_args) {
output_arg_names.emplace_back(arg.Name());
}
return output_arg_names;
}
void FunctionSchema::AddInputArg(const std::string& name) {
input_args.emplace_back(name, false);
}
void FunctionSchema::AddOutputArg(const std::string& name) {
output_args.emplace_back(name, true);
}
FunctionInfo::FunctionInfo(const std::string& func_name,
const std::vector<std::string>& param_names,
const framework::ProgramDesc& program_desc)
: func_name_(func_name),
param_names_(param_names),
program_desc_(program_desc) {
// Parse FunctionSchema
for (auto& in_name : program_desc_.GetFeedTargetNames()) {
schema_.AddInputArg(in_name);
}
for (auto& out_name : program_desc_.GetFetchTargetNames()) {
schema_.AddOutputArg(out_name);
}
// remove feed fetch op
RemoveFeedFetch(&program_desc_);
}
const std::string& FunctionInfo::GetFunctionName() const { return func_name_; }
const framework::ProgramDesc& FunctionInfo::GetProgramDesc() const {
return program_desc_;
}
const std::vector<std::string>& FunctionInfo::GetParamNames() const {
return param_names_;
}
const std::vector<std::string> FunctionInfo::GetInputArgNames() const {
return schema_.GetInputArgNames();
}
const std::vector<std::string> FunctionInfo::GetOutputArgNames() const {
return schema_.GetOutputArgNames();
}
} // namespace jit
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/fluid/jit/layer_utils.h"
namespace paddle {
namespace jit {
class Argument {
public:
explicit Argument(const std::string& name, bool is_out = false);
const std::string& Name() const;
private:
std::string name_;
// paddle::optional<Variable> default_val_;
bool is_output_;
};
class FunctionSchema {
public:
FunctionSchema() = default;
const std::vector<std::string> GetInputArgNames() const;
const std::vector<std::string> GetOutputArgNames() const;
void AddInputArg(const std::string& name);
void AddOutputArg(const std::string& name);
private:
// input_args and output_args are ordered
std::vector<Argument> input_args;
std::vector<Argument> output_args;
};
class FunctionInfo {
public:
FunctionInfo(const std::string& func_name,
const std::vector<std::string>& param_names,
const framework::ProgramDesc& program_desc);
const std::string& GetFunctionName() const;
const framework::ProgramDesc& GetProgramDesc() const;
const std::vector<std::string>& GetParamNames() const;
const std::vector<std::string> GetInputArgNames() const;
const std::vector<std::string> GetOutputArgNames() const;
private:
std::string func_name_;
std::vector<std::string> param_names_;
framework::ProgramDesc program_desc_;
FunctionSchema schema_;
};
} // namespace jit
} // namespace paddle
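FunctionInfo bundles what BaseFunction used to parse for itself: the function name, its parameter names, and the ProgramDesc, with feed/fetch targets recorded in a FunctionSchema and the feed/fetch ops stripped from the stored program. A sketch of building one from a loaded program (the function name "forward" and helper MakeInfo are placeholders):
// Sketch only: constructing a FunctionInfo from an already-loaded ProgramDesc.
#include "paddle/fluid/jit/function_schema.h"
std::shared_ptr<paddle::jit::FunctionInfo> MakeInfo(
    const paddle::framework::ProgramDesc &program_desc,
    const std::vector<std::string> &persistable_names) {
  auto info = std::make_shared<paddle::jit::FunctionInfo>(
      "forward", persistable_names, program_desc);
  // The schema was parsed from the program's feed/fetch targets in the constructor.
  auto in_names = info->GetInputArgNames();    // from feed targets
  auto out_names = info->GetOutputArgNames();  // from fetch targets
  (void)in_names;
  (void)out_names;
  return info;
}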
......@@ -19,25 +19,22 @@ namespace jit {
// TODO(dev): Make vector<string>, num_slot as in argument
// Layer(const std::shared_ptr<ClassType>& type) : obj_(type, /*num_slot*/ 0U)
// {}
Layer::Layer(
const std::vector<std::string>& func_names,
const std::vector<framework::ProgramDesc>& program_descs,
const std::vector<std::vector<std::string>>& param_names_for_each_program,
const VariableNameMap& params_dict,
const phi::Place& place) {
VLOG(3) << "program size: " << program_descs.size();
Layer::Layer(const std::vector<std::shared_ptr<FunctionInfo>>& infos,
const Name2VariableMap& params_dict,
const phi::Place& place)
: params_dict_(params_dict) {
VLOG(3) << "infos size: " << infos.size();
// Layer manage the life time of all parameter.
for (size_t i = 0; i < func_names.size(); ++i) {
for (size_t i = 0; i < infos.size(); ++i) {
// TODO(dev): choose exector or pe by flag
function_dict[func_names[i]] = std::make_shared<ExectorFunction>(
program_descs[i], param_names_for_each_program[i], params_dict, place);
unit_.AddExecutorFunction(
infos[i]->GetFunctionName(), infos[i], params_dict_, place);
}
}
std::shared_ptr<BaseFunction> Layer::GetFunction(
const std::string& name) const {
VLOG(3) << "funcs_ size: " << function_dict.size();
return function_dict.at(name);
return unit_.GetFunction(name);
}
std::vector<Variable> Layer::forward(const std::vector<Variable>& inputs) {
......
......@@ -13,47 +13,44 @@
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/jit/ast.h"
#include "paddle/fluid/jit/base_function.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/phi/common/place.h"
#include "paddle/fluid/jit/compilation_unit.h"
#include "paddle/fluid/jit/exector_function.h"
#include "paddle/fluid/jit/object.h"
#include "paddle/fluid/jit/pe_function.h"
#include "paddle/fluid/jit/function_schema.h"
namespace paddle {
namespace jit {
using Variable = paddle::framework::Variable;
using VariableNameMap = std::map<std::string, Variable>;
using DenseTensor = phi::DenseTensor;
using Name2VariableMap = std::unordered_map<std::string, Variable>;
class Layer {
public:
// TODO(dev): Make vector<string>, num_slot as in argument
// Layer(const std::shared_ptr<ClassType>& type) : obj_(type, /*num_slot*/ 0U)
// {}
// TODO(dev): consider make `func_name, program_desc, param_nams` as a class
Layer(
const std::vector<std::string>& func_names,
const std::vector<framework::ProgramDesc>& program_descs,
const std::vector<std::vector<std::string>>& param_names_for_each_program,
const VariableNameMap& params_dict,
Layer(const std::vector<std::shared_ptr<FunctionInfo>>& infos,
const Name2VariableMap& params_dict,
const phi::Place& place);
std::shared_ptr<BaseFunction> GetFunction(const std::string& name) const;
Variable GetAttribute(const std::string& name) const;
std::vector<Variable> forward(const std::vector<Variable>& inputs);
void to(const phi::Place& place);
private:
// internal::Object obj_;
// std::vector<framework::ProgramDesc> all_program_desc_;
// std::vector<std::vector<std::string>> param_name_for_each_program_;
// std::vector<Variable> all_param_;
std::map<std::string, std::shared_ptr<BaseFunction>> function_dict;
Name2VariableMap params_dict_;
Name2VariableMap attrs_dict_;
CompilationUnit unit_;
};
} // namespace jit
......
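After the refactor, Layer is built from a vector of FunctionInfo plus the shared parameter map and delegates function storage to CompilationUnit. A hedged sketch of constructing and running a Layer (the function name "forward" and helper RunLayer are placeholders):
// Sketch only: building a Layer from FunctionInfo objects and running it.
#include "paddle/fluid/jit/layer.h"
#include "paddle/phi/common/place.h"
std::vector<paddle::jit::Variable> RunLayer(
    const std::vector<std::shared_ptr<paddle::jit::FunctionInfo>> &infos,
    const paddle::jit::Name2VariableMap &params_dict,
    const phi::Place &place,
    const std::vector<paddle::jit::Variable> &inputs) {
  // The constructor registers one ExecutorFunction per FunctionInfo in the
  // internal CompilationUnit; params_dict keeps the parameters alive.
  paddle::jit::Layer layer(infos, params_dict, place);
  // Functions can also be fetched by name and called directly.
  auto func = layer.GetFunction("forward");
  return (*func)(inputs);
}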
......@@ -12,26 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/jit/layer.h"
#include <algorithm>
#include <fstream>
#include <iterator>
#include <string>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/jit/serializer.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/jit/serializer.h"
USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(matmul_v2);
USE_OP_ITSELF(relu);
......
......@@ -12,66 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/jit/base_function.h"
#include "paddle/fluid/jit/layer_utils.h"
namespace paddle {
namespace jit {
Argument::Argument(const std::string &name, bool is_out)
: name_(name), is_output_(is_out) {}
const std::string &Argument::Name() const { return name_; }
std::vector<std::string> FunctionSchema::GetInputArgNames() {
std::vector<std::string> input_arg_names;
for (auto &arg : input_args) {
input_arg_names.emplace_back(arg.Name());
}
return input_arg_names;
}
std::vector<std::string> FunctionSchema::GetOutputArgNames() {
std::vector<std::string> output_arg_names;
for (auto &arg : output_args) {
output_arg_names.emplace_back(arg.Name());
}
return output_arg_names;
}
void FunctionSchema::AddInputArg(std::string name) {
input_args.emplace_back(name, false);
}
void FunctionSchema::AddOutputArg(std::string name) {
output_args.emplace_back(name, true);
}
BaseFunction::BaseFunction(const framework::ProgramDesc &program_desc,
const std::vector<std::string> &param_names,
const VariableNameMap &params_dict,
const phi::Place &place)
: program_desc_(program_desc), place_(place) {
// Parse FunctionSchema
for (auto &in_name : program_desc_.GetFeedTargetNames()) {
schema_.AddInputArg(in_name);
}
for (auto &out_name : program_desc_.GetFetchTargetNames()) {
schema_.AddOutputArg(out_name);
}
// share params into scope
ShareParamsIntoScope(param_names, params_dict);
VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
// remove feed fetch op
RemoveFeedFetch();
}
void BaseFunction::FetchOutput(std::vector<Variable> *outs) {
for (auto &out_name : schema_.GetOutputArgNames()) {
void FetchVarsByNames(const std::vector<std::string> &names,
const framework::Scope &scope,
std::vector<Variable> *outs) {
for (auto &out_name : names) {
VLOG(3) << "fetch out: " << out_name;
auto *var = scope_.FindVar(out_name);
VLOG(3) << "after scope_.FindVar(out_name);";
auto &src_tensor = var->Get<phi::DenseTensor>();
VLOG(3) << "var->Get<phi::DenseTensor>();";
auto *var = scope.FindVar(out_name);
auto &src_tensor = var->Get<DenseTensor>();
Variable v;
auto *p = v.GetMutable<DenseTensor>();
*p = src_tensor;
......@@ -79,9 +31,10 @@ void BaseFunction::FetchOutput(std::vector<Variable> *outs) {
}
}
void BaseFunction::ShareInputsIntoScope(const std::vector<Variable> &vars) {
void ShareInputsIntoScope(const std::vector<std::string> &ordered_input_names,
const std::vector<Variable> &vars,
framework::Scope *scope) {
VLOG(3) << "vars size: " << vars.size();
std::vector<std::string> ordered_input_names = schema_.GetInputArgNames();
PADDLE_ENFORCE_EQ(
vars.size(),
ordered_input_names.size(),
......@@ -91,30 +44,30 @@ void BaseFunction::ShareInputsIntoScope(const std::vector<Variable> &vars) {
for (size_t i = 0; i < vars.size(); i++) {
VLOG(3) << "share into scope: " << ordered_input_names[i];
auto &dense_tensor = vars[i].Get<DenseTensor>();
auto *var = scope_.Var(ordered_input_names[i]);
auto *var = scope->Var(ordered_input_names[i]);
auto *dst_tensor = var->GetMutable<DenseTensor>();
*dst_tensor = dense_tensor;
}
}
void BaseFunction::ShareParamsIntoScope(
const std::vector<std::string> &param_names,
const VariableNameMap &params_dict) {
void ShareParamsIntoScope(const std::vector<std::string> &param_names,
const Name2VariableMap &params_dict,
framework::Scope *scope) {
VLOG(3) << "param_names size: " << param_names.size();
for (size_t i = 0; i < param_names.size(); ++i) {
std::string name = param_names[i];
Variable val = params_dict.find(name)->second;
auto &dense_tensor = val.Get<DenseTensor>();
auto &param = params_dict.find(name)->second;
auto &dense_tensor = param.Get<DenseTensor>();
VLOG(3) << "share into scope: " << name;
auto *var = scope_.Var(name);
auto *var = scope->Var(name);
auto *dst_tensor = var->GetMutable<DenseTensor>();
*dst_tensor = dense_tensor;
}
}
void BaseFunction::RemoveFeedFetch() {
for (size_t i = 0; i < program_desc_.Size(); ++i) {
auto *block = program_desc_.MutableBlock(i);
void RemoveFeedFetch(framework::ProgramDesc *program_desc) {
for (size_t i = 0; i < program_desc->Size(); ++i) {
auto *block = program_desc->MutableBlock(i);
const auto &all_ops = block->AllOps();
size_t op_size = all_ops.size();
VLOG(3) << "op_size: " << op_size;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace jit {
using Variable = paddle::framework::Variable;
using Name2VariableMap = std::unordered_map<std::string, Variable>;
using DenseTensor = phi::DenseTensor;
void FetchVarsByNames(const std::vector<std::string> &names,
const framework::Scope &scope,
std::vector<Variable> *outs);
void ShareInputsIntoScope(const std::vector<std::string> &ordered_input_names,
const std::vector<Variable> &vars,
framework::Scope *scope);
void ShareParamsIntoScope(const std::vector<std::string> &param_names,
const Name2VariableMap &params_dict,
framework::Scope *scope);
void RemoveFeedFetch(framework::ProgramDesc *program_desc);
} // namespace jit
} // namespace paddle
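The former BaseFunction member functions are now free helpers that take an explicit Scope, so ExecutorFunction and PEFunction can both reuse them. A short sketch of the share-run-fetch pattern they support (the helper name ShareAndFetch is illustrative):
// Sketch only: the typical share-run-fetch pattern built from the free helpers.
#include "paddle/fluid/jit/layer_utils.h"
void ShareAndFetch(const std::vector<std::string> &input_names,
                   const std::vector<paddle::jit::Variable> &inputs,
                   const std::vector<std::string> &output_names,
                   paddle::framework::Scope *scope,
                   std::vector<paddle::jit::Variable> *outs) {
  // Copy the caller's inputs into the scope under their argument names.
  paddle::jit::ShareInputsIntoScope(input_names, inputs, scope);
  // ... an executor would run the program against *scope here ...
  // Pull the named outputs back out of the scope as Variables.
  paddle::jit::FetchVarsByNames(output_names, *scope, outs);
}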
......@@ -15,42 +15,55 @@
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/jit/base_function.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/layer_utils.h"
namespace paddle {
namespace jit {
class PEFunction : public BaseFunction {
public:
PEFunction(const framework::ProgramDesc &program_desc,
const std::vector<std::string> param_names,
const VariableNameMap &params_dict,
PEFunction(const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place)
: BaseFunction(program_desc, param_names, params_dict, place) {}
: info_(info), place_(place) {
ShareParamsIntoScope(info_->GetParamNames(), params_dict, &scope_);
VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
}
~PEFunction() {}
~PEFunction() noexcept {}
std::vector<Variable> operator()(const std::vector<Variable> &inputs) {
// bool is_test = true;
std::string prog_string;
std::hash<std::string> string_hash;
program_desc_.Proto()->SerializePartialToString(&prog_string);
auto &program_desc = info_->GetProgramDesc();
const_cast<framework::ProgramDesc *>(&program_desc)
->Proto()
->SerializePartialToString(&prog_string);
// program_desc.Proto()->SerializePartialToString(&prog_string);
int64_t program_id = static_cast<int64_t>(string_hash(prog_string));
const framework::BlockDesc &global_block = program_desc_.Block(0);
const framework::BlockDesc &global_block = program_desc.Block(0);
int64_t start_op_index = 0;
int64_t end_op_index = static_cast<int64_t>(global_block.OpSize());
ShareInputsIntoScope(inputs);
std::vector<std::string> input_var_names = schema_.GetInputArgNames();
std::vector<std::string> output_var_names = schema_.GetOutputArgNames();
ShareInputsIntoScope(info_->GetInputArgNames(), inputs, &scope_);
std::vector<std::string> input_var_names = info_->GetInputArgNames();
std::vector<std::string> output_var_names = info_->GetOutputArgNames();
std::vector<std::string> dout_var_names;
if (end_op_index > start_op_index) {
// TODO(dev): support other devices
auto cache_info = framework::GetExecutorInfoFromCache(program_desc_,
auto cache_info = framework::GetExecutorInfoFromCache(program_desc,
place_,
start_op_index,
end_op_index,
......@@ -70,7 +83,7 @@ class PEFunction : public BaseFunction {
dout_var_names.begin(),
dout_var_names.end());
framework::details::ParseSafeEagerDeletionSkipVars(
program_desc_,
program_desc,
end_op_index,
output_var_names,
&skip_eager_delete_vars);
......@@ -79,9 +92,14 @@ class PEFunction : public BaseFunction {
}
VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
std::vector<Variable> res;
FetchOutput(&res);
FetchVarsByNames(info_->GetOutputArgNames(), scope_, &res);
return res;
}
private:
std::shared_ptr<FunctionInfo> info_;
framework::Scope scope_;
phi::Place place_;
};
} // namespace jit
......
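PEFunction follows the same pattern as ExecutorFunction but runs through the cached ParallelExecutor path, hashing the serialized program to key the executor cache. A hedged sketch of selecting it through CompilationUnit instead of the executor variant (the helper name UsePE is illustrative):
// Sketch only: choosing the ParallelExecutor-backed function for a FunctionInfo.
#include "paddle/fluid/jit/compilation_unit.h"
#include "paddle/phi/common/place.h"
void UsePE(paddle::jit::CompilationUnit *unit,
           const std::shared_ptr<paddle::jit::FunctionInfo> &info,
           const paddle::jit::Name2VariableMap &params_dict,
           const phi::Place &place) {
  // Same signature as AddExecutorFunction; only the backing function type differs.
  unit->AddPEFunction(info->GetFunctionName(), info, params_dict, place);
}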
......@@ -19,30 +19,26 @@ namespace jit {
Layer Deserializer::operator()(const std::string& dir_path) {
const auto& file_name_prefixs = GetPdmodelFileNamePrefix(dir_path);
std::vector<std::string> func_names;
std::vector<framework::ProgramDesc> program_descs;
std::vector<std::vector<std::string>> param_names_for_each_program;
// set is ordered
std::set<std::string> param_names_set;
VariableNameMap params_dict;
std::vector<std::shared_ptr<FunctionInfo>> infos;
Name2VariableMap params_dict;
for (auto& it : file_name_prefixs) {
func_names.emplace_back(it.first);
auto& func_name = it.first;
auto program_desc = LoadProgram(dir_path + it.second + PDMODEL_SUFFIX);
auto program = LoadProgram(dir_path + it.second + PDMODEL_SUFFIX);
program_descs.emplace_back(program);
// TODO(dev): load int/float params
std::vector<std::string> persistable_var_names;
auto all_var_desc = program.Block(0).AllVars();
// TODO(dev): load int/float attrs
std::vector<std::string> persist_var_names;
auto all_var_desc = program_desc.Block(0).AllVars();
for (auto* desc_ptr : all_var_desc) {
if (IsPersistable(desc_ptr)) {
persistable_var_names.emplace_back(desc_ptr->Name());
persist_var_names.emplace_back(desc_ptr->Name());
}
}
param_names_for_each_program.emplace_back(persistable_var_names);
param_names_set.insert(persistable_var_names.begin(),
persistable_var_names.end());
param_names_set.insert(persist_var_names.begin(), persist_var_names.end());
infos.emplace_back(std::make_shared<FunctionInfo>(
func_name, persist_var_names, program_desc));
}
auto default_place = imperative::GetCurrentTracer()->ExpectedPlace();
......@@ -52,11 +48,7 @@ Layer Deserializer::operator()(const std::string& dir_path) {
default_place,
&params_dict);
return Layer(func_names,
program_descs,
param_names_for_each_program,
params_dict,
default_place);
return Layer(infos, params_dict, default_place);
}
bool Deserializer::IsPersistable(framework::VarDesc* desc_ptr) {
......@@ -100,7 +92,7 @@ Deserializer::GetPdmodelFileNamePrefix(const std::string& path) {
void Deserializer::ReadTensorData(const std::string& file_name,
const std::set<std::string>& var_name,
const phi::Place& place,
VariableNameMap* params_dict) const {
Name2VariableMap* params_dict) const {
VLOG(3) << "ReadTensorData from: " << file_name;
std::ifstream fin(file_name, std::ios::binary);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
......
......@@ -15,19 +15,18 @@
#pragma once
#include <dirent.h>
#include <algorithm>
#include <fstream>
#include <set>
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/layer.h"
namespace paddle {
namespace jit {
static const char PDMODEL_SUFFIX[] = ".pdmodel";
......@@ -62,7 +61,7 @@ class Deserializer {
void ReadTensorData(const std::string& file_name,
const std::set<std::string>& var_name,
const phi::Place& place,
VariableNameMap* params_dict) const;
Name2VariableMap* params_dict) const;
// void ReadExtraInfo(const std::string& file_name) const;
// void ReadByteCode(const std::string& file_name) const;
......
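On the loading side, the Deserializer now collects one FunctionInfo per *.pdmodel file and hands them straight to the Layer constructor. A hedged end-to-end sketch, assuming Deserializer's call operator is the public entry point (the directory path and helper name LoadAndRun are placeholders):
// Sketch only: load a Layer from an export directory and run it.
#include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/jit/serializer.h"
std::vector<paddle::jit::Variable> LoadAndRun(
    const std::string &dir_path,
    const std::vector<paddle::jit::Variable> &inputs) {
  paddle::jit::Deserializer deserializer;
  // operator() loads the programs, builds FunctionInfo objects and the params_dict.
  paddle::jit::Layer layer = deserializer(dir_path);
  return layer.forward(inputs);
}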