From 5c79dbb2d5333152abf40ace9d260d2426625b28 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Wed, 26 May 2021 16:50:16 +0800 Subject: [PATCH] Marker op for profiling (#33034) --- paddle/fluid/operators/marker_op.cc | 76 +++++++++++++++++++ paddle/fluid/operators/marker_op.cu | 61 +++++++++++++++ paddle/fluid/platform/device_tracer.cc | 2 +- paddle/fluid/platform/event.h | 5 +- paddle/fluid/platform/profiler.cc | 23 ++++-- paddle/fluid/platform/profiler.h | 9 ++- .../fluid/tests/unittests/test_marker_op.py | 36 +++++++++ tools/static_mode_white_list.py | 1 + 8 files changed, 199 insertions(+), 14 deletions(-) create mode 100644 paddle/fluid/operators/marker_op.cc create mode 100644 paddle/fluid/operators/marker_op.cu create mode 100644 python/paddle/fluid/tests/unittests/test_marker_op.py diff --git a/paddle/fluid/operators/marker_op.cc b/paddle/fluid/operators/marker_op.cc new file mode 100644 index 0000000000..397e3bfc6a --- /dev/null +++ b/paddle/fluid/operators/marker_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace operators { + +class MarkerOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + std::string marker_role = ctx->Attrs().Get("marker_role"); + std::string marker_pos = ctx->Attrs().Get("marker_pos"); + + VLOG(3) << "The role is:" << marker_role << ";" + << "The position is:" << marker_pos << "."; + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } +}; + +class MarkerOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddAttr("marker_role", + "(string, default forward)forward or backward," + " mark different stages of process.") + .SetDefault("forward"); + AddAttr( + "marker_pos", + "(string, default B)the position where the marker is placed, " + "B stands for begin of duration," + " E stands for end of duration.") + .SetDefault("B"); + AddComment( + R"DOC(Marker Operator - Add marker at the beginning/end of a forward/backward process.)DOC"); + } +}; + +template +class MarkerOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto marker_role = ctx.Attr("marker_role"); + auto marker_pos = ctx.Attr("marker_pos"); + + platform::RecordEvent record_event( + "MarkerCPU", platform::EventRole::kInnerOp, + "marker_" + marker_role + "_" + marker_pos); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_WITHOUT_GRADIENT(marker, ops::MarkerOp, ops::MarkerOpMaker); +REGISTER_OP_CPU_KERNEL(marker, ops::MarkerOpCPUKernel); diff --git a/paddle/fluid/operators/marker_op.cu 
b/paddle/fluid/operators/marker_op.cu new file mode 100644 index 0000000000..b918210389 --- /dev/null +++ b/paddle/fluid/operators/marker_op.cu @@ -0,0 +1,61 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace operators { + +template +__global__ void SimpleMarkerKernel(T* in, T* out, int ndim) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + for (; idx < ndim; idx += blockDim.x * gridDim.x) { + out[idx] = in[idx]; + } +} + +template +class MarkerOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + + auto marker_role = ctx.Attr("marker_role"); + auto marker_pos = ctx.Attr("marker_pos"); + VLOG(3) << "marker role: " << marker_role + << " marker position: " << marker_pos; + + framework::Tensor A; + framework::Tensor B; + auto* in_temp = A.mutable_data({32, 1}, ctx.GetPlace()); + auto* out_temp = B.mutable_data({32, 1}, ctx.GetPlace()); + platform::RecordEvent record_event( + "MarkerCUDA", platform::EventRole::kInnerOp, + "marker_" + marker_role + "_" + marker_pos); + SimpleMarkerKernel<<<1, 32, 0, dev_ctx.stream()>>>(in_temp, out_temp, 
+ 32); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(marker, ops::MarkerOpCUDAKernel); diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 724a9b8483..1bd46c0bfa 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -511,7 +511,7 @@ class DeviceTracerImpl : public DeviceTracer { auto c = correlations_.find(r.correlation_id); if (c != correlations_.end() && c->second != nullptr) { event->set_name(c->second->name()); - event->set_detail_info(r.name); + event->set_detail_info(c->second->attr()); find++; } else { VLOG(10) << "Missing Kernel Event: " + r.name; diff --git a/paddle/fluid/platform/event.h b/paddle/fluid/platform/event.h index 0985b884d1..3a81cfab86 100644 --- a/paddle/fluid/platform/event.h +++ b/paddle/fluid/platform/event.h @@ -40,7 +40,7 @@ class Event { // The DeviceContext is used to get the cuda stream. // If CPU profiling mode, can pass nullptr. 
Event(EventType type, std::string name, uint32_t thread_id, - EventRole role = EventRole::kOrdinary); + EventRole role = EventRole::kOrdinary, std::string attr = "none"); const EventType& type() const; Event* parent() const { return parent_; } @@ -50,7 +50,7 @@ class Event { uint32_t thread_id() const { return thread_id_; } void set_name(std::string name) { name_ = name; } void set_role(EventRole role) { role_ = role; } - + std::string attr() const { return attr_; } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #ifndef PADDLE_WITH_CUPTI gpuEvent_t event() const { return event_; } @@ -69,6 +69,7 @@ class Event { EventRole role_{}; int64_t cpu_ns_; bool visited_status_{false}; + std::string attr_; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #ifdef PADDLE_WITH_CUPTI int64_t gpu_ns_ = 0; diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index aef7f8648f..9c33233e1f 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -32,8 +32,12 @@ namespace platform { MemEvenRecorder MemEvenRecorder::recorder; Event::Event(EventType type, std::string name, uint32_t thread_id, - EventRole role) - : type_(type), name_(name), thread_id_(thread_id), role_(role) { + EventRole role, std::string attr) + : type_(type), + name_(name), + thread_id_(thread_id), + role_(role), + attr_(attr) { cpu_ns_ = GetTimeInNsec(); } @@ -52,7 +56,8 @@ double Event::CudaElapsedMs(const Event &e) const { #endif } -RecordEvent::RecordEvent(const std::string &name, const EventRole role) { +RecordEvent::RecordEvent(const std::string &name, const EventRole role, + const std::string attr) { #ifndef _WIN32 #ifdef PADDLE_WITH_CUDA if (g_enable_nvprof_hook) { @@ -69,7 +74,7 @@ RecordEvent::RecordEvent(const std::string &name, const EventRole role) { is_enabled_ = true; // lock is not needed, the code below is thread-safe // Maybe need the same push/pop behavior. 
- Event *e = PushEvent(name, role); + Event *e = PushEvent(name, role, attr); SetCurAnnotation(e); name_ = e->name(); } @@ -186,12 +191,14 @@ void Mark(const std::string &name) { GetEventList().Record(EventType::kMark, name, g_thread_id); } -Event *PushEvent(const std::string &name, const EventRole role) { - return GetEventList().Record(EventType::kPushRange, name, g_thread_id, role); +Event *PushEvent(const std::string &name, const EventRole role, + std::string attr) { + return GetEventList().Record(EventType::kPushRange, name, g_thread_id, role, + attr); } -void PopEvent(const std::string &name, const EventRole role) { - GetEventList().Record(EventType::kPopRange, name, g_thread_id, role); +void PopEvent(const std::string &name, const EventRole role, std::string attr) { + GetEventList().Record(EventType::kPopRange, name, g_thread_id, role, attr); } void EnableProfiler(ProfilerState state) { PADDLE_ENFORCE_NE(state, ProfilerState::kDisabled, diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index 2e802bf5ea..512bbc195b 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -126,7 +126,8 @@ struct MemEvenRecorder { struct RecordEvent { RecordEvent(const std::string& name, - const EventRole role = EventRole::kOrdinary); + const EventRole role = EventRole::kOrdinary, + const std::string attr = "none"); ~RecordEvent(); @@ -200,8 +201,10 @@ void PushMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes, const Place& place, const std::string& annotation); void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes, const Place& place, const std::string& annotation); -Event* PushEvent(const std::string& name, const EventRole role); -void PopEvent(const std::string& name, const EventRole role); +Event* PushEvent(const std::string& name, const EventRole role, + const std::string attr = "none"); +void PopEvent(const std::string& name, const EventRole role, + const std::string attr = "none"); // Return 
the event list of all threads. Assumed the returned value calls // event_lists, event_lists[i][j] represents the j-th Event of i-th thread. std::vector> GetAllEvents(); diff --git a/python/paddle/fluid/tests/unittests/test_marker_op.py b/python/paddle/fluid/tests/unittests/test_marker_op.py new file mode 100644 index 0000000000..3f9f8c7d6b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_marker_op.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import numpy as np +from op_test import OpTest +from paddle.distributed.fleet.meta_optimizers.common import OpRole + + +class TestMarkerOp(OpTest): + def setUp(self): + self.op_type = "marker" + self.inputs = {} + self.attrs = { + 'marker_role': 'forward', + 'marker_pos': 'B', + 'op_role': OpRole.Forward + } + self.outputs = {} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 2c50c4bf9f..bc0c5af4d7 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -710,4 +710,5 @@ STATIC_MODE_TESTING_LIST = [ 'test_lamb_op_xpu', 'test_model_cast_to_bf16', 'test_sgd_op_bf16', + 'test_marker_op', ] -- GitLab