diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 2b90bb5abdfa5532f83dd4bcff968ee13ee90a33..f1d19efa97de433b1416785d966b5939e5241030 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -95,7 +95,7 @@ else() endif() cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle - fetch_op_handle computation_op_handle ${parallel_executor_cuda_deps}) + fetch_op_handle computation_op_handle ssa_graph ${parallel_executor_cuda_deps}) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 7565bc4c9c4203741e431cc2bb17f1573bd85056..9ed41ab94c3c4c81c140442133e65e1330dbeca1 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -5,3 +5,5 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) + +cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) diff --git a/paddle/fluid/framework/details/ssa_graph.cc b/paddle/fluid/framework/details/ssa_graph.cc new file mode 100644 index 0000000000000000000000000000000000000000..1b8c889449059c563ea39f86250075ac2537cdbe --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph.h" diff --git a/paddle/fluid/framework/details/ssa_graph.h b/paddle/fluid/framework/details/ssa_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..c1e041b8c0b4a894c152ce1ff6e8efe6b7056311 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph.h @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" + +namespace paddle { +namespace framework { +namespace details { + +struct SSAGraph { + std::vector>> vars_; + std::unordered_set> dep_vars_; + std::vector> ops_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b2be3d13055c9137b567118d53b1b6e6b08c3a8e..5c10595db9c72b4b65045335bd85e1511f034d66 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -15,15 +15,12 @@ limitations under the License. */ #include "paddle/fluid/framework/parallel_executor.h" #include "ThreadPool.h" #include "lod_tensor.h" -#include "lod_tensor_array.h" #include "op_registry.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" -#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" -#include "paddle/fluid/framework/details/var_handle.h" -#include "paddle/fluid/platform/nccl_helper.h" +#include "paddle/fluid/framework/details/ssa_graph.h" namespace paddle { namespace framework { @@ -34,15 +31,10 @@ using details::FetchOpHandle; using details::NCCLAllReduceOpHandle; using details::OpHandleBase; using details::ScaleLossGradOpHandle; +using details::SSAGraph; using details::VarHandle; using details::VarHandleBase; -struct SSAGraph { - std::vector>> vars_; - std::unordered_set> dep_vars_; - std::vector> ops_; -}; - class SSAGraphBuilder { public: virtual ~SSAGraphBuilder() {}