未验证 提交 5ba9fe6e 编写于 作者: H Huihuang Zheng 提交者: GitHub

Basic PR on Cost Model (#35774)

Add basic Cost Model, it uses executor to run program and profile it to get op time.

This is an early basic version, we will add more functions in the future.
上级 d4cd2590
......@@ -43,6 +43,7 @@ cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(cost_model SRCS cost_model.cc DEPS executor graph profiler proto_desc device_tracer)
SET(GRAPH_PATTERN_DETECTOR_DEPS graph graph_helper graph_traits)
if (WITH_TESTING)
......@@ -141,6 +142,7 @@ cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(cost_model_test SRCS cost_model_test.cc DEPS cost_model op_registry)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass)
cc_test(test_fc_fuse_pass_cc SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cost_model.h"
#include <memory>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
using ir::Graph;
using platform::Event;
using platform::MemEvent;
const double CostData::NOT_MEASURED = -1;
CostData::~CostData() {
// TODO(zhhsplendid): when we save a copy of program/graph, we should delete
// here.
}
double CostData::GetOpTimeMs(int op_id) const { return op_time_ms_.at(op_id); }
double CostData::GetOpMemoryBytes(int op_id) const {
return op_memory_bytes_.at(op_id);
}
double CostData::GetWholeTimeMs() const { return whole_time_ms_; }
double CostData::GetWholeMemoryBytes() const { return whole_memory_bytes_; }
const Graph* CostData::GetGraph() const { return graph_; }
const ProgramDesc* CostData::GetProgram() const { return program_; }
bool CostData::SetCostData(const ProgramDesc& program,
const std::vector<std::vector<Event>>& time_events) {
// TODO(zhhsplendid): Make a copy so that CostData can be available even if
// SWE changes Program, the copy can be saved into pointer program_
if (program.Size() == 0) {
whole_time_ms_ = 0;
whole_memory_bytes_ = 0;
return true;
}
if (time_events.empty()) {
LOG(WARNING) << "Input time_events for CostModel is empty";
return false;
}
std::vector<Event> main_thread_events = time_events[0];
// Support global block only
// TODO(zhhsplendid): support sub blocks
const BlockDesc& global_block = program.Block(0);
size_t op_size = global_block.OpSize();
if (op_size == 0) {
whole_time_ms_ = 0;
whole_memory_bytes_ = 0;
return true;
}
bool event_to_cost_success = true;
size_t event_index = 0;
for (size_t i = 0; i < op_size; ++i) {
const OpDesc* op_desc = global_block.Op(i);
std::string op_type = op_desc->Type();
while (event_index < main_thread_events.size()) {
if (main_thread_events[event_index].name() == op_type &&
main_thread_events[event_index].type() ==
platform::EventType::kPushRange) {
break;
}
++event_index;
}
if (event_index >= main_thread_events.size()) {
LOG(WARNING) << "Input time_events for Op " << i << ", type '" << op_type
<< "' have wrong format, skip this Op.";
event_to_cost_success = false;
continue;
}
size_t op_push_index = event_index;
while (event_index < main_thread_events.size()) {
// Is it possible to Push a lot of Ops with same type and then Pop?
// ControlFlow Op can be like that, but this version only support global
// block
// TODO(zhhsplendid): make a more strict mapping between push and pop
if (main_thread_events[event_index].name() == op_type &&
main_thread_events[event_index].type() ==
platform::EventType::kPopRange) {
break;
}
++event_index;
}
if (event_index >= main_thread_events.size()) {
LOG(WARNING) << "Input time_events for Op " << i << ", type '" << op_type
<< "' have wrong format, skip this Op.";
event_to_cost_success = false;
continue;
}
size_t op_pop_index = event_index;
double cpu_time_ms = main_thread_events[op_push_index].CpuElapsedMs(
main_thread_events[op_pop_index]);
double gpu_time_ms = 0;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gpu_time_ms = main_thread_events[op_push_index].CudaElapsedMs(
main_thread_events[op_pop_index]);
#endif
double time_ms = gpu_time_ms + cpu_time_ms;
op_time_ms_[i] = time_ms;
}
event_index = 0;
int start_profiler_idx = -1;
int stop_profiler_idx = -1;
while (event_index < main_thread_events.size()) {
if (main_thread_events[event_index].name() == "_start_profiler_") {
start_profiler_idx = event_index;
} else if (main_thread_events[event_index].name() == "_stop_profiler_") {
stop_profiler_idx = event_index;
break;
}
++event_index;
}
if (start_profiler_idx != -1 && stop_profiler_idx != -1) {
double cpu_time_ms = main_thread_events[start_profiler_idx].CpuElapsedMs(
main_thread_events[stop_profiler_idx]);
double gpu_time_ms = 0;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gpu_time_ms = main_thread_events[start_profiler_idx].CudaElapsedMs(
main_thread_events[stop_profiler_idx]);
#endif
whole_time_ms_ = gpu_time_ms + cpu_time_ms;
} else {
LOG(WARNING) << "Input time_events for whole time have wrong format";
event_to_cost_success = false;
}
return event_to_cost_success;
}
void PrintEvents(const std::vector<std::vector<Event>>* time_events,
const std::vector<std::vector<MemEvent>>* mem_events) {
if (time_events != nullptr) {
for (size_t i = 0; i < time_events->size(); ++i) {
for (size_t j = 0; j < (*time_events)[i].size(); ++j) {
VLOG(4) << "Print time event (" << i << ", " << j << ")" << std::endl;
VLOG(4) << (*time_events)[i][j].name() << " "
<< (*time_events)[i][j].attr() << std::endl;
VLOG(4) << "This: " << &(*time_events)[i][j]
<< ", Parent: " << (*time_events)[i][j].parent() << std::endl;
if ((*time_events)[i][j].role() == platform::EventRole::kInnerOp) {
VLOG(4) << "role kInnerOp" << std::endl;
} else if ((*time_events)[i][j].role() ==
platform::EventRole::kUniqueOp) {
VLOG(4) << "role kUniqueOp" << std::endl;
} else if ((*time_events)[i][j].role() ==
platform::EventRole::kOrdinary) {
VLOG(4) << "role kOrdinary" << std::endl;
} else if ((*time_events)[i][j].role() ==
platform::EventRole::kSpecial) {
VLOG(4) << "role kSpecial" << std::endl;
}
if ((*time_events)[i][j].type() == platform::EventType::kPopRange) {
VLOG(4) << "type kPopRange" << std::endl;
} else if ((*time_events)[i][j].type() ==
platform::EventType::kPushRange) {
VLOG(4) << "type kPushRange" << std::endl;
} else if ((*time_events)[i][j].type() == platform::EventType::kMark) {
VLOG(4) << "type kMark" << std::endl;
}
VLOG(4) << std::endl;
}
}
}
if (mem_events != nullptr) {
for (size_t i = 0; i < mem_events->size(); ++i) {
for (size_t j = 0; j < (*mem_events)[i].size(); ++j) {
VLOG(4) << "Print mem event (" << i << ", " << j << ")" << std::endl;
VLOG(4) << (*mem_events)[i][j].annotation() << std::endl;
}
}
}
}
std::string ToLowerCopy(const std::string& in) {
std::string out(in);
std::transform(out.begin(), out.end(), out.begin(),
[](unsigned char c) { return std::tolower(c); });
return out;
}
CostData CostModel::ProfileMeasure(
const ProgramDesc& main_program, const ProgramDesc& startup_program,
const std::string& device,
const std::vector<std::string>& fetch_cost_list) const {
// Currently fetch_cost_list is useless
// TODO(zhhsplendid): support different fetch data
platform::ProfilerState profiler_state;
platform::Place place;
std::string device_lower_case = ToLowerCopy(device);
if (device_lower_case == "cpu") {
profiler_state = platform::ProfilerState::kCPU;
place = platform::CPUPlace();
} else if (device_lower_case == "gpu") {
profiler_state = platform::ProfilerState::kAll;
place = platform::CUDAPlace();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Not support %s in CostModel now", device));
}
Executor executor(place);
Scope scope;
executor.Run(startup_program, &scope, /*block_id = */ 0);
// TODO(zhhsplendid): handle the case that Profiler is already enabled
SetTracerOption(platform::TracerOption::kAllOpDetail);
EnableProfiler(profiler_state);
executor.Run(main_program, &scope, /*block_id = */ 0);
std::unique_ptr<std::vector<std::vector<Event>>> time_events(
new std::vector<std::vector<Event>>());
std::unique_ptr<std::vector<std::vector<MemEvent>>> mem_events(
new std::vector<std::vector<MemEvent>>());
CompleteProfilerEvents(/*tracer_profile= */ nullptr, time_events.get(),
mem_events.get());
// TODO(zhhsplendid): remove debug vlog after this series of work
PrintEvents(time_events.get(), mem_events.get());
// Convert events to cost data
CostData cost_data;
cost_data.SetCostData(main_program, *time_events);
return cost_data;
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
class CostData {
public:
CostData() {}
~CostData();
// Support global block only
// TODO(zhhsplendid): add support for sub-block
double GetOpTimeMs(int op_id) const;
double GetOpMemoryBytes(int op_id) const;
double GetWholeTimeMs() const;
double GetWholeMemoryBytes() const;
const ir::Graph* GetGraph() const;
const ProgramDesc* GetProgram() const;
// Support Time Event only
// TODO(zhhsplendid): add memory
bool SetCostData(
const ProgramDesc& program,
const std::vector<std::vector<platform::Event>>& time_events);
static const double NOT_MEASURED;
private:
ir::Graph* graph_{nullptr};
ProgramDesc* program_{nullptr};
std::map<int, double> op_time_ms_; // from Op Node id to time
std::map<int, double>
op_memory_bytes_; // from Op Node id to total memory bytes
std::map<int, double> comm_; // from Op Node id to communicate cost
double whole_time_ms_{
NOT_MEASURED}; // time cost of the whole program or graph
double whole_memory_bytes_{
NOT_MEASURED}; // memory cost of the whole program or graph
double whole_comm_{
NOT_MEASURED}; // communication cost of the whole program or graph
};
class CostModel {
public:
CostModel() {}
~CostModel() {}
CostData ProfileMeasure(
const ProgramDesc& main_program, const ProgramDesc& startup_program,
const std::string& device,
const std::vector<std::string>& fetch_cost_list) const;
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cost_model.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/event.h"
namespace paddle {
namespace framework {
// Register test op
class FakeTestOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddInput("Y", "").AsDuplicable();
AddOutput("Out", "").AsDuplicable();
AddComment("");
}
};
class FakeTestOp : public OperatorBase {
public:
FakeTestOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope &scope,
const platform::Place &place) const override {
// Fake RunImpl, for test only
Variable *var = scope.FindVar("X");
if (var != nullptr) {
LoDTensor *tensor = var->GetMutable<LoDTensor>();
tensor->mutable_data<float>(place);
}
int count = 0;
while (count <= 1000) {
++count;
}
}
};
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(fake_test_op, paddle::framework::FakeTestOp,
paddle::framework::FakeTestOpMaker);
namespace paddle {
namespace framework {
ProgramDesc CreateTestProgram() {
// create a ProgramDesc:
// Z = fake_test_op(X, Y)
// Out = fake_test_op(Z, W)
ProgramDesc program;
auto *global_block = program.MutableBlock(0);
auto *x = global_block->Var("X");
x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(proto::VarType::FP32);
x->SetShape({1000, 784});
auto *y = global_block->Var("Y");
y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(proto::VarType::FP32);
y->SetShape({784, 100});
auto *op0 = global_block->AppendOp();
op0->SetType("fake_test_op");
op0->SetInput("X", {x->Name()});
op0->SetInput("Y", {y->Name()});
auto *z = global_block->Var("Z");
z->SetType(proto::VarType::LOD_TENSOR);
op0->SetOutput("Out", {z->Name()});
auto *w = global_block->Var("W");
w->SetType(proto::VarType::LOD_TENSOR);
w->SetLoDLevel(0);
w->SetDataType(proto::VarType::FP32);
w->SetShape({100, 10});
auto *op1 = global_block->AppendOp();
op1->SetType("fake_test_op");
op1->SetInput("X", {z->Name()});
op1->SetInput("Y", {w->Name()});
auto *out = global_block->Var("Out");
out->SetType(proto::VarType::LOD_TENSOR);
op1->SetOutput("Out", {out->Name()});
return program;
}
TEST(CostModelTest, TestProfileMeasure_EmptyProgram) {
CostModel cost_model;
ProgramDesc empty_program;
CostData cost_data =
cost_model.ProfileMeasure(empty_program, empty_program, "cpu", {"time"});
EXPECT_EQ(cost_data.GetWholeTimeMs(), 0);
}
TEST(CostModelTest, TestProfileMeasure_Program) {
CostModel cost_model;
ProgramDesc program = CreateTestProgram();
ProgramDesc empty_program;
CostData cost_data =
cost_model.ProfileMeasure(program, empty_program, "cpu", {"time"});
double op0_time_ms = cost_data.GetOpTimeMs(0);
double op1_time_ms = cost_data.GetOpTimeMs(1);
EXPECT_GT(op0_time_ms, 0);
EXPECT_GT(op1_time_ms, 0);
EXPECT_GT(cost_data.GetWholeTimeMs(), op0_time_ms + op1_time_ms);
}
TEST(CostModelTest, TestProfileMeasure_UnsupportedDevice) {
CostModel cost_model;
ProgramDesc program = CreateTestProgram();
ProgramDesc empty_program;
EXPECT_THROW(cost_model.ProfileMeasure(program, empty_program, "wrong_device",
{"time"}),
paddle::platform::EnforceNotMet);
}
TEST(CostDataTest, TestGetGraphProgram) {
CostData cost_data;
EXPECT_EQ(cost_data.GetGraph(), nullptr);
EXPECT_EQ(cost_data.GetProgram(), nullptr);
}
TEST(CostDataTest, TestUninitailzed) {
CostData cost_data;
EXPECT_EQ(cost_data.GetWholeMemoryBytes(), CostData::NOT_MEASURED);
EXPECT_EQ(cost_data.GetWholeTimeMs(), CostData::NOT_MEASURED);
}
TEST(CostDataTest, TestEmptyProgram) {
CostData cost_data;
ProgramDesc empty_program("");
EXPECT_EQ(cost_data.SetCostData(empty_program, {}), true);
EXPECT_EQ(cost_data.GetWholeMemoryBytes(), 0);
EXPECT_EQ(cost_data.GetWholeTimeMs(), 0);
}
TEST(CostDataTest, TestEmptyTimeEvent) {
CostData cost_data;
ProgramDesc program = CreateTestProgram();
EXPECT_EQ(cost_data.SetCostData(program, {}), false);
EXPECT_EQ(cost_data.GetWholeMemoryBytes(), CostData::NOT_MEASURED);
EXPECT_EQ(cost_data.GetWholeTimeMs(), CostData::NOT_MEASURED);
}
TEST(CostDataTest, TestNoOpEvent) {
CostData cost_data;
ProgramDesc program = CreateTestProgram();
std::vector<platform::Event> thread_events;
thread_events.push_back(
platform::Event(platform::EventType::kPushRange, "not exist name", 0));
std::vector<std::vector<platform::Event>> time_events{thread_events};
EXPECT_EQ(cost_data.SetCostData(program, time_events), false);
}
TEST(CostDataTest, TestNoOpPopEvent) {
CostData cost_data;
ProgramDesc program = CreateTestProgram();
std::vector<platform::Event> thread_events;
thread_events.push_back(
platform::Event(platform::EventType::kPushRange, "fake_test_op", 0));
std::vector<std::vector<platform::Event>> time_events{thread_events};
EXPECT_EQ(cost_data.SetCostData(program, time_events), false);
}
TEST(CostDataTest, TestNoWholeEvent) {
CostData cost_data;
ProgramDesc program = CreateTestProgram();
std::vector<platform::Event> thread_events;
thread_events.push_back(
platform::Event(platform::EventType::kPushRange, "fake_test_op", 0));
thread_events.push_back(
platform::Event(platform::EventType::kPopRange, "fake_test_op", 0));
thread_events.push_back(
platform::Event(platform::EventType::kPushRange, "fake_test_op", 0));
thread_events.push_back(
platform::Event(platform::EventType::kPopRange, "fake_test_op", 0));
std::vector<std::vector<platform::Event>> time_events{thread_events};
EXPECT_EQ(cost_data.SetCostData(program, time_events), false);
}
} // namespace framework
} // namespace paddle
......@@ -494,6 +494,16 @@ class DeviceTracerImpl : public DeviceTracer {
}
proto::Profile GenProfile(const std::string &profile_path) {
proto::Profile profile_pb = this->GetProfile();
std::ofstream profile_f;
profile_f.open(profile_path,
std::ios::out | std::ios::trunc | std::ios::binary);
profile_pb.SerializeToOstream(&profile_f);
profile_f.close();
return profile_pb;
}
proto::Profile GetProfile() {
int miss = 0, find = 0;
std::lock_guard<std::mutex> l(trace_mu_);
proto::Profile profile_pb;
......@@ -601,12 +611,6 @@ class DeviceTracerImpl : public DeviceTracer {
event->set_thread_id(r.thread_id);
}
}
std::ofstream profile_f;
profile_f.open(profile_path,
std::ios::out | std::ios::trunc | std::ios::binary);
profile_pb.SerializeToOstream(&profile_f);
profile_f.close();
return profile_pb;
}
......
......@@ -126,6 +126,9 @@ class DeviceTracer {
int64_t device_id, int64_t stream_id,
uint32_t correlation_id) = 0;
// Get a proto after done
virtual proto::Profile GetProfile() = 0;
// Generate a proto after done (Disabled).
virtual proto::Profile GenProfile(const std::string& profile_path) = 0;
......
......@@ -263,9 +263,40 @@ void DisableProfiler(EventSortingKey sorted_key,
ParseEvents(all_events, true, sorted_key);
ParseEvents(all_events, false, sorted_key);
if (VLOG_IS_ON(5)) {
std::vector<std::vector<MemEvent>> all_mem_events = GetMemEvents();
ParseMemEvents(all_mem_events);
std::vector<std::vector<MemEvent>> all_mem_events = GetMemEvents();
ParseMemEvents(all_mem_events);
ResetProfiler();
g_state = ProfilerState::kDisabled;
g_tracer_option = TracerOption::kDefault;
should_send_profile_state = true;
}
void CompleteProfilerEvents(proto::Profile *tracer_profile,
std::vector<std::vector<Event>> *time_events,
std::vector<std::vector<MemEvent>> *mem_events) {
SynchronizeAllDevice();
MemEvenRecorder::Instance().Flush();
std::lock_guard<std::mutex> l(profiler_mu);
if (g_state == ProfilerState::kDisabled) return;
// Mark the profiling stop.
Mark("_stop_profiler_");
DeviceTracer *tracer = GetDeviceTracer();
if (tracer->IsEnabled() && tracer_profile != nullptr) {
tracer->Disable();
tracer->GenEventKernelCudaElapsedTime();
*tracer_profile = tracer->GetProfile();
}
if (time_events != nullptr) {
*time_events = GetAllEvents();
}
if (mem_events != nullptr) {
*mem_events = GetMemEvents();
}
ResetProfiler();
......
......@@ -28,9 +28,12 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/event.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.pb.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/gpu_info.h"
#endif
namespace paddle {
namespace platform {
......@@ -215,6 +218,11 @@ void EnableProfiler(ProfilerState state);
void ResetProfiler();
void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path);
// Disable profiler but return events instead of print it.
void CompleteProfilerEvents(proto::Profile* tracer_profile,
std::vector<std::vector<Event>>* time_events,
std::vector<std::vector<MemEvent>>* mem_events);
// Test if the profiler is currently enabled.
bool IsProfileEnabled();
// Whether the trainer should send profiling state to PS.
......
......@@ -820,7 +820,6 @@ void ParseEvents(const std::vector<std::vector<Event>> &events,
std::multimap<std::string, EventItem> child_map;
size_t max_name_width = 0;
OverHead overhead;
AnalyzeEvent(analyze_events, &events_table, &child_map, sorted_func,
sorted_by, &max_name_width, &overhead, merge_thread);
......
......@@ -6,7 +6,8 @@ include_directories(${PADDLE_SOURCE_DIR}/paddle/utils)
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator)
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
cost_model)
if (WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
......@@ -59,6 +60,7 @@ set(PYBIND_SRCS
data_set_py.cc
imperative.cc
ir.cc
bind_cost_model.cc
inference_api.cc
compatible.cc
io.cc
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/bind_cost_model.h"
#include <pybind11/stl.h>
#include "paddle/fluid/framework/ir/cost_model.h"
#include "paddle/fluid/framework/program_desc.h"
namespace py = pybind11;
using paddle::framework::CostData;
using paddle::framework::CostModel;
using paddle::framework::ProgramDesc;
namespace paddle {
namespace pybind {
void BindCostModel(py::module* m) {
py::class_<CostData>(*m, "CostData")
.def(py::init<>())
.def("get_whole_time_ms", &CostData::GetWholeTimeMs)
.def("get_op_time_ms", &CostData::GetOpTimeMs);
py::class_<CostModel>(*m, "CostModel")
.def(py::init<>())
.def("profile_measure",
[](CostModel& self, py::object py_main_program,
py::object py_startup_program, const std::string& device,
const std::vector<std::string>& fetch_cost_list) {
py::object py_main_program_desc = py_main_program.attr("desc");
ProgramDesc* main_program_desc =
py_main_program_desc.cast<ProgramDesc*>();
py::object py_startup_program_desc =
py_startup_program.attr("desc");
ProgramDesc* startup_program_desc =
py_startup_program_desc.cast<ProgramDesc*>();
return self.ProfileMeasure(*main_program_desc,
*startup_program_desc, device,
fetch_cost_list);
});
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pybind11/pybind11.h>
namespace paddle {
namespace pybind {
void BindCostModel(pybind11::module *m);
} // namespace pybind
} // namespace paddle
......@@ -38,6 +38,7 @@ limitations under the License. */
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h"
#include "paddle/fluid/framework/ir/cost_model.h"
#include "paddle/fluid/framework/ir/generate_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
......@@ -78,6 +79,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
#endif
#include "paddle/fluid/pybind/bind_cost_model.h"
#include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/compatible.h"
#include "paddle/fluid/pybind/const_value.h"
......@@ -2131,6 +2133,7 @@ All parameter, weight, gradient are variables in Paddle.
BindBlockDesc(&m);
BindVarDsec(&m);
BindOpDesc(&m);
BindCostModel(&m);
BindConstValue(&m);
BindGlobalValueGetterSetter(&m);
BindProcessMeshDesc(&m);
......@@ -2439,7 +2442,6 @@ All parameter, weight, gradient are variables in Paddle.
[](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); });
// -- python binds for parallel executor.
py::class_<ParallelExecutor> pe(m, "ParallelExecutor");
py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy", R"DOC(
ExecutionStrategy allows the user to more preciously control how to run
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.static as static
import numpy as np
from paddle.fluid import core
class CostModel():
def __init__(self):
pass
def build_program(self):
paddle.enable_static()
main_program = static.Program()
startup_program = static.Program()
with static.program_guard(
main_program=main_program, startup_program=startup_program):
data = paddle.static.data(
name='X', shape=[None, 1], dtype='float32')
hidden = paddle.static.nn.fc(data, 10)
loss = paddle.mean(hidden)
paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)
print("main program is: {}".format(main_program))
#print("start up program is: {}".format(startup_program))
return startup_program, main_program
def profile_measure(self,
startup_program,
main_program,
device='gpu',
fetch_cost_list=['time', 'memory']):
place = paddle.set_device('gpu')
x = np.random.random(size=(10, 1)).astype('float32')
exe = paddle.static.Executor(place)
exe.run(startup_program)
paddle.fluid.profiler.start_profiler("All")
exe.run(main_program, feed={"X": x}, fetch_list=[])
# core.CostModel.ProfileMeasure(main_program, device)
print("core:<<<<<<<")
cost_model = core.CostModel()
cost_data = cost_model.ProfileMeasure(device)
# cost_list = self.stop_cost_model()
# return cost_list
cost_model = CostModel()
startup_program, main_program = cost_model.build_program()
cost_model.profile_measure(startup_program, main_program)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid.core as core
paddle.enable_static()
device = "gpu" if core.is_compiled_with_cuda() else "cpu"
class TestCostModel(unittest.TestCase):
def test_profiler_measure_empty_program(self):
cost_model = core.CostModel()
empty_program = paddle.static.Program()
startup_program = paddle.static.Program()
cost_data = cost_model.profile_measure(empty_program, startup_program,
device, ["time"])
self.assertEqual(cost_data.get_whole_time_ms(), 0)
def test_profiler_measure_program(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
# TODO(zhhsplendid): support paddle.static.data, which is uninitialized data
data = paddle.ones(name='X', shape=[16, 100], dtype='float32')
hidden = paddle.static.nn.fc(data, 10)
loss = paddle.mean(hidden)
cost_model = core.CostModel()
cost_data = cost_model.profile_measure(main_program, startup_program,
device, ["time"])
fc_op_time = cost_data.get_op_time_ms(0)
mean_op_time = cost_data.get_op_time_ms(1)
self.assertGreater(fc_op_time, 0)
self.assertGreater(mean_op_time, 0)
self.assertGreaterEqual(cost_data.get_whole_time_ms(),
fc_op_time + mean_op_time)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册