Unverified commit ff7da117, authored by chenjian, committed by GitHub

Add infer shape in dygraph (#43822)

* record memory and op supplement info

* update

* update

* fix a bug

* fix memory recording

* fix a bug

* update

* update

* fix a bug

* update

* fix a bug

* fix a bug

* fix a bug

* update dygraph record

* add infer shape record

* fix

* fix

* fix

* add comments

* fix a bug

* fix

* fix
Parent: 9a459efb
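
In short, the hunks below add tracing around shape inference in dygraph: the imperative CMake targets gain a profiler dependency, PreparedOpRunImpl and PreparedOpRunPtImpl emit an operator supplement event right after infer_shape_, a new RecordOpInfoSupplement overload takes a phi::KernelSignature, and the profiler's HostEventRecorder and NodeTrees code is adjusted so thread-local recorders stay alive and unmatched device events are skipped. Condensed from the PreparedOpRun hunks below (record_event and infer_shape_ctx are constructed earlier in the same block; the non-phi path passes the RuntimeContext ctx instead of kernel_signature):

op.Info().infer_shape_(&infer_shape_ctx);
record_event.End();  // close the infer_shape span before recording the supplement
platform::RecordOpInfoSupplement(
    op.Type(), op.Attrs(), infer_shape_ctx, kernel_signature);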
@@ -22,7 +22,8 @@ if(WITH_XPU)
     nan_inf_utils
     phi_api
     phi_utils
-    var_helper)
+    var_helper
+    profiler)
 else()
   cc_library(
     prepared_operator
@@ -38,7 +39,8 @@ else()
     nan_inf_utils
     phi_api
     phi_utils
-    var_helper)
+    var_helper
+    profiler)
 endif()
 cc_library(
   layer
......
@@ -28,6 +28,7 @@
 #include "paddle/fluid/framework/library_type.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"
+#include "paddle/fluid/platform/profiler/supplement_tracing.h"
 
 DECLARE_bool(check_nan_inf);
 DECLARE_bool(benchmark);
@@ -514,6 +515,9 @@ static void PreparedOpRunImpl(
         arg_map_fn,
         default_kernel_signature);
     op.Info().infer_shape_(&infer_shape_ctx);
+    record_event.End();
+    platform::RecordOpInfoSupplement(
+        op.Type(), op.Attrs(), infer_shape_ctx, ctx);
   }
 
   {
@@ -583,6 +587,9 @@ static void PreparedOpRunPtImpl(
         arg_map_fn,
         default_kernel_signature);
     op.Info().infer_shape_(&infer_shape_ctx);
+    record_event.End();
+    platform::RecordOpInfoSupplement(
+        op.Type(), op.Attrs(), infer_shape_ctx, kernel_signature);
   }
 
   {
......
@@ -277,6 +277,37 @@ RecordOpInfoSupplement::RecordOpInfoSupplement(
       PosixInNsec(), type, input_shapes, dtypes, callstack);
 }
 
+RecordOpInfoSupplement::RecordOpInfoSupplement(
+    const std::string &type,
+    const framework::AttributeMap &attrs,
+    const framework::InferShapeContext &shape_ctx,
+    const phi::KernelSignature &kernel_signature) {
+  if (FLAGS_enable_host_event_recorder_hook == false) {
+    return;
+  }
+  std::map<std::string, std::vector<framework::DDim>> input_shapes;
+  std::map<std::string, std::vector<framework::proto::VarType::Type>> dtypes;
+  for (auto it = kernel_signature.input_names.begin();
+       it != kernel_signature.input_names.end();
+       it++) {
+    std::string input_name(*it);
+    if (shape_ctx.HasInputs(input_name)) {
+      input_shapes[input_name] = shape_ctx.GetInputsDim(input_name);
+      dtypes[input_name] = shape_ctx.GetInputsVarType(input_name);
+    }
+  }
+  const std::vector<std::string> *callstack_ptr = nullptr;
+  std::vector<std::string> callstack;
+  auto iter = attrs.find(
+      framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
+  if (iter != attrs.end()) {
+    callstack_ptr = &BOOST_GET_CONST(std::vector<std::string>, iter->second);
+    callstack = *callstack_ptr;
+  }
+  HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent(
+      PosixInNsec(), type, input_shapes, dtypes, callstack);
+}
+
 RecordMemEvent::RecordMemEvent(const void *ptr,
                                const phi::Place &place,
                                size_t size,
......
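
For illustration, the shape collection in the new constructor reduces to the pattern below, a standalone sketch with a hypothetical Dim stand-in for framework::DDim and plain maps in place of the InferShapeContext (not Paddle code, only the same filtering idea): walk the kernel signature's input names and keep the dims of every input the shape-inference context actually knows about.

#include <cstdint>
#include <map>
#include <string>
#include <vector>

using Dim = std::vector<int64_t>;  // hypothetical stand-in for framework::DDim

// Mirrors the loop over kernel_signature.input_names above: keep only the
// inputs the shape-inference context can report shapes for.
std::map<std::string, std::vector<Dim>> CollectInputShapes(
    const std::vector<std::string>& input_names,
    const std::map<std::string, std::vector<Dim>>& known_shapes) {
  std::map<std::string, std::vector<Dim>> result;
  for (const auto& name : input_names) {
    auto it = known_shapes.find(name);  // analogous to shape_ctx.HasInputs(name)
    if (it != known_shapes.end()) {
      result[name] = it->second;        // analogous to shape_ctx.GetInputsDim(name)
    }
  }
  return result;
}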
@@ -91,6 +91,8 @@ struct CommonMemEvent {
         type(type),
         increase_bytes(increase_bytes),
         place(place),
+        current_allocated(current_allocated),
+        current_reserved(current_reserved),
         peak_allocated(peak_allocated),
         peak_reserved(peak_reserved) {}
   uint64_t timestamp_ns;
......
@@ -93,11 +93,9 @@ void NodeTrees::BuildTrees(
        ++it) {
     auto dst_iter =
         correlation_id2runtime_event_node.find((*it)->CorrelationId());
-    PADDLE_ENFORCE_NE(
-        dst_iter,
-        correlation_id2runtime_event_node.end(),
-        platform::errors::NotFound("Unknown device events, "
-                                   "no corresponding cuda runtime events"));
+    if (dst_iter == correlation_id2runtime_event_node.end()) {
+      continue;
+    }
     dst_iter->second->AddDeviceTraceEventNode(*it);
   }
   // construct thread2mem_event_nodes
@@ -376,22 +374,9 @@ HostTraceEventNode* NodeTrees::BuildTreeRelationship(
           hasenter = true;
         }
         (*it)->SetOperatorSupplementNode(*op_supplement_it);
-        PADDLE_ENFORCE_EQ((*it)->Type(),
-                          TracerEventType::Operator,
-                          platform::errors::PreconditionNotMet(
-                              "Operator supplement events should be embraced "
-                              "by event of type TracerEventType::Operator, "
-                              "but got type TracerEventType::%s",
-                              StringTracerEventType((*it)->Type())));
         op_supplement_count += 1;
       } else {
         if ((*op_supplement_it)->TimeStampNs() > (*it)->EndNs()) {
-          PADDLE_ENFORCE_LE(op_supplement_count,
-                            1,
-                            platform::errors::PreconditionNotMet(
-                                "One event of TracerEventType::Operator has no "
-                                "more than 1 op supplement event, but got %d.",
-                                op_supplement_count));
           lastposition = op_supplement_it;
           break;
         }
......
@@ -244,36 +244,50 @@ class HostEventRecorder {
   // It will cause deep-copy to harm performance.
   template <typename... Args>
   void RecordEvent(Args &&...args) {
-    GetThreadLocalRecorder()->RecordEvent(std::forward<Args>(args)...);
+    // Get the thread-local ThreadEventRecorder.
+    // If it does not exist yet, create a new one.
+    // Both HostEventRecorder and the thread-local variable in
+    // ThreadEventRecorderRegistry keep the shared pointer. We do this to
+    // prevent the ThreadEventRecorder from being destroyed by the thread-local
+    // variable in ThreadEventRecorderRegistry, which would lose its data.
+    if (GetThreadLocalRecorder()->get() == nullptr) {
+      std::shared_ptr<ThreadEventRecorder<EventType>>
+          thread_event_recorder_ptr =
+              std::make_shared<ThreadEventRecorder<EventType>>();
+      *(GetThreadLocalRecorder()) = thread_event_recorder_ptr;
+      thr_recorders_.push_back(thread_event_recorder_ptr);
+    }
+    (*GetThreadLocalRecorder())->RecordEvent(std::forward<Args>(args)...);
   }
 
   // Thread-unsafe; make sure there is no running tracing when calling it.
   // Poor performance; call it only at the end.
   HostEventSection<EventType> GatherEvents() {
-    auto thr_recorders =
-        ThreadEventRecorderRegistry::GetInstance().GetAllThreadDataByRef();
     HostEventSection<EventType> host_sec;
     host_sec.process_id = GetProcessId();
-    host_sec.thr_sections.reserve(thr_recorders.size());
-    for (auto &kv : thr_recorders) {
-      auto &thr_recorder = kv.second.get();
-      host_sec.thr_sections.emplace_back(
-          std::move(thr_recorder.GatherEvents()));
+    host_sec.thr_sections.reserve(thr_recorders_.size());
+    for (auto &v : thr_recorders_) {
+      host_sec.thr_sections.emplace_back(std::move(v->GatherEvents()));
     }
     return host_sec;
   }
 
  private:
-  using ThreadEventRecorderRegistry =
-      framework::ThreadDataRegistry<ThreadEventRecorder<EventType>>;
+  using ThreadEventRecorderRegistry = framework::ThreadDataRegistry<
+      std::shared_ptr<ThreadEventRecorder<EventType>>>;
 
   HostEventRecorder() = default;
   DISABLE_COPY_AND_ASSIGN(HostEventRecorder);
 
-  ThreadEventRecorder<EventType> *GetThreadLocalRecorder() {
+  std::shared_ptr<ThreadEventRecorder<EventType>> *GetThreadLocalRecorder() {
     return ThreadEventRecorderRegistry::GetInstance()
         .GetMutableCurrentThreadData();
   }
 
+  // Holds all thread-local ThreadEventRecorders. Both the registry and
+  // HostEventRecorder keep this shared pointer, which prevents the
+  // ThreadEventRecorder from being destroyed by the thread-local variable
+  // in ThreadEventRecorderRegistry and losing its data.
+  std::vector<std::shared_ptr<ThreadEventRecorder<EventType>>> thr_recorders_;
 };
 
 }  // namespace platform
......
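
The HostEventRecorder change above is essentially a lifetime fix: each per-thread recorder is now co-owned by the thread-local slot in the registry and by a vector inside the singleton, so its buffered events survive thread exit and can still be gathered. A standalone sketch of that ownership pattern, with hypothetical names and the registry and recorder collapsed into one class (not the Paddle classes):

#include <memory>
#include <mutex>
#include <vector>

struct ThreadRecorder {
  std::vector<int> events;  // placeholder for the real per-thread event buffer
};

class RecorderRegistry {
 public:
  static RecorderRegistry& Instance() {
    static RecorderRegistry instance;
    return instance;
  }

  // Returns this thread's recorder, creating it on first use. The recorder is
  // co-owned by the thread_local slot and by all_, so it outlives the thread
  // that produced it.
  std::shared_ptr<ThreadRecorder> GetThreadLocalRecorder() {
    thread_local std::shared_ptr<ThreadRecorder> recorder;
    if (recorder == nullptr) {
      recorder = std::make_shared<ThreadRecorder>();
      std::lock_guard<std::mutex> guard(mutex_);
      all_.push_back(recorder);  // the second owner keeps the data alive
    }
    return recorder;
  }

  // Like GatherEvents() above: call only after tracing has stopped.
  std::vector<std::shared_ptr<ThreadRecorder>> GatherAll() {
    std::lock_guard<std::mutex> guard(mutex_);
    return all_;
  }

 private:
  std::mutex mutex_;
  std::vector<std::shared_ptr<ThreadRecorder>> all_;
};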
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/shape_inference.h"
 #include "paddle/fluid/framework/type_defs.h"
 #include "paddle/fluid/platform/profiler/trace_event.h"
+#include "paddle/phi/core/compat/arg_map_context.h"
 
 namespace paddle {
@@ -39,6 +40,16 @@ class RecordOpInfoSupplement {
                                   const framework::AttributeMap& attrs,
                                   const framework::InferShapeContext& shape_ctx,
                                   const framework::RuntimeContext& ctx);
+  /**
+   * @param type: Operator type name.
+   * @param attrs: Attribute map of the op.
+   * @param shape_ctx: InferShape context object.
+   * @param kernel_signature: KernelSignature object, used in dygraph.
+   */
+  explicit RecordOpInfoSupplement(const std::string& type,
+                                  const framework::AttributeMap& attrs,
+                                  const framework::InferShapeContext& shape_ctx,
+                                  const phi::KernelSignature& kernel_signature);
 };
 
 }  // namespace platform
......