未验证 提交 c3135426 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] refactor code of interpretercore gc (#39617)


* relocate code of interpretercore gc
上级 f29da150
# Libraries linked into interpretercore_util. The interpreter-core garbage
# collector libraries are intentionally absent from this list: they are
# attached to the interpretercore target directly (see below).
set(INTERPRETERCORE_DEPS
    op_registry device_context scope framework_proto data_feed_proto
    heter_service_proto trainer_desc_proto glog
    lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper
    box_wrapper lodtensor_printer feed_fetch_method
    graph_to_program_pass variable_helper timer monitor nan_inf_utils)

add_subdirectory(workqueue)
# NOTE(review): the interpretercore_*garbage_collector targets referenced
# below are expected to be declared by this subdirectory - confirm that its
# CMakeLists.txt defines all three (base, event, and fast for GPU/ROCm).
add_subdirectory(garbage_collector)

cc_library(data_transfer SRCS data_transfer.cc DEPS enforce scope glog)
cc_library(new_executor_defs SRCS new_executor_defs.cc DEPS enforce glog scope)
cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue new_executor_defs data_transfer)
cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog new_executor_defs)
cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context new_executor_defs)

# interpretercore is declared exactly once. Device builds additionally depend
# on interpretercore_fast_garbage_collector; other builds depend only on the
# event-based collector.
if(WITH_GPU OR WITH_ROCM)
  cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_event_garbage_collector interpretercore_fast_garbage_collector stream_analyzer event_manager)
else()
  cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_event_garbage_collector stream_analyzer event_manager)
endif()

cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
# cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
# skip win32 since wget is not installed by default on windows machine.
# skip COVERAGE_CI since the test runs slowly because of instrumentation.
......
# Base interpreter-core garbage collector library, built from
# garbage_collector.cc on top of the garbage_collector target.
cc_library(interpretercore_garbage_collector
  SRCS garbage_collector.cc
  DEPS garbage_collector)

# Collector built from event_garbage_collector.cc; declared for every build
# configuration.
cc_library(interpretercore_event_garbage_collector
  SRCS event_garbage_collector.cc
  DEPS interpretercore_garbage_collector)

# The fast collector is declared only for device builds: through nv_library
# when WITH_GPU is set, otherwise through hip_library when WITH_ROCM is set.
# When neither option is set, no fast collector target exists.
if(WITH_GPU)
  nv_library(interpretercore_fast_garbage_collector
    SRCS fast_garbage_collector.cc
    DEPS interpretercore_garbage_collector)
elseif(WITH_ROCM)
  hip_library(interpretercore_fast_garbage_collector
    SRCS fast_garbage_collector.cc
    DEPS interpretercore_garbage_collector)
endif()
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_event_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/event_garbage_collector.h"
#if !defined(_WIN32)
#include <sched.h>
......@@ -36,7 +36,7 @@ InterpreterCoreEventGarbageCollector::~InterpreterCoreEventGarbageCollector() {
}
void InterpreterCoreEventGarbageCollector::Add(
Garbage garbage, platform::DeviceEvent& event,
Garbage garbage, platform::DeviceEvent* event,
const platform::DeviceContext* ctx) {
if (!garbage) {
return;
......@@ -60,8 +60,14 @@ void InterpreterCoreEventGarbageCollector::Add(
}
}
// Single-argument Add overload. The event-based collector does not support
// collecting a variable without an event and device context, so this always
// throws an Unimplemented error; `var` is never inspected.
void InterpreterCoreEventGarbageCollector::Add(Variable* var) {
  PADDLE_THROW(platform::errors::Unimplemented(
      "Add(Variable* var) is not implemented "
      "for InterpreterCoreEventGarbageCollector."));
}
void InterpreterCoreEventGarbageCollector::Add(
Variable* var, platform::DeviceEvent& event,
Variable* var, platform::DeviceEvent* event,
const platform::DeviceContext* ctx) {
if (UNLIKELY(max_memory_size_ < 0) || var == nullptr) {
return;
......@@ -100,11 +106,11 @@ void InterpreterCoreEventGarbageCollector::Add(
}
void InterpreterCoreEventGarbageCollector::Free(
GarbageQueue* garbages, platform::DeviceEvent& event,
GarbageQueue* garbages, platform::DeviceEvent* event,
const platform::DeviceContext* ctx) {
event.Record(ctx);
event.SetFininshed(); // Only for CPU Event
queue_->AddTask([ container = garbages, event = &event ]() {
event->Record(ctx);
event->SetFininshed(); // Only for CPU Event
queue_->AddTask([ container = garbages, event = event ]() {
while (!event->Query()) {
#if defined(_WIN32)
SleepEx(50, FALSE);
......@@ -118,11 +124,11 @@ void InterpreterCoreEventGarbageCollector::Free(
}
void InterpreterCoreEventGarbageCollector::Free(
Garbage& garbage, platform::DeviceEvent& event,
const Garbage& garbage, platform::DeviceEvent* event,
const platform::DeviceContext* ctx) {
event.Record(ctx);
event.SetFininshed(); // Only for CPU Event
queue_->AddTask([ container = garbage, event = &event ]() {
event->Record(ctx);
event->SetFininshed(); // Only for CPU Event
queue_->AddTask([ container = garbage, event = event ]() {
while (!event->Query()) {
#if defined(_WIN32)
SleepEx(50, FALSE);
......
......@@ -14,7 +14,7 @@
#pragma once
#include <queue>
#include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
namespace paddle {
......@@ -26,15 +26,17 @@ class InterpreterCoreEventGarbageCollector
InterpreterCoreEventGarbageCollector();
~InterpreterCoreEventGarbageCollector();
virtual void Add(Variable* var, platform::DeviceEvent& event,
const platform::DeviceContext* ctx) override;
void Add(Variable* var) override;
virtual void Add(Variable* var, platform::DeviceEvent* event,
const platform::DeviceContext* ctx);
private:
void Add(Garbage garbage, platform::DeviceEvent& event,
void Add(Garbage garbage, platform::DeviceEvent* event,
const platform::DeviceContext* ctx);
void Free(GarbageQueue* garbages, platform::DeviceEvent& event,
void Free(GarbageQueue* garbages, platform::DeviceEvent* event,
const platform::DeviceContext* ctx);
void Free(Garbage& garbage, platform::DeviceEvent& event,
void Free(const Garbage& garbage, platform::DeviceEvent* event,
const platform::DeviceContext* ctx);
std::unique_ptr<WorkQueue> queue_;
......
......@@ -12,11 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_fast_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/fast_garbage_collector.h"
namespace paddle {
namespace framework {
// Event-taking Add overload. The fast collector does not support it: the call
// unconditionally throws an Unimplemented error and none of the arguments are
// read.
void InterpreterCoreFastGarbageCollector::Add(
    Variable* var, platform::DeviceEvent* event,
    const platform::DeviceContext* ctx) {
  PADDLE_THROW(platform::errors::Unimplemented(
      "Not implemented for "
      "InterpreterCoreFastGarbageCollector."));
}
void InterpreterCoreFastGarbageCollector::Add(Variable* var) {
if (UNLIKELY(max_memory_size_ < 0) || var == nullptr) {
return;
......
......@@ -13,7 +13,9 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
namespace paddle {
namespace framework {
......@@ -21,10 +23,14 @@ namespace framework {
class InterpreterCoreFastGarbageCollector
: public InterpreterCoreGarbageCollector {
public:
virtual void Add(Variable* var) override;
void Add(Variable* var) override;
void Add(Variable* var, platform::DeviceEvent* event,
const platform::DeviceContext* ctx) override;
private:
void Add(Garbage garbage);
};
} // namespace framework
} // namespace paddle
#endif
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
......@@ -24,19 +24,5 @@ InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() {
cur_memory_size_ = 0;
}
// Base-class Add(Variable*) definition. It is not meant to be invoked: every
// call throws an Unimplemented error without touching `var`.
void InterpreterCoreGarbageCollector::Add(Variable* var) {
  PADDLE_THROW(platform::errors::Unimplemented(
      "Not allowed to call the member function of "
      "InterpreterCoreGarbageCollector"));
}
// Base-class definition of the event-taking Add overload. It is not meant to
// be invoked: every call throws an Unimplemented error and none of the
// arguments are read.
void InterpreterCoreGarbageCollector::Add(
    Variable* var, platform::DeviceEvent& event,
    const platform::DeviceContext* ctx) {
  PADDLE_THROW(platform::errors::Unimplemented(
      "Not allowed to call the member function of "
      "InterpreterCoreGarbageCollector"));
}
} // namespace framework
} // namespace paddle
......@@ -16,6 +16,8 @@
#include <queue>
#include "paddle/fluid/memory/allocation/spin_lock.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace framework {
......@@ -26,10 +28,10 @@ using GarbageQueue = std::deque<Garbage>;
class InterpreterCoreGarbageCollector {
public:
InterpreterCoreGarbageCollector();
virtual ~InterpreterCoreGarbageCollector(){};
virtual void Add(Variable* var);
virtual void Add(Variable* var, platform::DeviceEvent& event,
const platform::DeviceContext* ctx);
virtual ~InterpreterCoreGarbageCollector() {}
virtual void Add(Variable* var) = 0;
virtual void Add(Variable* var, platform::DeviceEvent* event,
const platform::DeviceContext* ctx) = 0;
DISABLE_COPY_AND_ASSIGN(InterpreterCoreGarbageCollector);
protected:
......
......@@ -16,16 +16,13 @@
#include <unordered_set>
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
#include "paddle/fluid/framework/new_executor/interpretercore_event_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/event_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/fast_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/framework/new_executor/interpretercore_fast_garbage_collector.h"
#endif
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true,
"Use inplace in new executor");
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope, true,
......@@ -726,12 +723,12 @@ void InterpreterCore::CheckGC(const Instruction& instr) {
} else {
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id), gc_event_.at(instr_id),
var_scope.Var(var_id), &gc_event_.at(instr_id),
&instr.DeviceContext());
}
#else
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id), gc_event_.at(instr_id),
var_scope.Var(var_id), &gc_event_.at(instr_id),
&instr.DeviceContext());
#endif
}
......
......@@ -21,7 +21,7 @@
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/new_executor/event_manager.h"
#include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/profiler.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册