未验证 提交 bae5930b 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix graph attr copy issues, test=develop (#23191)

上级 092a62e2
......@@ -171,6 +171,10 @@ std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
"issue at https://github.com/PaddlePaddle/Paddle/issues/new. And "
"we will resolve it with high priority."));
if (place_num == 1) {
return {};
}
std::vector<std::unique_ptr<ir::Graph>> graphs(place_num);
for (auto &g : graphs) {
g.reset(new ir::Graph(ProgramDesc()));
......@@ -208,6 +212,10 @@ std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
graph->Erase(kGraphVars);
graph->Erase(kGraphDepVars);
for (auto &g : graphs) {
CopyGraphAttrIfExists<ProgramDescs>(*graph, g.get(), kProgramDescs);
CopyGraphAttrIfExists<FusedVars>(*graph, g.get(), kFusedVars);
}
return graphs;
}
......
......@@ -108,6 +108,15 @@ bool HasDropLastReadOp(const ir::Graph &graph);
bool HasKeepLastReadOp(const ir::Graph &graph);
template <typename T>
void CopyGraphAttrIfExists(const ir::Graph &src, ir::Graph *dst,
const std::string &name) {
if (src.Has(name)) {
auto &attr = src.Get<T>(name);
dst->Set(name, new T(attr));
}
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -252,5 +253,22 @@ TEST(GraphTest, TestException) {
}
ASSERT_TRUE(not_met_exception);
}
TEST(GraphTest, TestAttrCopy) {
ProgramDesc prog;
ir::Graph src_g(prog);
ir::Graph dst_g(prog);
const std::string kIntValue = "int_value";
const std::string kFloatValue = "float_value";
const int INT_VALUE = 3;
src_g.Set<int>(kIntValue, new int(INT_VALUE));
details::CopyGraphAttrIfExists<int>(src_g, &dst_g, kIntValue);
details::CopyGraphAttrIfExists<float>(src_g, &dst_g, kFloatValue);
ASSERT_TRUE(dst_g.Has(kIntValue));
ASSERT_EQ(dst_g.Get<int>(kIntValue), INT_VALUE);
ASSERT_FALSE(dst_g.Has(kFloatValue));
}
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册