diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index d773497e6c7305774927b62886abf0fbd00ca86a..03933fdecc9cc99bee387edb3866278fdd654fcc 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -20,9 +20,23 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" +#ifdef WITH_GPERFTOOLS +#include "gperftools/profiler.h" +#endif + +DEFINE_string( + tracer_profile_fname, "", + "Profiler filename for imperative tracer, which generated by gperftools." + "Only valid when compiled `WITH_PROFILER=ON`. Empty if disable."); + namespace paddle { namespace imperative { +static std::once_flag gTracerProfileOnce; +#ifdef WITH_GPERFTOOLS +static bool gTracerProfilerStarted = false; +#endif + void CreateGradOp(const framework::OpDesc& op_desc, const std::unordered_set& no_grad_set, const std::vector& grad_sub_block, @@ -68,11 +82,31 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) { return result; } +Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) { + if (!FLAGS_tracer_profile_fname.empty()) { + std::call_once(gTracerProfileOnce, [] { +#ifdef WITH_GPERFTOOLS + ProfilerStart(FLAGS_tracer_profile_fname.c_str()); + gTracerProfilerStarted = true; +#else + LOG(WARNING) << "Paddle is not compiled with gperftools. " + "FLAGS_tracer_profile_fname will be ignored"; +#endif + }); + } +} + std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, const VarBasePtrMap& outputs, framework::BlockDesc* block, const platform::Place expected_place, const bool stop_gradient) { +#ifdef WITH_GPERFTOOLS + if (gTracerProfilerStarted) { + ProfilerFlush(); + } +#endif + std::map vars; framework::OpDesc* op_desc = op->op_desc_; diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 98909e378f0e4188250fcb6efd9502dcc9740da4..8a0267c37f7c98a172fe0fa573955dc420952c0a 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -40,7 +40,7 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs); class Tracer { public: - explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {} + explicit Tracer(framework::BlockDesc* root_block); virtual ~Tracer() {}