From dd73d18bb7b7cb521cab2f3547633fd6736e8c12 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 22 Mar 2018 10:49:51 +0800 Subject: [PATCH] Extract SSAGraph --- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/details/CMakeLists.txt | 2 ++ paddle/fluid/framework/details/ssa_graph.cc | 15 ++++++++ paddle/fluid/framework/details/ssa_graph.h | 34 +++++++++++++++++++ paddle/fluid/framework/parallel_executor.cc | 12 ++----- 5 files changed, 54 insertions(+), 11 deletions(-) create mode 100644 paddle/fluid/framework/details/ssa_graph.cc create mode 100644 paddle/fluid/framework/details/ssa_graph.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 2b90bb5abdf..f1d19efa97d 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -95,7 +95,7 @@ else() endif() cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle - fetch_op_handle computation_op_handle ${parallel_executor_cuda_deps}) + fetch_op_handle computation_op_handle ssa_graph ${parallel_executor_cuda_deps}) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 7565bc4c9c4..9ed41ab94c3 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -5,3 +5,5 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) + +cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) diff --git a/paddle/fluid/framework/details/ssa_graph.cc b/paddle/fluid/framework/details/ssa_graph.cc new file mode 100644 index 00000000000..1b8c8894490 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph.h" diff --git a/paddle/fluid/framework/details/ssa_graph.h b/paddle/fluid/framework/details/ssa_graph.h new file mode 100644 index 00000000000..c1e041b8c0b --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph.h @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" + +namespace paddle { +namespace framework { +namespace details { + +struct SSAGraph { + std::vector>> vars_; + std::unordered_set> dep_vars_; + std::vector> ops_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b2be3d13055..5c10595db9c 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -15,15 +15,12 @@ limitations under the License. */ #include "paddle/fluid/framework/parallel_executor.h" #include "ThreadPool.h" #include "lod_tensor.h" -#include "lod_tensor_array.h" #include "op_registry.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" -#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" -#include "paddle/fluid/framework/details/var_handle.h" -#include "paddle/fluid/platform/nccl_helper.h" +#include "paddle/fluid/framework/details/ssa_graph.h" namespace paddle { namespace framework { @@ -34,15 +31,10 @@ using details::FetchOpHandle; using details::NCCLAllReduceOpHandle; using details::OpHandleBase; using details::ScaleLossGradOpHandle; +using details::SSAGraph; using details::VarHandle; using details::VarHandleBase; -struct SSAGraph { - std::vector>> vars_; - std::unordered_set> dep_vars_; - std::vector> ops_; -}; - class SSAGraphBuilder { public: virtual ~SSAGraphBuilder() {} -- GitLab