From 5368e50d845bd70d9c9f38a5a75db6cba949f48a Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 21 Mar 2018 14:58:28 +0800
Subject: [PATCH] Reorganize code

---
 paddle/fluid/framework/CMakeLists.txt         |  2 +-
 paddle/fluid/framework/details/CMakeLists.txt |  1 +
 .../details/scale_loss_grad_op_handle.cc      | 47 +++++++++++++++++++
 .../details/scale_loss_grad_op_handle.h       | 39 +++++++++++++++
 paddle/fluid/framework/parallel_executor.cc   | 35 +-------------
 5 files changed, 90 insertions(+), 34 deletions(-)
 create mode 100644 paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
 create mode 100644 paddle/fluid/framework/details/scale_loss_grad_op_handle.h

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index afc7ec9d663..123b9cb735b 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -88,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
 cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table feed_fetch_method)

 cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
-           framework_proto backward glog lod_rank_table simple_threadpool var_handle op_handle_base)
+           framework_proto backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle)

 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index d9bdf0b94d6..427785d5182 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -1,2 +1,3 @@
 cc_library(var_handle SRCS var_handle.cc DEPS place)
 cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
+cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
new file mode 100644
index 00000000000..df9ca371802
--- /dev/null
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
@@ -0,0 +1,47 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
+                                             platform::Place place)
+    : coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {}
+
+ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
+
+void ScaleLossGradOpHandle::RunImpl() {
+  std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
+
+  float *tmp =
+      scope_->FindVar(var_name)->GetMutable<LoDTensor>()->mutable_data<float>(
+          make_ddim({1}), place_);
+
+  if (platform::is_cpu_place(place_)) {
+    *tmp = coeff_;
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    auto stream =
+        static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
+            ->stream();
+    memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
+                 platform::CPUPlace(), &coeff_, sizeof(float), stream);
+#endif
+  }
+}
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
new file mode 100644
index 00000000000..44a10e33756
--- /dev/null
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct ScaleLossGradOpHandle : public OpHandleBase {
+  float coeff_;
+  Scope *scope_;
+  platform::Place place_;
+
+  ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place);
+
+  ~ScaleLossGradOpHandle() final;
+
+ protected:
+  void RunImpl() override;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 3c24fa4bdf6..5dba3e94c18 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "lod_tensor_array.h"
 #include "op_registry.h"
 #include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/platform/nccl_helper.h"
@@ -27,42 +28,10 @@ namespace framework {
 
 using details::DummyVarHandle;
 using details::OpHandleBase;
+using details::ScaleLossGradOpHandle;
 using details::VarHandle;
 using details::VarHandleBase;
 
-struct ScaleLossGradOpHandle : public OpHandleBase {
-  float coeff_;
-  Scope *scope_;
-  platform::Place place_;
-
-  explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
-                                 platform::Place place)
-      : coeff_(static_cast<float>(1.0 / num_dev)),
-        scope_(scope),
-        place_(place) {}
-
-  ~ScaleLossGradOpHandle() {}
-
- protected:
-  void RunImpl() override {
-    std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
-
-    float *tmp = scope_->FindVar(var_name)
-                     ->GetMutable<LoDTensor>()
-                     ->mutable_data<float>(make_ddim({1}), place_);
-
-    if (platform::is_cpu_place(place_)) {
-      *tmp = coeff_;
-    } else {
-      auto stream =
-          static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
-              ->stream();
-      memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
-                   platform::CPUPlace(), &coeff_, sizeof(float), stream);
-    }
-  }
-};
-
 struct FetchOpHandle : public OpHandleBase {
   FeedFetchList *data_;
   size_t offset_;
-- 
GitLab
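
Aside (not part of the patch): RunImpl() writes the constant 1/num_dev into a
one-element float tensor that may live on the CPU or on a GPU. On CPU this is
a plain store; on GPU it enqueues a host-to-device copy on that device's
compute stream, so the write is ordered before any later op on the stream that
reads the scaled loss gradient. Below is a minimal standalone sketch of the
same pattern, assuming the plain CUDA runtime API in place of Paddle's
memory::Copy; the helper name is hypothetical.

    // Hypothetical standalone sketch -- not Paddle code.
    #include <cuda_runtime.h>

    // Write coeff = 1/num_dev into a one-element float buffer `dst` that
    // lives either on the host or on the current CUDA device.
    void WriteScaleCoeff(float *dst, size_t num_dev, bool dst_on_gpu,
                         cudaStream_t stream) {
      float coeff = static_cast<float>(1.0 / num_dev);
      if (!dst_on_gpu) {
        *dst = coeff;  // Host buffer: a plain store suffices.
      } else {
        // Device buffer: enqueue the copy on the compute stream so it is
        // ordered before kernels launched on that stream afterwards. For
        // pageable host memory, cudaMemcpyAsync stages the source before
        // returning, so the stack-local `coeff` is safe here.
        cudaMemcpyAsync(dst, &coeff, sizeof(float), cudaMemcpyHostToDevice,
                        stream);
      }
    }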