Unverified commit a2240190, authored by chenjian, committed by GitHub

Record op shape data for profiler [cherry-pick PR43405 43578 43822] (#44384)

* add serialization for new field in event node (#43405)

* add serialization for new field in event node

* fix a bug

* add more field to memory record (#43578)

* Add infer shape in dygraph (#43822)

* record memory and op supplement info

* update

* update

* fix a bug

* fix memory recording

* fix a bug

* update

* update

* fix a bug

* update

* fix a bug

* fix a bug

* fix a bug

* update dygraph record

* add infer shape record

* fix

* fix

* fix

* add comments

* fix a bug

* fix

* fix

* add record op info

* fix file mode

* add op input shape info

* fix dependency
Parent 94271bc2
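The diffs below instrument both the static-graph interpreter and the dygraph path with the same pattern: time the infer-shape phase with a `RecordEvent`, then attach the op's input shapes and dtypes through the new `RecordOpInfoSupplement`. A minimal sketch of that pattern follows; it is distilled from the call sites in this diff, not verbatim Paddle code, and the arguments are assumed to be supplied by the surrounding executor.

```cpp
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"

// Sketch: how an executor wraps InferShape and then records op supplement
// info. All parameters are assumed to come from the executor's state.
void RunInferShapeInstrumented(
    const paddle::framework::OperatorBase& op,
    paddle::framework::InferShapeContext* infer_shape_ctx,
    const paddle::framework::RuntimeContext& runtime_ctx) {
  namespace platform = paddle::platform;
  platform::RecordEvent infershape_event(
      "infer_shape", platform::TracerEventType::OperatorInner, 1,
      platform::EventRole::kInnerOp);
  op.Info().infer_shape_(infer_shape_ctx);  // run the op's InferShape
  infershape_event.End();           // close the timing span first ...
  platform::RecordOpInfoSupplement(  // ... then record shapes and dtypes
      op.Type(), op.Attrs(), *infer_shape_ctx, runtime_ctx);
}
```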
......@@ -22,14 +22,17 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
#include "paddle/phi/core/kernel_context.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true,
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace,
true,
"Use inplace in new executor");
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope, true,
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope,
true,
"Use local_scope in new executor(especially used "
"in UT), can turn off for better performance");
......@@ -167,8 +170,8 @@ paddle::framework::FetchList InterpreterCore::Run(
// scope?
}
global_scope_->SetLocalScope(local_scope_);
paddle::framework::interpreter::build_variable_scope(block_, global_scope_,
create_local_scope_);
paddle::framework::interpreter::build_variable_scope(
block_, global_scope_, create_local_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(
place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
......@@ -490,7 +493,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
// If it is OperatorBase, InferShape does nothing.
if (op_with_kernel != nullptr) {
platform::RecordEvent infershape_event(
"infer_shape", platform::TracerEventType::OperatorInner, 1,
"infer_shape",
platform::TracerEventType::OperatorInner,
1,
platform::EventRole::kInnerOp);
// see OperatorWithKernel::RunImpl in operator.cc for why
......@@ -499,6 +504,11 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
op_with_kernel->Info().infer_shape_(
instr_node.InnerInferShapeContext().get());
}
infershape_event.End();
platform::RecordOpInfoSupplement(op->Type(),
op->Attrs(),
*(instr_node.InnerInferShapeContext()),
*(instr_node.InnerRuntimeContext()));
}
}
......@@ -516,7 +526,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
{
platform::RecordEvent compute_event(
"compute", platform::TracerEventType::OperatorInner, 1,
"compute",
platform::TracerEventType::OperatorInner,
1,
platform::EventRole::kInnerOp);
if (op_with_kernel == nullptr) {
instr_node.OpBase()->Run(*local_scope, place_);
......@@ -571,7 +583,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
if (op_with_kernel != nullptr && FLAGS_check_nan_inf) {
VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf(
*op, *global_scope_,
*op,
*global_scope_,
place); // TODO(xiongkun03) change it to inner scope.
}
}
......@@ -596,10 +609,14 @@ void InterpreterCore::ExecuteInstructionList(
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
async_work_queue_->AddTask(vec_instr.at(i).KernelType(), [
this, i, atomic_deps = atomic_deps.get(),
atomic_var_ref = atomic_var_ref.get()
] { RunInstructionAsync(i, atomic_deps, atomic_var_ref); });
async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
[this,
i,
atomic_deps = atomic_deps.get(),
atomic_var_ref = atomic_var_ref.get()] {
RunInstructionAsync(
i, atomic_deps, atomic_var_ref);
});
}
}
......@@ -615,7 +632,8 @@ void InterpreterCore::ExecuteInstructionList(
}
VLOG(4) << "Cancel ok";
PADDLE_ENFORCE_EQ(
main_thread_blocker_.Clear(), 0,
main_thread_blocker_.Clear(),
0,
platform::errors::PreconditionNotMet(
"main_thread_blocker_.Clear() return -1, clear failed"));
VLOG(4) << "clear ok";
......@@ -624,7 +642,8 @@ void InterpreterCore::ExecuteInstructionList(
}
void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops,
const Instruction& instr,
std::queue<size_t>* reserved_next_ops,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref) {
auto& next_instr = instr.NextInstructions();
......@@ -691,7 +710,8 @@ void InterpreterCore::RunNextInstructions(
}
void InterpreterCore::RunInstructionAsync(
size_t instr_id, std::vector<std::atomic<size_t>>* atomic_deps,
size_t instr_id,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref) {
std::queue<size_t> ready_ops;
ready_ops.push(instr_id);
......@@ -700,10 +720,10 @@ void InterpreterCore::RunInstructionAsync(
ready_ops.pop();
auto& instr_node = vec_instruction_.at(instr_id);
VLOG(5) << __func__ << " OP id:" << instr_node.Id()
<< " name:" << instr_node.OpBase()->Type()
<< " type:" << (instr_node.KernelType() == OpFuncType::kQueueSync
? "kQueueSync"
: "kQueueAsync")
<< " name:" << instr_node.OpBase()->Type() << " type:"
<< (instr_node.KernelType() == OpFuncType::kQueueSync
? "kQueueSync"
: "kQueueAsync")
<< " runs on " << platform::GetCurrentThreadName();
auto* op = instr_node.OpBase();
......@@ -877,12 +897,14 @@ void InterpreterCore::CheckGC(
} else {
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id), &gc_event_.at(instr_id),
var_scope.Var(var_id),
&gc_event_.at(instr_id),
&instr.DeviceContext());
}
#else
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id), &gc_event_.at(instr_id),
var_scope.Var(var_id),
&gc_event_.at(instr_id),
&instr.DeviceContext());
#endif
}
......@@ -891,20 +913,24 @@ void InterpreterCore::CheckGC(
void InterpreterCore::Prepare(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors, bool prepare_feed) {
PADDLE_ENFORCE_EQ(feed_names.size(), feed_tensors.size(),
const std::vector<framework::LoDTensor>& feed_tensors,
bool prepare_feed) {
PADDLE_ENFORCE_EQ(feed_names.size(),
feed_tensors.size(),
platform::errors::PreconditionNotMet(
"Required feed_names.size() == feed_tensors.size(), "
"but received %d != %d",
feed_names.size(), feed_tensors.size()));
feed_names.size(),
feed_tensors.size()));
auto FeedInput = [&] {
VLOG(4) << "Feed inputs";
for (size_t i = 0; i < feed_names.size(); ++i) {
auto* feed_var = global_scope_->FindVar(feed_names[i]);
PADDLE_ENFORCE_NOT_NULL(
feed_var, platform::errors::NotFound(
"Variable %s should not be nullptr.", feed_names[i]));
feed_var,
platform::errors::NotFound("Variable %s should not be nullptr.",
feed_names[i]));
auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
feed_tensor->ShareDataWith(feed_tensors[i]);
......@@ -913,8 +939,8 @@ void InterpreterCore::Prepare(
};
if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(block_, global_scope_,
create_local_scope_);
paddle::framework::interpreter::build_variable_scope(
block_, global_scope_, create_local_scope_);
FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(
......
This diff is collapsed.
cc_library(imperative_flag SRCS flags.cc DEPS gflags flags)
cc_library(var_helper SRCS var_helper.cc DEPS tensor phi_api)
IF(WITH_XPU)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils phi_api phi phi_utils var_helper)
ELSE()
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils phi_api phi phi_utils var_helper)
ENDIF()
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry var_helper phi_api)
cc_library(
imperative_flag
SRCS flags.cc
DEPS gflags flags)
cc_library(
var_helper
SRCS var_helper.cc
DEPS tensor phi_api)
if(WITH_XPU)
cc_library(
prepared_operator
SRCS prepared_operator.cc
DEPS xpu_op_list
proto_desc
operator
device_context
lod_tensor
selected_rows_utils
var_type_traits
op_kernel_type
data_transform
nan_inf_utils
phi_api
phi_utils
var_helper
profiler)
else()
cc_library(
prepared_operator
SRCS prepared_operator.cc
DEPS proto_desc
operator
device_context
lod_tensor
selected_rows_utils
var_type_traits
op_kernel_type
data_transform
nan_inf_utils
phi_api
phi_utils
var_helper
profiler)
endif()
cc_library(
layer
SRCS layer.cc
DEPS prepared_operator
math_function
imperative_flag
variable_helper
op_registry
var_helper
phi_api)
add_subdirectory(jit)
if (WITH_GPU)
cc_library(layout_autotune SRCS layout_autotune.cc DEPS op_info phi_gpu_info)
if(WITH_GPU)
cc_library(
layout_autotune
SRCS layout_autotune.cc
DEPS op_info phi_gpu_info)
else()
cc_library(layout_autotune SRCS layout_autotune.cc DEPS op_info)
cc_library(
layout_autotune
SRCS layout_autotune.cc
DEPS op_info)
endif()
cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper layout_autotune)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(imperative_profiler SRCS profiler.cc DEPS flags)
cc_library(
amp
SRCS amp_auto_cast.cc
DEPS layer var_helper)
cc_library(
tracer
SRCS tracer.cc
DEPS layer
engine
program_desc_tracer
amp
denormal
garbage_collector
var_helper
layout_autotune)
cc_library(
basic_engine
SRCS basic_engine.cc
DEPS layer gradient_accumulator switch_autotune)
cc_library(
engine
SRCS basic_engine.cc partial_grad_engine.cc
DEPS layer gradient_accumulator switch_autotune)
cc_library(
imperative_profiler
SRCS profiler.cc
DEPS flags)
if(NOT WIN32)
if(WITH_NCCL OR WITH_RCCL)
cc_library(imperative_all_reduce SRCS all_reduce.cc DEPS collective_helper device_context selected_rows_utils tensor)
cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context imperative_all_reduce var_type_traits)
if(WITH_NCCL)
nv_library(reducer SRCS reducer.cc reducer.cu DEPS layer imperative_all_reduce)
endif()
if(WITH_RCCL)
hip_library(reducer SRCS reducer.cc reducer.cu DEPS layer imperative_all_reduce)
endif()
endif()
if(WITH_XPU_BKCL)
cc_library(bkcl_context SRCS bkcl_context.cc DEPS collective_helper device_context tensor var_type_traits)
cc_library(reducer SRCS reducer.cc DEPS layer)
if(WITH_NCCL OR WITH_RCCL)
cc_library(
imperative_all_reduce
SRCS all_reduce.cc
DEPS collective_helper device_context selected_rows_utils tensor)
cc_library(
nccl_context
SRCS nccl_context.cc
DEPS collective_helper device_context imperative_all_reduce
var_type_traits)
if(WITH_NCCL)
nv_library(
reducer
SRCS reducer.cc reducer.cu
DEPS layer imperative_all_reduce)
endif()
if(WITH_ASCEND_CL)
cc_library(hccl_context SRCS hccl_context.cc DEPS collective_helper device_context tensor var_type_traits)
cc_library(reducer SRCS reducer.cc DEPS layer)
if(WITH_RCCL)
hip_library(
reducer
SRCS reducer.cc reducer.cu
DEPS layer imperative_all_reduce)
endif()
if(WITH_CNCL)
cc_library(cncl_context SRCS cncl_context.cc DEPS collective_helper device_context tensor var_type_traits)
cc_library(reducer SRCS reducer.cc DEPS layer)
endif()
if(WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL)
cc_library(heter_ccl_context SRCS heter_ccl_context.cc DEPS collective_helper device_context tensor var_type_traits)
endif()
cc_library(data_loader SRCS data_loader.cc DEPS enforce)
endif()
if(WITH_XPU_BKCL)
cc_library(
bkcl_context
SRCS bkcl_context.cc
DEPS collective_helper device_context tensor var_type_traits)
cc_library(
reducer
SRCS reducer.cc
DEPS layer)
endif()
if(WITH_ASCEND_CL)
cc_library(
hccl_context
SRCS hccl_context.cc
DEPS collective_helper device_context tensor var_type_traits)
cc_library(
reducer
SRCS reducer.cc
DEPS layer)
endif()
if(WITH_CNCL)
cc_library(
cncl_context
SRCS cncl_context.cc
DEPS collective_helper device_context tensor var_type_traits)
cc_library(
reducer
SRCS reducer.cc
DEPS layer)
endif()
if(WITH_NCCL
OR WITH_RCCL
OR WITH_XPU_BKCL
OR WITH_ASCEND_CL)
cc_library(
heter_ccl_context
SRCS heter_ccl_context.cc
DEPS collective_helper device_context tensor var_type_traits)
endif()
cc_library(
data_loader
SRCS data_loader.cc
DEPS enforce)
endif(NOT WIN32)
if(WITH_GLOO)
cc_library(imperative_gloo_context SRCS gloo_context.cc DEPS collective_helper device_context tensor var_type_traits)
if ( WIN32 OR (NOT (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL OR WITH_CNCL) ))
cc_library(reducer SRCS reducer.cc DEPS layer)
endif()
cc_library(
imperative_gloo_context
SRCS gloo_context.cc
DEPS collective_helper device_context tensor var_type_traits)
if(WIN32
OR (NOT
(WITH_NCCL
OR WITH_RCCL
OR WITH_XPU_BKCL
OR WITH_ASCEND_CL
OR WITH_CNCL)
))
cc_library(
reducer
SRCS reducer.cc
DEPS layer)
endif()
endif()
if(WITH_MLU)
SET(MLU_DEPS mlu_baseop)
set(MLU_DEPS mlu_baseop)
endif()
if(NOT WITH_ASCEND_CL)
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function phi_tensor ${MLU_DEPS})
cc_library(
gradient_accumulator
SRCS gradient_accumulator.cc
DEPS blas
operator
lod_tensor
selected_rows_utils
selected_rows_functor
var_type_traits
layer
math_function
phi_tensor
${MLU_DEPS})
else()
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function npu_op_runner phi_tensor)
cc_library(
gradient_accumulator
SRCS gradient_accumulator.cc
DEPS blas
operator
lod_tensor
selected_rows_utils
selected_rows_functor
var_type_traits
layer
math_function
npu_op_runner
phi_tensor)
endif()
add_subdirectory(tests)
......@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/library_type.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark);
......@@ -91,8 +92,8 @@ void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
<< framework::DataTypeToString(var->ForwardDataType())
<< " real var in dynamic graph.";
framework::Tensor out;
framework::TransComplexToReal(var->ForwardDataType(), var->DataType(),
*tensor, &out);
framework::TransComplexToReal(
var->ForwardDataType(), var->DataType(), *tensor, &out);
SetTensorToVariable(var->Var(), out, var->MutableVar());
}
}
......@@ -147,8 +148,10 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
template <typename VarType>
PreparedOp PrepareImpl(
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op, const platform::Place& place,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const phi::KernelFactory& phi_kernel_factory,
......@@ -254,7 +257,7 @@ PreparedOp PrepareImpl(
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
&& !is_xpu_unsupport
#endif
) {
) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_kernel_key
<< " | kernel: " << phi_kernel;
......@@ -263,9 +266,14 @@ PreparedOp PrepareImpl(
dev_ctx = pool.Get(expected_kernel_key.place_);
}
return PreparedOp(op, empty_ctx, expected_kernel_key, arg_map_fn,
default_kernel_signature, std::move(kernel_signature),
phi_kernel, dev_ctx);
return PreparedOp(op,
empty_ctx,
expected_kernel_key,
arg_map_fn,
default_kernel_signature,
std::move(kernel_signature),
phi_kernel,
dev_ctx);
} else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found.";
......@@ -302,7 +310,7 @@ PreparedOp PrepareImpl(
#if defined(PADDLE_WITH_XPU_KP)
|| (is_xpu_unsupport && !is_xpu_kp_support)
#endif
) {
) {
if (has_phi_kernel) {
auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
......@@ -313,15 +321,21 @@ PreparedOp PrepareImpl(
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, empty_ctx, expected_kernel_key, arg_map_fn,
default_kernel_signature, std::move(kernel_signature),
pt_cpu_kernel, cpu_ctx);
return PreparedOp(op,
empty_ctx,
expected_kernel_key,
arg_map_fn,
default_kernel_signature,
std::move(kernel_signature),
pt_cpu_kernel,
cpu_ctx);
}
}
}
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
kernels_iter,
all_op_kernels.end(),
platform::errors::NotFound(
"There are no kernels which are registered in the %s operator.",
op.Type()));
......@@ -397,17 +411,24 @@ PreparedOp PrepareImpl(
#endif
// TODO(jiabin): Add operator.cc's line 1000 part back when we need that
// case
PADDLE_ENFORCE_NE(kernel_iter, kernels.end(),
platform::errors::NotFound(
"Operator %s does not have kernel for %s.", op.Type(),
KernelTypeToString(expected_kernel_key)));
PADDLE_ENFORCE_NE(
kernel_iter,
kernels.end(),
platform::errors::NotFound("Operator %s does not have kernel for %s.",
op.Type(),
KernelTypeToString(expected_kernel_key)));
if (!(expected_kernel_key.place_ == place)) {
dev_ctx = pool.Get(expected_kernel_key.place_);
}
return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second,
arg_map_fn, default_kernel_signature, dev_ctx);
return PreparedOp(op,
empty_ctx,
expected_kernel_key,
kernel_iter->second,
arg_map_fn,
default_kernel_signature,
dev_ctx);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
......@@ -416,8 +437,14 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs,
phi_kernel_factory, phi_op_utils_map,
return PrepareImpl<VarBase>(ins,
outs,
op,
place,
attrs,
default_attrs,
phi_kernel_factory,
phi_op_utils_map,
default_phi_kernel_sig_map);
}
......@@ -427,9 +454,15 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<VariableWrapper>(
ins, outs, op, place, attrs, default_attrs, phi_kernel_factory,
phi_op_utils_map, default_phi_kernel_sig_map);
return PrepareImpl<VariableWrapper>(ins,
outs,
op,
place,
attrs,
default_attrs,
phi_kernel_factory,
phi_op_utils_map,
default_phi_kernel_sig_map);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
......@@ -438,39 +471,58 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<egr::EagerVariable>(
ins, outs, op, place, attrs, default_attrs, phi_kernel_factory,
phi_op_utils_map, default_phi_kernel_sig_map);
return PrepareImpl<egr::EagerVariable>(ins,
outs,
op,
place,
attrs,
default_attrs,
phi_kernel_factory,
phi_op_utils_map,
default_phi_kernel_sig_map);
}
template <typename VarType>
static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const framework::OperatorWithKernel::OpKernelFunc& func,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
platform::DeviceContext* dev_ctx,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
// TODO(zjl): remove scope in dygraph
{
platform::RecordEvent record_event("infer_shape",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx(
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type,
arg_map_fn, default_kernel_signature);
1,
platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins,
&outs,
&attrs,
&default_attrs,
op.Type(),
&kernel_type,
arg_map_fn,
default_kernel_signature);
op.Info().infer_shape_(&infer_shape_ctx);
record_event.End();
platform::RecordOpInfoSupplement(
op.Type(), op.Attrs(), infer_shape_ctx, ctx);
}
{
platform::RecordEvent record_event("compute",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
1,
platform::EventRole::kInnerOp);
func(DygraphExecutionContext<VarType>(op, empty_scope, *dev_ctx, ctx, ins,
outs, attrs, default_attrs));
func(DygraphExecutionContext<VarType>(
op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs));
}
if (FLAGS_check_nan_inf) {
......@@ -509,30 +561,48 @@ static void PreparedOpRunPtImpl(
const framework::OpKernelType& kernel_type,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
const phi::KernelSignature& kernel_signature, const phi::Kernel& phi_kernel,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const phi::KernelSignature& kernel_signature,
const phi::Kernel& phi_kernel,
platform::DeviceContext* dev_ctx,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
{
platform::RecordEvent record_event("infer_shape",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx(
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type,
arg_map_fn, default_kernel_signature);
1,
platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins,
&outs,
&attrs,
&default_attrs,
op.Type(),
&kernel_type,
arg_map_fn,
default_kernel_signature);
op.Info().infer_shape_(&infer_shape_ctx);
record_event.End();
platform::RecordOpInfoSupplement(
op.Type(), op.Attrs(), infer_shape_ctx, kernel_signature);
}
{
platform::RecordEvent record_event("compute",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
1,
platform::EventRole::kInnerOp);
PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);
phi::KernelContext pt_kernel_context;
BuildDygraphPhiKernelContext<VarType>(kernel_signature, phi_kernel, ins,
outs, attrs, default_attrs, dev_ctx,
BuildDygraphPhiKernelContext<VarType>(kernel_signature,
phi_kernel,
ins,
outs,
attrs,
default_attrs,
dev_ctx,
&pt_kernel_context);
phi_kernel(&pt_kernel_context);
......@@ -561,14 +631,29 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_phi_kernel_) {
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, arg_map_fn_,
default_kernel_signature_, kernel_signature_,
phi_kernel_, dev_ctx_, ins, outs, attrs,
PreparedOpRunPtImpl<VarBase>(op_,
kernel_type_,
arg_map_fn_,
default_kernel_signature_,
kernel_signature_,
phi_kernel_,
dev_ctx_,
ins,
outs,
attrs,
default_attrs);
} else {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, arg_map_fn_,
default_kernel_signature_, dev_ctx_, ins, outs,
attrs, default_attrs);
PreparedOpRunImpl<VarBase>(op_,
ctx_,
kernel_type_,
func_,
arg_map_fn_,
default_kernel_signature_,
dev_ctx_,
ins,
outs,
attrs,
default_attrs);
}
}
......@@ -577,14 +662,29 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_phi_kernel_) {
PreparedOpRunPtImpl<VariableWrapper>(
op_, kernel_type_, arg_map_fn_, default_kernel_signature_,
kernel_signature_, phi_kernel_, dev_ctx_, ins, outs, attrs,
default_attrs);
PreparedOpRunPtImpl<VariableWrapper>(op_,
kernel_type_,
arg_map_fn_,
default_kernel_signature_,
kernel_signature_,
phi_kernel_,
dev_ctx_,
ins,
outs,
attrs,
default_attrs);
} else {
PreparedOpRunImpl<VariableWrapper>(
op_, ctx_, kernel_type_, func_, arg_map_fn_, default_kernel_signature_,
dev_ctx_, ins, outs, attrs, default_attrs);
PreparedOpRunImpl<VariableWrapper>(op_,
ctx_,
kernel_type_,
func_,
arg_map_fn_,
default_kernel_signature_,
dev_ctx_,
ins,
outs,
attrs,
default_attrs);
}
}
......@@ -593,14 +693,29 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_phi_kernel_) {
PreparedOpRunPtImpl<egr::EagerVariable>(
op_, kernel_type_, arg_map_fn_, default_kernel_signature_,
kernel_signature_, phi_kernel_, dev_ctx_, ins, outs, attrs,
default_attrs);
PreparedOpRunPtImpl<egr::EagerVariable>(op_,
kernel_type_,
arg_map_fn_,
default_kernel_signature_,
kernel_signature_,
phi_kernel_,
dev_ctx_,
ins,
outs,
attrs,
default_attrs);
} else {
PreparedOpRunImpl<egr::EagerVariable>(
op_, ctx_, kernel_type_, func_, arg_map_fn_, default_kernel_signature_,
dev_ctx_, ins, outs, attrs, default_attrs);
PreparedOpRunImpl<egr::EagerVariable>(op_,
ctx_,
kernel_type_,
func_,
arg_map_fn_,
default_kernel_signature_,
dev_ctx_,
ins,
outs,
attrs,
default_attrs);
}
}
......
This diff is collapsed.
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include <mutex> // NOLINT
#include <random>
#include <sstream>
......@@ -20,7 +22,6 @@ limitations under the License. */
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/common_event.h"
#include "paddle/fluid/platform/profiler/host_event_recorder.h"
#include "paddle/fluid/platform/profiler/host_tracer.h"
......@@ -29,12 +30,16 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/nvtx.h"
#endif
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/os_info.h"
PADDLE_DEFINE_EXPORTED_bool(enable_rpc_profiler, false,
PADDLE_DEFINE_EXPORTED_bool(enable_rpc_profiler,
false,
"Enable rpc profiler or not.");
DEFINE_bool(enable_host_event_recorder_hook, false,
DEFINE_bool(enable_host_event_recorder_hook,
false,
"enable HostEventRecorder, hook Profiler");
namespace paddle {
......@@ -42,8 +47,11 @@ namespace platform {
MemEvenRecorder MemEvenRecorder::recorder;
Event::Event(EventType type, std::string name, uint32_t thread_id,
EventRole role, std::string attr)
Event::Event(EventType type,
std::string name,
uint32_t thread_id,
EventRole role,
std::string attr)
: type_(type),
name_(name),
thread_id_(thread_id),
......@@ -67,8 +75,10 @@ double Event::CudaElapsedMs(const Event &e) const {
#endif
}
RecordEvent::RecordEvent(const char *name, const TracerEventType type,
uint32_t level, const EventRole role) {
RecordEvent::RecordEvent(const char *name,
const TracerEventType type,
uint32_t level,
const EventRole role) {
#ifndef _WIN32
#ifdef PADDLE_WITH_CUDA
if (g_enable_nvprof_hook) {
......@@ -99,8 +109,10 @@ RecordEvent::RecordEvent(const char *name, const TracerEventType type,
start_ns_ = PosixInNsec();
}
RecordEvent::RecordEvent(const std::string &name, const TracerEventType type,
uint32_t level, const EventRole role) {
RecordEvent::RecordEvent(const std::string &name,
const TracerEventType type,
uint32_t level,
const EventRole role) {
#ifndef _WIN32
#ifdef PADDLE_WITH_CUDA
if (g_enable_nvprof_hook) {
......@@ -129,8 +141,10 @@ RecordEvent::RecordEvent(const std::string &name, const TracerEventType type,
start_ns_ = PosixInNsec();
}
RecordEvent::RecordEvent(const std::string &name, const std::string &attr,
const TracerEventType type, uint32_t level,
RecordEvent::RecordEvent(const std::string &name,
const std::string &attr,
const TracerEventType type,
uint32_t level,
const EventRole role) {
#ifndef _WIN32
#ifdef PADDLE_WITH_CUDA
......@@ -191,15 +205,15 @@ void RecordEvent::End() {
if (LIKELY(FLAGS_enable_host_event_recorder_hook && is_enabled_)) {
uint64_t end_ns = PosixInNsec();
if (LIKELY(shallow_copy_name_ != nullptr)) {
HostEventRecorder::GetInstance().RecordEvent(
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
shallow_copy_name_, start_ns_, end_ns, role_, type_);
} else if (name_ != nullptr) {
if (attr_ == nullptr) {
HostEventRecorder::GetInstance().RecordEvent(*name_, start_ns_, end_ns,
role_, type_);
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
*name_, start_ns_, end_ns, role_, type_);
} else {
HostEventRecorder::GetInstance().RecordEvent(*name_, start_ns_, end_ns,
role_, type_, *attr_);
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
*name_, start_ns_, end_ns, role_, type_, *attr_);
delete attr_;
}
delete name_;
......@@ -214,8 +228,8 @@ void RecordEvent::End() {
DeviceTracer *tracer = GetDeviceTracer();
if (tracer) {
uint64_t end_ns = PosixInNsec();
tracer->AddCPURecords(CurAnnotationName(), start_ns_, end_ns, BlockDepth(),
g_thread_id);
tracer->AddCPURecords(
CurAnnotationName(), start_ns_, end_ns, BlockDepth(), g_thread_id);
}
ClearCurAnnotation();
PopEvent(*name_, role_);
......@@ -225,30 +239,96 @@ void RecordEvent::End() {
is_enabled_ = false;
}
RecordInstantEvent::RecordInstantEvent(const char *name, TracerEventType type,
RecordInstantEvent::RecordInstantEvent(const char *name,
TracerEventType type,
uint32_t level) {
if (UNLIKELY(HostTraceLevel::GetInstance().NeedTrace(level) == false)) {
return;
}
auto start_end_ns = PosixInNsec();
HostEventRecorder::GetInstance().RecordEvent(name, start_end_ns, start_end_ns,
EventRole::kOrdinary, type);
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
name, start_end_ns, start_end_ns, EventRole::kOrdinary, type);
}
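Every `HostEventRecorder` call site in this file gains an explicit template argument because the recorder is now a class template keyed by event type, giving `CommonEvent` and the new `OperatorSupplementOriginEvent` independent singletons. A minimal sketch of that pattern (an illustration only, not Paddle's actual implementation):

```cpp
#include <utility>
#include <vector>

// Toy stand-in for the templated recorder: one static instance per EventType.
template <typename EventType>
class MiniHostEventRecorder {
 public:
  static MiniHostEventRecorder& GetInstance() {
    static MiniHostEventRecorder instance;  // distinct per template argument
    return instance;
  }

  template <typename... Args>
  void RecordEvent(Args&&... args) {
    events_.emplace_back(std::forward<Args>(args)...);
  }

  // Transfers the collected events to the caller.
  std::vector<EventType> GatherEvents() { return std::move(events_); }

 private:
  MiniHostEventRecorder() = default;
  std::vector<EventType> events_;  // the real recorder uses per-thread arenas
};
```

With this shape, `MiniHostEventRecorder<CommonEvent>` and `MiniHostEventRecorder<OperatorSupplementOriginEvent>` are unrelated types with separate storage, which is why `GatherEvents` is fetched per event type further down.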
RecordOpInfoSupplement::RecordOpInfoSupplement(
const std::string &type,
const framework::AttributeMap &attrs,
const framework::InferShapeContext &shape_ctx,
const framework::RuntimeContext &ctx) {
if (FLAGS_enable_host_event_recorder_hook == false) {
return;
}
std::map<std::string, std::vector<framework::DDim>> input_shapes;
std::map<std::string, std::vector<framework::proto::VarType::Type>> dtypes;
for (auto it = ctx.inputs.begin(); it != ctx.inputs.end(); it++) {
input_shapes[it->first] = shape_ctx.GetInputsDim(it->first);
dtypes[it->first] = shape_ctx.GetInputsVarType(it->first);
}
const std::vector<std::string> *callstack_ptr = nullptr;
std::vector<std::string> callstack;
auto iter = attrs.find(
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (iter != attrs.end()) {
callstack_ptr = &BOOST_GET_CONST(std::vector<std::string>, iter->second);
callstack = *callstack_ptr;
}
HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent(
PosixInNsec(), type, input_shapes, dtypes, callstack);
}
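For concreteness, here is roughly what `input_shapes` built above might hold for a hypothetical matmul-like op whose input slot `X` carries one `[8, 16]` tensor and `Y` one `[16, 32]` tensor (names and values are made up for illustration):

```cpp
#include <map>
#include <string>
#include <vector>

#include "paddle/phi/core/ddim.h"  // framework::DDim aliases phi::DDim here

// Hypothetical contents for one recorded op instance; each input slot maps
// to one DDim per tensor behind that slot.
std::map<std::string, std::vector<phi::DDim>> example_input_shapes{
    {"X", {phi::make_ddim({8, 16})}},
    {"Y", {phi::make_ddim({16, 32})}},
};
```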
RecordOpInfoSupplement::RecordOpInfoSupplement(
const std::string &type,
const framework::AttributeMap &attrs,
const framework::InferShapeContext &shape_ctx,
const phi::KernelSignature &kernel_signature) {
if (FLAGS_enable_host_event_recorder_hook == false) {
return;
}
std::map<std::string, std::vector<framework::DDim>> input_shapes;
std::map<std::string, std::vector<framework::proto::VarType::Type>> dtypes;
for (auto it = kernel_signature.input_names.begin();
it != kernel_signature.input_names.end();
it++) {
std::string input_name(*it);
if (shape_ctx.HasInputs(input_name)) {
input_shapes[input_name] = shape_ctx.GetInputsDim(input_name);
dtypes[input_name] = shape_ctx.GetInputsVarType(input_name);
}
}
const std::vector<std::string> *callstack_ptr = nullptr;
std::vector<std::string> callstack;
auto iter = attrs.find(
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (iter != attrs.end()) {
callstack_ptr = &BOOST_GET_CONST(std::vector<std::string>, iter->second);
callstack = *callstack_ptr;
}
HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent(
PosixInNsec(), type, input_shapes, dtypes, callstack);
}
void MemEvenRecorder::PushMemRecord(const void *ptr, const Place &place,
void MemEvenRecorder::PushMemRecord(const void *ptr,
const Place &place,
size_t size) {
if (g_state == ProfilerState::kDisabled) return;
if (g_state == ProfilerState::kDisabled) {
return;
}
std::lock_guard<std::mutex> guard(mtx_);
auto &events = address_memevent_[place];
PADDLE_ENFORCE_EQ(events.count(ptr), 0,
PADDLE_ENFORCE_EQ(events.count(ptr),
0,
platform::errors::InvalidArgument(
"The Place can't exist in the stage of PushMemRecord"));
events.emplace(ptr, std::unique_ptr<RecordMemEvent>(
new MemEvenRecorder::RecordMemEvent(place, size)));
events.emplace(ptr,
std::unique_ptr<RecordMemEvent>(
new MemEvenRecorder::RecordMemEvent(place, size)));
}
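`PushMemRecord` and `PopMemRecord` bracket an allocation's lifetime: push when memory is obtained, pop with the same pointer when it is released. A hedged sketch of the calling convention, using hypothetical `TrackedAlloc`/`TrackedFree` hooks rather than Paddle's real allocator entry points:

```cpp
#include <cstdlib>

#include "paddle/fluid/platform/profiler.h"

// Hypothetical allocator hooks showing the intended pairing; `recorder`
// stands for the process-wide MemEvenRecorder singleton.
void* TrackedAlloc(paddle::platform::MemEvenRecorder& recorder,
                   const paddle::platform::Place& place, size_t size) {
  void* ptr = std::malloc(size);  // stand-in for the real device allocation
  recorder.PushMemRecord(ptr, place, size);  // no-op while profiling is off
  return ptr;
}

void TrackedFree(paddle::platform::MemEvenRecorder& recorder,
                 const paddle::platform::Place& place, void* ptr) {
  recorder.PopMemRecord(ptr, place);  // must pass the pointer given to Push
  std::free(ptr);
}
```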
void MemEvenRecorder::PopMemRecord(const void *ptr, const Place &place) {
if (g_state == ProfilerState::kDisabled) return;
if (g_state == ProfilerState::kDisabled) {
return;
}
std::lock_guard<std::mutex> guard(mtx_);
auto &events = address_memevent_[place];
auto iter = events.find(ptr);
......@@ -278,8 +358,13 @@ MemEvenRecorder::RecordMemEvent::~RecordMemEvent() {
auto annotation_free = CurAnnotationName();
if (tracer) {
tracer->AddMemInfoRecord(start_ns_, end_ns_, bytes_, place_, alloc_in_,
annotation_free, g_mem_thread_id);
tracer->AddMemInfoRecord(start_ns_,
end_ns_,
bytes_,
place_,
alloc_in_,
annotation_free,
g_mem_thread_id);
}
PopMemEvent(start_ns_, end_ns_, bytes_, place_, annotation_free);
}
......@@ -306,44 +391,62 @@ RecordBlock::~RecordBlock() {
if (tracer) {
// We try to put all blocks at the same nested depth in the
// same timeline lane, and distinguish them by thread_id.
tracer->AddCPURecords(name_, start_ns_, PosixInNsec(), BlockDepth(),
g_thread_id);
tracer->AddCPURecords(
name_, start_ns_, PosixInNsec(), BlockDepth(), g_thread_id);
}
ClearCurBlock();
}
void PushMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
const Place &place, const std::string &annotation) {
GetMemEventList().Record(EventType::kPushRange, start_ns, end_ns, bytes,
place, g_mem_thread_id, annotation);
}
void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
const Place &place, const std::string &annotation) {
GetMemEventList().Record(EventType::kPopRange, start_ns, end_ns, bytes, place,
g_mem_thread_id, annotation);
void PushMemEvent(uint64_t start_ns,
uint64_t end_ns,
size_t bytes,
const Place &place,
const std::string &annotation) {
GetMemEventList().Record(EventType::kPushRange,
start_ns,
end_ns,
bytes,
place,
g_mem_thread_id,
annotation);
}
void PopMemEvent(uint64_t start_ns,
uint64_t end_ns,
size_t bytes,
const Place &place,
const std::string &annotation) {
GetMemEventList().Record(EventType::kPopRange,
start_ns,
end_ns,
bytes,
place,
g_mem_thread_id,
annotation);
}
void Mark(const std::string &name) {
if (FLAGS_enable_host_event_recorder_hook) {
HostEventRecorder::GetInstance().RecordEvent(
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
name, 0, 0, EventRole::kOrdinary, TracerEventType::UserDefined);
return;
}
GetEventList().Record(EventType::kMark, name, g_thread_id);
}
Event *PushEvent(const std::string &name, const EventRole role,
Event *PushEvent(const std::string &name,
const EventRole role,
std::string attr) {
return GetEventList().Record(EventType::kPushRange, name, g_thread_id, role,
attr);
return GetEventList().Record(
EventType::kPushRange, name, g_thread_id, role, attr);
}
void PopEvent(const std::string &name, const EventRole role, std::string attr) {
GetEventList().Record(EventType::kPopRange, name, g_thread_id, role, attr);
}
void EnableProfiler(ProfilerState state) {
PADDLE_ENFORCE_NE(state, ProfilerState::kDisabled,
PADDLE_ENFORCE_NE(state,
ProfilerState::kDisabled,
platform::errors::InvalidArgument(
"Can't enable profiling, since the input state is"
"ProfilerState::kDisabled"));
......@@ -379,7 +482,8 @@ void ResetProfiler() {
(*it)->Clear();
}
for (auto it = g_all_mem_event_lists.begin();
it != g_all_mem_event_lists.end(); ++it) {
it != g_all_mem_event_lists.end();
++it) {
(*it)->Clear();
}
}
......@@ -521,7 +625,8 @@ void DisableHostEventRecorder() {
std::string PrintHostEvents() {
std::ostringstream oss;
auto host_evt_sec = HostEventRecorder::GetInstance().GatherEvents();
auto host_evt_sec =
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
for (const auto &thr_evt_sec : host_evt_sec.thr_sections) {
oss << thr_evt_sec.thread_id << std::endl;
for (const auto &evt : thr_evt_sec.events) {
......@@ -533,8 +638,9 @@ std::string PrintHostEvents() {
return oss.str();
}
static void EmulateEventPushAndPop(const HostEventSection &host_sec,
std::map<uint64_t, ThreadEvents> *out) {
static void EmulateEventPushAndPop(
const HostEventSection<CommonEvent> &host_sec,
std::map<uint64_t, ThreadEvents> *out) {
for (const auto &thr_sec : host_sec.thr_sections) {
uint64_t tid = thr_sec.thread_id;
auto cur_thr_list = std::make_shared<EventList<Event>>();
......@@ -573,15 +679,16 @@ static void EmulateEventPushAndPop(const HostEventSection &host_sec,
std::string name =
prefix_stk.empty() ? evt.name : prefix_stk.top() + "/" + evt.name;
const char *attr = (evt.attr == nullptr ? "none" : evt.attr);
Event *orig_evt = cur_thr_list->Record(EventType::kPushRange, name, tid,
evt.role, attr);
Event *orig_evt = cur_thr_list->Record(
EventType::kPushRange, name, tid, evt.role, attr);
(*out)[tid][evt.end_ns] = std::make_pair(orig_evt, evt.start_ns);
cur_thr_list->Record(EventType::kPopRange, name, tid, evt.role, attr);
}
}
}
static void EmulateCPURecordsAdd(const HostEventSection &host_sec) {
static void EmulateCPURecordsAdd(
const HostEventSection<CommonEvent> &host_sec) {
DeviceTracer *tracer = GetDeviceTracer();
if (tracer == nullptr) {
return;
......@@ -589,8 +696,8 @@ static void EmulateCPURecordsAdd(const HostEventSection &host_sec) {
for (const auto &thr_sec : host_sec.thr_sections) {
uint64_t tid = thr_sec.thread_id;
for (const auto &evt : thr_sec.events) {
tracer->AddCPURecords(evt.name, evt.start_ns, evt.end_ns, BlockDepth(),
tid);
tracer->AddCPURecords(
evt.name, evt.start_ns, evt.end_ns, BlockDepth(), tid);
}
}
}
......@@ -609,10 +716,11 @@ static std::map<uint64_t, ThreadEvents> DockHostEventRecorderHostPart() {
if (FLAGS_enable_host_event_recorder_hook == false) {
return thr_events;
}
auto host_evt_sec = HostEventRecorder::GetInstance().GatherEvents();
auto host_evt_sec =
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
EmulateEventPushAndPop(host_evt_sec, &thr_events);
EmulateCPURecordsAdd(host_evt_sec);
return std::move(thr_events);
return thr_events;
}
static void DockHostEventRecorderDevicePart(
......
......@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.pb.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#endif
......@@ -160,7 +161,8 @@ struct EventList {
std::vector<T> Reduce() {
std::vector<T> result;
for (auto& block : event_blocks) {
result.insert(result.begin(), std::make_move_iterator(block.begin()),
result.insert(result.begin(),
std::make_move_iterator(block.begin()),
std::make_move_iterator(block.end()));
}
event_blocks.clear();
......@@ -173,13 +175,21 @@ struct EventList {
};
void Mark(const std::string& name);
void PushMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
const Place& place, const std::string& annotation);
void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
const Place& place, const std::string& annotation);
Event* PushEvent(const std::string& name, const EventRole role,
void PushMemEvent(uint64_t start_ns,
uint64_t end_ns,
size_t bytes,
const Place& place,
const std::string& annotation);
void PopMemEvent(uint64_t start_ns,
uint64_t end_ns,
size_t bytes,
const Place& place,
const std::string& annotation);
Event* PushEvent(const std::string& name,
const EventRole role,
const std::string attr = "none");
void PopEvent(const std::string& name, const EventRole role,
void PopEvent(const std::string& name,
const EventRole role,
const std::string attr = "none");
// Return the event list of all threads. Assuming the returned value is
// called event_lists, event_lists[i][j] represents the j-th Event of the
// i-th thread.
......
cc_library(host_tracer SRCS host_tracer.cc DEPS enforce)
cc_library(cuda_tracer SRCS cuda_tracer.cc cupti_data_process.cc DEPS workqueue_utils enforce glog)
cc_library(
host_tracer
SRCS host_tracer.cc
DEPS enforce ddim var_type_traits)
cc_library(
cuda_tracer
SRCS cuda_tracer.cc cupti_data_process.cc
DEPS workqueue_utils enforce glog)
add_subdirectory(mlu)
cc_library(event_node SRCS event_node.cc DEPS enforce)
cc_library(profiler_utils SRCS utils.cc DEPS enforce glog)
cc_library(
event_node
SRCS event_node.cc
DEPS enforce place)
cc_library(
profiler_utils
SRCS utils.cc
DEPS enforce glog)
add_subdirectory(dump)
cc_library(profiler_logger SRCS chrometracing_logger.cc dump/serialization_logger.cc dump/deserialization_reader.cc DEPS nodetreeproto event_node profiler_utils)
cc_library(event_bind SRCS event_python.cc DEPS profiler_logger)
cc_library(cpu_utilization SRCS cpu_utilization.cc DEPS cpu_info os_info enforce glog)
cc_library(new_profiler SRCS profiler.cc DEPS host_tracer cuda_tracer profiler_utils cpu_utilization event_bind mlu_tracer)
cc_test(test_event_node SRCS test_event_node.cc DEPS event_node profiler_logger)
cc_test(test_extra_info SRCS test_extra_info.cc DEPS profiler_utils)
cc_test(test_serialization_logger SRCS dump/test_serialization_logger.cc DEPS event_bind)
cc_test(new_profiler_test SRCS profiler_test.cc DEPS new_profiler)
cc_library(
profiler_logger
SRCS chrometracing_logger.cc dump/serialization_logger.cc
dump/deserialization_reader.cc
DEPS nodetreeproto event_node profiler_utils)
cc_library(
event_bind
SRCS event_python.cc
DEPS profiler_logger)
cc_library(
cpu_utilization
SRCS cpu_utilization.cc
DEPS cpu_info os_info enforce glog)
cc_library(
new_profiler
SRCS profiler.cc
DEPS host_tracer cuda_tracer profiler_utils cpu_utilization event_bind
mlu_tracer)
cc_test(
test_event_node
SRCS test_event_node.cc
DEPS event_node profiler_logger)
cc_test(
test_extra_info
SRCS test_extra_info.cc
DEPS profiler_utils)
cc_test(
test_serialization_logger
SRCS dump/test_serialization_logger.cc
DEPS event_bind)
cc_test(
new_profiler_test
SRCS profiler_test.cc
DEPS new_profiler)
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <cstdio>
#include <ctime>
#include <limits>
#include <regex>
#include "glog/logging.h"
......@@ -128,27 +129,32 @@ void ChromeTracingLogger::LogMemTraceEventNode(
std::string(
R"JSON(
{
"name": "[memory]", "pid": %lld, "tid": "%lld",
"name": "[memory]", "pid": %lld, "tid": "%lld(C++)",
"ts": %lld,
"ph": "i", "cat": "%s",
"args": {
"place": "%s",
"addr": "%llu",
"increase_bytes": %lld,
"current_allocated": %llu,
"current_reserved": %llu,
"increase_bytes": %lld
"peak_allocated": %llu,
"peak_reserved": %llu
}
},
)JSON"),
mem_node.ProcessId(),
mem_node.ThreadId(),
mem_node.TimeStampNs(),
nsToUs(mem_node.TimeStampNs()),
StringTracerMemEventType(mem_node.Type()),
mem_node.Place().c_str(),
mem_node.Addr(),
mem_node.IncreaseBytes(),
mem_node.CurrentAllocated(),
mem_node.CurrentReserved(),
mem_node.IncreaseBytes());
mem_node.PeakAllocated(),
mem_node.PeakReserved());
pid_tid_set_.insert({mem_node.ProcessId(), mem_node.ThreadId()});
}
void ChromeTracingLogger::LogHostTraceEventNode(
......@@ -172,6 +178,8 @@ void ChromeTracingLogger::LogHostTraceEventNode(
input_shapes = op_supplement_node->InputShapes();
input_dtypes = op_supplement_node->Dtypes();
callstack = op_supplement_node->CallStack();
callstack = std::regex_replace(callstack, std::regex("\""), "\'");
callstack = std::regex_replace(callstack, std::regex("\n"), "\\n");
}
switch (host_node.Type()) {
case TracerEventType::ProfileStep:
......
......@@ -17,16 +17,22 @@
#include <cstring>
#include <functional>
#include <string>
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/event.h" // import EventRole, TODO(TIEXING): remove later
#include "paddle/fluid/platform/profiler/trace_event.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace platform {
struct CommonEvent {
public:
CommonEvent(const char *name, uint64_t start_ns, uint64_t end_ns,
EventRole role, TracerEventType type)
CommonEvent(const char *name,
uint64_t start_ns,
uint64_t end_ns,
EventRole role,
TracerEventType type)
: name(name),
start_ns(start_ns),
end_ns(end_ns),
......@@ -34,8 +40,12 @@ struct CommonEvent {
type(type) {}
CommonEvent(std::function<void *(size_t)> arena_allocator,
const std::string &name_str, uint64_t start_ns, uint64_t end_ns,
EventRole role, TracerEventType type, const std::string &attr_str)
const std::string &name_str,
uint64_t start_ns,
uint64_t end_ns,
EventRole role,
TracerEventType type,
const std::string &attr_str)
: start_ns(start_ns), end_ns(end_ns), role(role), type(type) {
auto buf = static_cast<char *>(arena_allocator(name_str.length() + 1));
strncpy(buf, name_str.c_str(), name_str.length() + 1);
......@@ -46,8 +56,11 @@ struct CommonEvent {
}
CommonEvent(std::function<void *(size_t)> arena_allocator,
const std::string &name_str, uint64_t start_ns, uint64_t end_ns,
EventRole role, TracerEventType type)
const std::string &name_str,
uint64_t start_ns,
uint64_t end_ns,
EventRole role,
TracerEventType type)
: start_ns(start_ns), end_ns(end_ns), role(role), type(type) {
auto buf = static_cast<char *>(arena_allocator(name_str.length() + 1));
strncpy(buf, name_str.c_str(), name_str.length() + 1);
......@@ -62,5 +75,32 @@ struct CommonEvent {
const char *attr = nullptr; // not owned, designed for performance
};
struct OperatorSupplementOriginEvent {
public:
OperatorSupplementOriginEvent(
std::function<void *(size_t)> arena_allocator,
uint64_t timestamp_ns,
const std::string &type_name,
const std::map<std::string, std::vector<framework::DDim>> &input_shapes,
const std::map<std::string, std::vector<framework::proto::VarType::Type>>
&dtypes,
const std::vector<std::string> callstack)
: timestamp_ns(timestamp_ns),
input_shapes(input_shapes),
dtypes(dtypes),
callstack(callstack) {
auto buf = static_cast<char *>(arena_allocator(type_name.length() + 1));
strncpy(buf, type_name.c_str(), type_name.length() + 1);
op_type = buf;
}
uint64_t timestamp_ns;
const char *op_type = nullptr; // not owned, designed for performance
// input shapes
std::map<std::string, std::vector<framework::DDim>> input_shapes;
std::map<std::string, std::vector<framework::proto::VarType::Type>> dtypes;
// call stack
const std::vector<std::string> callstack;
};
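The `arena_allocator` parameter lets the recorder copy `type_name` into memory it owns, so `op_type` never dangles after the caller's string dies. A toy construction, assuming this runs inside `namespace paddle::platform` with the maps already filled by the recording site (a real recorder passes its arena's allocate function rather than `std::malloc`):

```cpp
std::function<void *(size_t)> toy_allocator = [](size_t n) {
  return std::malloc(n);  // toy allocator: leaks here; real arenas free in bulk
};
std::map<std::string, std::vector<framework::DDim>> input_shapes;
std::map<std::string, std::vector<framework::proto::VarType::Type>> dtypes;
OperatorSupplementOriginEvent evt(toy_allocator,
                                  /*timestamp_ns=*/PosixInNsec(),
                                  /*type_name=*/"matmul_v2",
                                  input_shapes, dtypes,
                                  /*callstack=*/{});
// evt.op_type now points into allocator-owned memory, not at "matmul_v2".
```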
} // namespace platform
} // namespace paddle
......@@ -45,7 +45,8 @@ std::unique_ptr<ProfilerResult> DeserializationReader::Parse() {
ExtraInfo extrainfo;
for (auto indx = 0; indx < node_trees_proto_->extra_info_size(); indx++) {
ExtraInfoMap extra_info_map = node_trees_proto_->extra_info(indx);
extrainfo.AddExtraInfo(extra_info_map.key(), std::string("%s"),
extrainfo.AddExtraInfo(extra_info_map.key(),
std::string("%s"),
extra_info_map.value().c_str());
}
// restore NodeTrees
......@@ -90,6 +91,26 @@ std::unique_ptr<ProfilerResult> DeserializationReader::Parse() {
device_node); // insert into runtime_node
}
}
// handle mem node
for (int mem_node_index = 0;
mem_node_index < host_node_proto.mem_nodes_size();
mem_node_index++) {
const MemTraceEventNodeProto& mem_node_proto =
host_node_proto.mem_nodes(mem_node_index);
MemTraceEventNode* mem_node = RestoreMemTraceEventNode(mem_node_proto);
host_node->AddMemNode(mem_node);
}
// handle op supplement node
for (int op_supplement_node_index = 0;
op_supplement_node_index <
host_node_proto.op_supplement_nodes_size();
op_supplement_node_index++) {
const OperatorSupplementEventNodeProto& op_supplement_node_proto =
host_node_proto.op_supplement_nodes(op_supplement_node_index);
OperatorSupplementEventNode* op_supplement_node =
RestoreOperatorSupplementEventNode(op_supplement_node_proto);
host_node->SetOperatorSupplementNode(op_supplement_node);
}
}
// restore parent-child relationship
for (auto it = child_parent_map.begin(); it != child_parent_map.end();
......@@ -174,6 +195,64 @@ HostTraceEventNode* DeserializationReader::RestoreHostTraceEventNode(
return new HostTraceEventNode(host_event);
}
MemTraceEventNode* DeserializationReader::RestoreMemTraceEventNode(
const MemTraceEventNodeProto& mem_node_proto) {
const MemTraceEventProto& mem_event_proto = mem_node_proto.mem_event();
MemTraceEvent mem_event;
mem_event.timestamp_ns = mem_event_proto.timestamp_ns();
mem_event.addr = mem_event_proto.addr();
mem_event.type = static_cast<TracerMemEventType>(mem_event_proto.type());
mem_event.process_id = mem_event_proto.process_id();
mem_event.thread_id = mem_event_proto.thread_id();
mem_event.increase_bytes = mem_event_proto.increase_bytes();
mem_event.place = mem_event_proto.place();
mem_event.current_allocated = mem_event_proto.current_allocated();
mem_event.current_reserved = mem_event_proto.current_reserved();
mem_event.peak_allocated = mem_event_proto.peak_allocated();
mem_event.peak_reserved = mem_event_proto.peak_reserved();
return new MemTraceEventNode(mem_event);
}
OperatorSupplementEventNode*
DeserializationReader::RestoreOperatorSupplementEventNode(
const OperatorSupplementEventNodeProto& op_supplement_node_proto) {
const OperatorSupplementEventProto& op_supplement_event_proto =
op_supplement_node_proto.op_supplement_event();
OperatorSupplementEvent op_supplement_event;
op_supplement_event.timestamp_ns = op_supplement_event_proto.timestamp_ns();
op_supplement_event.op_type = op_supplement_event_proto.op_type();
op_supplement_event.callstack = op_supplement_event_proto.callstack();
op_supplement_event.process_id = op_supplement_event_proto.process_id();
op_supplement_event.thread_id = op_supplement_event_proto.thread_id();
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
auto input_shape_proto = op_supplement_event_proto.input_shapes();
for (int i = 0; i < input_shape_proto.key_size(); i++) {
// Bind a reference, not a copy, so the parsed shapes land in the map.
auto& input_shape_vec = input_shapes[input_shape_proto.key(i)];
auto shape_vectors_proto = input_shape_proto.shape_vecs(i);
for (int j = 0; j < shape_vectors_proto.shapes_size(); j++) {
auto shape_vector_proto = shape_vectors_proto.shapes(j);
std::vector<int64_t> shape;
for (int k = 0; k < shape_vector_proto.size_size(); k++) {
shape.push_back(shape_vector_proto.size(k));
}
input_shape_vec.push_back(shape);
}
}
op_supplement_event.input_shapes = input_shapes;
auto dtype_proto = op_supplement_event_proto.dtypes();
for (int i = 0; i < dtype_proto.key_size(); i++) {
// Bind a reference, not a copy, so the parsed dtypes land in the map.
auto& dtype_vec = dtypes[dtype_proto.key(i)];
auto dtype_vec_proto = dtype_proto.dtype_vecs(i);
for (int j = 0; j < dtype_vec_proto.dtype_size(); j++) {
auto dtype_string = dtype_vec_proto.dtype(j);
dtype_vec.push_back(dtype_string);
}
}
op_supplement_event.dtypes = dtypes;
return new OperatorSupplementEventNode(op_supplement_event);
}
KernelEventInfo DeserializationReader::HandleKernelEventInfoProto(
const DeviceTraceEventProto& device_event_proto) {
const KernelEventInfoProto& kernel_info_proto =
......@@ -203,11 +282,14 @@ MemcpyEventInfo DeserializationReader::HandleMemcpyEventInfoProto(
device_event_proto.memcpy_info();
MemcpyEventInfo memcpy_info;
memcpy_info.num_bytes = memcpy_info_proto.num_bytes();
std::strncpy(memcpy_info.copy_kind, memcpy_info_proto.copy_kind().c_str(),
std::strncpy(memcpy_info.copy_kind,
memcpy_info_proto.copy_kind().c_str(),
kMemKindMaxLen - 1);
std::strncpy(memcpy_info.src_kind, memcpy_info_proto.src_kind().c_str(),
std::strncpy(memcpy_info.src_kind,
memcpy_info_proto.src_kind().c_str(),
kMemKindMaxLen - 1);
std::strncpy(memcpy_info.dst_kind, memcpy_info_proto.dst_kind().c_str(),
std::strncpy(memcpy_info.dst_kind,
memcpy_info_proto.dst_kind().c_str(),
kMemKindMaxLen - 1);
return memcpy_info;
}
......@@ -218,7 +300,8 @@ MemsetEventInfo DeserializationReader::HandleMemsetEventInfoProto(
device_event_proto.memset_info();
MemsetEventInfo memset_info;
memset_info.num_bytes = memset_info_proto.num_bytes();
std::strncpy(memset_info.memory_kind, memset_info_proto.memory_kind().c_str(),
std::strncpy(memset_info.memory_kind,
memset_info_proto.memory_kind().c_str(),
kMemKindMaxLen - 1);
memset_info.value = memset_info_proto.value();
return memset_info;
......
......@@ -36,6 +36,9 @@ class DeserializationReader {
KernelEventInfo HandleKernelEventInfoProto(const DeviceTraceEventProto&);
MemcpyEventInfo HandleMemcpyEventInfoProto(const DeviceTraceEventProto&);
MemsetEventInfo HandleMemsetEventInfoProto(const DeviceTraceEventProto&);
MemTraceEventNode* RestoreMemTraceEventNode(const MemTraceEventNodeProto&);
OperatorSupplementEventNode* RestoreOperatorSupplementEventNode(
const OperatorSupplementEventNodeProto&);
std::string filename_;
std::ifstream input_file_stream_;
NodeTreesProto* node_trees_proto_;
......
......@@ -46,6 +46,19 @@ enum TracerEventTypeProto {
PythonOp = 13;
// Used to mark python level userdefined
PythonUserDefined = 14;
// Used to mark mlu runtime record returned by cnpapi
MluRuntime = 15;
};
enum TracerMemEventTypeProto {
// Used to mark memory allocation which is managed by paddle
Allocate = 0;
// Used to mark memory free which is managed by paddle
Free = 1;
// Used to mark reserved memory allocation, i.e. memory requested from the device.
ReservedAllocate = 2;
// Used to mark reserved memory free, i.e. memory released back to the device.
ReservedFree = 3;
};
message KernelEventInfoProto {
......@@ -121,6 +134,62 @@ message HostTraceEventProto {
required uint64 thread_id = 6;
}
message MemTraceEventProto {
// timestamp of the record
required uint64 timestamp_ns = 1;
// memory manipulation type
required TracerMemEventTypeProto type = 2;
// memory addr of allocation or free
required uint64 addr = 3;
// process id of the record
required uint64 process_id = 4;
// thread id of the record
required uint64 thread_id = 5;
// bytes changed by this manipulation: positive for an allocation, negative
// for a free
required int64 increase_bytes = 6;
// place
required string place = 7;
// current total allocated memory
required uint64 current_allocated = 8;
// current total reserved memory
required uint64 current_reserved = 9;
// current peak allocated memory
required uint64 peak_allocated = 10;
// current peak reserved memory
required uint64 peak_reserved = 11;
}
message OperatorSupplementEventProto {
// timestamp of the record
required uint64 timestamp_ns = 1;
// op type name
required string op_type = 2;
// process id of the record
required uint64 process_id = 3;
// thread id of the record
required uint64 thread_id = 4;
// input shapes
message input_shape_proto {
repeated string key = 1;
message shape_vector {
message shape { repeated uint64 size = 1; }
repeated shape shapes = 1;
}
repeated shape_vector shape_vecs = 2;
}
required input_shape_proto input_shapes = 5;
// dtypes
message dtype_proto {
repeated string key = 1;
message dtype_vector { repeated string dtype = 1; }
repeated dtype_vector dtype_vecs = 2;
}
required dtype_proto dtypes = 6;
// call stack
required string callstack = 7;
}
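To make the nested layout concrete, here is a hedged C++ sketch filling in one input slot `X` that holds a single `[8, 16]` tensor, using the accessors protoc generates for these messages (the `paddle::platform` package is assumed from the reader code elsewhere in this PR; values are illustrative):

```cpp
paddle::platform::OperatorSupplementEventProto op_evt;
op_evt.set_timestamp_ns(1656000000000000000ull);
op_evt.set_op_type("matmul_v2");
op_evt.set_process_id(1);
op_evt.set_thread_id(1);
op_evt.set_callstack("");

// key(i) pairs with shape_vecs(i): one shape_vector per input slot,
// one shape per tensor behind that slot.
auto* shapes = op_evt.mutable_input_shapes();
shapes->add_key("X");
auto* shape = shapes->add_shape_vecs()->add_shapes();
shape->add_size(8);
shape->add_size(16);

// Same pairing convention for dtypes.
auto* dtypes = op_evt.mutable_dtypes();
dtypes->add_key("X");
dtypes->add_dtype_vecs()->add_dtype("float32");
```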
message CudaRuntimeTraceEventProto {
// record name
required string name = 1;
......@@ -166,6 +235,12 @@ message DeviceTraceEventProto {
}
}
message OperatorSupplementEventNodeProto {
required OperatorSupplementEventProto op_supplement_event = 1;
}
message MemTraceEventNodeProto { required MemTraceEventProto mem_event = 1; }
message DeviceTraceEventNodeProto {
required DeviceTraceEventProto device_event = 1;
}
......@@ -180,6 +255,9 @@ message HostTraceEventNodeProto {
required int64 parentid = 2;
required HostTraceEventProto host_trace_event = 3;
repeated CudaRuntimeTraceEventNodeProto runtime_nodes = 4;
// the fields below were added in version 1.0.1
repeated MemTraceEventNodeProto mem_nodes = 5;
repeated OperatorSupplementEventNodeProto op_supplement_nodes = 6;
}
message ThreadNodeTreeProto {
......
......@@ -20,19 +20,19 @@ namespace paddle {
namespace platform {
static const char* kDefaultFilename = "pid_%s_time_%s.paddle_trace.pb";
static const char* version = "1.0.0";
static const char* version = "1.0.1";
static uint32_t span_indx = 0;
static std::string DefaultFileName() {
auto pid = GetProcessId();
return string_format(std::string(kDefaultFilename), pid,
GetStringFormatLocalTime().c_str());
return string_format(
std::string(kDefaultFilename), pid, GetStringFormatLocalTime().c_str());
}
void SerializationLogger::OpenFile() {
output_file_stream_.open(filename_, std::ofstream::out |
std::ofstream::trunc |
std::ofstream::binary);
output_file_stream_.open(
filename_,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
if (!output_file_stream_) {
LOG(WARNING) << "Unable to open file for writing profiling data."
<< std::endl;
......@@ -50,7 +50,8 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) {
thread2host_event_nodes = node_trees.Traverse(true);
for (auto it = thread2host_event_nodes.begin();
it != thread2host_event_nodes.end(); ++it) {
it != thread2host_event_nodes.end();
++it) {
// 1. assign every node an index and a parent
std::map<HostTraceEventNode*, int64_t> node_index_map;
std::map<HostTraceEventNode*, int64_t> node_parent_map;
......@@ -64,7 +65,8 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) {
for (auto hostnode = it->second.begin(); hostnode != it->second.end();
++hostnode) {
for (auto childnode = (*hostnode)->GetChildren().begin();
childnode != (*hostnode)->GetChildren().end(); ++childnode) {
childnode != (*hostnode)->GetChildren().end();
++childnode) {
node_parent_map[(*childnode)] =
node_index_map[(*hostnode)]; // mark each node's parent
}
......@@ -106,10 +108,36 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) {
(*devicenode)->LogMe(this); // fill detail information
}
}
for (auto memnode = (*hostnode)->GetMemTraceEventNodes().begin();
memnode != (*hostnode)->GetMemTraceEventNodes().end();
++memnode) {
MemTraceEventNodeProto* mem_node_proto =
current_host_trace_event_node_proto_->add_mem_nodes();
current_mem_trace_event_node_proto_ = mem_node_proto;
(*memnode)->LogMe(this);
}
}
}
}
void SerializationLogger::LogMemTraceEventNode(
const MemTraceEventNode& mem_node) {
MemTraceEventProto* mem_trace_event = new MemTraceEventProto();
mem_trace_event->set_timestamp_ns(mem_node.TimeStampNs());
mem_trace_event->set_type(
static_cast<TracerMemEventTypeProto>(mem_node.Type()));
mem_trace_event->set_addr(mem_node.Addr());
mem_trace_event->set_process_id(mem_node.ProcessId());
mem_trace_event->set_thread_id(mem_node.ThreadId());
mem_trace_event->set_increase_bytes(mem_node.IncreaseBytes());
mem_trace_event->set_place(mem_node.Place());
mem_trace_event->set_current_allocated(mem_node.CurrentAllocated());
mem_trace_event->set_current_reserved(mem_node.CurrentReserved());
mem_trace_event->set_peak_allocated(mem_node.PeakAllocated());
mem_trace_event->set_peak_reserved(mem_node.PeakReserved());
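// set_allocated_mem_event() transfers ownership of mem_trace_event to the
// proto message, which frees it later; no explicit delete is needed here.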
current_mem_trace_event_node_proto_->set_allocated_mem_event(mem_trace_event);
}
void SerializationLogger::LogHostTraceEventNode(
const HostTraceEventNode& host_node) {
HostTraceEventProto* host_trace_event = new HostTraceEventProto();
......@@ -122,6 +150,63 @@ void SerializationLogger::LogHostTraceEventNode(
host_trace_event->set_thread_id(host_node.ThreadId());
current_host_trace_event_node_proto_->set_allocated_host_trace_event(
host_trace_event);
OperatorSupplementEventNode* op_supplement_event_node =
host_node.GetOperatorSupplementEventNode();
if (op_supplement_event_node != nullptr) {
current_op_supplement_event_node_proto_ =
current_host_trace_event_node_proto_->add_op_supplement_nodes();
OperatorSupplementEventProto* op_supplement_event_proto =
new OperatorSupplementEventProto();
op_supplement_event_proto->set_op_type(op_supplement_event_node->Name());
op_supplement_event_proto->set_timestamp_ns(
op_supplement_event_node->TimeStampNs());
op_supplement_event_proto->set_process_id(
op_supplement_event_node->ProcessId());
op_supplement_event_proto->set_thread_id(
op_supplement_event_node->ThreadId());
op_supplement_event_proto->set_callstack(
op_supplement_event_node->CallStack());
OperatorSupplementEventProto::input_shape_proto* input_shape_proto =
op_supplement_event_proto->mutable_input_shapes();
for (auto it = op_supplement_event_node->InputShapes().begin();
it != op_supplement_event_node->InputShapes().end();
it++) {
input_shape_proto->add_key(it->first);
OperatorSupplementEventProto::input_shape_proto::shape_vector*
shape_vectors_proto = input_shape_proto->add_shape_vecs();
auto shape_vectors = it->second;
for (auto shape_vecs_it = shape_vectors.begin();
shape_vecs_it != shape_vectors.end();
shape_vecs_it++) {
auto shape_vector = *shape_vecs_it;
OperatorSupplementEventProto::input_shape_proto::shape_vector::shape*
shape_proto = shape_vectors_proto->add_shapes();
for (auto shape_it = shape_vector.begin();
shape_it != shape_vector.end();
shape_it++) {
shape_proto->add_size(*shape_it);
}
}
}
OperatorSupplementEventProto::dtype_proto* dtype_proto =
op_supplement_event_proto->mutable_dtypes();
for (auto it = op_supplement_event_node->Dtypes().begin();
it != op_supplement_event_node->Dtypes().end();
it++) {
dtype_proto->add_key(it->first);
OperatorSupplementEventProto::dtype_proto::dtype_vector*
dtype_vector_proto = dtype_proto->add_dtype_vecs();
auto dtype_vector = it->second;
for (auto dtype_it = dtype_vector.begin(); dtype_it != dtype_vector.end();
dtype_it++) {
dtype_vector_proto->add_dtype(*dtype_it);
}
}
current_op_supplement_event_node_proto_->set_allocated_op_supplement_event(
op_supplement_event_proto);
}
}
void SerializationLogger::LogRuntimeTraceEventNode(
......
......@@ -34,6 +34,7 @@ class SerializationLogger : public BaseLogger {
void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) override;
void LogNodeTrees(const NodeTrees&) override;
void LogMetaInfo(const std::unordered_map<std::string, std::string>);
void LogMemTraceEventNode(const MemTraceEventNode&) override;
private:
void OpenFile();
......@@ -48,6 +49,8 @@ class SerializationLogger : public BaseLogger {
HostTraceEventNodeProto* current_host_trace_event_node_proto_;
CudaRuntimeTraceEventNodeProto* current_runtime_trace_event_node_proto_;
DeviceTraceEventNodeProto* current_device_trace_event_node_proto_;
MemTraceEventNodeProto* current_mem_trace_event_node_proto_;
OperatorSupplementEventNodeProto* current_op_supplement_event_node_proto_;
};
} // namespace platform
......
......@@ -35,6 +35,7 @@ using paddle::platform::ProfilerResult;
using paddle::platform::RuntimeTraceEvent;
using paddle::platform::SerializationLogger;
using paddle::platform::TracerEventType;
using paddle::platform::TracerMemEventType;
TEST(SerializationLoggerTest, dump_case0) {
std::list<HostTraceEvent> host_events;
......@@ -54,6 +55,36 @@ TEST(SerializationLoggerTest, dump_case0) {
std::string("op2"), TracerEventType::Operator, 21000, 30000, 10, 10));
host_events.push_back(HostTraceEvent(
std::string("op3"), TracerEventType::Operator, 31000, 40000, 10, 11));
mem_events.push_back(MemTraceEvent(11500,
0x1000,
TracerMemEventType::Allocate,
10,
10,
50,
"GPU:0",
50,
50,
100,
100));
mem_events.push_back(MemTraceEvent(11900,
0x1000,
TracerMemEventType::Free,
10,
10,
-50,
"GPU:0",
0,
50,
100,
100));
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
input_shapes[std::string("X")].push_back(std::vector<int64_t>{1, 2, 3});
input_shapes[std::string("X")].push_back(std::vector<int64_t>{4, 5, 6, 7});
dtypes[std::string("X")].push_back(std::string("int8"));
dtypes[std::string("X")].push_back(std::string("float32"));
op_supplement_events.push_back(OperatorSupplementEvent(
11600, "op1", input_shapes, dtypes, "op1()", 10, 10));
runtime_events.push_back(RuntimeTraceEvent(
std::string("cudalaunch1"), 15000, 17000, 10, 10, 1, 0));
runtime_events.push_back(RuntimeTraceEvent(
......@@ -128,6 +159,8 @@ TEST(SerializationLoggerTest, dump_case0) {
if ((*it)->Name() == "op1") {
EXPECT_EQ((*it)->GetChildren().size(), 0u);
EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u);
EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u);
EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr);
}
}
for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) {
......@@ -137,6 +170,7 @@ TEST(SerializationLoggerTest, dump_case0) {
}
}
tree.LogMe(&logger);
logger.LogMetaInfo(std::unordered_map<std::string, std::string>());
}
TEST(SerializationLoggerTest, dump_case1) {
......@@ -224,6 +258,7 @@ TEST(SerializationLoggerTest, dump_case1) {
}
}
tree.LogMe(&logger);
logger.LogMetaInfo(std::unordered_map<std::string, std::string>());
}
TEST(DeserializationReaderTest, restore_case0) {
......@@ -243,6 +278,8 @@ TEST(DeserializationReaderTest, restore_case0) {
if ((*it)->Name() == "op1") {
EXPECT_EQ((*it)->GetChildren().size(), 0u);
EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u);
EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u);
EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr);
}
}
for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) {
......
......@@ -92,11 +92,9 @@ void NodeTrees::BuildTrees(
++it) {
auto dst_iter =
correlation_id2runtime_event_node.find((*it)->CorrelationId());
PADDLE_ENFORCE_NE(
dst_iter,
correlation_id2runtime_event_node.end(),
platform::errors::NotFound("Unknown device events, "
"no corresponding cuda runtime events"));
if (dst_iter == correlation_id2runtime_event_node.end()) {
continue;
}
dst_iter->second->AddDeviceTraceEventNode(*it);
}
// construct thread2mem_event_nodes
......@@ -375,22 +373,9 @@ HostTraceEventNode* NodeTrees::BuildTreeRelationship(
hasenter = true;
}
(*it)->SetOperatorSupplementNode(*op_supplement_it);
PADDLE_ENFORCE_EQ((*it)->Type(),
TracerEventType::Operator,
platform::errors::PreconditionNotMet(
"Operator supplement events should be embraced "
"by event of type TracerEventType::Operator, "
"but got type TracerEventType::%s",
StringTracerEventType((*it)->Type())));
op_supplement_count += 1;
} else {
if ((*op_supplement_it)->TimeStampNs() > (*it)->EndNs()) {
PADDLE_ENFORCE_LE(op_supplement_count,
1,
platform::errors::PreconditionNotMet(
"One event of TracerEventType::Operator has no "
"more than 1 op supplement event, but got %d.",
op_supplement_count));
lastposition = op_supplement_it;
break;
}
......
......@@ -47,6 +47,8 @@ class MemTraceEventNode {
std::string Place() const { return mem_event_.place; }
uint64_t CurrentAllocated() const { return mem_event_.current_allocated; }
uint64_t CurrentReserved() const { return mem_event_.current_reserved; }
uint64_t PeakAllocated() const { return mem_event_.peak_allocated; }
uint64_t PeakReserved() const { return mem_event_.peak_reserved; }
// member function
void LogMe(BaseLogger* logger) { logger->LogMemTraceEventNode(*this); }
......
......@@ -31,6 +31,9 @@ HostPythonNode::~HostPythonNode() {
for (auto it = device_node_ptrs.begin(); it != device_node_ptrs.end(); ++it) {
delete *it;
}
for (auto it = mem_node_ptrs.begin(); it != mem_node_ptrs.end(); ++it) {
delete *it;
}
}
HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
......@@ -52,7 +55,8 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
}
// copy its CudaRuntimeTraceEventNode
for (auto runtimenode = root->GetRuntimeTraceEventNodes().begin();
runtimenode != root->GetRuntimeTraceEventNodes().end(); ++runtimenode) {
runtimenode != root->GetRuntimeTraceEventNodes().end();
++runtimenode) {
HostPythonNode* runtime_python_node = new HostPythonNode();
runtime_python_node->name = (*runtimenode)->Name();
runtime_python_node->type = (*runtimenode)->Type();
......@@ -76,6 +80,32 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
runtime_python_node->device_node_ptrs.push_back(device_python_node);
}
}
// copy MemTraceEventNode
for (auto memnode = root->GetMemTraceEventNodes().begin();
memnode != root->GetMemTraceEventNodes().end();
memnode++) {
MemPythonNode* mem_python_node = new MemPythonNode();
mem_python_node->timestamp_ns = (*memnode)->TimeStampNs();
mem_python_node->addr = (*memnode)->Addr();
mem_python_node->type = (*memnode)->Type();
mem_python_node->process_id = (*memnode)->ProcessId();
mem_python_node->thread_id = (*memnode)->ThreadId();
mem_python_node->increase_bytes = (*memnode)->IncreaseBytes();
mem_python_node->place = (*memnode)->Place();
mem_python_node->current_allocated = (*memnode)->CurrentAllocated();
mem_python_node->current_reserved = (*memnode)->CurrentReserved();
mem_python_node->peak_allocated = (*memnode)->PeakAllocated();
mem_python_node->peak_reserved = (*memnode)->PeakReserved();
host_python_node->mem_node_ptrs.push_back(mem_python_node);
}
// copy OperatorSupplementEventNode's information if it exists
OperatorSupplementEventNode* op_supplement_node =
root->GetOperatorSupplementEventNode();
if (op_supplement_node != nullptr) {
host_python_node->input_shapes = op_supplement_node->InputShapes();
host_python_node->dtypes = op_supplement_node->Dtypes();
host_python_node->callstack = op_supplement_node->CallStack();
}
return host_python_node;
}
......@@ -93,7 +123,8 @@ ProfilerResult::ProfilerResult(std::unique_ptr<NodeTrees> tree,
ProfilerResult::~ProfilerResult() {
// delete all root nodes
for (auto it = thread_event_trees_map_.begin();
it != thread_event_trees_map_.end(); ++it) {
it != thread_event_trees_map_.end();
++it) {
delete it->second;
}
}
......
......@@ -43,6 +43,35 @@ struct DevicePythonNode {
uint64_t stream_id;
};
struct MemPythonNode {
MemPythonNode() = default;
~MemPythonNode() {}
// timestamp of the record
uint64_t timestamp_ns;
// memory addr of allocation or free
uint64_t addr;
// memory manipulation type
TracerMemEventType type;
// process id of the record
uint64_t process_id;
// thread id of the record
uint64_t thread_id;
// bytes increased by this manipulation: positive for an allocation,
// negative for a free
int64_t increase_bytes;
// place
std::string place;
// current total allocated memory
uint64_t current_allocated;
// current total reserved memory
uint64_t current_reserved;
// peak allocated memory
uint64_t peak_allocated;
// peak reserved memory
uint64_t peak_reserved;
};
struct HostPythonNode {
HostPythonNode() = default;
~HostPythonNode();
......@@ -58,12 +87,19 @@ struct HostPythonNode {
uint64_t process_id;
// thread id of the record
uint64_t thread_id;
// input shapes
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
// call stack
std::string callstack;
// children node
std::vector<HostPythonNode*> children_node_ptrs;
// runtime node
std::vector<HostPythonNode*> runtime_node_ptrs;
// device node
std::vector<DevicePythonNode*> device_node_ptrs;
// mem node
std::vector<MemPythonNode*> mem_node_ptrs;
};
class ProfilerResult {
......
......@@ -17,10 +17,10 @@
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/common_event.h"
namespace paddle {
namespace platform {
......@@ -28,9 +28,11 @@ namespace platform {
template <typename HeadType, typename... RestTypes>
struct ContainsStdString
: std::conditional_t<
std::is_same<std::string, std::remove_cv_t<std::remove_reference_t<
HeadType>>>::value,
std::true_type, ContainsStdString<RestTypes...>> {};
std::is_same<
std::string,
std::remove_cv_t<std::remove_reference_t<HeadType>>>::value,
std::true_type,
ContainsStdString<RestTypes...>> {};
template <typename TailType>
struct ContainsStdString<TailType>
......@@ -58,7 +60,7 @@ class EventContainer {
public:
// Record an event
template <typename... Args>
void Record(Args &&... args) {
void Record(Args &&...args) {
DoRecord(ContainsStdString<Args...>(), std::forward<Args>(args)...);
}
......@@ -112,7 +114,7 @@ class EventContainer {
// Record an event with string arguments
template <typename... Args>
void DoRecord(std::true_type, Args &&... args) {
void DoRecord(std::true_type, Args &&...args) {
auto *storage = GetEventStorage();
std::function<void *(size_t)> allocator = [this](size_t size) {
return GetStrBufFromArena(size);
......@@ -122,7 +124,7 @@ class EventContainer {
// Record an event without any string argument
template <typename... Args>
void DoRecord(std::false_type, Args &&... args) {
void DoRecord(std::false_type, Args &&...args) {
auto *storage = GetEventStorage();
new (storage) EventType(std::forward<Args>(args)...);
}
......@@ -181,12 +183,14 @@ char *EventContainer<EventType>::GetStringStorage(size_t sz) {
return storage;
}
template <typename EventType>
struct ThreadEventSection {
std::string thread_name;
uint64_t thread_id;
std::vector<CommonEvent> events;
std::vector<EventType> events;
};
template <typename EventType>
class ThreadEventRecorder {
public:
ThreadEventRecorder() {
......@@ -199,12 +203,12 @@ class ThreadEventRecorder {
public:
// Forward call to EventContainer::Record
template <typename... Args>
void RecordEvent(Args &&... args) {
void RecordEvent(Args &&...args) {
base_evt_cntr_.Record(std::forward<Args>(args)...);
}
ThreadEventSection GatherEvents() {
ThreadEventSection thr_sec;
ThreadEventSection<EventType> GatherEvents() {
ThreadEventSection<EventType> thr_sec;
thr_sec.thread_name = thread_name_;
thr_sec.thread_id = thread_id_;
thr_sec.events = std::move(base_evt_cntr_.Reduce());
......@@ -214,15 +218,17 @@ class ThreadEventRecorder {
private:
uint64_t thread_id_;
std::string thread_name_;
EventContainer<CommonEvent> base_evt_cntr_;
EventContainer<EventType> base_evt_cntr_;
};
template <typename EventType>
struct HostEventSection {
std::string process_name;
uint64_t process_id;
std::vector<ThreadEventSection> thr_sections;
std::vector<ThreadEventSection<EventType>> thr_sections;
};
template <typename EventType>
class HostEventRecorder {
public:
// singleton
......@@ -237,37 +243,51 @@ class HostEventRecorder {
// Do your best to avoid using 'std::string' as the argument type.
// It causes a deep copy that harms performance.
template <typename... Args>
void RecordEvent(Args &&... args) {
GetThreadLocalRecorder()->RecordEvent(std::forward<Args>(args)...);
void RecordEvent(Args &&...args) {
// Get the thread-local ThreadEventRecorder; create one if it does not
// exist yet. Both HostEventRecorder and the thread-local variable in
// ThreadEventRecorderRegistry hold the shared pointer, so the recorder
// cannot be destroyed by the thread-local variable in
// ThreadEventRecorderRegistry, which would lose its data.
if (GetThreadLocalRecorder()->get() == nullptr) {
std::shared_ptr<ThreadEventRecorder<EventType>>
thread_event_recorder_ptr =
std::make_shared<ThreadEventRecorder<EventType>>();
*(GetThreadLocalRecorder()) = thread_event_recorder_ptr;
thr_recorders_.push_back(thread_event_recorder_ptr);
}
(*GetThreadLocalRecorder())->RecordEvent(std::forward<Args>(args)...);
}
// Thread-unsafe; make sure there is no tracing running.
// Poor performance; call it only at the end.
HostEventSection GatherEvents() {
auto thr_recorders =
ThreadEventRecorderRegistry::GetInstance().GetAllThreadDataByRef();
HostEventSection host_sec;
HostEventSection<EventType> GatherEvents() {
HostEventSection<EventType> host_sec;
host_sec.process_id = GetProcessId();
host_sec.thr_sections.reserve(thr_recorders.size());
for (auto &kv : thr_recorders) {
auto &thr_recorder = kv.second.get();
host_sec.thr_sections.emplace_back(
std::move(thr_recorder.GatherEvents()));
host_sec.thr_sections.reserve(thr_recorders_.size());
for (auto &v : thr_recorders_) {
host_sec.thr_sections.emplace_back(std::move(v->GatherEvents()));
}
return host_sec;
}
private:
using ThreadEventRecorderRegistry =
framework::ThreadDataRegistry<ThreadEventRecorder>;
using ThreadEventRecorderRegistry = framework::ThreadDataRegistry<
std::shared_ptr<ThreadEventRecorder<EventType>>>;
HostEventRecorder() = default;
DISABLE_COPY_AND_ASSIGN(HostEventRecorder);
ThreadEventRecorder *GetThreadLocalRecorder() {
std::shared_ptr<ThreadEventRecorder<EventType>> *GetThreadLocalRecorder() {
return ThreadEventRecorderRegistry::GetInstance()
.GetMutableCurrentThreadData();
}
// Holds all thread-local ThreadEventRecorders.
// ThreadEventRecorderRegistry and HostEventRecorder both keep this shared
// pointer, so a ThreadEventRecorder cannot be destroyed by the thread-local
// variable in ThreadEventRecorderRegistry, which would lose data.
std::vector<std::shared_ptr<ThreadEventRecorder<EventType>>> thr_recorders_;
};
} // namespace platform
......
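For orientation, a minimal usage sketch of the templated recorder with a hypothetical, trivially constructible MyEvent type (the real CommonEvent and OperatorSupplementOriginEvent constructors take more arguments):
#include <cstdint>
// Hypothetical event type; any type that can be constructed in place works.
struct MyEvent {
  explicit MyEvent(uint64_t ts) : timestamp_ns(ts) {}
  uint64_t timestamp_ns;
};
void WorkerThread() {
  // The first call on a thread lazily creates its ThreadEventRecorder<MyEvent>;
  // the recorder is also retained via shared_ptr so events survive thread exit.
  paddle::platform::HostEventRecorder<MyEvent>::GetInstance().RecordEvent(
      uint64_t(42));
}
void CollectAtEnd() {
  // Thread-unsafe; call once after tracing has stopped.
  auto section =
      paddle::platform::HostEventRecorder<MyEvent>::GetInstance().GatherEvents();
  for (auto &thr_sec : section.thr_sections) {
    // consume thr_sec.thread_id and thr_sec.events here
  }
}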
......@@ -11,8 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/profiler/host_tracer.h"
#include <sstream>
#include "glog/logging.h"
#include "paddle/fluid/platform/flags.h"
#include "paddle/fluid/platform/profiler/common_event.h"
......@@ -20,7 +22,8 @@
// Used to filter events, works like glog VLOG(level).
// RecordEvent works if host_trace_level >= level.
PADDLE_DEFINE_EXPORTED_int64(host_trace_level, 1,
PADDLE_DEFINE_EXPORTED_int64(host_trace_level,
1,
"RecordEvent will works "
"if host_trace_level >= level.");
......@@ -29,7 +32,7 @@ namespace platform {
namespace {
void ProcessHostEvents(const HostEventSection& host_events,
void ProcessHostEvents(const HostEventSection<CommonEvent>& host_events,
TraceEventCollector* collector) {
for (const auto& thr_sec : host_events.thr_sections) {
uint64_t tid = thr_sec.thread_id;
......@@ -49,6 +52,53 @@ void ProcessHostEvents(const HostEventSection& host_events,
}
}
void ProcessOperatorSupplementEvents(
const HostEventSection<OperatorSupplementOriginEvent>& op_supplement_events,
TraceEventCollector* collector) {
for (const auto& thr_sec : op_supplement_events.thr_sections) {
uint64_t tid = thr_sec.thread_id;
if (thr_sec.thread_name != kDefaultThreadName) {
collector->AddThreadName(tid, thr_sec.thread_name);
}
for (const auto& evt : thr_sec.events) {
OperatorSupplementEvent event;
event.timestamp_ns = evt.timestamp_ns;
event.op_type = evt.op_type;
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
std::string callstack;
for (auto it = evt.input_shapes.begin(); it != evt.input_shapes.end();
it++) {
for (auto idx = 0lu; idx < it->second.size(); idx++) {
input_shapes[it->first].push_back(std::vector<int64_t>());
for (auto dim_idx = 0; dim_idx < it->second.at(idx).size();
dim_idx++) {
input_shapes[it->first][idx].push_back(
it->second.at(idx).at(dim_idx));
}
}
}
for (auto it = evt.dtypes.begin(); it != evt.dtypes.end(); it++) {
for (auto idx = 0lu; idx < it->second.size(); idx++) {
dtypes[it->first].push_back(
framework::proto::VarType::Type_Name(it->second.at(idx)));
}
}
std::ostringstream result_string;
for (auto it = evt.callstack.begin(); it != evt.callstack.end(); it++) {
result_string << (*it) << std::endl;
}
event.input_shapes = input_shapes;
event.dtypes = dtypes;
event.callstack = result_string.str();
event.process_id = op_supplement_events.process_id;
event.thread_id = tid;
collector->AddOperatorSupplementEvent(std::move(event));
}
}
}
} // namespace
void HostTracer::PrepareTracing() {
......@@ -59,16 +109,20 @@ void HostTracer::PrepareTracing() {
void HostTracer::StartTracing() {
PADDLE_ENFORCE_EQ(
state_ == TracerState::READY || state_ == TracerState::STOPED, true,
state_ == TracerState::READY || state_ == TracerState::STOPED,
true,
platform::errors::PreconditionNotMet("TracerState must be READY"));
HostEventRecorder::GetInstance().GatherEvents();
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance()
.GatherEvents();
HostTraceLevel::GetInstance().SetLevel(options_.trace_level);
state_ = TracerState::STARTED;
}
void HostTracer::StopTracing() {
PADDLE_ENFORCE_EQ(
state_, TracerState::STARTED,
state_,
TracerState::STARTED,
platform::errors::PreconditionNotMet("TracerState must be STARTED"));
HostTraceLevel::GetInstance().SetLevel(HostTraceLevel::kDisabled);
state_ = TracerState::STOPED;
......@@ -76,11 +130,16 @@ void HostTracer::StopTracing() {
void HostTracer::CollectTraceData(TraceEventCollector* collector) {
PADDLE_ENFORCE_EQ(
state_, TracerState::STOPED,
state_,
TracerState::STOPED,
platform::errors::PreconditionNotMet("TracerState must be STOPED"));
HostEventSection host_events =
HostEventRecorder::GetInstance().GatherEvents();
HostEventSection<CommonEvent> host_events =
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
ProcessHostEvents(host_events, collector);
HostEventSection<OperatorSupplementOriginEvent> op_supplement_events =
HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance()
.GatherEvents();
ProcessOperatorSupplementEvents(op_supplement_events, collector);
}
} // namespace platform
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/profiler/trace_event.h"
#include "paddle/phi/core/compat/arg_map_context.h"
namespace paddle {
namespace framework {
class RuntimeContext;
}
namespace platform {
class RecordOpInfoSupplement {
public:
/**
* @param type: Operator type name.
* @param attrs: Attribute map of op.
* @param shape_ctx: InferShape context object.
* @param ctx: Runtime context object.
*/
explicit RecordOpInfoSupplement(const std::string& type,
const framework::AttributeMap& attrs,
const framework::InferShapeContext& shape_ctx,
const framework::RuntimeContext& ctx);
/**
* @param type: Operator type name.
* @param attrs: Attribute map of op.
* @param shape_ctx: InferShape context object.
* @param kernel_signature: KernelSignature object, used in dygraph.
*/
explicit RecordOpInfoSupplement(const std::string& type,
const framework::AttributeMap& attrs,
const framework::InferShapeContext& shape_ctx,
const phi::KernelSignature& kernel_signature);
};
} // namespace platform
} // namespace paddle
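A minimal call-site sketch for the dygraph overload (hypothetical surrounding names; the constructor itself emits the record, so a temporary suffices):
// Assumed to run right after infer-shape in the dygraph path; op_type, attrs,
// infer_shape_ctx and kernel_signature are hypothetical in-scope variables.
paddle::platform::RecordOpInfoSupplement(
    op_type, attrs, infer_shape_ctx, kernel_signature);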
......@@ -60,9 +60,20 @@ TEST(NodeTreesTest, LogMe_case0) {
50,
"GPU:0",
50,
50));
mem_events.push_back(MemTraceEvent(
11900, 0x1000, TracerMemEventType::Free, 10, 10, -50, "GPU:0", 0, 50));
50,
100,
100));
mem_events.push_back(MemTraceEvent(11900,
0x1000,
TracerMemEventType::Free,
10,
10,
-50,
"GPU:0",
0,
50,
100,
100));
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
input_shapes[std::string("X")].push_back(std::vector<int64_t>{1, 2, 3});
......@@ -267,9 +278,20 @@ TEST(NodeTreesTest, HandleTrees_case0) {
50,
"GPU:0",
50,
50));
mem_events.push_back(MemTraceEvent(
11900, 0x1000, TracerMemEventType::Free, 10, 10, -50, "GPU:0", 0, 50));
50,
100,
100));
mem_events.push_back(MemTraceEvent(11900,
0x1000,
TracerMemEventType::Free,
10,
10,
-50,
"GPU:0",
0,
50,
100,
100));
op_supplement_events.push_back(OperatorSupplementEvent(
11600,
"op1",
......
......@@ -59,10 +59,14 @@ enum class TracerEventType {
};
enum class TracerMemEventType {
// Used to mark memory allocation
// Used to mark a memory allocation managed by Paddle
Allocate = 0,
// Used to mark memory free
// Used to mark a memory free managed by Paddle
Free = 1,
// Used to mark a reserved-memory allocation requested from the device.
ReservedAllocate = 2,
// Used to mark a reserved-memory free returned to the device.
ReservedFree = 3,
// A flag to denote the number of current types
NumTypes
};
......@@ -318,7 +322,9 @@ struct MemTraceEvent {
int64_t increase_bytes,
const std::string& place,
uint64_t current_allocated,
uint64_t current_reserved)
uint64_t current_reserved,
uint64_t peak_allocated,
uint64_t peak_reserved)
: timestamp_ns(timestamp_ns),
addr(addr),
type(type),
......@@ -327,7 +333,9 @@ struct MemTraceEvent {
increase_bytes(increase_bytes),
place(place),
current_allocated(current_allocated),
current_reserved(current_reserved) {}
current_reserved(current_reserved),
peak_allocated(peak_allocated),
peak_reserved(peak_reserved) {}
// timestamp of the record
uint64_t timestamp_ns;
......@@ -348,6 +356,10 @@ struct MemTraceEvent {
uint64_t current_allocated;
// current total reserved memory
uint64_t current_reserved;
// current peak allocated memory
uint64_t peak_allocated;
// current peak reserved memory
uint64_t peak_reserved;
};
} // namespace platform
......
......@@ -91,7 +91,8 @@ float CalculateEstOccupancy(uint32_t DeviceId,
#endif
const char* StringTracerMemEventType(TracerMemEventType type) {
static const char* categary_name_[] = {"Allocate", "Free"};
static const char* categary_name_[] = {
"Allocate", "Free", "ReservedAllocate", "ReservedFree"};
return categary_name_[static_cast<int>(type)];
}
......
......@@ -3515,6 +3515,10 @@ All parameter, weight, gradient are variables in Paddle.
.def_readwrite("process_id",
&paddle::platform::HostPythonNode::process_id)
.def_readwrite("thread_id", &paddle::platform::HostPythonNode::thread_id)
.def_readwrite("input_shapes",
&paddle::platform::HostPythonNode::input_shapes)
.def_readwrite("dtypes", &paddle::platform::HostPythonNode::dtypes)
.def_readwrite("callstack", &paddle::platform::HostPythonNode::callstack)
.def_readwrite("children_node",
&paddle::platform::HostPythonNode::children_node_ptrs)
.def_readwrite("runtime_node",
......