未验证 提交 ff7da117 编写于 作者: C chenjian 提交者: GitHub

Add infer shape in dygraph (#43822)

* record memory and op supplement info

* update

* update

* fix a bug

* fix memory recording

* fix a bug

* update

* update

* fix a bug

* update

* fix a bug

* fix a bug

* fix a bug

* update dygraph record

* add infer shape record

* fix

* fix

* fix

* add comments

* fix a bug

* fix

* fix
上级 9a459efb
...@@ -22,7 +22,8 @@ if(WITH_XPU) ...@@ -22,7 +22,8 @@ if(WITH_XPU)
nan_inf_utils nan_inf_utils
phi_api phi_api
phi_utils phi_utils
var_helper) var_helper
profiler)
else() else()
cc_library( cc_library(
prepared_operator prepared_operator
...@@ -38,7 +39,8 @@ else() ...@@ -38,7 +39,8 @@ else()
nan_inf_utils nan_inf_utils
phi_api phi_api
phi_utils phi_utils
var_helper) var_helper
profiler)
endif() endif()
cc_library( cc_library(
layer layer
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/framework/library_type.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
DECLARE_bool(check_nan_inf); DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
...@@ -514,6 +515,9 @@ static void PreparedOpRunImpl( ...@@ -514,6 +515,9 @@ static void PreparedOpRunImpl(
arg_map_fn, arg_map_fn,
default_kernel_signature); default_kernel_signature);
op.Info().infer_shape_(&infer_shape_ctx); op.Info().infer_shape_(&infer_shape_ctx);
record_event.End();
platform::RecordOpInfoSupplement(
op.Type(), op.Attrs(), infer_shape_ctx, ctx);
} }
{ {
...@@ -583,6 +587,9 @@ static void PreparedOpRunPtImpl( ...@@ -583,6 +587,9 @@ static void PreparedOpRunPtImpl(
arg_map_fn, arg_map_fn,
default_kernel_signature); default_kernel_signature);
op.Info().infer_shape_(&infer_shape_ctx); op.Info().infer_shape_(&infer_shape_ctx);
record_event.End();
platform::RecordOpInfoSupplement(
op.Type(), op.Attrs(), infer_shape_ctx, kernel_signature);
} }
{ {
......
...@@ -277,6 +277,37 @@ RecordOpInfoSupplement::RecordOpInfoSupplement( ...@@ -277,6 +277,37 @@ RecordOpInfoSupplement::RecordOpInfoSupplement(
PosixInNsec(), type, input_shapes, dtypes, callstack); PosixInNsec(), type, input_shapes, dtypes, callstack);
} }
/**
 * Dygraph overload: emits one OperatorSupplementOriginEvent for an operator.
 *
 * Does nothing when FLAGS_enable_host_event_recorder_hook is false.
 * Otherwise, for every name in kernel_signature.input_names that shape_ctx
 * reports via HasInputs(), the input dims and variable types are collected.
 * The op creation callstack is copied from attrs when that attribute exists
 * and is left empty otherwise.
 *
 * @param type: Operator type name.
 * @param attrs: Attribute map of op.
 * @param shape_ctx: Infershape context object.
 * @param kernel_signature: KernelSignature object supplying the input names.
 */
RecordOpInfoSupplement::RecordOpInfoSupplement(
    const std::string &type,
    const framework::AttributeMap &attrs,
    const framework::InferShapeContext &shape_ctx,
    const phi::KernelSignature &kernel_signature) {
  // Skip all collection work when the host event recorder hook is off.
  if (!FLAGS_enable_host_event_recorder_hook) {
    return;
  }

  std::map<std::string, std::vector<framework::DDim>> shapes_by_input;
  std::map<std::string, std::vector<framework::proto::VarType::Type>>
      dtypes_by_input;
  for (const auto &raw_name : kernel_signature.input_names) {
    const std::string input_name(raw_name);
    // Only inputs the shape context actually holds are recorded.
    if (!shape_ctx.HasInputs(input_name)) {
      continue;
    }
    shapes_by_input[input_name] = shape_ctx.GetInputsDim(input_name);
    dtypes_by_input[input_name] = shape_ctx.GetInputsVarType(input_name);
  }

  // The callstack stays empty unless the op carries the creation-callstack
  // attribute.
  std::vector<std::string> op_callstack;
  const auto callstack_attr = attrs.find(
      framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
  if (callstack_attr != attrs.end()) {
    op_callstack =
        BOOST_GET_CONST(std::vector<std::string>, callstack_attr->second);
  }

  HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent(
      PosixInNsec(), type, shapes_by_input, dtypes_by_input, op_callstack);
}
RecordMemEvent::RecordMemEvent(const void *ptr, RecordMemEvent::RecordMemEvent(const void *ptr,
const phi::Place &place, const phi::Place &place,
size_t size, size_t size,
......
...@@ -91,6 +91,8 @@ struct CommonMemEvent { ...@@ -91,6 +91,8 @@ struct CommonMemEvent {
type(type), type(type),
increase_bytes(increase_bytes), increase_bytes(increase_bytes),
place(place), place(place),
current_allocated(current_allocated),
current_reserved(current_reserved),
peak_allocated(peak_allocated), peak_allocated(peak_allocated),
peak_reserved(peak_reserved) {} peak_reserved(peak_reserved) {}
uint64_t timestamp_ns; uint64_t timestamp_ns;
......
...@@ -93,11 +93,9 @@ void NodeTrees::BuildTrees( ...@@ -93,11 +93,9 @@ void NodeTrees::BuildTrees(
++it) { ++it) {
auto dst_iter = auto dst_iter =
correlation_id2runtime_event_node.find((*it)->CorrelationId()); correlation_id2runtime_event_node.find((*it)->CorrelationId());
PADDLE_ENFORCE_NE( if (dst_iter == correlation_id2runtime_event_node.end()) {
dst_iter, continue;
correlation_id2runtime_event_node.end(), }
platform::errors::NotFound("Unknown device events, "
"no corresponding cuda runtime events"));
dst_iter->second->AddDeviceTraceEventNode(*it); dst_iter->second->AddDeviceTraceEventNode(*it);
} }
// construct thread2mem_event_nodes // construct thread2mem_event_nodes
...@@ -376,22 +374,9 @@ HostTraceEventNode* NodeTrees::BuildTreeRelationship( ...@@ -376,22 +374,9 @@ HostTraceEventNode* NodeTrees::BuildTreeRelationship(
hasenter = true; hasenter = true;
} }
(*it)->SetOperatorSupplementNode(*op_supplement_it); (*it)->SetOperatorSupplementNode(*op_supplement_it);
PADDLE_ENFORCE_EQ((*it)->Type(),
TracerEventType::Operator,
platform::errors::PreconditionNotMet(
"Operator supplement events should be embraced "
"by event of type TracerEventType::Operator, "
"but got type TracerEventType::%s",
StringTracerEventType((*it)->Type())));
op_supplement_count += 1; op_supplement_count += 1;
} else { } else {
if ((*op_supplement_it)->TimeStampNs() > (*it)->EndNs()) { if ((*op_supplement_it)->TimeStampNs() > (*it)->EndNs()) {
PADDLE_ENFORCE_LE(op_supplement_count,
1,
platform::errors::PreconditionNotMet(
"One event of TracerEventType::Operator has no "
"more than 1 op supplement event, but got %d.",
op_supplement_count));
lastposition = op_supplement_it; lastposition = op_supplement_it;
break; break;
} }
......
...@@ -244,36 +244,50 @@ class HostEventRecorder { ...@@ -244,36 +244,50 @@ class HostEventRecorder {
// It will cause deep-copy to harm performance. // It will cause deep-copy to harm performance.
template <typename... Args> template <typename... Args>
void RecordEvent(Args &&...args) { void RecordEvent(Args &&...args) {
GetThreadLocalRecorder()->RecordEvent(std::forward<Args>(args)...); // Get thread local ThreadEventRecorder
// If it does not exist, we create a new one.
// Both HostEventRecorder and the thread-local variable in
// ThreadEventRecorderRegistry keep the shared pointer. We add this to
// prevent the ThreadEventRecorder from being destroyed by the thread-local
// variable in ThreadEventRecorderRegistry, which would lose its data.
if (GetThreadLocalRecorder()->get() == nullptr) {
std::shared_ptr<ThreadEventRecorder<EventType>>
thread_event_recorder_ptr =
std::make_shared<ThreadEventRecorder<EventType>>();
*(GetThreadLocalRecorder()) = thread_event_recorder_ptr;
thr_recorders_.push_back(thread_event_recorder_ptr);
}
(*GetThreadLocalRecorder())->RecordEvent(std::forward<Args>(args)...);
} }
// thread-unsafe, make sure there is no running tracing. // thread-unsafe, make sure there is no running tracing.
// Poor performance, call it at the ending // Poor performance, call it at the ending
HostEventSection<EventType> GatherEvents() { HostEventSection<EventType> GatherEvents() {
auto thr_recorders =
ThreadEventRecorderRegistry::GetInstance().GetAllThreadDataByRef();
HostEventSection<EventType> host_sec; HostEventSection<EventType> host_sec;
host_sec.process_id = GetProcessId(); host_sec.process_id = GetProcessId();
host_sec.thr_sections.reserve(thr_recorders.size()); host_sec.thr_sections.reserve(thr_recorders_.size());
for (auto &kv : thr_recorders) { for (auto &v : thr_recorders_) {
auto &thr_recorder = kv.second.get(); host_sec.thr_sections.emplace_back(std::move(v->GatherEvents()));
host_sec.thr_sections.emplace_back(
std::move(thr_recorder.GatherEvents()));
} }
return host_sec; return host_sec;
} }
private: private:
using ThreadEventRecorderRegistry = using ThreadEventRecorderRegistry = framework::ThreadDataRegistry<
framework::ThreadDataRegistry<ThreadEventRecorder<EventType>>; std::shared_ptr<ThreadEventRecorder<EventType>>>;
HostEventRecorder() = default; HostEventRecorder() = default;
DISABLE_COPY_AND_ASSIGN(HostEventRecorder); DISABLE_COPY_AND_ASSIGN(HostEventRecorder);
ThreadEventRecorder<EventType> *GetThreadLocalRecorder() { std::shared_ptr<ThreadEventRecorder<EventType>> *GetThreadLocalRecorder() {
return ThreadEventRecorderRegistry::GetInstance() return ThreadEventRecorderRegistry::GetInstance()
.GetMutableCurrentThreadData(); .GetMutableCurrentThreadData();
} }
// Holds all thread-local ThreadEventRecorders.
// ThreadEventRecorderRegistry and HostEventRecorder both own this shared
// pointer. We add this to prevent a ThreadEventRecorder from being destroyed
// by the thread-local variable in ThreadEventRecorderRegistry and losing data.
std::vector<std::shared_ptr<ThreadEventRecorder<EventType>>> thr_recorders_;
}; };
} // namespace platform } // namespace platform
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/profiler/trace_event.h" #include "paddle/fluid/platform/profiler/trace_event.h"
#include "paddle/phi/core/compat/arg_map_context.h"
namespace paddle { namespace paddle {
...@@ -39,6 +40,16 @@ class RecordOpInfoSupplement { ...@@ -39,6 +40,16 @@ class RecordOpInfoSupplement {
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::InferShapeContext& shape_ctx, const framework::InferShapeContext& shape_ctx,
const framework::RuntimeContext& ctx); const framework::RuntimeContext& ctx);
/**
 * Constructor overload that takes a KernelSignature, used in dygraph.
 *
 * Records a supplement event for the operator containing the dims and
 * variable types of each input named in kernel_signature.input_names that
 * shape_ctx reports as present, plus the op creation callstack when attrs
 * carries it. Nothing is recorded when FLAGS_enable_host_event_recorder_hook
 * is false.
 *
 * @param type: Operator type name.
 * @param attrs: Attribute map of op; looked up for the op creation callstack.
 * @param shape_ctx: Infershape context object; queried for input dims and
 * variable types.
 * @param kernel_signature: KernelSignature object, used in dygraph; its
 * input_names select which inputs are recorded.
 */
explicit RecordOpInfoSupplement(const std::string& type,
const framework::AttributeMap& attrs,
const framework::InferShapeContext& shape_ctx,
const phi::KernelSignature& kernel_signature);
}; };
} // namespace platform } // namespace platform
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册