From bae5930ba134d0569ddf0043d59825c5b397d095 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 24 Mar 2020 21:08:58 -0500 Subject: [PATCH] fix graph attr copy issues, test=develop (#23191) --- .../framework/details/multi_devices_helper.cc | 8 ++++++++ .../framework/details/multi_devices_helper.h | 9 +++++++++ paddle/fluid/framework/ir/graph_test.cc | 18 ++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/paddle/fluid/framework/details/multi_devices_helper.cc b/paddle/fluid/framework/details/multi_devices_helper.cc index a3cdf4a0e3f..79279f1b143 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.cc +++ b/paddle/fluid/framework/details/multi_devices_helper.cc @@ -171,6 +171,10 @@ std::vector> TrySeparateToMultipleSingleDeviceGraphs( "issue at https://github.com/PaddlePaddle/Paddle/issues/new. And " "we will resolve it with high priority.")); + if (place_num == 1) { + return {}; + } + std::vector> graphs(place_num); for (auto &g : graphs) { g.reset(new ir::Graph(ProgramDesc())); @@ -208,6 +212,10 @@ std::vector> TrySeparateToMultipleSingleDeviceGraphs( graph->Erase(kGraphVars); graph->Erase(kGraphDepVars); + for (auto &g : graphs) { + CopyGraphAttrIfExists(*graph, g.get(), kProgramDescs); + CopyGraphAttrIfExists(*graph, g.get(), kFusedVars); + } return graphs; } diff --git a/paddle/fluid/framework/details/multi_devices_helper.h b/paddle/fluid/framework/details/multi_devices_helper.h index 797bc8ec48b..ab68cf53280 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.h +++ b/paddle/fluid/framework/details/multi_devices_helper.h @@ -108,6 +108,15 @@ bool HasDropLastReadOp(const ir::Graph &graph); bool HasKeepLastReadOp(const ir::Graph &graph); +template +void CopyGraphAttrIfExists(const ir::Graph &src, ir::Graph *dst, + const std::string &name) { + if (src.Has(name)) { + auto &attr = src.Get(name); + dst->Set(name, new T(attr)); + } +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 1317e8771f5..37d22ec566c 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" #include "gtest/gtest.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" @@ -252,5 +253,22 @@ TEST(GraphTest, TestException) { } ASSERT_TRUE(not_met_exception); } + +TEST(GraphTest, TestAttrCopy) { + ProgramDesc prog; + ir::Graph src_g(prog); + ir::Graph dst_g(prog); + const std::string kIntValue = "int_value"; + const std::string kFloatValue = "float_value"; + const int INT_VALUE = 3; + src_g.Set(kIntValue, new int(INT_VALUE)); + details::CopyGraphAttrIfExists(src_g, &dst_g, kIntValue); + details::CopyGraphAttrIfExists(src_g, &dst_g, kFloatValue); + + ASSERT_TRUE(dst_g.Has(kIntValue)); + ASSERT_EQ(dst_g.Get(kIntValue), INT_VALUE); + ASSERT_FALSE(dst_g.Has(kFloatValue)); +} + } // namespace framework } // namespace paddle -- GitLab