未验证 提交 ace512a3 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #8596 from kuke/profiler_multi_gpu

Fix the profiler's bug in multi-gpu mode
...@@ -56,7 +56,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) ...@@ -56,7 +56,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor) shape_inference data_transform lod_tensor profiler)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
...@@ -80,7 +80,7 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) ...@@ -80,7 +80,7 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog) cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table profiler feed_fetch_method) framework_proto backward glog lod_rank_table feed_fetch_method)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -25,7 +25,6 @@ limitations under the License. */ ...@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
DEFINE_bool(check_nan_inf, false, DEFINE_bool(check_nan_inf, false,
...@@ -126,11 +125,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -126,11 +125,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// TODO(panyx0718): Need a program id to distinguish programs.
platform::RecordEvent record_event(op->Type(), pool.Get(place_),
op_desc->Block()->ID());
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
...@@ -497,7 +498,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -497,7 +498,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(place); auto dev_ctx = pool.Get(place);
// profile
platform::RecordEvent record_event(Type(), dev_ctx);
// check if op[type] has kernel registered. // check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels(); auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_); auto kernels_iter = all_op_kernels.find(type_);
......
...@@ -132,21 +132,19 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx) { ...@@ -132,21 +132,19 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kPopRange, name, g_thread_id, dev_ctx); GetEventList().Record(EventKind::kPopRange, name, g_thread_id, dev_ctx);
} }
RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx, RecordEvent::RecordEvent(const std::string& name,
int32_t block_id) { const DeviceContext* dev_ctx) {
if (g_state == ProfilerState::kDisabled) return; if (g_state == ProfilerState::kDisabled) return;
dev_ctx_ = dev_ctx; dev_ctx_ = dev_ctx;
name_ = name; name_ = name;
PushEvent(name_, dev_ctx_); PushEvent(name_, dev_ctx_);
full_name_ = string::Sprintf("%s_b%d", name, block_id);
// Maybe need the same push/pop behavior. // Maybe need the same push/pop behavior.
SetCurAnnotation(full_name_.c_str()); SetCurAnnotation(name_.c_str());
} }
RecordEvent::~RecordEvent() { RecordEvent::~RecordEvent() {
ClearCurAnnotation();
if (g_state == ProfilerState::kDisabled) return; if (g_state == ProfilerState::kDisabled) return;
ClearCurAnnotation();
PopEvent(name_, dev_ctx_); PopEvent(name_, dev_ctx_);
} }
......
...@@ -104,8 +104,7 @@ void PushEvent(const std::string& name, const DeviceContext* dev_ctx); ...@@ -104,8 +104,7 @@ void PushEvent(const std::string& name, const DeviceContext* dev_ctx);
void PopEvent(const std::string& name, const DeviceContext* dev_ctx); void PopEvent(const std::string& name, const DeviceContext* dev_ctx);
struct RecordEvent { struct RecordEvent {
RecordEvent(const std::string& name, const DeviceContext* dev_ctx, RecordEvent(const std::string& name, const DeviceContext* dev_ctx);
int32_t block_id);
~RecordEvent(); ~RecordEvent();
......
...@@ -95,7 +95,7 @@ TEST(RecordEvent, RecordEvent) { ...@@ -95,7 +95,7 @@ TEST(RecordEvent, RecordEvent) {
*/ */
for (int i = 1; i < 5; ++i) { for (int i = 1; i < 5; ++i) {
std::string name = "evs_op_" + std::to_string(i); std::string name = "evs_op_" + std::to_string(i);
RecordEvent record_event(name, dev_ctx, 0); RecordEvent record_event(name, dev_ctx);
int counter = 1; int counter = 1;
while (counter != i * 1000) counter++; while (counter != i * 1000) counter++;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册