未验证 提交 3d1981ad 编写于 作者: C Chitsing KUI 提交者: GitHub

[PROFILER] add flops for Profiler (#47766)

* attr ready

* op ip ready

* start dynamic

* end2end ok

* input shape to map, stat by op

* layer wip

* first version ready

* fix proto depds

* fix profiler deps

* fix flops typo, rm tuple shape
上级 889318d8
...@@ -742,7 +742,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -742,7 +742,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
platform::RecordOpInfoSupplement(op->Type(), platform::RecordOpInfoSupplement(op->Type(),
op->Attrs(), op->Attrs(),
*(instr_node.InnerInferShapeContext()), *(instr_node.InnerInferShapeContext()),
*(instr_node.InnerRuntimeContext())); *(instr_node.InnerRuntimeContext()),
op->Id());
} }
} }
if (op_with_kernel != nullptr && FLAGS_new_executor_use_inplace) { if (op_with_kernel != nullptr && FLAGS_new_executor_use_inplace) {
......
...@@ -125,11 +125,13 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp( ...@@ -125,11 +125,13 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
} }
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) { std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
return CreateOp(op_desc.Type(), auto op = CreateOp(op_desc.Type(),
op_desc.Inputs(), op_desc.Inputs(),
op_desc.Outputs(), op_desc.Outputs(),
op_desc.GetAttrMap(), op_desc.GetAttrMap(),
op_desc.GetRuntimeAttrMap()); op_desc.GetRuntimeAttrMap());
op->SetId(op_desc.Id());
return op;
} }
} // namespace framework } // namespace framework
......
...@@ -1802,7 +1802,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1802,7 +1802,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
this->Info().infer_shape_(&infer_shape_ctx); this->Info().infer_shape_(&infer_shape_ctx);
record_event.End(); record_event.End();
platform::RecordOpInfoSupplement( platform::RecordOpInfoSupplement(
Type(), Attrs(), infer_shape_ctx, *runtime_ctx); Type(), Attrs(), infer_shape_ctx, *runtime_ctx, Id());
} }
if (FLAGS_enable_unused_var_check) { if (FLAGS_enable_unused_var_check) {
......
...@@ -251,6 +251,10 @@ class OperatorBase { ...@@ -251,6 +251,10 @@ class OperatorBase {
return place; return place;
} }
uint64_t Id() const { return id_; }
void SetId(uint64_t id) { id_ = id; }
protected: protected:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains: // NOTE: in case of OpGrad, inputs_ contains:
...@@ -273,6 +277,9 @@ class OperatorBase { ...@@ -273,6 +277,9 @@ class OperatorBase {
// OpInfo // OpInfo
const OpInfo* info_; const OpInfo* info_;
// OpDesc Id
uint64_t id_ = UINT64_MAX;
// Whether this operator executes in an Executor. // Whether this operator executes in an Executor.
bool run_by_executor_{true}; bool run_by_executor_{true};
......
...@@ -600,7 +600,7 @@ static void PreparedOpRunImpl( ...@@ -600,7 +600,7 @@ static void PreparedOpRunImpl(
op.Info().infer_shape_(&infer_shape_ctx); op.Info().infer_shape_(&infer_shape_ctx);
record_event.End(); record_event.End();
platform::RecordOpInfoSupplement( platform::RecordOpInfoSupplement(
op.Type(), op.Attrs(), infer_shape_ctx, ctx); op.Type(), op.Attrs(), infer_shape_ctx, ctx, op.Id());
} }
{ {
......
...@@ -42,7 +42,9 @@ DEFINE_bool(enable_host_event_recorder_hook, ...@@ -42,7 +42,9 @@ DEFINE_bool(enable_host_event_recorder_hook,
false, false,
"enable HostEventRecorder, hook Profiler"); "enable HostEventRecorder, hook Profiler");
DEFINE_bool(enable_record_input_shape, false, "enable input shape recorder"); DEFINE_bool(enable_record_op_info,
false,
"enable operator supplement info recorder");
DEFINE_bool(enable_record_memory, false, "enable memory recorder"); DEFINE_bool(enable_record_memory, false, "enable memory recorder");
...@@ -258,7 +260,8 @@ RecordOpInfoSupplement::RecordOpInfoSupplement( ...@@ -258,7 +260,8 @@ RecordOpInfoSupplement::RecordOpInfoSupplement(
const std::string &type, const std::string &type,
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
const framework::InferShapeContext &shape_ctx, const framework::InferShapeContext &shape_ctx,
const framework::RuntimeContext &ctx) { const framework::RuntimeContext &ctx,
uint64_t op_id) {
if (FLAGS_enable_host_event_recorder_hook == false) { if (FLAGS_enable_host_event_recorder_hook == false) {
return; return;
} }
...@@ -272,16 +275,8 @@ RecordOpInfoSupplement::RecordOpInfoSupplement( ...@@ -272,16 +275,8 @@ RecordOpInfoSupplement::RecordOpInfoSupplement(
dtypes[it->first] = shape_ctx.GetInputsVarType(it->first); dtypes[it->first] = shape_ctx.GetInputsVarType(it->first);
} }
const std::vector<std::string> *callstack_ptr = nullptr;
std::vector<std::string> callstack;
auto iter = attrs.find(
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (iter != attrs.end()) {
callstack_ptr = &PADDLE_GET_CONST(std::vector<std::string>, iter->second);
callstack = *callstack_ptr;
}
HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent( HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent(
PosixInNsec(), type, input_shapes, dtypes, callstack); PosixInNsec(), type, input_shapes, dtypes, attrs, op_id);
} }
RecordOpInfoSupplement::RecordOpInfoSupplement( RecordOpInfoSupplement::RecordOpInfoSupplement(
...@@ -306,22 +301,16 @@ RecordOpInfoSupplement::RecordOpInfoSupplement( ...@@ -306,22 +301,16 @@ RecordOpInfoSupplement::RecordOpInfoSupplement(
dtypes[input_name] = shape_ctx.GetInputsVarType(input_name); dtypes[input_name] = shape_ctx.GetInputsVarType(input_name);
} }
} }
const std::vector<std::string> *callstack_ptr = nullptr; uint64_t op_id = 0;
std::vector<std::string> callstack;
auto iter = attrs.find(
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (iter != attrs.end()) {
callstack_ptr = &PADDLE_GET_CONST(std::vector<std::string>, iter->second);
callstack = *callstack_ptr;
}
HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent( HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent(
PosixInNsec(), type, input_shapes, dtypes, callstack); PosixInNsec(), type, input_shapes, dtypes, attrs, op_id);
} }
RecordOpInfoSupplement::RecordOpInfoSupplement( RecordOpInfoSupplement::RecordOpInfoSupplement(
const std::string &type, const std::string &type,
const std::vector<std::pair<const char *, std::vector<framework::DDim>>> const std::vector<std::pair<const char *, std::vector<framework::DDim>>>
&input_shapes) { &input_shapes,
const framework::AttributeMap &attrs) {
if (FLAGS_enable_host_event_recorder_hook == false) { if (FLAGS_enable_host_event_recorder_hook == false) {
return; return;
} }
...@@ -329,9 +318,9 @@ RecordOpInfoSupplement::RecordOpInfoSupplement( ...@@ -329,9 +318,9 @@ RecordOpInfoSupplement::RecordOpInfoSupplement(
return; return;
} }
std::map<std::string, std::vector<framework::proto::VarType::Type>> dtypes; std::map<std::string, std::vector<framework::proto::VarType::Type>> dtypes;
std::vector<std::string> callstack; uint64_t op_id = 0;
HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent( HostEventRecorder<OperatorSupplementOriginEvent>::GetInstance().RecordEvent(
PosixInNsec(), type, input_shapes, dtypes, callstack); PosixInNsec(), type, input_shapes, dtypes, attrs, op_id);
} }
bool RecordEvent::IsEnabled() { bool RecordEvent::IsEnabled() {
...@@ -339,9 +328,7 @@ bool RecordEvent::IsEnabled() { ...@@ -339,9 +328,7 @@ bool RecordEvent::IsEnabled() {
g_state != ProfilerState::kDisabled; g_state != ProfilerState::kDisabled;
} }
bool RecordOpInfoSupplement::IsEnabled() { bool RecordOpInfoSupplement::IsEnabled() { return FLAGS_enable_record_op_info; }
return FLAGS_enable_record_input_shape;
}
bool RecordMemEvent::IsEnabled() { return FLAGS_enable_record_memory; } bool RecordMemEvent::IsEnabled() { return FLAGS_enable_record_memory; }
...@@ -1087,9 +1074,9 @@ void DisableHostEventRecorder() { ...@@ -1087,9 +1074,9 @@ void DisableHostEventRecorder() {
FLAGS_enable_host_event_recorder_hook = false; FLAGS_enable_host_event_recorder_hook = false;
} }
void EnableInputShapeRecorder() { FLAGS_enable_record_input_shape = true; } void EnableOpInfoRecorder() { FLAGS_enable_record_op_info = true; }
void DisableInputShapeRecorder() { FLAGS_enable_record_input_shape = false; } void DisableOpInfoRecorder() { FLAGS_enable_record_op_info = false; }
void EnableMemoryRecorder() { FLAGS_enable_record_memory = true; } void EnableMemoryRecorder() { FLAGS_enable_record_memory = true; }
......
...@@ -251,8 +251,8 @@ void DisableHostEventRecorder(); ...@@ -251,8 +251,8 @@ void DisableHostEventRecorder();
void EnableMemoryRecorder(); void EnableMemoryRecorder();
void DisableMemoryRecorder(); void DisableMemoryRecorder();
void EnableInputShapeRecorder(); void EnableOpInfoRecorder();
void DisableInputShapeRecorder(); void DisableOpInfoRecorder();
// Defined for UT // Defined for UT
std::string PrintHostEvents(); std::string PrintHostEvents();
......
cc_library( cc_library(
host_tracer host_tracer
SRCS host_tracer.cc SRCS host_tracer.cc
DEPS enforce ddim var_type_traits) DEPS framework_proto enforce ddim var_type_traits)
cc_library( cc_library(
cuda_tracer cuda_tracer
SRCS cuda_tracer.cc cupti_data_process.cc SRCS cuda_tracer.cc cupti_data_process.cc
......
...@@ -115,11 +115,13 @@ struct OperatorSupplementOriginEvent { ...@@ -115,11 +115,13 @@ struct OperatorSupplementOriginEvent {
const std::map<std::string, std::vector<framework::DDim>> &input_shapes, const std::map<std::string, std::vector<framework::DDim>> &input_shapes,
const std::map<std::string, std::vector<framework::proto::VarType::Type>> const std::map<std::string, std::vector<framework::proto::VarType::Type>>
&dtypes, &dtypes,
const std::vector<std::string> callstack) const framework::AttributeMap &attributes,
uint64_t op_id)
: timestamp_ns(timestamp_ns), : timestamp_ns(timestamp_ns),
input_shapes(input_shapes), input_shapes(input_shapes),
dtypes(dtypes), dtypes(dtypes),
callstack(callstack) { attributes(attributes),
op_id(op_id) {
auto buf = static_cast<char *>(arena_allocator(type_name.length() + 1)); auto buf = static_cast<char *>(arena_allocator(type_name.length() + 1));
strncpy(buf, type_name.c_str(), type_name.length() + 1); strncpy(buf, type_name.c_str(), type_name.length() + 1);
op_type = buf; op_type = buf;
...@@ -132,8 +134,12 @@ struct OperatorSupplementOriginEvent { ...@@ -132,8 +134,12 @@ struct OperatorSupplementOriginEvent {
&shapes, &shapes,
const std::map<std::string, std::vector<framework::proto::VarType::Type>> const std::map<std::string, std::vector<framework::proto::VarType::Type>>
&dtypes, &dtypes,
const std::vector<std::string> callstack) const framework::AttributeMap &attributes,
: timestamp_ns(timestamp_ns), dtypes(dtypes), callstack(callstack) { uint64_t op_id)
: timestamp_ns(timestamp_ns),
dtypes(dtypes),
attributes(attributes),
op_id(op_id) {
auto buf = static_cast<char *>(arena_allocator(type_name.length() + 1)); auto buf = static_cast<char *>(arena_allocator(type_name.length() + 1));
strncpy(buf, type_name.c_str(), type_name.length() + 1); strncpy(buf, type_name.c_str(), type_name.length() + 1);
op_type = buf; op_type = buf;
...@@ -146,8 +152,10 @@ struct OperatorSupplementOriginEvent { ...@@ -146,8 +152,10 @@ struct OperatorSupplementOriginEvent {
// input shapes // input shapes
std::map<std::string, std::vector<framework::DDim>> input_shapes; std::map<std::string, std::vector<framework::DDim>> input_shapes;
std::map<std::string, std::vector<framework::proto::VarType::Type>> dtypes; std::map<std::string, std::vector<framework::proto::VarType::Type>> dtypes;
// call stack // op attributes
const std::vector<std::string> callstack; framework::AttributeMap attributes;
// op id
uint64_t op_id;
}; };
} // namespace platform } // namespace platform
......
...@@ -274,6 +274,7 @@ DeserializationReader::RestoreOperatorSupplementEventNode( ...@@ -274,6 +274,7 @@ DeserializationReader::RestoreOperatorSupplementEventNode(
op_supplement_event.timestamp_ns = op_supplement_event_proto.timestamp_ns(); op_supplement_event.timestamp_ns = op_supplement_event_proto.timestamp_ns();
op_supplement_event.op_type = op_supplement_event_proto.op_type(); op_supplement_event.op_type = op_supplement_event_proto.op_type();
op_supplement_event.callstack = op_supplement_event_proto.callstack(); op_supplement_event.callstack = op_supplement_event_proto.callstack();
op_supplement_event.op_id = op_supplement_event_proto.op_id();
op_supplement_event.process_id = op_supplement_event_proto.process_id(); op_supplement_event.process_id = op_supplement_event_proto.process_id();
op_supplement_event.thread_id = op_supplement_event_proto.thread_id(); op_supplement_event.thread_id = op_supplement_event_proto.thread_id();
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes; std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
......
...@@ -194,6 +194,8 @@ message OperatorSupplementEventProto { ...@@ -194,6 +194,8 @@ message OperatorSupplementEventProto {
required dtype_proto dtypes = 6; required dtype_proto dtypes = 6;
// call stack // call stack
required string callstack = 7; required string callstack = 7;
required uint64 op_id = 8;
} }
message CudaRuntimeTraceEventProto { message CudaRuntimeTraceEventProto {
......
...@@ -197,6 +197,7 @@ void SerializationLogger::LogHostTraceEventNode( ...@@ -197,6 +197,7 @@ void SerializationLogger::LogHostTraceEventNode(
op_supplement_event_node->ThreadId()); op_supplement_event_node->ThreadId());
op_supplement_event_proto->set_callstack( op_supplement_event_proto->set_callstack(
op_supplement_event_node->CallStack()); op_supplement_event_node->CallStack());
op_supplement_event_proto->set_op_id(op_supplement_event_node->OpId());
OperatorSupplementEventProto::input_shape_proto* input_shape_proto = OperatorSupplementEventProto::input_shape_proto* input_shape_proto =
op_supplement_event_proto->mutable_input_shapes(); op_supplement_event_proto->mutable_input_shapes();
......
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
// limitations under the License. // limitations under the License.
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/profiler/dump/deserialization_reader.h" #include "paddle/fluid/platform/profiler/dump/deserialization_reader.h"
#include "paddle/fluid/platform/profiler/dump/serialization_logger.h" #include "paddle/fluid/platform/profiler/dump/serialization_logger.h"
#include "paddle/fluid/platform/profiler/event_node.h" #include "paddle/fluid/platform/profiler/event_node.h"
#include "paddle/fluid/platform/profiler/event_python.h" #include "paddle/fluid/platform/profiler/event_python.h"
using paddle::framework::AttributeMap;
using paddle::platform::CudaRuntimeTraceEventNode; using paddle::platform::CudaRuntimeTraceEventNode;
using paddle::platform::DeserializationReader; using paddle::platform::DeserializationReader;
using paddle::platform::DeviceTraceEvent; using paddle::platform::DeviceTraceEvent;
...@@ -82,8 +84,9 @@ TEST(SerializationLoggerTest, dump_case0) { ...@@ -82,8 +84,9 @@ TEST(SerializationLoggerTest, dump_case0) {
input_shapes[std::string("X")].push_back(std::vector<int64_t>{4, 5, 6, 7}); input_shapes[std::string("X")].push_back(std::vector<int64_t>{4, 5, 6, 7});
dtypes[std::string("X")].push_back(std::string("int8")); dtypes[std::string("X")].push_back(std::string("int8"));
dtypes[std::string("X")].push_back(std::string("float32")); dtypes[std::string("X")].push_back(std::string("float32"));
AttributeMap attrs;
op_supplement_events.push_back(OperatorSupplementEvent( op_supplement_events.push_back(OperatorSupplementEvent(
11600, "op1", input_shapes, dtypes, "op1()", 10, 10)); 11600, "op1", input_shapes, dtypes, "op1()", attrs, 0, 10, 10));
runtime_events.push_back(RuntimeTraceEvent( runtime_events.push_back(RuntimeTraceEvent(
std::string("cudalaunch1"), 15000, 17000, 10, 10, 1, 0)); std::string("cudalaunch1"), 15000, 17000, 10, 10, 1, 0));
runtime_events.push_back(RuntimeTraceEvent( runtime_events.push_back(RuntimeTraceEvent(
......
...@@ -76,6 +76,10 @@ class OperatorSupplementEventNode { ...@@ -76,6 +76,10 @@ class OperatorSupplementEventNode {
return op_supplement_event_.dtypes; return op_supplement_event_.dtypes;
} }
std::string CallStack() { return op_supplement_event_.callstack; } std::string CallStack() { return op_supplement_event_.callstack; }
framework::AttributeMap Attributes() {
return op_supplement_event_.attributes;
}
uint64_t OpId() const { return op_supplement_event_.op_id; }
uint64_t ProcessId() const { return op_supplement_event_.process_id; } uint64_t ProcessId() const { return op_supplement_event_.process_id; }
uint64_t ThreadId() const { return op_supplement_event_.thread_id; } uint64_t ThreadId() const { return op_supplement_event_.thread_id; }
......
...@@ -131,6 +131,8 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) { ...@@ -131,6 +131,8 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
host_python_node->input_shapes = op_supplement_node->InputShapes(); host_python_node->input_shapes = op_supplement_node->InputShapes();
host_python_node->dtypes = op_supplement_node->Dtypes(); host_python_node->dtypes = op_supplement_node->Dtypes();
host_python_node->callstack = op_supplement_node->CallStack(); host_python_node->callstack = op_supplement_node->CallStack();
host_python_node->attributes = op_supplement_node->Attributes();
host_python_node->op_id = op_supplement_node->OpId();
} }
return host_python_node; return host_python_node;
} }
......
...@@ -121,6 +121,10 @@ struct HostPythonNode { ...@@ -121,6 +121,10 @@ struct HostPythonNode {
std::map<std::string, std::vector<std::string>> dtypes; std::map<std::string, std::vector<std::string>> dtypes;
// call stack // call stack
std::string callstack; std::string callstack;
// op attributes
framework::AttributeMap attributes;
// op id
uint64_t op_id;
// children node // children node
std::vector<HostPythonNode*> children_node_ptrs; std::vector<HostPythonNode*> children_node_ptrs;
// runtime node // runtime node
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <sstream> #include <sstream>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/platform/flags.h" #include "paddle/fluid/platform/flags.h"
#include "paddle/fluid/platform/profiler/common_event.h" #include "paddle/fluid/platform/profiler/common_event.h"
#include "paddle/fluid/platform/profiler/host_event_recorder.h" #include "paddle/fluid/platform/profiler/host_event_recorder.h"
...@@ -87,6 +88,21 @@ void ProcessOperatorSupplementEvents( ...@@ -87,6 +88,21 @@ void ProcessOperatorSupplementEvents(
collector->AddThreadName(tid, thr_sec.thread_name); collector->AddThreadName(tid, thr_sec.thread_name);
} }
for (const auto& evt : thr_sec.events) { for (const auto& evt : thr_sec.events) {
// get callstack from event
std::vector<std::string> callstacks;
const std::vector<std::string>* callstack_ptr = nullptr;
auto iter = evt.attributes.find(
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (iter != evt.attributes.end()) {
callstack_ptr =
&PADDLE_GET_CONST(std::vector<std::string>, iter->second);
callstacks = *callstack_ptr;
}
std::ostringstream result_string;
for (auto it = callstacks.begin(); it != callstacks.end(); it++) {
result_string << (*it) << std::endl;
}
OperatorSupplementEvent event; OperatorSupplementEvent event;
event.timestamp_ns = evt.timestamp_ns; event.timestamp_ns = evt.timestamp_ns;
event.op_type = evt.op_type; event.op_type = evt.op_type;
...@@ -111,13 +127,11 @@ void ProcessOperatorSupplementEvents( ...@@ -111,13 +127,11 @@ void ProcessOperatorSupplementEvents(
} }
} }
std::ostringstream result_string;
for (auto it = evt.callstack.begin(); it != evt.callstack.end(); it++) {
result_string << (*it) << std::endl;
}
event.input_shapes = input_shapes; event.input_shapes = input_shapes;
event.dtypes = dtypes; event.dtypes = dtypes;
event.callstack = result_string.str(); event.callstack = result_string.str();
event.attributes = evt.attributes;
event.op_id = evt.op_id;
event.process_id = op_supplement_events.process_id; event.process_id = op_supplement_events.process_id;
event.thread_id = tid; event.thread_id = tid;
collector->AddOperatorSupplementEvent(std::move(event)); collector->AddOperatorSupplementEvent(std::move(event));
......
...@@ -43,7 +43,8 @@ class RecordOpInfoSupplement { ...@@ -43,7 +43,8 @@ class RecordOpInfoSupplement {
explicit RecordOpInfoSupplement(const std::string& type, explicit RecordOpInfoSupplement(const std::string& type,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::InferShapeContext& shape_ctx, const framework::InferShapeContext& shape_ctx,
const framework::RuntimeContext& ctx); const framework::RuntimeContext& ctx,
uint64_t op_id);
/** /**
* @param type: Operator type name. * @param type: Operator type name.
* @param attrs: Attribute map of op. * @param attrs: Attribute map of op.
...@@ -61,7 +62,8 @@ class RecordOpInfoSupplement { ...@@ -61,7 +62,8 @@ class RecordOpInfoSupplement {
explicit RecordOpInfoSupplement( explicit RecordOpInfoSupplement(
const std::string& type, const std::string& type,
const std::vector<std::pair<const char*, std::vector<framework::DDim>>>& const std::vector<std::pair<const char*, std::vector<framework::DDim>>>&
input_shapes); input_shapes,
const framework::AttributeMap& attrs);
}; };
} // namespace platform } // namespace platform
......
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
// limitations under the License. // limitations under the License.
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/profiler/chrometracing_logger.h" #include "paddle/fluid/platform/profiler/chrometracing_logger.h"
#include "paddle/fluid/platform/profiler/event_node.h" #include "paddle/fluid/platform/profiler/event_node.h"
using paddle::framework::AttributeMap;
using paddle::platform::ChromeTracingLogger; using paddle::platform::ChromeTracingLogger;
using paddle::platform::CudaRuntimeTraceEventNode; using paddle::platform::CudaRuntimeTraceEventNode;
using paddle::platform::DeviceTraceEvent; using paddle::platform::DeviceTraceEvent;
...@@ -33,6 +35,7 @@ using paddle::platform::OperatorSupplementEventNode; ...@@ -33,6 +35,7 @@ using paddle::platform::OperatorSupplementEventNode;
using paddle::platform::RuntimeTraceEvent; using paddle::platform::RuntimeTraceEvent;
using paddle::platform::TracerEventType; using paddle::platform::TracerEventType;
using paddle::platform::TracerMemEventType; using paddle::platform::TracerMemEventType;
TEST(NodeTreesTest, LogMe_case0) { TEST(NodeTreesTest, LogMe_case0) {
std::list<HostTraceEvent> host_events; std::list<HostTraceEvent> host_events;
std::list<RuntimeTraceEvent> runtime_events; std::list<RuntimeTraceEvent> runtime_events;
...@@ -79,8 +82,9 @@ TEST(NodeTreesTest, LogMe_case0) { ...@@ -79,8 +82,9 @@ TEST(NodeTreesTest, LogMe_case0) {
input_shapes[std::string("X")].push_back(std::vector<int64_t>{4, 5, 6, 7}); input_shapes[std::string("X")].push_back(std::vector<int64_t>{4, 5, 6, 7});
dtypes[std::string("X")].push_back(std::string("int8")); dtypes[std::string("X")].push_back(std::string("int8"));
dtypes[std::string("X")].push_back(std::string("float32")); dtypes[std::string("X")].push_back(std::string("float32"));
AttributeMap attrs;
op_supplement_events.push_back(OperatorSupplementEvent( op_supplement_events.push_back(OperatorSupplementEvent(
11600, "op1", input_shapes, dtypes, "op1()", 10, 10)); 11600, "op1", input_shapes, dtypes, "op1()", attrs, 0, 10, 10));
runtime_events.push_back(RuntimeTraceEvent( runtime_events.push_back(RuntimeTraceEvent(
std::string("cudalaunch1"), 15000, 17000, 10, 10, 1, 0)); std::string("cudalaunch1"), 15000, 17000, 10, 10, 1, 0));
runtime_events.push_back(RuntimeTraceEvent( runtime_events.push_back(RuntimeTraceEvent(
...@@ -293,12 +297,15 @@ TEST(NodeTreesTest, HandleTrees_case0) { ...@@ -293,12 +297,15 @@ TEST(NodeTreesTest, HandleTrees_case0) {
50, 50,
100, 100,
100)); 100));
AttributeMap attrs;
op_supplement_events.push_back(OperatorSupplementEvent( op_supplement_events.push_back(OperatorSupplementEvent(
11600, 11600,
"op1", "op1",
std::map<std::string, std::vector<std::vector<int64_t>>>(), std::map<std::string, std::vector<std::vector<int64_t>>>(),
std::map<std::string, std::vector<std::string>>(), std::map<std::string, std::vector<std::string>>(),
"op1()", "op1()",
attrs,
0,
10, 10,
10)); 10));
runtime_events.push_back(RuntimeTraceEvent( runtime_events.push_back(RuntimeTraceEvent(
......
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/type_defs.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -147,6 +149,8 @@ struct OperatorSupplementEvent { ...@@ -147,6 +149,8 @@ struct OperatorSupplementEvent {
input_shapes, input_shapes,
const std::map<std::string, std::vector<std::string>>& dtypes, const std::map<std::string, std::vector<std::string>>& dtypes,
const std::string& callstack, const std::string& callstack,
const framework::AttributeMap& attributes,
uint64_t op_id,
uint64_t process_id, uint64_t process_id,
uint64_t thread_id) uint64_t thread_id)
: timestamp_ns(timestamp_ns), : timestamp_ns(timestamp_ns),
...@@ -154,6 +158,8 @@ struct OperatorSupplementEvent { ...@@ -154,6 +158,8 @@ struct OperatorSupplementEvent {
input_shapes(input_shapes), input_shapes(input_shapes),
dtypes(dtypes), dtypes(dtypes),
callstack(callstack), callstack(callstack),
attributes(attributes),
op_id(op_id),
process_id(process_id), process_id(process_id),
thread_id(thread_id) {} thread_id(thread_id) {}
// timestamp of the record // timestamp of the record
...@@ -165,6 +171,10 @@ struct OperatorSupplementEvent { ...@@ -165,6 +171,10 @@ struct OperatorSupplementEvent {
std::map<std::string, std::vector<std::string>> dtypes; std::map<std::string, std::vector<std::string>> dtypes;
// call stack // call stack
std::string callstack; std::string callstack;
// op attributes
framework::AttributeMap attributes;
// op id
uint64_t op_id;
// process id of the record // process id of the record
uint64_t process_id; uint64_t process_id;
// thread id of the record // thread id of the record
......
...@@ -2261,6 +2261,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2261,6 +2261,9 @@ All parameter, weight, gradient are variables in Paddle.
&paddle::platform::HostPythonNode::input_shapes) &paddle::platform::HostPythonNode::input_shapes)
.def_readwrite("dtypes", &paddle::platform::HostPythonNode::dtypes) .def_readwrite("dtypes", &paddle::platform::HostPythonNode::dtypes)
.def_readwrite("callstack", &paddle::platform::HostPythonNode::callstack) .def_readwrite("callstack", &paddle::platform::HostPythonNode::callstack)
.def_readwrite("attributes",
&paddle::platform::HostPythonNode::attributes)
.def_readwrite("op_id", &paddle::platform::HostPythonNode::op_id)
.def_readwrite("children_node", .def_readwrite("children_node",
&paddle::platform::HostPythonNode::children_node_ptrs) &paddle::platform::HostPythonNode::children_node_ptrs)
.def_readwrite("runtime_node", .def_readwrite("runtime_node",
...@@ -2334,10 +2337,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2334,10 +2337,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("load_profiler_result", &paddle::platform::LoadProfilerResult); m.def("load_profiler_result", &paddle::platform::LoadProfilerResult);
m.def("enable_memory_recorder", &paddle::platform::EnableMemoryRecorder); m.def("enable_memory_recorder", &paddle::platform::EnableMemoryRecorder);
m.def("disable_memory_recorder", &paddle::platform::DisableMemoryRecorder); m.def("disable_memory_recorder", &paddle::platform::DisableMemoryRecorder);
m.def("enable_input_shape_recorder", m.def("enable_op_info_recorder", &paddle::platform::EnableOpInfoRecorder);
&paddle::platform::EnableInputShapeRecorder); m.def("disable_op_info_recorder", &paddle::platform::DisableOpInfoRecorder);
m.def("disable_input_shape_recorder",
&paddle::platform::DisableInputShapeRecorder);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
m.def("set_cublas_switch", platform::SetAllowTF32Cublas); m.def("set_cublas_switch", platform::SetAllowTF32Cublas);
......
...@@ -1017,10 +1017,74 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -1017,10 +1017,74 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} input_shapes.emplace_back("{input_name}", ddims_vec);""" {code_indent} input_shapes.emplace_back("{input_name}", ddims_vec);"""
) )
input_tensor_code += f"""
{code_indent} framework::AttributeMap attrs;"""
for attr_name in self.attrs['names']:
if 'IntArray' in self.attrs['attr_info'][attr_name][0]:
input_tensor_code += f"""
{code_indent} attrs["{attr_name}"] = {attr_name}.GetData();"""
elif 'vector<phi::Scalar>' in self.attrs['attr_info'][attr_name][0]:
input_tensor_code += f"""
{code_indent} attrs["{attr_name}"] = "";""" # TODO(kuizhiqing)
elif 'Scalar' in self.attrs['attr_info'][attr_name][0]:
input_tensor_code += f"""
{code_indent} switch ({attr_name}.dtype()) {{
{code_indent} case DataType::FLOAT32:
{code_indent} attrs["{attr_name}"] = static_cast<float>({attr_name}.to<float>());
{code_indent} break;
{code_indent} case DataType::FLOAT64:
{code_indent} attrs["{attr_name}"] = static_cast<double>({attr_name}.to<double>());
{code_indent} break;
{code_indent} case DataType::FLOAT16:
{code_indent} attrs["{attr_name}"] = static_cast<float>({attr_name}.to<float16>());
{code_indent} break;
{code_indent} case DataType::BFLOAT16:
{code_indent} attrs["{attr_name}"] = static_cast<float>({attr_name}.to<bfloat16>());
{code_indent} break;
{code_indent} case DataType::INT32:
{code_indent} attrs["{attr_name}"] = static_cast<int32_t>({attr_name}.to<int32_t>());
{code_indent} break;
{code_indent} case DataType::INT64:
{code_indent} attrs["{attr_name}"] = static_cast<int64_t>({attr_name}.to<int64_t>());
{code_indent} break;
{code_indent} case DataType::INT16:
{code_indent} attrs["{attr_name}"] = static_cast<int16_t>({attr_name}.to<int16_t>());
{code_indent} break;
{code_indent} case DataType::INT8:
{code_indent} attrs["{attr_name}"] = static_cast<int8_t>({attr_name}.to<int8_t>());
{code_indent} break;
{code_indent} case DataType::UINT16:
{code_indent} attrs["{attr_name}"] = static_cast<uint16_t>({attr_name}.to<uint16_t>());
{code_indent} break;
{code_indent} case DataType::UINT8:
{code_indent} attrs["{attr_name}"] = static_cast<uint8_t>({attr_name}.to<uint8_t>());
{code_indent} break;
{code_indent} case DataType::BOOL:
{code_indent} attrs["{attr_name}"] = static_cast<bool>({attr_name}.to<bool>());
{code_indent} break;
{code_indent} case DataType::COMPLEX64:
{code_indent} attrs["{attr_name}"] = static_cast<float>({attr_name}.to<complex64>());
{code_indent} break;
{code_indent} case DataType::COMPLEX128:
{code_indent} attrs["{attr_name}"] = static_cast<double>({attr_name}.to<complex128>());
{code_indent} break;
{code_indent} default:
{code_indent} attrs["{attr_name}"] = "";
{code_indent} break;
{code_indent} }}"""
elif 'DataType' in self.attrs['attr_info'][attr_name][0]:
pass # no need
elif 'Place' in self.attrs['attr_info'][attr_name][0]:
pass # no need
else:
input_tensor_code += f"""
{code_indent} attrs["{attr_name}"] = {attr_name};"""
input_tensor_code = ( input_tensor_code = (
input_tensor_code input_tensor_code
+ f""" + f"""
{code_indent} platform::RecordOpInfoSupplement("{self.api}", input_shapes); {code_indent} platform::RecordOpInfoSupplement("{self.api}", input_shapes, attrs);
{code_indent} }}""" {code_indent} }}"""
) )
kernel_args = ["*dev_ctx"] kernel_args = ["*dev_ctx"]
......
...@@ -225,8 +225,8 @@ class TestProfilerAPIError(unittest.TestCase): ...@@ -225,8 +225,8 @@ class TestProfilerAPIError(unittest.TestCase):
class TestFLOPSAPI(unittest.TestCase): class TestFLOPSAPI(unittest.TestCase):
def test_flops(self): def test_flops(self):
self.assertTrue(flops('relu', ([12, 12],), output=4) == 144) self.assertTrue(flops('relu', {'X': [[12, 12]]}, {'output': 4}) == 144)
self.assertTrue(flops('dropout', ([12, 12],), **{'output': 4}) == 0) self.assertTrue(flops('dropout', {}, {'output': 4}) == 0)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -27,9 +27,9 @@ from paddle.fluid.core import ( ...@@ -27,9 +27,9 @@ from paddle.fluid.core import (
ProfilerOptions, ProfilerOptions,
TracerEventType, TracerEventType,
enable_memory_recorder, enable_memory_recorder,
enable_input_shape_recorder, enable_op_info_recorder,
disable_memory_recorder, disable_memory_recorder,
disable_input_shape_recorder, disable_op_info_recorder,
) )
from .utils import RecordEvent, wrap_optimizers from .utils import RecordEvent, wrap_optimizers
...@@ -115,7 +115,7 @@ def make_scheduler( ...@@ -115,7 +115,7 @@ def make_scheduler(
ready: int, ready: int,
record: int, record: int,
repeat: int = 0, repeat: int = 0,
skip_first: int = 0 skip_first: int = 0,
) -> Callable: ) -> Callable:
r""" r"""
Return a scheduler function, which scheduler the :ref:`state <api_paddle_profiler_ProfilerState>` according to the setting. Return a scheduler function, which scheduler the :ref:`state <api_paddle_profiler_ProfilerState>` according to the setting.
...@@ -351,6 +351,7 @@ class Profiler: ...@@ -351,6 +351,7 @@ class Profiler:
be timed and profiled. Default: False. be timed and profiled. Default: False.
record_shapes (bool, optional): If it is True, collect op's input shape information. Default: False. record_shapes (bool, optional): If it is True, collect op's input shape information. Default: False.
profile_memory (bool, optional): If it is True, collect tensor memory allocation and release information. Default: False. profile_memory (bool, optional): If it is True, collect tensor memory allocation and release information. Default: False.
with_flops (bool, optional): If it is True, the flops of the op will be calculated. Default: False.
Examples: Examples:
1. profiling range [2, 5). 1. profiling range [2, 5).
...@@ -468,10 +469,11 @@ class Profiler: ...@@ -468,10 +469,11 @@ class Profiler:
scheduler: Union[Callable[[int], ProfilerState], tuple, None] = None, scheduler: Union[Callable[[int], ProfilerState], tuple, None] = None,
on_trace_ready: Optional[Callable[..., Any]] = None, on_trace_ready: Optional[Callable[..., Any]] = None,
record_shapes: Optional[bool] = False, record_shapes: Optional[bool] = False,
profile_memory=False, profile_memory: Optional[bool] = False,
timer_only: Optional[bool] = False, timer_only: Optional[bool] = False,
emit_nvtx: Optional[bool] = False, emit_nvtx: Optional[bool] = False,
custom_device_types: Optional[list] = [] custom_device_types: Optional[list] = [],
with_flops: Optional[bool] = False,
): ):
supported_targets = _get_supported_targets() supported_targets = _get_supported_targets()
if targets: if targets:
...@@ -534,6 +536,7 @@ class Profiler: ...@@ -534,6 +536,7 @@ class Profiler:
self.timer_only = timer_only self.timer_only = timer_only
self.record_shapes = record_shapes self.record_shapes = record_shapes
self.profile_memory = profile_memory self.profile_memory = profile_memory
self.with_flops = with_flops
self.emit_nvtx = emit_nvtx self.emit_nvtx = emit_nvtx
def __enter__(self): def __enter__(self):
...@@ -571,8 +574,8 @@ class Profiler: ...@@ -571,8 +574,8 @@ class Profiler:
utils._is_profiler_used = True utils._is_profiler_used = True
if self.timer_only: if self.timer_only:
return return
if self.record_shapes: if self.record_shapes or self.with_flops:
enable_input_shape_recorder() enable_op_info_recorder()
if self.profile_memory: if self.profile_memory:
enable_memory_recorder() enable_memory_recorder()
# CLOSED -> self.current_state # CLOSED -> self.current_state
...@@ -614,8 +617,8 @@ class Profiler: ...@@ -614,8 +617,8 @@ class Profiler:
benchmark().end() benchmark().end()
if self.timer_only: if self.timer_only:
return return
if self.record_shapes: if self.record_shapes or self.with_flops:
disable_input_shape_recorder() disable_op_info_recorder()
if self.profile_memory: if self.profile_memory:
disable_memory_recorder() disable_memory_recorder()
# self.current_state -> CLOSED # self.current_state -> CLOSED
......
...@@ -17,6 +17,8 @@ import re ...@@ -17,6 +17,8 @@ import re
from paddle.fluid.core import TracerEventType, TracerMemEventType from paddle.fluid.core import TracerEventType, TracerMemEventType
from paddle.utils.flops import flops
from .statistic_helper import ( from .statistic_helper import (
intersection_ranges, intersection_ranges,
merge_ranges, merge_ranges,
...@@ -92,24 +94,40 @@ class HostStatisticNode: ...@@ -92,24 +94,40 @@ class HostStatisticNode:
self.self_gpu_time = 0 self.self_gpu_time = 0
self.general_gpu_time = 0 # besides kernel, include time of gpu events like memcpy and memset self.general_gpu_time = 0 # besides kernel, include time of gpu events like memcpy and memset
self.self_general_gpu_time = 0 self.self_general_gpu_time = 0
self.flops = 0
def cal_flops(self):
if self.hostnode.type == TracerEventType.Operator:
if hasattr(self.hostnode, 'input_shapes'):
op_name = self.hostnode.name
op_name = op_name.replace(' compute', '')
op_name = op_name.replace(' dygraph', '')
op_name = op_name.replace(' pybind_imperative_func', '')
self.flops = flops(
op_name,
self.hostnode.input_shapes,
self.hostnode.attributes,
)
def cal_statistic(self): def cal_statistic(self):
for child in self.children_node:
child.cal_statistic()
for rt in self.runtime_node:
rt.cal_statistic()
self.cpu_time = self.hostnode.end_ns - self.hostnode.start_ns self.cpu_time = self.hostnode.end_ns - self.hostnode.start_ns
self.self_cpu_time = self.cpu_time self.self_cpu_time = self.cpu_time
for child in self.children_node: for child in self.children_node:
child.cal_flops()
child.cal_statistic()
self.gpu_time += child.gpu_time self.gpu_time += child.gpu_time
self.general_gpu_time += child.general_gpu_time self.general_gpu_time += child.general_gpu_time
self.self_cpu_time -= child.end_ns - child.start_ns self.self_cpu_time -= child.end_ns - child.start_ns
self.flops += child.flops
for rt in self.runtime_node: for rt in self.runtime_node:
rt.cal_statistic()
self.self_cpu_time -= rt.end_ns - rt.start_ns self.self_cpu_time -= rt.end_ns - rt.start_ns
self.gpu_time += rt.gpu_time self.gpu_time += rt.gpu_time
self.self_gpu_time += rt.gpu_time self.self_gpu_time += rt.gpu_time
self.general_gpu_time += rt.general_gpu_time self.general_gpu_time += rt.general_gpu_time
self.self_general_gpu_time += rt.general_gpu_time self.self_general_gpu_time += rt.general_gpu_time
for device in self.hostnode.device_node: for device in self.hostnode.device_node:
if device.type == TracerEventType.Kernel: if device.type == TracerEventType.Kernel:
self.gpu_time += device.end_ns - device.start_ns self.gpu_time += device.end_ns - device.start_ns
...@@ -229,6 +247,7 @@ class TimeRangeSummary: ...@@ -229,6 +247,7 @@ class TimeRangeSummary:
) )
) # device_id/type/stream_id ) # device_id/type/stream_id
for hostnode in hostnodes[1:]: # skip root node for hostnode in hostnodes[1:]: # skip root node
CPUTimeRange[hostnode.type].append( CPUTimeRange[hostnode.type].append(
(hostnode.start_ns, hostnode.end_ns) (hostnode.start_ns, hostnode.end_ns)
) )
...@@ -407,6 +426,11 @@ class EventSummary: ...@@ -407,6 +426,11 @@ class EventSummary:
self.general_gpu_time = 0 self.general_gpu_time = 0
self.min_general_gpu_time = float('inf') self.min_general_gpu_time = float('inf')
self.max_general_gpu_time = 0 self.max_general_gpu_time = 0
self._flops = 0
@property
def flops(self):
return self._flops
@property @property
def avg_cpu_time(self): def avg_cpu_time(self):
...@@ -444,11 +468,15 @@ class EventSummary: ...@@ -444,11 +468,15 @@ class EventSummary:
def add_call(self): def add_call(self):
self.call += 1 self.call += 1
def add_flops(self, flops):
self._flops += flops
def add_item(self, node): def add_item(self, node):
self.add_call() self.add_call()
self.add_cpu_time(node.cpu_time) self.add_cpu_time(node.cpu_time)
self.add_gpu_time(node.gpu_time) self.add_gpu_time(node.gpu_time)
self.add_general_gpu_time(node.general_gpu_time) self.add_general_gpu_time(node.general_gpu_time)
self.add_flops(node.flops)
for child in node.children_node: for child in node.children_node:
if child.type != TracerEventType.Operator: if child.type != TracerEventType.Operator:
if child.name not in self.operator_inners: if child.name not in self.operator_inners:
...@@ -1328,6 +1356,7 @@ def _build_table( ...@@ -1328,6 +1356,7 @@ def _build_table(
), ),
format_ratio(gpu_ratio), format_ratio(gpu_ratio),
), ),
item.flops,
] ]
all_row_values.append(row_values) all_row_values.append(row_values)
if op_detail: if op_detail:
...@@ -1393,6 +1422,7 @@ def _build_table( ...@@ -1393,6 +1422,7 @@ def _build_table(
), ),
format_ratio(gpu_ratio), format_ratio(gpu_ratio),
), ),
'-',
] ]
all_row_values.append(row_values) all_row_values.append(row_values)
for ( for (
...@@ -1436,6 +1466,7 @@ def _build_table( ...@@ -1436,6 +1466,7 @@ def _build_table(
), ),
format_ratio(gpu_ratio), format_ratio(gpu_ratio),
), ),
'-',
] ]
all_row_values.append(row_values) all_row_values.append(row_values)
for ( for (
...@@ -1473,12 +1504,14 @@ def _build_table( ...@@ -1473,12 +1504,14 @@ def _build_table(
), ),
format_ratio(gpu_ratio), format_ratio(gpu_ratio),
), ),
'-',
] ]
all_row_values.append(row_values) all_row_values.append(row_values)
# Calculate the column width # Calculate the column width
calltime_width = 6 calltime_width = 6
cpu_data_description_width = 40 cpu_data_description_width = 40
gpu_data_description_width = 40 gpu_data_description_width = 40
flops_width = 10
for row_values in all_row_values: for row_values in all_row_values:
if isinstance(row_values, str): if isinstance(row_values, str):
continue continue
...@@ -1496,6 +1529,7 @@ def _build_table( ...@@ -1496,6 +1529,7 @@ def _build_table(
'Calls', 'Calls',
'CPU Total / Avg / Max / Min / Ratio(%)', 'CPU Total / Avg / Max / Min / Ratio(%)',
'GPU Total / Avg / Max / Min / Ratio(%)', 'GPU Total / Avg / Max / Min / Ratio(%)',
'FLOPs',
] ]
row_format_list = [""] row_format_list = [""]
header_sep_list = [""] header_sep_list = [""]
...@@ -1504,6 +1538,7 @@ def _build_table( ...@@ -1504,6 +1538,7 @@ def _build_table(
add_column(calltime_width) add_column(calltime_width)
add_column(cpu_data_description_width) add_column(cpu_data_description_width)
add_column(gpu_data_description_width) add_column(gpu_data_description_width)
add_column(flops_width)
row_format = row_format_list[0] row_format = row_format_list[0]
header_sep = header_sep_list[0] header_sep = header_sep_list[0]
......
...@@ -12,29 +12,35 @@ ...@@ -12,29 +12,35 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from numpy import prod
_FLOPS_COMPUTE_FUNC_MAP = {} _FLOPS_COMPUTE_FUNC_MAP = {}
def flops(op_type: str, input_shapes: tuple, **attrs) -> int: def prod(s):
p = 1
for v in s:
p *= v
return p
def flops(op_type: str, input_shapes: dict, attrs: dict) -> int:
""" """
count flops for operation. count FLOPs for operation.
Args: Args:
op_type (str): the type of operation. op_type (str): the type of operation.
input_shapes (tuple): the shapes of inputs. input_shapes (dict): the shapes of inputs.
attrs (dict): the attributes of the operation. attrs (dict): the attributes of the operation.
Returns: Returns:
the total flops of the operation. the total FLOPs of the operation.
""" """
if op_type not in _FLOPS_COMPUTE_FUNC_MAP: if op_type not in _FLOPS_COMPUTE_FUNC_MAP:
return 0 return 0
else: else:
func = _FLOPS_COMPUTE_FUNC_MAP[op_type] func = _FLOPS_COMPUTE_FUNC_MAP[op_type]
return func(input_shapes, **attrs) return func(input_shapes, attrs)
def register_flops(op_type): def register_flops(op_type):
...@@ -51,10 +57,10 @@ def register_flops(op_type): ...@@ -51,10 +57,10 @@ def register_flops(op_type):
@register_flops("dropout") @register_flops("dropout")
def _dropout_flops(input_shapes, **attrs): def _dropout_flops(input_shapes, attrs):
return 0 return 0
@register_flops("relu") @register_flops("relu")
def _relu_flops(input_shapes, **attrs): def _relu_flops(input_shapes, attrs):
return prod(input_shapes[0]) return prod(input_shapes.get('X')[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册