From 5ba9fe6ef8b3b52b7f871a5deb70853ee80c3b37 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Sat, 18 Sep 2021 18:17:57 +0800 Subject: [PATCH] Basic PR on Cost Model (#35774) Add basic Cost Model, it uses executor to run program and profile it to get op time. This is an early basic version, we will add more functions in the future. --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + paddle/fluid/framework/ir/cost_model.cc | 256 ++++++++++++++++++ paddle/fluid/framework/ir/cost_model.h | 85 ++++++ paddle/fluid/framework/ir/cost_model_test.cc | 209 ++++++++++++++ paddle/fluid/platform/device_tracer.cc | 16 +- paddle/fluid/platform/device_tracer.h | 3 + paddle/fluid/platform/profiler.cc | 37 ++- paddle/fluid/platform/profiler.h | 8 + paddle/fluid/platform/profiler_helper.h | 1 - paddle/fluid/pybind/CMakeLists.txt | 4 +- paddle/fluid/pybind/bind_cost_model.cc | 56 ++++ paddle/fluid/pybind/bind_cost_model.h | 25 ++ paddle/fluid/pybind/pybind.cc | 4 +- python/paddle/cost_model/cost_model.py | 69 +++++ .../fluid/tests/unittests/test_cost_model.py | 56 ++++ 15 files changed, 819 insertions(+), 12 deletions(-) create mode 100644 paddle/fluid/framework/ir/cost_model.cc create mode 100644 paddle/fluid/framework/ir/cost_model.h create mode 100644 paddle/fluid/framework/ir/cost_model_test.cc create mode 100644 paddle/fluid/pybind/bind_cost_model.cc create mode 100644 paddle/fluid/pybind/bind_cost_model.h create mode 100644 python/paddle/cost_model/cost_model.py create mode 100644 python/paddle/fluid/tests/unittests/test_cost_model.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 175bd591334..99c691e6cf6 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -43,6 +43,7 @@ cc_library(graph SRCS graph.cc DEPS node pretty_log) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(graph_traits SRCS graph_traits.cc DEPS graph) +cc_library(cost_model SRCS cost_model.cc DEPS executor graph profiler proto_desc device_tracer) SET(GRAPH_PATTERN_DETECTOR_DEPS graph graph_helper graph_traits) if (WITH_TESTING) @@ -141,6 +142,7 @@ cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) +cc_test(cost_model_test SRCS cost_model_test.cc DEPS cost_model op_registry) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass) cc_test(test_fc_fuse_pass_cc SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) diff --git a/paddle/fluid/framework/ir/cost_model.cc b/paddle/fluid/framework/ir/cost_model.cc new file mode 100644 index 00000000000..5027c50103a --- /dev/null +++ b/paddle/fluid/framework/ir/cost_model.cc @@ -0,0 +1,256 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/cost_model.h" + +#include +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/errors.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { + +using ir::Graph; +using platform::Event; +using platform::MemEvent; + +const double CostData::NOT_MEASURED = -1; + +CostData::~CostData() { + // TODO(zhhsplendid): when we save a copy of program/graph, we should delete + // here. +} + +double CostData::GetOpTimeMs(int op_id) const { return op_time_ms_.at(op_id); } +double CostData::GetOpMemoryBytes(int op_id) const { + return op_memory_bytes_.at(op_id); +} +double CostData::GetWholeTimeMs() const { return whole_time_ms_; } +double CostData::GetWholeMemoryBytes() const { return whole_memory_bytes_; } + +const Graph* CostData::GetGraph() const { return graph_; } +const ProgramDesc* CostData::GetProgram() const { return program_; } + +bool CostData::SetCostData(const ProgramDesc& program, + const std::vector>& time_events) { + // TODO(zhhsplendid): Make a copy so that CostData can be available even if + // SWE changes Program, the copy can be saved into pointer program_ + if (program.Size() == 0) { + whole_time_ms_ = 0; + whole_memory_bytes_ = 0; + return true; + } + + if (time_events.empty()) { + LOG(WARNING) << "Input time_events for CostModel is empty"; + return false; + } + + std::vector main_thread_events = time_events[0]; + // Support global block only + // TODO(zhhsplendid): support sub blocks + const BlockDesc& global_block = program.Block(0); + size_t op_size = global_block.OpSize(); + if (op_size == 0) { + whole_time_ms_ = 0; + whole_memory_bytes_ = 0; + return true; + } + + bool event_to_cost_success = true; + size_t event_index = 0; + for (size_t i = 0; i < op_size; ++i) { + const OpDesc* op_desc = global_block.Op(i); + std::string op_type = op_desc->Type(); + + while (event_index < main_thread_events.size()) { + if (main_thread_events[event_index].name() == op_type && + main_thread_events[event_index].type() == + platform::EventType::kPushRange) { + break; + } + ++event_index; + } + if (event_index >= main_thread_events.size()) { + LOG(WARNING) << "Input time_events for Op " << i << ", type '" << op_type + << "' have wrong format, skip this Op."; + event_to_cost_success = false; + continue; + } + size_t op_push_index = event_index; + + while (event_index < main_thread_events.size()) { + // Is it possible to Push a lot of Ops with same type and then Pop? + // ControlFlow Op can be like that, but this version only support global + // block + // TODO(zhhsplendid): make a more strict mapping between push and pop + if (main_thread_events[event_index].name() == op_type && + main_thread_events[event_index].type() == + platform::EventType::kPopRange) { + break; + } + ++event_index; + } + if (event_index >= main_thread_events.size()) { + LOG(WARNING) << "Input time_events for Op " << i << ", type '" << op_type + << "' have wrong format, skip this Op."; + event_to_cost_success = false; + continue; + } + size_t op_pop_index = event_index; + double cpu_time_ms = main_thread_events[op_push_index].CpuElapsedMs( + main_thread_events[op_pop_index]); + double gpu_time_ms = 0; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + gpu_time_ms = main_thread_events[op_push_index].CudaElapsedMs( + main_thread_events[op_pop_index]); +#endif + double time_ms = gpu_time_ms + cpu_time_ms; + op_time_ms_[i] = time_ms; + } + + event_index = 0; + int start_profiler_idx = -1; + int stop_profiler_idx = -1; + while (event_index < main_thread_events.size()) { + if (main_thread_events[event_index].name() == "_start_profiler_") { + start_profiler_idx = event_index; + } else if (main_thread_events[event_index].name() == "_stop_profiler_") { + stop_profiler_idx = event_index; + break; + } + ++event_index; + } + if (start_profiler_idx != -1 && stop_profiler_idx != -1) { + double cpu_time_ms = main_thread_events[start_profiler_idx].CpuElapsedMs( + main_thread_events[stop_profiler_idx]); + double gpu_time_ms = 0; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + gpu_time_ms = main_thread_events[start_profiler_idx].CudaElapsedMs( + main_thread_events[stop_profiler_idx]); +#endif + whole_time_ms_ = gpu_time_ms + cpu_time_ms; + } else { + LOG(WARNING) << "Input time_events for whole time have wrong format"; + event_to_cost_success = false; + } + + return event_to_cost_success; +} + +void PrintEvents(const std::vector>* time_events, + const std::vector>* mem_events) { + if (time_events != nullptr) { + for (size_t i = 0; i < time_events->size(); ++i) { + for (size_t j = 0; j < (*time_events)[i].size(); ++j) { + VLOG(4) << "Print time event (" << i << ", " << j << ")" << std::endl; + VLOG(4) << (*time_events)[i][j].name() << " " + << (*time_events)[i][j].attr() << std::endl; + VLOG(4) << "This: " << &(*time_events)[i][j] + << ", Parent: " << (*time_events)[i][j].parent() << std::endl; + if ((*time_events)[i][j].role() == platform::EventRole::kInnerOp) { + VLOG(4) << "role kInnerOp" << std::endl; + } else if ((*time_events)[i][j].role() == + platform::EventRole::kUniqueOp) { + VLOG(4) << "role kUniqueOp" << std::endl; + } else if ((*time_events)[i][j].role() == + platform::EventRole::kOrdinary) { + VLOG(4) << "role kOrdinary" << std::endl; + } else if ((*time_events)[i][j].role() == + platform::EventRole::kSpecial) { + VLOG(4) << "role kSpecial" << std::endl; + } + + if ((*time_events)[i][j].type() == platform::EventType::kPopRange) { + VLOG(4) << "type kPopRange" << std::endl; + } else if ((*time_events)[i][j].type() == + platform::EventType::kPushRange) { + VLOG(4) << "type kPushRange" << std::endl; + } else if ((*time_events)[i][j].type() == platform::EventType::kMark) { + VLOG(4) << "type kMark" << std::endl; + } + VLOG(4) << std::endl; + } + } + } + if (mem_events != nullptr) { + for (size_t i = 0; i < mem_events->size(); ++i) { + for (size_t j = 0; j < (*mem_events)[i].size(); ++j) { + VLOG(4) << "Print mem event (" << i << ", " << j << ")" << std::endl; + VLOG(4) << (*mem_events)[i][j].annotation() << std::endl; + } + } + } +} + +std::string ToLowerCopy(const std::string& in) { + std::string out(in); + std::transform(out.begin(), out.end(), out.begin(), + [](unsigned char c) { return std::tolower(c); }); + return out; +} + +CostData CostModel::ProfileMeasure( + const ProgramDesc& main_program, const ProgramDesc& startup_program, + const std::string& device, + const std::vector& fetch_cost_list) const { + // Currently fetch_cost_list is useless + // TODO(zhhsplendid): support different fetch data + + platform::ProfilerState profiler_state; + platform::Place place; + + std::string device_lower_case = ToLowerCopy(device); + if (device_lower_case == "cpu") { + profiler_state = platform::ProfilerState::kCPU; + place = platform::CPUPlace(); + } else if (device_lower_case == "gpu") { + profiler_state = platform::ProfilerState::kAll; + place = platform::CUDAPlace(); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Not support %s in CostModel now", device)); + } + + Executor executor(place); + Scope scope; + executor.Run(startup_program, &scope, /*block_id = */ 0); + + // TODO(zhhsplendid): handle the case that Profiler is already enabled + SetTracerOption(platform::TracerOption::kAllOpDetail); + EnableProfiler(profiler_state); + executor.Run(main_program, &scope, /*block_id = */ 0); + + std::unique_ptr>> time_events( + new std::vector>()); + std::unique_ptr>> mem_events( + new std::vector>()); + + CompleteProfilerEvents(/*tracer_profile= */ nullptr, time_events.get(), + mem_events.get()); + + // TODO(zhhsplendid): remove debug vlog after this series of work + PrintEvents(time_events.get(), mem_events.get()); + + // Convert events to cost data + CostData cost_data; + cost_data.SetCostData(main_program, *time_events); + + return cost_data; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/cost_model.h b/paddle/fluid/framework/ir/cost_model.h new file mode 100644 index 00000000000..41567df2cb3 --- /dev/null +++ b/paddle/fluid/framework/ir/cost_model.h @@ -0,0 +1,85 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/platform/variant.h" + +namespace paddle { +namespace framework { + +class CostData { + public: + CostData() {} + + ~CostData(); + + // Support global block only + // TODO(zhhsplendid): add support for sub-block + double GetOpTimeMs(int op_id) const; + double GetOpMemoryBytes(int op_id) const; + double GetWholeTimeMs() const; + double GetWholeMemoryBytes() const; + + const ir::Graph* GetGraph() const; + const ProgramDesc* GetProgram() const; + + // Support Time Event only + // TODO(zhhsplendid): add memory + bool SetCostData( + const ProgramDesc& program, + const std::vector>& time_events); + + static const double NOT_MEASURED; + + private: + ir::Graph* graph_{nullptr}; + ProgramDesc* program_{nullptr}; + std::map op_time_ms_; // from Op Node id to time + std::map + op_memory_bytes_; // from Op Node id to total memory bytes + std::map comm_; // from Op Node id to communicate cost + double whole_time_ms_{ + NOT_MEASURED}; // time cost of the whole program or graph + double whole_memory_bytes_{ + NOT_MEASURED}; // memory cost of the whole program or graph + double whole_comm_{ + NOT_MEASURED}; // communication cost of the whole program or graph +}; + +class CostModel { + public: + CostModel() {} + ~CostModel() {} + + CostData ProfileMeasure( + const ProgramDesc& main_program, const ProgramDesc& startup_program, + const std::string& device, + const std::vector& fetch_cost_list) const; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/cost_model_test.cc b/paddle/fluid/framework/ir/cost_model_test.cc new file mode 100644 index 00000000000..57f3904d845 --- /dev/null +++ b/paddle/fluid/framework/ir/cost_model_test.cc @@ -0,0 +1,209 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/cost_model.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/errors.h" +#include "paddle/fluid/platform/event.h" + +namespace paddle { +namespace framework { + +// Register test op +class FakeTestOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddInput("Y", "").AsDuplicable(); + AddOutput("Out", "").AsDuplicable(); + AddComment(""); + } +}; + +class FakeTestOp : public OperatorBase { + public: + FakeTestOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const Scope &scope, + const platform::Place &place) const override { + // Fake RunImpl, for test only + Variable *var = scope.FindVar("X"); + if (var != nullptr) { + LoDTensor *tensor = var->GetMutable(); + tensor->mutable_data(place); + } + int count = 0; + while (count <= 1000) { + ++count; + } + } +}; + +} // namespace framework +} // namespace paddle + +REGISTER_OPERATOR(fake_test_op, paddle::framework::FakeTestOp, + paddle::framework::FakeTestOpMaker); + +namespace paddle { +namespace framework { + +ProgramDesc CreateTestProgram() { + // create a ProgramDesc: + // Z = fake_test_op(X, Y) + // Out = fake_test_op(Z, W) + ProgramDesc program; + auto *global_block = program.MutableBlock(0); + + auto *x = global_block->Var("X"); + x->SetType(proto::VarType::LOD_TENSOR); + x->SetLoDLevel(0); + x->SetDataType(proto::VarType::FP32); + x->SetShape({1000, 784}); + + auto *y = global_block->Var("Y"); + y->SetType(proto::VarType::LOD_TENSOR); + y->SetLoDLevel(0); + y->SetDataType(proto::VarType::FP32); + y->SetShape({784, 100}); + + auto *op0 = global_block->AppendOp(); + op0->SetType("fake_test_op"); + op0->SetInput("X", {x->Name()}); + op0->SetInput("Y", {y->Name()}); + + auto *z = global_block->Var("Z"); + z->SetType(proto::VarType::LOD_TENSOR); + op0->SetOutput("Out", {z->Name()}); + + auto *w = global_block->Var("W"); + w->SetType(proto::VarType::LOD_TENSOR); + w->SetLoDLevel(0); + w->SetDataType(proto::VarType::FP32); + w->SetShape({100, 10}); + + auto *op1 = global_block->AppendOp(); + op1->SetType("fake_test_op"); + op1->SetInput("X", {z->Name()}); + op1->SetInput("Y", {w->Name()}); + + auto *out = global_block->Var("Out"); + out->SetType(proto::VarType::LOD_TENSOR); + op1->SetOutput("Out", {out->Name()}); + return program; +} + +TEST(CostModelTest, TestProfileMeasure_EmptyProgram) { + CostModel cost_model; + ProgramDesc empty_program; + CostData cost_data = + cost_model.ProfileMeasure(empty_program, empty_program, "cpu", {"time"}); + EXPECT_EQ(cost_data.GetWholeTimeMs(), 0); +} + +TEST(CostModelTest, TestProfileMeasure_Program) { + CostModel cost_model; + ProgramDesc program = CreateTestProgram(); + ProgramDesc empty_program; + CostData cost_data = + cost_model.ProfileMeasure(program, empty_program, "cpu", {"time"}); + double op0_time_ms = cost_data.GetOpTimeMs(0); + double op1_time_ms = cost_data.GetOpTimeMs(1); + EXPECT_GT(op0_time_ms, 0); + EXPECT_GT(op1_time_ms, 0); + EXPECT_GT(cost_data.GetWholeTimeMs(), op0_time_ms + op1_time_ms); +} + +TEST(CostModelTest, TestProfileMeasure_UnsupportedDevice) { + CostModel cost_model; + ProgramDesc program = CreateTestProgram(); + ProgramDesc empty_program; + + EXPECT_THROW(cost_model.ProfileMeasure(program, empty_program, "wrong_device", + {"time"}), + paddle::platform::EnforceNotMet); +} + +TEST(CostDataTest, TestGetGraphProgram) { + CostData cost_data; + EXPECT_EQ(cost_data.GetGraph(), nullptr); + EXPECT_EQ(cost_data.GetProgram(), nullptr); +} + +TEST(CostDataTest, TestUninitailzed) { + CostData cost_data; + EXPECT_EQ(cost_data.GetWholeMemoryBytes(), CostData::NOT_MEASURED); + EXPECT_EQ(cost_data.GetWholeTimeMs(), CostData::NOT_MEASURED); +} + +TEST(CostDataTest, TestEmptyProgram) { + CostData cost_data; + ProgramDesc empty_program(""); + EXPECT_EQ(cost_data.SetCostData(empty_program, {}), true); + EXPECT_EQ(cost_data.GetWholeMemoryBytes(), 0); + EXPECT_EQ(cost_data.GetWholeTimeMs(), 0); +} + +TEST(CostDataTest, TestEmptyTimeEvent) { + CostData cost_data; + ProgramDesc program = CreateTestProgram(); + EXPECT_EQ(cost_data.SetCostData(program, {}), false); + EXPECT_EQ(cost_data.GetWholeMemoryBytes(), CostData::NOT_MEASURED); + EXPECT_EQ(cost_data.GetWholeTimeMs(), CostData::NOT_MEASURED); +} + +TEST(CostDataTest, TestNoOpEvent) { + CostData cost_data; + ProgramDesc program = CreateTestProgram(); + std::vector thread_events; + thread_events.push_back( + platform::Event(platform::EventType::kPushRange, "not exist name", 0)); + std::vector> time_events{thread_events}; + EXPECT_EQ(cost_data.SetCostData(program, time_events), false); +} + +TEST(CostDataTest, TestNoOpPopEvent) { + CostData cost_data; + ProgramDesc program = CreateTestProgram(); + std::vector thread_events; + thread_events.push_back( + platform::Event(platform::EventType::kPushRange, "fake_test_op", 0)); + std::vector> time_events{thread_events}; + EXPECT_EQ(cost_data.SetCostData(program, time_events), false); +} + +TEST(CostDataTest, TestNoWholeEvent) { + CostData cost_data; + ProgramDesc program = CreateTestProgram(); + std::vector thread_events; + thread_events.push_back( + platform::Event(platform::EventType::kPushRange, "fake_test_op", 0)); + thread_events.push_back( + platform::Event(platform::EventType::kPopRange, "fake_test_op", 0)); + thread_events.push_back( + platform::Event(platform::EventType::kPushRange, "fake_test_op", 0)); + thread_events.push_back( + platform::Event(platform::EventType::kPopRange, "fake_test_op", 0)); + std::vector> time_events{thread_events}; + EXPECT_EQ(cost_data.SetCostData(program, time_events), false); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 1bd46c0bfaf..8160a06ddea 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -494,6 +494,16 @@ class DeviceTracerImpl : public DeviceTracer { } proto::Profile GenProfile(const std::string &profile_path) { + proto::Profile profile_pb = this->GetProfile(); + std::ofstream profile_f; + profile_f.open(profile_path, + std::ios::out | std::ios::trunc | std::ios::binary); + profile_pb.SerializeToOstream(&profile_f); + profile_f.close(); + return profile_pb; + } + + proto::Profile GetProfile() { int miss = 0, find = 0; std::lock_guard l(trace_mu_); proto::Profile profile_pb; @@ -601,12 +611,6 @@ class DeviceTracerImpl : public DeviceTracer { event->set_thread_id(r.thread_id); } } - - std::ofstream profile_f; - profile_f.open(profile_path, - std::ios::out | std::ios::trunc | std::ios::binary); - profile_pb.SerializeToOstream(&profile_f); - profile_f.close(); return profile_pb; } diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h index 9bae7a87052..ef06d0d609e 100644 --- a/paddle/fluid/platform/device_tracer.h +++ b/paddle/fluid/platform/device_tracer.h @@ -126,6 +126,9 @@ class DeviceTracer { int64_t device_id, int64_t stream_id, uint32_t correlation_id) = 0; + // Get a proto after done + virtual proto::Profile GetProfile() = 0; + // Generate a proto after done (Disabled). virtual proto::Profile GenProfile(const std::string& profile_path) = 0; diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index 2c8f918414d..40d9bb99f44 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -263,9 +263,40 @@ void DisableProfiler(EventSortingKey sorted_key, ParseEvents(all_events, true, sorted_key); ParseEvents(all_events, false, sorted_key); - if (VLOG_IS_ON(5)) { - std::vector> all_mem_events = GetMemEvents(); - ParseMemEvents(all_mem_events); + + std::vector> all_mem_events = GetMemEvents(); + ParseMemEvents(all_mem_events); + + ResetProfiler(); + g_state = ProfilerState::kDisabled; + g_tracer_option = TracerOption::kDefault; + should_send_profile_state = true; +} + +void CompleteProfilerEvents(proto::Profile *tracer_profile, + std::vector> *time_events, + std::vector> *mem_events) { + SynchronizeAllDevice(); + MemEvenRecorder::Instance().Flush(); + + std::lock_guard l(profiler_mu); + if (g_state == ProfilerState::kDisabled) return; + + // Mark the profiling stop. + Mark("_stop_profiler_"); + + DeviceTracer *tracer = GetDeviceTracer(); + if (tracer->IsEnabled() && tracer_profile != nullptr) { + tracer->Disable(); + tracer->GenEventKernelCudaElapsedTime(); + *tracer_profile = tracer->GetProfile(); + } + + if (time_events != nullptr) { + *time_events = GetAllEvents(); + } + if (mem_events != nullptr) { + *mem_events = GetMemEvents(); } ResetProfiler(); diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index 512bbc195b5..fbae6165e31 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -28,9 +28,12 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/event.h" #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/profiler.pb.h" + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/gpu_info.h" #endif + namespace paddle { namespace platform { @@ -215,6 +218,11 @@ void EnableProfiler(ProfilerState state); void ResetProfiler(); void DisableProfiler(EventSortingKey sorted_key, const std::string& profile_path); +// Disable profiler but return events instead of print it. +void CompleteProfilerEvents(proto::Profile* tracer_profile, + std::vector>* time_events, + std::vector>* mem_events); + // Test if the profiler is currently enabled. bool IsProfileEnabled(); // Whether the trainer should send profiling state to PS. diff --git a/paddle/fluid/platform/profiler_helper.h b/paddle/fluid/platform/profiler_helper.h index ae4d75113cd..a8438263cb9 100644 --- a/paddle/fluid/platform/profiler_helper.h +++ b/paddle/fluid/platform/profiler_helper.h @@ -820,7 +820,6 @@ void ParseEvents(const std::vector> &events, std::multimap child_map; size_t max_name_width = 0; OverHead overhead; - AnalyzeEvent(analyze_events, &events_table, &child_map, sorted_func, sorted_by, &max_name_width, &overhead, merge_thread); diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index b30e6c39f54..22778013f23 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -6,7 +6,8 @@ include_directories(${PADDLE_SOURCE_DIR}/paddle/utils) set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context - gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator) + gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator + cost_model) if (WITH_PSCORE) set(PYBIND_DEPS ${PYBIND_DEPS} ps_service) @@ -59,6 +60,7 @@ set(PYBIND_SRCS data_set_py.cc imperative.cc ir.cc + bind_cost_model.cc inference_api.cc compatible.cc io.cc diff --git a/paddle/fluid/pybind/bind_cost_model.cc b/paddle/fluid/pybind/bind_cost_model.cc new file mode 100644 index 00000000000..a4a40f1fd02 --- /dev/null +++ b/paddle/fluid/pybind/bind_cost_model.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/bind_cost_model.h" + +#include +#include "paddle/fluid/framework/ir/cost_model.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace py = pybind11; +using paddle::framework::CostData; +using paddle::framework::CostModel; +using paddle::framework::ProgramDesc; + +namespace paddle { +namespace pybind { + +void BindCostModel(py::module* m) { + py::class_(*m, "CostData") + .def(py::init<>()) + .def("get_whole_time_ms", &CostData::GetWholeTimeMs) + .def("get_op_time_ms", &CostData::GetOpTimeMs); + + py::class_(*m, "CostModel") + .def(py::init<>()) + .def("profile_measure", + [](CostModel& self, py::object py_main_program, + py::object py_startup_program, const std::string& device, + const std::vector& fetch_cost_list) { + py::object py_main_program_desc = py_main_program.attr("desc"); + ProgramDesc* main_program_desc = + py_main_program_desc.cast(); + + py::object py_startup_program_desc = + py_startup_program.attr("desc"); + ProgramDesc* startup_program_desc = + py_startup_program_desc.cast(); + return self.ProfileMeasure(*main_program_desc, + *startup_program_desc, device, + fetch_cost_list); + }); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/bind_cost_model.h b/paddle/fluid/pybind/bind_cost_model.h new file mode 100644 index 00000000000..2545ab67502 --- /dev/null +++ b/paddle/fluid/pybind/bind_cost_model.h @@ -0,0 +1,25 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace paddle { +namespace pybind { + +void BindCostModel(pybind11::module *m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index e404f27a10d..c00f529f617 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -38,6 +38,7 @@ limitations under the License. */ #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h" +#include "paddle/fluid/framework/ir/cost_model.h" #include "paddle/fluid/framework/ir/generate_pass.h" #include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/lod_rank_table.h" @@ -78,6 +79,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_ASCEND #include "paddle/fluid/pybind/ascend_wrapper_py.h" #endif +#include "paddle/fluid/pybind/bind_cost_model.h" #include "paddle/fluid/pybind/box_helper_py.h" #include "paddle/fluid/pybind/compatible.h" #include "paddle/fluid/pybind/const_value.h" @@ -2131,6 +2133,7 @@ All parameter, weight, gradient are variables in Paddle. BindBlockDesc(&m); BindVarDsec(&m); BindOpDesc(&m); + BindCostModel(&m); BindConstValue(&m); BindGlobalValueGetterSetter(&m); BindProcessMeshDesc(&m); @@ -2439,7 +2442,6 @@ All parameter, weight, gradient are variables in Paddle. [](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); }); // -- python binds for parallel executor. - py::class_ pe(m, "ParallelExecutor"); py::class_ exec_strategy(pe, "ExecutionStrategy", R"DOC( ExecutionStrategy allows the user to more preciously control how to run diff --git a/python/paddle/cost_model/cost_model.py b/python/paddle/cost_model/cost_model.py new file mode 100644 index 00000000000..93c89d0c892 --- /dev/null +++ b/python/paddle/cost_model/cost_model.py @@ -0,0 +1,69 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.static as static +import numpy as np +from paddle.fluid import core + + +class CostModel(): + def __init__(self): + pass + + def build_program(self): + paddle.enable_static() + + main_program = static.Program() + startup_program = static.Program() + with static.program_guard( + main_program=main_program, startup_program=startup_program): + data = paddle.static.data( + name='X', shape=[None, 1], dtype='float32') + hidden = paddle.static.nn.fc(data, 10) + loss = paddle.mean(hidden) + paddle.optimizer.SGD(learning_rate=0.01).minimize(loss) + + print("main program is: {}".format(main_program)) + #print("start up program is: {}".format(startup_program)) + + return startup_program, main_program + + def profile_measure(self, + startup_program, + main_program, + device='gpu', + fetch_cost_list=['time', 'memory']): + + place = paddle.set_device('gpu') + x = np.random.random(size=(10, 1)).astype('float32') + exe = paddle.static.Executor(place) + + exe.run(startup_program) + paddle.fluid.profiler.start_profiler("All") + exe.run(main_program, feed={"X": x}, fetch_list=[]) + # core.CostModel.ProfileMeasure(main_program, device) + print("core:<<<<<<<") + + cost_model = core.CostModel() + cost_data = cost_model.ProfileMeasure(device) + # cost_list = self.stop_cost_model() + # return cost_list + + +cost_model = CostModel() + +startup_program, main_program = cost_model.build_program() + +cost_model.profile_measure(startup_program, main_program) diff --git a/python/paddle/fluid/tests/unittests/test_cost_model.py b/python/paddle/fluid/tests/unittests/test_cost_model.py new file mode 100644 index 00000000000..483f665fde7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cost_model.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle +import paddle.fluid.core as core + +paddle.enable_static() + +device = "gpu" if core.is_compiled_with_cuda() else "cpu" + + +class TestCostModel(unittest.TestCase): + def test_profiler_measure_empty_program(self): + cost_model = core.CostModel() + empty_program = paddle.static.Program() + startup_program = paddle.static.Program() + cost_data = cost_model.profile_measure(empty_program, startup_program, + device, ["time"]) + self.assertEqual(cost_data.get_whole_time_ms(), 0) + + def test_profiler_measure_program(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + # TODO(zhhsplendid): support paddle.static.data, which is uninitialized data + data = paddle.ones(name='X', shape=[16, 100], dtype='float32') + hidden = paddle.static.nn.fc(data, 10) + loss = paddle.mean(hidden) + cost_model = core.CostModel() + cost_data = cost_model.profile_measure(main_program, startup_program, + device, ["time"]) + fc_op_time = cost_data.get_op_time_ms(0) + mean_op_time = cost_data.get_op_time_ms(1) + self.assertGreater(fc_op_time, 0) + self.assertGreater(mean_op_time, 0) + self.assertGreaterEqual(cost_data.get_whole_time_ms(), + fc_op_time + mean_op_time) + + +if __name__ == '__main__': + unittest.main() -- GitLab