From 31815010130249033096ea584bc2c89983a7e367 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Mar 2018 17:02:51 +0800 Subject: [PATCH] Rerange code --- paddle/fluid/framework/CMakeLists.txt | 4 +- paddle/fluid/framework/details/CMakeLists.txt | 1 + .../details/computation_op_handle.cc | 40 +++++++++++++++++++ .../framework/details/computation_op_handle.h | 39 ++++++++++++++++++ paddle/fluid/framework/parallel_executor.cc | 28 +------------ 5 files changed, 84 insertions(+), 28 deletions(-) create mode 100644 paddle/fluid/framework/details/computation_op_handle.cc create mode 100644 paddle/fluid/framework/details/computation_op_handle.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 12d6541b8fa..2b90bb5abdf 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -94,8 +94,8 @@ else() set(parallel_executor_cuda_deps) endif() cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope - framework_proto backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle - fetch_op_handle ${parallel_executor_cuda_deps}) + backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle + fetch_op_handle computation_op_handle ${parallel_executor_cuda_deps}) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index fb276ea7038..7565bc4c9c4 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -4,3 +4,4 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda) +cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc new file mode 100644 index 00000000000..5867f8fc554 --- /dev/null +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/computation_op_handle.h" + +namespace paddle { +namespace framework { +namespace details { +ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, + platform::Place place) + : op_(framework::OpRegistry::CreateOp(op_desc)), + scope_(scope), + place_(place) {} + +void ComputationOpHandle::RunImpl() { + auto *cur_ctx = dev_ctx_[place_]; + for (auto *in : inputs_) { + bool need_wait = + in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx; + if (need_wait) { + in->generated_op_->Wait(cur_ctx); + } + } + + op_->Run(*scope_, place_); +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h new file mode 100644 index 00000000000..1fbfd4eabe0 --- /dev/null +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { +namespace details { +struct ComputationOpHandle : public OpHandleBase { + std::unique_ptr op_; + Scope *scope_; + platform::Place place_; + + ComputationOpHandle(const OpDesc &op_desc, Scope *scope, + platform::Place place); + + protected: + void RunImpl() override; +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 93db5ad3e5c..440040a2ef6 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "lod_tensor.h" #include "lod_tensor_array.h" #include "op_registry.h" +#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/op_handle_base.h" @@ -34,6 +35,7 @@ using details::OpHandleBase; using details::ScaleLossGradOpHandle; using details::VarHandle; using details::VarHandleBase; +using details::ComputationOpHandle; class ParallelExecutorPrivate { public: @@ -127,32 +129,6 @@ class ParallelExecutorPrivate { } }; -struct ComputationOpHandle : public OpHandleBase { - std::unique_ptr op_; - Scope *scope_; - platform::Place place_; - - explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope, - platform::Place place) - : op_(framework::OpRegistry::CreateOp(op_desc)), - scope_(scope), - place_(place) {} - - protected: - void RunImpl() override { - auto *cur_ctx = dev_ctx_[place_]; - for (auto *in : inputs_) { - bool need_wait = - in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx; - if (need_wait) { - in->generated_op_->Wait(cur_ctx); - } - } - - op_->Run(*scope_, place_); - } -}; - ParallelExecutor::ParallelExecutor( size_t num_threads, const std::vector &places, const std::unordered_set ¶ms, -- GitLab