提交 8291b916 编写于 作者: C chengduoZH

replace graph_builder_factory with ssa_graph_builder_factory

上级 9ac785be
......@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -30,7 +30,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
......@@ -272,6 +272,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
return std::unique_ptr<SSAGraph>(graph);
}
......
......@@ -85,7 +85,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
}
std::unique_ptr<SSAGraph> SSAGraphBuilder::BuildAndCheck(
const ProgramDesc &program) final {
const ProgramDesc &program) {
std::unique_ptr<SSAGraph> graph = Build(program);
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return std::move(graph);
......@@ -138,7 +138,7 @@ bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const {
if (ready_vars.empty()) {
return false;
}
for (auto ready_var : ready_vars.) {
for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
auto &deps = --pending_ops[op];
......
......@@ -31,7 +31,7 @@ class SSAGraphBuilder {
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program) final;
std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program);
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_builder_factory.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
......
......@@ -22,8 +22,8 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/graph_builder_factory.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -114,7 +114,7 @@ ParallelExecutor::ParallelExecutor(
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places,
builder_factory.Create()->Build(main_program)));
builder_factory.Create()->BuildAndCheck(main_program)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册