未验证 提交 ace512a3 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #8596 from kuke/profiler_multi_gpu

Fix the profiler's bug in multi-gpu mode
......@@ -56,7 +56,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor)
shape_inference data_transform lod_tensor profiler)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
......@@ -80,7 +80,7 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table profiler feed_fetch_method)
framework_proto backward glog lod_rank_table feed_fetch_method)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool(benchmark);
DEFINE_bool(check_nan_inf, false,
......@@ -126,11 +125,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// TODO(panyx0718): Need a program id to distinguish programs.
platform::RecordEvent record_event(op->Type(), pool.Get(place_),
op_desc->Block()->ID());
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool(benchmark);
......@@ -497,7 +498,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(place);
// profile
platform::RecordEvent record_event(Type(), dev_ctx);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
......
......@@ -132,21 +132,19 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kPopRange, name, g_thread_id, dev_ctx);
}
RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx,
int32_t block_id) {
RecordEvent::RecordEvent(const std::string& name,
const DeviceContext* dev_ctx) {
if (g_state == ProfilerState::kDisabled) return;
dev_ctx_ = dev_ctx;
name_ = name;
PushEvent(name_, dev_ctx_);
full_name_ = string::Sprintf("%s_b%d", name, block_id);
// Maybe need the same push/pop behavior.
SetCurAnnotation(full_name_.c_str());
SetCurAnnotation(name_.c_str());
}
RecordEvent::~RecordEvent() {
ClearCurAnnotation();
if (g_state == ProfilerState::kDisabled) return;
ClearCurAnnotation();
PopEvent(name_, dev_ctx_);
}
......
......@@ -104,8 +104,7 @@ void PushEvent(const std::string& name, const DeviceContext* dev_ctx);
void PopEvent(const std::string& name, const DeviceContext* dev_ctx);
struct RecordEvent {
RecordEvent(const std::string& name, const DeviceContext* dev_ctx,
int32_t block_id);
RecordEvent(const std::string& name, const DeviceContext* dev_ctx);
~RecordEvent();
......
......@@ -95,7 +95,7 @@ TEST(RecordEvent, RecordEvent) {
*/
for (int i = 1; i < 5; ++i) {
std::string name = "evs_op_" + std::to_string(i);
RecordEvent record_event(name, dev_ctx, 0);
RecordEvent record_event(name, dev_ctx);
int counter = 1;
while (counter != i * 1000) counter++;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册