You need to sign in or sign up before continuing.
提交 5368e50d 编写于 作者: Y Yu Yang

Reorganize code

上级 fe7ed285
...@@ -88,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo ...@@ -88,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method) framework_proto backward glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table simple_threadpool var_handle op_handle_base) framework_proto backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
cc_library(var_handle SRCS var_handle.cc DEPS place) cc_library(var_handle SRCS var_handle.cc DEPS place)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
platform::Place place)
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {}
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
void ScaleLossGradOpHandle::RunImpl() {
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
float *tmp =
scope_->FindVar(var_name)->GetMutable<LoDTensor>()->mutable_data<float>(
make_ddim({1}), place_);
if (platform::is_cpu_place(place_)) {
*tmp = coeff_;
} else {
#ifdef PADDLE_WITH_CUDA
auto stream =
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream);
#endif
}
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace details {
struct ScaleLossGradOpHandle : public OpHandleBase {
float coeff_;
Scope *scope_;
platform::Place place_;
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place);
~ScaleLossGradOpHandle() final;
protected:
void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "lod_tensor_array.h" #include "lod_tensor_array.h"
#include "op_registry.h" #include "op_registry.h"
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
...@@ -27,42 +28,10 @@ namespace framework { ...@@ -27,42 +28,10 @@ namespace framework {
using details::DummyVarHandle; using details::DummyVarHandle;
using details::OpHandleBase; using details::OpHandleBase;
using details::ScaleLossGradOpHandle;
using details::VarHandle; using details::VarHandle;
using details::VarHandleBase; using details::VarHandleBase;
struct ScaleLossGradOpHandle : public OpHandleBase {
float coeff_;
Scope *scope_;
platform::Place place_;
explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
platform::Place place)
: coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope),
place_(place) {}
~ScaleLossGradOpHandle() {}
protected:
void RunImpl() override {
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
float *tmp = scope_->FindVar(var_name)
->GetMutable<framework::LoDTensor>()
->mutable_data<float>(make_ddim({1}), place_);
if (platform::is_cpu_place(place_)) {
*tmp = coeff_;
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream);
}
}
};
struct FetchOpHandle : public OpHandleBase { struct FetchOpHandle : public OpHandleBase {
FeedFetchList *data_; FeedFetchList *data_;
size_t offset_; size_t offset_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册