diff --git a/paddle/fluid/framework/details/multi_devices_helper.cc b/paddle/fluid/framework/details/multi_devices_helper.cc index a3cdf4a0e3f2844d8a7464b8447bbc96c98c05fd..79279f1b1435bd5e89ecf7af68aab25eb8ab5baf 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.cc +++ b/paddle/fluid/framework/details/multi_devices_helper.cc @@ -171,6 +171,10 @@ std::vector> TrySeparateToMultipleSingleDeviceGraphs( "issue at https://github.com/PaddlePaddle/Paddle/issues/new. And " "we will resolve it with high priority.")); + if (place_num == 1) { + return {}; + } + std::vector> graphs(place_num); for (auto &g : graphs) { g.reset(new ir::Graph(ProgramDesc())); @@ -208,6 +212,10 @@ std::vector> TrySeparateToMultipleSingleDeviceGraphs( graph->Erase(kGraphVars); graph->Erase(kGraphDepVars); + for (auto &g : graphs) { + CopyGraphAttrIfExists(*graph, g.get(), kProgramDescs); + CopyGraphAttrIfExists(*graph, g.get(), kFusedVars); + } return graphs; } diff --git a/paddle/fluid/framework/details/multi_devices_helper.h b/paddle/fluid/framework/details/multi_devices_helper.h index 797bc8ec48bc90f661cce6dfa079da7610487dc9..ab68cf53280c2dd7b7996b9c0839b3bc809860cc 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.h +++ b/paddle/fluid/framework/details/multi_devices_helper.h @@ -108,6 +108,15 @@ bool HasDropLastReadOp(const ir::Graph &graph); bool HasKeepLastReadOp(const ir::Graph &graph); +template +void CopyGraphAttrIfExists(const ir::Graph &src, ir::Graph *dst, + const std::string &name) { + if (src.Has(name)) { + auto &attr = src.Get(name); + dst->Set(name, new T(attr)); + } +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 1317e8771f5ed4e8f1d95d9648d78c2410a849b7..37d22ec566c1927983b7c6b19fca8a965c433213 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" #include "gtest/gtest.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" @@ -252,5 +253,22 @@ TEST(GraphTest, TestException) { } ASSERT_TRUE(not_met_exception); } + +TEST(GraphTest, TestAttrCopy) { + ProgramDesc prog; + ir::Graph src_g(prog); + ir::Graph dst_g(prog); + const std::string kIntValue = "int_value"; + const std::string kFloatValue = "float_value"; + const int INT_VALUE = 3; + src_g.Set(kIntValue, new int(INT_VALUE)); + details::CopyGraphAttrIfExists(src_g, &dst_g, kIntValue); + details::CopyGraphAttrIfExists(src_g, &dst_g, kFloatValue); + + ASSERT_TRUE(dst_g.Has(kIntValue)); + ASSERT_EQ(dst_g.Get(kIntValue), INT_VALUE); + ASSERT_FALSE(dst_g.Has(kFloatValue)); +} + } // namespace framework } // namespace paddle