未验证 提交 8b2c906a 编写于 作者: A Aurelius84 提交者: GitHub

Simplify constructor of InterpreterCore (#37072)

* Simplify constructor of InterpreterCore

* fix bool

* clean code
上级 76d2fd1d
...@@ -16,9 +16,8 @@ ...@@ -16,9 +16,8 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpreter {
void EventManager::WaitEvent(const Instruction& instruction, void WaitEvent(const Instruction& instruction, const platform::Place& place) {
const platform::Place& place) {
// If InterpreterCore in on CPUPlace, do nothing. // If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return; if (platform::is_cpu_place(place)) return;
...@@ -32,8 +31,7 @@ void EventManager::WaitEvent(const Instruction& instruction, ...@@ -32,8 +31,7 @@ void EventManager::WaitEvent(const Instruction& instruction,
} }
} }
void EventManager::RecordEvent(const Instruction& instruction, void RecordEvent(const Instruction& instruction, const platform::Place& place) {
const platform::Place& place) {
// If InterpreterCore in on CPUPlace, do nothing. // If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return; if (platform::is_cpu_place(place)) return;
...@@ -43,5 +41,6 @@ void EventManager::RecordEvent(const Instruction& instruction, ...@@ -43,5 +41,6 @@ void EventManager::RecordEvent(const Instruction& instruction,
} }
} }
} // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,14 +17,11 @@ ...@@ -17,14 +17,11 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpreter {
void RecordEvent(const Instruction& instruction, const platform::Place& place);
class EventManager { void WaitEvent(const Instruction& instruction, const platform::Place& place);
public:
void RecordEvent(const Instruction& instruction,
const platform::Place& place);
void WaitEvent(const Instruction& instruction, const platform::Place& place);
};
} // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -33,9 +33,9 @@ namespace framework { ...@@ -33,9 +33,9 @@ namespace framework {
// NOTE(Aurelius84): Need a better strategy to determine it. // NOTE(Aurelius84): Need a better strategy to determine it.
static constexpr size_t kHostNumThreads = 4; static constexpr size_t kHostNumThreads = 4;
InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block, InterpreterCore::InterpreterCore(const platform::Place& place,
VariableScope* global_scope, const BlockDesc& block,
const std::vector<std::string>& feed_names) VariableScope* global_scope)
: place_(place), : place_(place),
block_(block), block_(block),
global_scope_(global_scope), global_scope_(global_scope),
...@@ -45,8 +45,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block, ...@@ -45,8 +45,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block,
new interpreter::AsyncWorkQueue(kHostNumThreads, &main_thread_blocker_)); new interpreter::AsyncWorkQueue(kHostNumThreads, &main_thread_blocker_));
gc_.reset(new InterpreterCoreGarbageCollector()); gc_.reset(new InterpreterCoreGarbageCollector());
feed_names_ = feed_names;
exception_notifier_ = main_thread_blocker_.RegisterEvent( exception_notifier_ = main_thread_blocker_.RegisterEvent(
kExceptionCaught, [this]() { return exception_holder_.IsCaught(); }); kExceptionCaught, [this]() { return exception_holder_.IsCaught(); });
...@@ -65,27 +63,12 @@ InterpreterCore::~InterpreterCore() { ...@@ -65,27 +63,12 @@ InterpreterCore::~InterpreterCore() {
} }
paddle::framework::FetchList InterpreterCore::Run( paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<framework::LoDTensor>& feed_tensors) {
auto FeedInput = [&] { bool is_build = is_build_;
for (size_t i = 0; i < feed_names_.size(); ++i) { Prepare(feed_names, feed_tensors, is_build);
auto* feed_var = global_scope_->Var(feed_names_[i]);
auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
feed_tensor->ShareDataWith(feed_tensors[i]);
feed_tensor->set_lod(feed_tensors[i].lod());
}
};
if (is_build_ == false) { if (is_build) {
paddle::framework::interpreter::build_variable_scope(*block_,
global_scope_);
FeedInput();
paddle::framework::interpreter::build_op_func_list(
place_, *block_, &vec_func_list_, global_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert();
} else {
FeedInput();
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
} }
...@@ -95,9 +78,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -95,9 +78,9 @@ paddle::framework::FetchList InterpreterCore::Run(
} }
void InterpreterCore::Convert() { void InterpreterCore::Convert() {
auto& vec_meta_info = global_scope_->MutableVecMetaInfo();
auto var_nums = global_scope_->VarSize(); auto var_nums = global_scope_->VarSize();
input_var2op_info_.resize(var_nums); input_var2op_info_.resize(var_nums);
vec_meta_info_.resize(var_nums);
auto op_nums = vec_func_list_.size(); auto op_nums = vec_func_list_.size();
vec_instruction_.reserve(op_nums); vec_instruction_.reserve(op_nums);
...@@ -136,7 +119,7 @@ void InterpreterCore::Convert() { ...@@ -136,7 +119,7 @@ void InterpreterCore::Convert() {
gc_check_input_list.erase(last, gc_check_input_list.end()); gc_check_input_list.erase(last, gc_check_input_list.end());
for (auto var_id : gc_check_input_list) { for (auto var_id : gc_check_input_list) {
vec_meta_info_[var_id].var_ref_count_++; vec_meta_info[var_id].var_ref_count_++;
instr.AddGCCheckVar(var_id); instr.AddGCCheckVar(var_id);
} }
} }
...@@ -148,7 +131,7 @@ void InterpreterCore::Convert() { ...@@ -148,7 +131,7 @@ void InterpreterCore::Convert() {
if (input_var2op_info_.at(id).size() == 0) { if (input_var2op_info_.at(id).size() == 0) {
// output var not be used by any kernel // output var not be used by any kernel
vec_instruction_[i].AddGCCheckVar(id); vec_instruction_[i].AddGCCheckVar(id);
vec_meta_info_[id].var_ref_count_++; vec_meta_info[id].var_ref_count_++;
} }
} }
} }
...@@ -180,7 +163,7 @@ void InterpreterCore::Convert() { ...@@ -180,7 +163,7 @@ void InterpreterCore::Convert() {
} }
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
BuildAndCacheInstructionCtx(&vec_instruction_[i], *global_scope_, place_); BuildAndCacheInstructionCtx(&vec_instruction_[i]);
} }
BuildSkipShareLoDInfo(); BuildSkipShareLoDInfo();
...@@ -248,16 +231,14 @@ void InterpreterCore::BuildInplace() { ...@@ -248,16 +231,14 @@ void InterpreterCore::BuildInplace() {
} }
} }
void InterpreterCore::BuildAndCacheInstructionCtx( void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
Instruction* instr_node, const VariableScope& var_scope,
const platform::Place& place) {
VariableValueMap ins_map; VariableValueMap ins_map;
for (auto& var_name_item : instr_node->Inputs()) { for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars; std::vector<Variable*> input_vars;
input_vars.reserve(var_name_item.second.size()); input_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) { for (auto& id : var_name_item.second) {
input_vars.emplace_back(var_scope.Var(id)); input_vars.emplace_back(global_scope_->Var(id));
} }
ins_map.emplace(var_name_item.first, std::move(input_vars)); ins_map.emplace(var_name_item.first, std::move(input_vars));
} }
...@@ -268,7 +249,7 @@ void InterpreterCore::BuildAndCacheInstructionCtx( ...@@ -268,7 +249,7 @@ void InterpreterCore::BuildAndCacheInstructionCtx(
out_vars.reserve(var_name_item.second.size()); out_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) { for (auto& id : var_name_item.second) {
out_vars.emplace_back(var_scope.Var(id)); out_vars.emplace_back(global_scope_->Var(id));
} }
outs_map.emplace(var_name_item.first, std::move(out_vars)); outs_map.emplace(var_name_item.first, std::move(out_vars));
} }
...@@ -359,7 +340,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -359,7 +340,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
void InterpreterCore::ExecuteInstructionList( void InterpreterCore::ExecuteInstructionList(
const std::vector<Instruction>& vec_instr) { const std::vector<Instruction>& vec_instr) {
async_work_queue_->PrepareAtomicDeps(dependecy_count_); async_work_queue_->PrepareAtomicDeps(dependecy_count_);
async_work_queue_->PrepareAtomicVarRef(vec_meta_info_); async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
op_run_number_ = 0; op_run_number_ = 0;
exception_holder_.Clear(); exception_holder_.Clear();
...@@ -452,7 +433,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { ...@@ -452,7 +433,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
auto& instr_node = vec_instruction_.at(instr_id); auto& instr_node = vec_instruction_.at(instr_id);
auto* op = instr_node.OpBase(); auto* op = instr_node.OpBase();
platform::RecordEvent instruction_event(op->Type()); platform::RecordEvent instruction_event(op->Type());
event_manager_.WaitEvent(instr_node, place_); interpreter::WaitEvent(instr_node, place_);
try { try {
RunInstruction(instr_node); RunInstruction(instr_node);
...@@ -479,7 +460,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { ...@@ -479,7 +460,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
return; return;
} }
event_manager_.RecordEvent(instr_node, place_); interpreter::RecordEvent(instr_node, place_);
op_run_number_.fetch_add(1, std::memory_order_relaxed); op_run_number_.fetch_add(1, std::memory_order_relaxed);
// GC infomation // GC infomation
...@@ -508,11 +489,18 @@ void InterpreterCore::CheckGC(const Instruction& instr) { ...@@ -508,11 +489,18 @@ void InterpreterCore::CheckGC(const Instruction& instr) {
} }
} }
void InterpreterCore::DryRunPrepare( void InterpreterCore::Prepare(
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors, bool prepare_feed) {
PADDLE_ENFORCE_EQ(feed_names.size(), feed_tensors.size(),
platform::errors::PreconditionNotMet(
"Required feed_names.size() == feed_tensors.size(), "
"but received %d != %d",
feed_names.size(), feed_tensors.size()));
auto FeedInput = [&] { auto FeedInput = [&] {
for (size_t i = 0; i < feed_names_.size(); ++i) { for (size_t i = 0; i < feed_names.size(); ++i) {
auto* feed_var = global_scope_->FindVar(feed_names_[i]); auto* feed_var = global_scope_->FindVar(feed_names[i]);
PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound( PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound(
"feed_var shall not be nullptr.")); "feed_var shall not be nullptr."));
...@@ -522,35 +510,33 @@ void InterpreterCore::DryRunPrepare( ...@@ -522,35 +510,33 @@ void InterpreterCore::DryRunPrepare(
} }
}; };
if (is_build_ == false) { if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(*block_, paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
global_scope_);
FeedInput(); FeedInput();
paddle::framework::interpreter::build_op_func_list( paddle::framework::interpreter::build_op_func_list(
place_, *block_, &vec_func_list_, global_scope_); place_, block_, &vec_func_list_, global_scope_);
is_build_ = true; is_build_ = true;
// convert vec func_list to graph // convert vec func_list to graph
Convert(); Convert();
} }
// NOTE: Because feed_tensor will be GC after // NOTE: Because feed_tensor will be GC after
// paddle::framework::build_op_func_list, so we should // paddle::framework::build_op_func_list, so we should
// call // call FeedInput again.
// FeedInput again. if (prepare_feed) FeedInput();
FeedInput();
} }
const CostInfo& InterpreterCore::DryRun( interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<framework::LoDTensor>& feed_tensors) {
DryRunPrepare(feed_tensors); Prepare(feed_names, feed_tensors, true);
// DryRun may be called many times. interpreter::CostInfo cost_info;
dry_run_profiler_.Reset(); {
dry_run_profiler_.Start(); interpreter::ProfilerGuard(place_, &cost_info);
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
platform::DeviceContextPool::Instance().Get(place_)->Wait(); platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
dry_run_profiler_.Pause();
dry_run_profiler_.TotalCUDAAllocatedMemorySize(place_); return cost_info;
return dry_run_profiler_.GetCostInfo();
} }
} // namespace framework } // namespace framework
......
...@@ -40,23 +40,23 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>; ...@@ -40,23 +40,23 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
class InterpreterCore { class InterpreterCore {
public: public:
InterpreterCore(const platform::Place& place, BlockDesc* block, InterpreterCore(const platform::Place& place, const BlockDesc& block,
VariableScope* global_scope, VariableScope* global_scope);
const std::vector<std::string>& feed_names);
~InterpreterCore(); ~InterpreterCore();
paddle::framework::FetchList Run( paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors); const std::vector<framework::LoDTensor>& feed_tensors);
const CostInfo& DryRun(const std::vector<framework::LoDTensor>& feed_tensors); interpreter::CostInfo DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors);
private: private:
void Convert(); void Convert();
void BuildAndCacheInstructionCtx(Instruction* instr_node, void BuildAndCacheInstructionCtx(Instruction* instr_node);
const VariableScope& var_scope,
const platform::Place& place);
void BuildInplace(); void BuildInplace();
...@@ -66,7 +66,9 @@ class InterpreterCore { ...@@ -66,7 +66,9 @@ class InterpreterCore {
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr); void ExecuteInstructionList(const std::vector<Instruction>& vec_instr);
void DryRunPrepare(const std::vector<framework::LoDTensor>& feed_tensors); void Prepare(const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors,
bool prepare_feed);
void CheckGC(const Instruction& instr); void CheckGC(const Instruction& instr);
...@@ -79,22 +81,17 @@ class InterpreterCore { ...@@ -79,22 +81,17 @@ class InterpreterCore {
bool is_build_; bool is_build_;
const platform::Place& place_; const platform::Place& place_;
BlockDesc* block_; // not owned const BlockDesc& block_; // not owned
VariableScope* global_scope_; // not owned VariableScope* global_scope_; // not owned
std::vector<paddle::framework::OpFuncNode> vec_func_list_; std::vector<paddle::framework::OpFuncNode> vec_func_list_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
InstructionInfo instruction_info_;
std::vector<size_t> dependecy_count_; std::vector<size_t> dependecy_count_;
std::atomic<size_t> op_run_number_{0};
std::vector<std::vector<size_t>> input_var2op_info_; std::vector<std::vector<size_t>> input_var2op_info_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<std::string> feed_names_;
InterpreterProfiler dry_run_profiler_;
StreamAnalyzer stream_analyzer_; StreamAnalyzer stream_analyzer_;
EventManager event_manager_;
EventsWaiter main_thread_blocker_; EventsWaiter main_thread_blocker_;
std::unique_ptr<interpreter::AsyncWorkQueue> async_work_queue_; std::unique_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_; details::ExceptionHolder exception_holder_;
...@@ -102,7 +99,6 @@ class InterpreterCore { ...@@ -102,7 +99,6 @@ class InterpreterCore {
std::unique_ptr<InterpreterCoreGarbageCollector> gc_; std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_; std::vector<paddle::platform::DeviceEvent> gc_event_;
std::atomic<size_t> op_run_number_{0};
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -607,6 +607,12 @@ class VariableScope : public ScopeBase { ...@@ -607,6 +607,12 @@ class VariableScope : public ScopeBase {
platform::errors::NotFound("%s not in VariableScope.", name)); platform::errors::NotFound("%s not in VariableScope.", name));
} }
std::vector<VariableMetaInfo>& MutableVecMetaInfo() { return vec_meta_info_; }
const std::vector<VariableMetaInfo>& VecMetaInfo() const {
return vec_meta_info_;
}
private: private:
std::vector<Variable*> var_list_; std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_; std::map<std::string, int> name2id_;
......
...@@ -20,84 +20,41 @@ ...@@ -20,84 +20,41 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpreter {
static void GetTensors(Variable* var, std::unordered_set<Tensor*>* tensor_set) {
if (var->IsType<LoDTensor>() && var->Get<LoDTensor>().IsInitialized()) {
tensor_set->insert(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>() &&
var->Get<SelectedRows>().value().IsInitialized()) {
tensor_set->insert(var->GetMutable<SelectedRows>()->mutable_value());
} else if (var->IsType<LoDTensorArray>()) {
auto* tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto& t : *tensor_arr) {
if (t.IsInitialized()) {
tensor_set->insert(&t);
}
}
}
}
static std::pair<size_t, size_t> GetTensorMemorySize(
const std::vector<Variable*>& var_list) {
std::unordered_set<Tensor*> tensor_set;
for (auto* var : var_list) {
GetTensors(var, &tensor_set);
}
size_t host_memory_bytes = 0;
size_t device_memory_bytes = 0;
std::unordered_set<memory::Allocation*> allocation_set;
for (auto* tensor : tensor_set) {
auto allocation = tensor->Holder().get();
if (!allocation_set.count(allocation)) {
allocation_set.insert(allocation);
if (platform::is_cuda_pinned_place(tensor->place()) ||
platform::is_cpu_place(tensor->place())) {
VLOG(3) << "found host memory : " << allocation->size();
host_memory_bytes += allocation->size();
} else {
VLOG(3) << "found device memory : " << allocation->size();
device_memory_bytes += allocation->size();
}
}
}
return {host_memory_bytes, device_memory_bytes};
}
struct CostInfo { struct CostInfo {
double total_time{0.}; // ms double total_time{0.}; // ms
size_t device_memory_bytes{0}; // total allocated memory size size_t device_memory_bytes{0}; // total allocated memory size
}; };
class InterpreterProfiler { class ProfilerGuard {
public: public:
void Start() { timer_.Start(); } ProfilerGuard(const platform::Place& place, CostInfo* cost_info)
: place_(place), cost_info_(cost_info) {
void Pause() { timer_.Start();
timer_.Pause();
cost_info_.total_time += timer_.ElapsedMS();
} }
void Reset() { ~ProfilerGuard() {
timer_.Reset(); timer_.Pause();
cost_info_.total_time = 0.; cost_info_->total_time += timer_.ElapsedMS();
cost_info_.device_memory_bytes = 0; TotalCUDAAllocatedMemorySize(place_);
} }
private:
void TotalCUDAAllocatedMemorySize(const platform::Place& place) { void TotalCUDAAllocatedMemorySize(const platform::Place& place) {
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, place); auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, place);
cost_info_.device_memory_bytes = cost_info_->device_memory_bytes =
platform::RecordedCudaMallocSize(cuda_place.device); platform::RecordedCudaMallocSize(cuda_place.device);
#endif #endif
} }
} }
const CostInfo& GetCostInfo() const { return cost_info_; } const platform::Place& place_;
CostInfo* cost_info_;
private:
platform::Timer timer_; platform::Timer timer_;
CostInfo cost_info_;
}; };
} // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -51,16 +51,15 @@ paddle::framework::FetchList StandaloneExecutor::Run( ...@@ -51,16 +51,15 @@ paddle::framework::FetchList StandaloneExecutor::Run(
const std::vector<std::string>& fetch_names) { const std::vector<std::string>& fetch_names) {
auto core = GetInterpreterCore(feed_names, fetch_names); auto core = GetInterpreterCore(feed_names, fetch_names);
return core->Run(feed_tensors); return core->Run(feed_names, feed_tensors);
} }
const CostInfo& StandaloneExecutor::DryRun( framework::interpreter::CostInfo StandaloneExecutor::DryRun(
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<framework::LoDTensor>& feed_tensors) {
auto core = GetInterpreterCore(feed_names, {}); auto core = GetInterpreterCore(feed_names, {});
auto& cost_info = core->DryRun(feed_tensors); return core->DryRun(feed_names, feed_tensors);
return cost_info;
} }
void StandaloneExecutor::BuildVariableOuterScope( void StandaloneExecutor::BuildVariableOuterScope(
...@@ -102,8 +101,8 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore( ...@@ -102,8 +101,8 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
auto* block = new_prog->MutableBlock(0); auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block); interpreter::add_fetch(fetch_names, block);
auto core = std::make_shared<InterpreterCore>(place_, block, &global_scope_, auto core =
feed_names); std::make_shared<InterpreterCore>(place_, *block, &global_scope_);
programs_.emplace(oss.str(), new_prog); programs_.emplace(oss.str(), new_prog);
interpretercores_.emplace(oss.str(), core); interpretercores_.emplace(oss.str(), core);
return core; return core;
......
...@@ -45,8 +45,9 @@ class StandaloneExecutor : public ExecutorBase { ...@@ -45,8 +45,9 @@ class StandaloneExecutor : public ExecutorBase {
const std::vector<framework::LoDTensor>& feed_tensors, const std::vector<framework::LoDTensor>& feed_tensors,
const std::vector<std::string>& fetch_names); const std::vector<std::string>& fetch_names);
const CostInfo& DryRun(const std::vector<std::string>& feed_names, framework::interpreter::CostInfo DryRun(
const std::vector<framework::LoDTensor>& feed_tensors); const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors);
private: private:
void BuildVariableOuterScope(const framework::ProgramDesc& pdesc, void BuildVariableOuterScope(const framework::ProgramDesc& pdesc,
......
...@@ -2069,11 +2069,13 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2069,11 +2069,13 @@ All parameter, weight, gradient are variables in Paddle.
fetch_vars); fetch_vars);
}); });
py::class_<framework::CostInfo>(m, "CostInfo") py::class_<framework::interpreter::CostInfo>(m, "CostInfo")
.def(py::init<>()) .def(py::init<>())
.def("total_time", [](CostInfo &self) { return self.total_time; }) .def("total_time",
.def("device_memory_bytes", [](interpreter::CostInfo &self) { return self.total_time; })
[](CostInfo &self) { return self.device_memory_bytes; }); .def("device_memory_bytes", [](interpreter::CostInfo &self) {
return self.device_memory_bytes;
});
py::class_<framework::StandaloneExecutor>(m, "StandaloneExecutor") py::class_<framework::StandaloneExecutor>(m, "StandaloneExecutor")
.def(py::init<const platform::Place &, const ProgramDesc &, .def(py::init<const platform::Place &, const ProgramDesc &,
...@@ -2134,7 +2136,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2134,7 +2136,7 @@ All parameter, weight, gradient are variables in Paddle.
feed_tensors.push_back(t); feed_tensors.push_back(t);
} }
CostInfo cost_info; framework::interpreter::CostInfo cost_info;
{ {
pybind11::gil_scoped_release release; pybind11::gil_scoped_release release;
cost_info = self.DryRun(feed_names, feed_tensors); cost_info = self.DryRun(feed_names, feed_tensors);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册