提交 4b193db1 编写于 作者: Y Yancey1989

polish code test=develop

上级 d5090c89
...@@ -36,6 +36,11 @@ namespace framework { ...@@ -36,6 +36,11 @@ namespace framework {
namespace details { namespace details {
namespace { namespace {
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) { bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) {
return boost::get<int>( return boost::get<int>(
...@@ -221,6 +226,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl( ...@@ -221,6 +226,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
* Only variables should be the leaves of graph. * Only variables should be the leaves of graph.
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
result.Erase(kGraphOps);
return graph; return graph;
} }
......
...@@ -44,12 +44,6 @@ const char kGraphVars[] = "vars"; ...@@ -44,12 +44,6 @@ const char kGraphVars[] = "vars";
typedef std::unordered_set<VarHandleBase *> GraphDepVars; typedef std::unordered_set<VarHandleBase *> GraphDepVars;
const char kGraphDepVars[] = "dep_vars"; const char kGraphDepVars[] = "dep_vars";
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -30,7 +30,6 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph( ...@@ -30,7 +30,6 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
auto &g = graphs.back(); auto &g = graphs.back();
g->Set(kGraphVars, new GraphVars(1UL)); g->Set(kGraphVars, new GraphVars(1UL));
g->Set(kGraphDepVars, new GraphDepVars); g->Set(kGraphDepVars, new GraphDepVars);
g->Set(kGraphOps, new GraphOps);
} }
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph); auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
...@@ -38,9 +37,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph( ...@@ -38,9 +37,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
auto &dev_ctx = op->DeviceContext(); auto &dev_ctx = op->DeviceContext();
auto &p = dev_ctx.begin()->first; auto &p = dev_ctx.begin()->first;
int dev_id = boost::get<platform::CUDAPlace>(p).device; int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps);
auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars); auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
dev_ops.emplace_back(op);
graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release()); graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release());
for (auto &var : op->Inputs()) { for (auto &var : op->Inputs()) {
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#pragma once #pragma once
#include <fstream>
#include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
......
...@@ -28,6 +28,9 @@ namespace paddle { ...@@ -28,6 +28,9 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
// This attr is not recommended, because the graph should not dependence
// the program once it is built.
constexpr char kAllOpDescs[] = "all_op_descs"; constexpr char kAllOpDescs[] = "all_op_descs";
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册