提交 f6741df4 编写于 作者: S sneaxiy

merge develop

fix bug
test=develop
......@@ -199,6 +199,7 @@ paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
......
......@@ -72,6 +72,8 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
......@@ -183,6 +185,8 @@ else()
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
target_link_libraries(executor garbage_collector)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph build_strategy
......
......@@ -45,10 +45,10 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
......@@ -56,10 +56,7 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......
......@@ -20,11 +20,13 @@ namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
platform::Place place)
platform::Place place,
size_t scope_idx)
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())),
scope_(scope),
place_(place) {}
place_(place),
scope_idx_(scope_idx) {}
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
......
......@@ -28,7 +28,8 @@ namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
public:
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
size_t scope_idx);
std::string Name() const override;
......@@ -38,6 +39,8 @@ struct ComputationOpHandle : public OpHandleBase {
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
size_t GetScopeIdx() const { return scope_idx_; }
protected:
void RunImpl() override;
......@@ -47,6 +50,7 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
size_t scope_idx_;
bool is_lock_and_record_event_free_{false};
};
} // namespace details
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
namespace paddle {
namespace framework {
namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, const Scope *scope, const platform::Place &place,
const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names),
gc_(gc),
ref_cnts_(ref_cnts) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) {
dev_ctx_ = reinterpret_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
if (dynamic_cast<StreamGarbageCollector *>(gc_)) {
platform::CUDADeviceGuard guard(
boost::get<platform::CUDAPlace>(place).device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
PADDLE_ENFORCE_NOT_NULL(event_);
}
}
#endif
}
EagerDeletionOpHandle::~EagerDeletionOpHandle() {
#ifdef PADDLE_WITH_CUDA
if (event_) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
platform::CUDADeviceGuard guard(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
#endif
}
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
// Var not found, not reference count has not decreased to 0
if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
continue;
}
auto *var = exec_scope->FindVar(name);
if (var == nullptr) {
continue;
}
VLOG(2) << "Erase variable " << name;
if (var->IsType<LoDTensor>()) {
garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
} else if (var->IsType<SelectedRows>()) {
garbages.emplace_back(
var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder());
} else if (var->IsType<LoDTensorArray>()) {
auto *tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto &t : *tensor_arr) {
garbages.emplace_back(t.MoveMemoryHolder());
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
var->Type().name(), name);
}
}
if (!garbages.empty()) {
ClearGarbages(&garbages);
}
}
void EagerDeletionOpHandle::ClearGarbages(
std::deque<std::shared_ptr<memory::Allocation>> *garbages) {
#ifdef PADDLE_WITH_CUDA
if (event_) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream =
reinterpret_cast<StreamGarbageCollector *>(gc_)->stream();
auto callback_func = [=]() {
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
};
gc_->Add(std::move(*garbages), callback_func);
} else {
#endif
gc_->Add(std::move(*garbages));
#ifdef PADDLE_WITH_CUDA
}
#endif
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <deque>
#include <string>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
namespace paddle {
namespace framework {
class Scope;
namespace details {
class EagerDeletionOpHandle : public OpHandleBase {
public:
EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
const platform::Place &place,
const std::unordered_set<std::string> &var_names,
GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts);
~EagerDeletionOpHandle();
std::string Name() const override;
protected:
void RunImpl() override;
private:
void ClearGarbages(std::deque<std::shared_ptr<memory::Allocation>> *garbages);
const Scope *scope_;
std::unordered_set<std::string> var_names_;
GarbageCollector *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr};
#endif
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <queue>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts =
Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
PADDLE_ENFORCE(ref_cnts.empty(),
"kRuntimeReferenceCount should be initialized here!");
const auto &vars = graph->Get<GraphVars>(kGraphVars);
ref_cnts.resize(vars.size());
const auto &last_live_ops =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
const auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
const auto &places = Get<std::vector<platform::Place>>(kAllPlaces);
// a reverse map of last_live_ops
// i.e., last op --> variable names which can be deleted.
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
op_vars_map;
for (auto &var_ops_map : last_live_ops) {
for (auto &var_ops_pair : var_ops_map) {
const std::string &var_name = var_ops_pair.first;
for (auto *op : var_ops_pair.second) {
op_vars_map[op].insert(var_name);
}
}
}
for (auto &pair : op_vars_map) {
auto *op = pair.first;
auto &var_names = pair.second;
auto *eager_deletion_node =
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
auto *eager_deletion_op = new EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
gcs.at(places[op->GetScopeIdx()]).get(),
&(ref_cnts[op->GetScopeIdx()]));
auto it = std::find_if(
op->Outputs().begin(), op->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != op->Outputs().end()) {
eager_deletion_op->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
op->AddOutput(dep_var);
eager_deletion_op->AddInput(dep_var);
}
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
eager_deletion_op->AddOutput(dummy_leaf);
}
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(eager_deletion_pass,
paddle::framework::details::EagerDeletionPass)
.RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class EagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -565,7 +565,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
int dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id]));
local_scopes_[dev_id], places_[dev_id], dev_id));
CreateOpHandleIOs(result, node, dev_id);
}
......@@ -688,8 +688,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle(
result->CreateOpNode(node->Op()), s, p, scope_idx));
CreateOpHandleIOs(result, node, scope_idx);
}
}
......
......@@ -23,6 +23,8 @@ namespace details {
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
preceding_ops_.clear();
pending_ops_.clear();
for (auto &op : ops) {
preceding_ops_[op];
pending_ops_[op];
......@@ -40,6 +42,7 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
ret.reserve(preceding_ops_.size());
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
......
......@@ -14,7 +14,7 @@
#pragma once
#include <memory>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
......@@ -34,6 +34,11 @@ class OpGraphView {
bool HasOp(OpHandleBase *op) const;
// Use a visitor to visit all pending ops of op
// Stop when callback returns false
template <typename Callback>
bool VisitAllPendingOps(OpHandleBase *op, Callback &&callback) const;
private:
void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const;
......@@ -44,6 +49,28 @@ class OpGraphView {
pending_ops_;
};
template <typename Callback>
bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
Callback &&callback) const {
EnforceHasOp(op);
std::unordered_set<OpHandleBase *> visited;
std::queue<OpHandleBase *> q;
q.push(op);
do {
op = q.front();
q.pop();
for (auto &pending_op : pending_ops_.at(op)) {
if (visited.count(pending_op) == 0) {
visited.insert(pending_op);
if (!callback(pending_op)) {
return false;
}
}
}
} while (!q.empty());
return true;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace details {
using ReferenceCountMap = std::unordered_map<std::string, int>;
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<int>>;
using DeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<ReferenceCountMap>>;
using AtomicDeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<AtomicReferenceCountMap>>;
using DeviceGarbageCollectorMap =
std::unordered_map<int,
std::unique_ptr<GarbageCollector<framework::Tensor>>>;
class ReferenceCountOpHandle : public OpHandleBase {
public:
ReferenceCountOpHandle(ir::Node *node, const Scope *scope,
const platform::CUDAPlace &place,
const std::vector<std::string> &var_names,
GarbageCollector<Tensor> *gc,
AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) {
platform::SetDeviceId(place.device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}
for (auto &name : var_names) AddVar(name);
}
~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
platform::SetDeviceId(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
}
std::string Name() const override { return "reference_count"; }
void AddVar(const std::string &name) {
auto it = var_names_.find(name);
if (it != var_names_.end())
++(it->second);
else
var_names_[name] = 1;
}
protected:
void RunImpl() override {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::vector<Tensor *> tensors;
for (auto &pair : var_names_) {
auto &name = pair.first;
auto it = ref_cnts_->find(name);
if (it == ref_cnts_->end()) continue;
auto *var = exec_scope->FindVar(name);
if (var == nullptr) continue;
if (var->IsType<LoDTensor>()) {
if (it->second.fetch_sub(pair.second) <= pair.second) {
tensors.emplace_back(var->GetMutable<LoDTensor>());
}
} else if (var->IsType<SelectedRows>()) {
if (it->second.fetch_sub(pair.second) <= pair.second) {
tensors.emplace_back(
var->GetMutable<SelectedRows>()->mutable_value());
}
}
}
if (!tensors.empty()) {
ClearTensors(tensors);
}
}
private:
void ClearTensors(const std::vector<Tensor *> &tensors) {
auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
if (gc != nullptr) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream = gc->stream();
auto callback_func = [=]() {
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
};
gc_->Add(tensors, callback_func);
} else {
gc_->Add(tensors);
}
}
bool IsStreamGarabageCollector() const {
return dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
}
const Scope *scope_;
platform::CUDADeviceContext *dev_ctx_;
std::unordered_map<std::string, int> var_names_;
GarbageCollector<Tensor> *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
cudaEvent_t event_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -14,187 +14,240 @@
#include <queue>
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
std::queue<VarHandleBase *> queue;
queue.push(var_in);
do {
auto *var = queue.front();
queue.pop();
for (auto *op : var->PendingOps()) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
return compute_op;
// A functor to shrink/remove operators who depend on other operators in a set
class ShrinkDepsOpFunctor {
private:
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
public:
explicit ShrinkDepsOpFunctor(const std::vector<OpHandleBase *> &all_ops)
: graph_(all_ops) {}
template <typename OpSet>
OpSet operator()(const OpSet &op_set) const {
using KeyType = typename OpSet::key_type;
static_assert(
std::is_base_of<OpHandleBase,
typename std::remove_pointer<KeyType>::type>::value,
"Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
if (op_set.size() <= 1) return op_set;
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
OpSet ret;
auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; };
for (size_t i = 0; i < rels.size(); ++i) {
if (std::all_of(rels[i].begin(), rels[i].end(), not_before)) {
ret.emplace(static_cast<KeyType>(ops[i]));
}
for (auto *out_var : op->Outputs()) {
queue.push(out_var);
}
return ret;
}
private:
std::vector<std::vector<RelationShip>> GetRelations(
const std::vector<OpHandleBase *> &ops) const {
std::unordered_map<OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
op_to_idx[ops[i]] = i;
}
PADDLE_ENFORCE(op_to_idx.size() == ops.size(), "Duplicate ops");
std::vector<std::vector<RelationShip>> ret(ops.size());
for (auto &e : ret) {
e.assign(ops.size(), kSame);
}
size_t found_num = ops.size();
size_t total_num = ops.size() * ops.size();
auto visitor = [&](OpHandleBase *op, size_t i) {
auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) {
size_t j = it->second;
if (i != j && ret[i][j] == kSame) {
ret[i][j] = kBefore;
ret[j][i] = kAfter;
found_num += 2;
if (found_num == total_num) {
return false;
}
}
}
return true;
};
for (size_t i = 0; i < ops.size(); ++i) {
auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); };
if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) {
break;
}
}
for (size_t i = 0; i < ops.size(); ++i) {
for (size_t j = i + 1; j < ops.size(); ++j) {
if (ret[i][j] != kSame) continue;
ret[i][j] = kNoDeps;
ret[j][i] = kNoDeps;
}
}
return ret;
}
const OpGraphView graph_;
};
/**
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
q.push(op);
do {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
}
}
} while (!queue.empty());
} while (!q.empty());
return nullptr;
}
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
ir::Graph *graph) {
auto it = std::find_if(
in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != in->Outputs().end()) {
out->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
in->AddOutput(dep_var);
out->AddInput(dep_var);
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
const ShrinkDepsOpFunctor &shrink_func,
bool *ok) {
// stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable. So the last operator is the op
// who generates this variable.
candidates.emplace(var->GeneratedOp());
} else {
candidates = var->PendingOps();
}
// No pending ops or generated op is nullptr
if (candidates.empty()) {
*ok = false;
return {};
}
}
// stage two. Try to cast them to computation op.
// return (*ok=false) when failed.
//
// The reason why we cannot make any types of op handle to be the last lived
// op is:
// some op handle may operate on many DeviceContext, however, our garbage
// collector can only wait one DeviceContext for now. So currently, we wait
// the nearest compute op.
std::unordered_set<ComputationOpHandle *> computation_op;
{
for (auto *op : candidates) {
auto *compute_op =
FindNextComputationOpHandleOrReturnItself(op, scope_idx);
if (compute_op == nullptr) {
*ok = false;
return {};
}
computation_op.emplace(compute_op);
}
}
// stage three. Try to shrink computation op if they depend on each other.
// Get the smallest set of the most ops.
*ok = true;
return shrink_func(computation_op);
}
static VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
return var_desc;
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
auto &cur_ref_cnts = Get<AtomicDeviceReferenceCountMap>(kCurReferenceCount);
auto &gcs = Get<DeviceGarbageCollectorMap>(kGarbageCollector);
// It is not easy to find the right reference counts of varaibles in graph
// Step 1: Find all variables in computation ops
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&](
OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op;
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
if (!platform::is_gpu_place(var_handle->place_) ||
boost::get<platform::CUDAPlace>(var_handle->place_) != place)
continue;
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
PADDLE_ENFORCE(last_live_ops_of_vars.empty() && ref_cnts.empty(),
"Last Live Ops and Reference Counts of vars should be "
"initialized at here.");
VarDesc *var_desc = var_handle->Node()->Var();
auto var_name = var_handle->Node()->Name();
const auto &vars = graph->Get<GraphVars>(kGraphVars);
// This is weird but there is really some variables without var_desc
// in computation_op
if (var_desc == nullptr) {
var_desc = compute_op->Node()->Op()->Block()->FindVar(var_name);
if (var_desc == nullptr) continue;
last_live_ops_of_vars.resize(vars.size());
ref_cnts.resize(vars.size());
ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph));
for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) {
// Whether this variable can be reused or deleted? If not, we do not
// compute reference counts and dependencies.
VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second);
if (var_desc == nullptr || var_desc->Persistable()) {
continue;
}
if (var_desc->Persistable()) continue;
auto var_type = var_desc->Proto()->type().type();
if (var_type != proto::VarType::LOD_TENSOR &&
var_type != proto::VarType::SELECTED_ROWS) {
var_type != proto::VarType::SELECTED_ROWS &&
var_type != proto::VarType::LOD_TENSOR_ARRAY) {
// Var type cannot be deleted
continue;
}
// compute op only runs in one device
if (ref_cnts[place.device]->count(var_name))
++(*ref_cnts[place.device])[var_name];
else
(*ref_cnts[place.device])[var_name] = 1;
bool ok;
auto result = ExtractComputationOpFromLastLivedVar(
name_var_pair.second.back(), i, shrink_func, &ok);
names.insert(var_name);
var_names_in_op.push_back(var_name);
}
return var_names_in_op;
};
auto update_ref_cnts_from_non_compute_op = [&](
OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
if (dynamic_cast<ComputationOpHandle *>(op) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
auto var_name = var_handle->Node()->Name();
auto var_place = var_handle->place_;
if (!platform::is_gpu_place(var_place)) continue;
auto place = boost::get<platform::CUDAPlace>(var_place);
if (names.count(var_name) == 0) continue;
if (ref_cnts.count(place.device) &&
ref_cnts[place.device]->count(var_name)) {
++(*ref_cnts[place.device])[var_name];
auto *next_compute_op = FindNextComputationOpHandle(var_handle);
if (next_compute_op != nullptr) {
if (compute_ref_cnt_map.count(next_compute_op)) {
compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
VLOG(5) << "Add reference count of " << var_name << " to Operator "
<< next_compute_op->Name();
} else {
// Create new reference_count_op_handle
ir::Node *ref_cnt_node = graph->CreateEmptyNode(
"reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[next_compute_op] = ref_cnt_handle;
}
}
if (ok) {
auto &var_name = name_var_pair.first;
PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
var_name);
ref_cnts[i].emplace(var_name, result.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(result));
}
}
};
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[compute_op] = ref_cnt_handle;
}
for (auto &op : all_ops) {
update_ref_cnts_from_non_compute_op(op, op->Inputs());
update_ref_cnts_from_non_compute_op(op, op->Outputs());
}
std::vector<OpHandleBase *> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back());
if (it != compute_ref_cnt_map.end()) {
// Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
it->second->AddOutput(dummy_leaf);
new_all_ops.emplace_back(std::move(it->second));
}
}
all_ops.swap(new_all_ops);
return graph;
}
......@@ -205,5 +258,4 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kCurReferenceCount)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);
......@@ -14,7 +14,6 @@
#pragma once
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
......@@ -22,10 +21,6 @@ namespace paddle {
namespace framework {
namespace details {
constexpr char kGlobalReferenceCount[] = "reference_count";
constexpr char kCurReferenceCount[] = "current_reference_count";
constexpr char kGarbageCollector[] = "garbage_collector";
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
namespace paddle {
namespace framework {
namespace details {} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
namespace framework {
namespace details {
class ComputationOpHandle;
using ReferenceCountMap = std::unordered_map<std::string, size_t>;
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<size_t>>;
using GarbageCollectorMap =
std::map<platform::Place, std::unique_ptr<GarbageCollector>>;
const char kGlobalReferenceCount[] = "global_reference_count";
const char kRuntimeReferenceCount[] = "runtime_reference_count";
const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -18,9 +18,6 @@
#include <vector>
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#endif
namespace paddle {
namespace framework {
......@@ -69,27 +66,12 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
drop_scope_counter_ += 1;
#ifdef PADDLE_WITH_CUDA
const std::string gc_name = "garbage_collector";
DeviceGarbageCollectorMap *gc =
Graph().Has(gc_name) ? &(Graph().Get<DeviceGarbageCollectorMap>(gc_name))
: nullptr;
#endif
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0;
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
#ifdef PADDLE_WITH_CUDA
if (gc != nullptr && platform::is_gpu_place(p)) {
auto gpu_place = boost::get<platform::CUDAPlace>(p);
auto &gc_at_place = gc->at(gpu_place.device);
gc_at_place->Wait();
gc_at_place->Reset();
}
#endif
}
for (auto &scope : local_scopes_) {
auto &local_scope =
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include <deque>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
......@@ -41,11 +42,43 @@ namespace {
int kProgramId = -1;
} // namespace
static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
const BlockDesc& block, const std::vector<std::string>& skip_var_list) {
std::unordered_map<std::string, size_t> ref_cnts;
std::unordered_set<std::string> skip_vars(skip_var_list.begin(),
skip_var_list.end());
auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
if (skip_vars.count(name)) continue;
auto* var_desc = block.FindVar(name);
if (var_desc == nullptr || var_desc->Persistable()) continue;
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR &&
type != proto::VarType::SELECTED_ROWS &&
type != proto::VarType::LOD_TENSOR_ARRAY) {
continue;
}
++ref_cnts[name];
}
}
};
for (auto op_desc : block.AllOps()) {
update_ref_cnts(op_desc, op_desc->Inputs());
update_ref_cnts(op_desc, op_desc->Outputs());
}
return ref_cnts;
}
ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id)
const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars)
: prog_(prog), block_id_(block_id) {
if (GetEagerDeletionThreshold() >= 0) {
ref_cnts_ = GetNonPersistableReferenceCount<int>(prog_, block_id_);
global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
skip_ref_cnt_vars);
}
}
......@@ -53,28 +86,40 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext";
}
template <typename RefCntMap>
static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
GarbageCollector<Tensor>* gc,
RefCntMap* ref_cnts) {
std::unordered_set<Tensor*> erase_tensors;
static void DeleteUnusedTensors(
const Scope& scope, const OperatorBase* op, GarbageCollector* gc,
std::unordered_map<std::string, size_t>* ref_cnts) {
std::deque<std::shared_ptr<memory::Allocation>> garbages;
auto handler = [&](const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
auto it = ref_cnts->find(name);
if (it == ref_cnts->end()) continue;
if ((it->second)-- == 1) {
auto* var = scope.FindVar(name);
if (var != nullptr) {
VLOG(10) << "Erase tensor \'" << name << "\'";
if (var->IsType<LoDTensor>()) {
erase_tensors.insert(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
erase_tensors.insert(
var->GetMutable<SelectedRows>()->mutable_value());
}
if (--(it->second) != 0) {
continue;
}
auto* var = scope.FindVar(name);
if (var != nullptr) {
continue;
}
VLOG(2) << "Erase variable " << name;
if (var->IsType<LoDTensor>()) {
garbages.emplace_back(
var->GetMutable<LoDTensor>()->MoveMemoryHolder());
} else if (var->IsType<SelectedRows>()) {
garbages.emplace_back(var->GetMutable<SelectedRows>()
->mutable_value()
->MoveMemoryHolder());
} else if (var->IsType<LoDTensorArray>()) {
auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto& t : *lod_tensor_arr) {
garbages.emplace_back(t.MoveMemoryHolder());
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
var->Type().name(), name);
}
}
}
......@@ -83,8 +128,8 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
handler(op->Inputs());
handler(op->Outputs());
if (!erase_tensors.empty()) {
gc->Add(erase_tensors);
if (!garbages.empty()) {
gc->Add(std::move(garbages));
}
}
......@@ -325,9 +370,10 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id) {
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars) {
std::unique_ptr<ExecutorPrepareContext> ctx(
new ExecutorPrepareContext(program, block_id));
new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
......@@ -338,16 +384,28 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
}
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids) {
const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) {
PADDLE_ENFORCE(
skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
"skip_ref_cnt_vars should be either empty or equals to block number %d",
block_ids.size());
std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
size_t idx = 0;
for (auto& bid : block_ids) {
auto* ctx = new ExecutorPrepareContext(program, bid);
ExecutorPrepareContext* ctx;
if (skip_ref_cnt_vars.empty()) {
ctx = new ExecutorPrepareContext(program, bid);
} else {
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]);
}
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
auto& block = program.Block(bid);
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
++idx;
}
return result;
}
......@@ -365,22 +423,23 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc;
// WhileOp would set keep_kids to true,
// because WhileGradOp needs the scopes created in WhileOp.
// Perhaps, we should not perform eager deletion in WhileOp
// The scopes and variables created by WhileOp would be deleted
// in WhileGradOp.
std::unique_ptr<GarbageCollector> gc;
// skip while_op and while_grad_op temporarily
if (max_memory_size >= 0 && !keep_kids) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
gc.reset(new DefaultStreamGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place_), max_memory_size));
} else {
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(new UnsafeFastGPUGarbageCollector(
boost::get<platform::CUDAPlace>(place_), max_memory_size));
} else {
gc.reset(new DefaultStreamGarbageCollector(
boost::get<platform::CUDAPlace>(place_), max_memory_size));
}
} else if (platform::is_cpu_place(place_)) {
#endif
gc.reset(new CPUGarbageCollector<Tensor>(
boost::get<platform::CPUPlace>(place_), max_memory_size));
gc.reset(new CPUGarbageCollector(boost::get<platform::CPUPlace>(place_),
max_memory_size));
#ifdef PADDLE_WITH_CUDA
}
#endif
......@@ -389,17 +448,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_);
if (gc != nullptr) {
if (gc) {
DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
&(ctx->cur_ref_cnts_));
&(ctx->runtime_ref_cnts_));
}
}
if (gc != nullptr) {
gc->Wait();
} else {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (local_scope != scope) {
scope->DeleteScope(local_scope);
......
......@@ -27,52 +27,21 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
const ProgramDesc& prog, size_t block_id) {
auto& block = prog.Block(block_id);
std::unordered_map<std::string, T> ref_cnts;
auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
auto* var_desc = block.FindVar(name);
if (var_desc == nullptr || var_desc->Persistable()) continue;
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR &&
type != proto::VarType::SELECTED_ROWS) {
continue;
}
auto it = ref_cnts.find(name);
if (it != ref_cnts.end()) {
++it->second;
} else {
ref_cnts[name] = 1;
}
}
}
};
for (auto op_desc : block.AllOps()) {
update_ref_cnts(op_desc, op_desc->Inputs());
update_ref_cnts(op_desc, op_desc->Outputs());
}
return ref_cnts;
}
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>());
~ExecutorPrepareContext();
void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
void ResetReferenceCount() { runtime_ref_cnts_ = global_ref_cnts_; }
const framework::ProgramDesc& prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, int> ref_cnts_;
std::unordered_map<std::string, int> cur_ref_cnts_;
std::unordered_map<std::string, size_t> global_ref_cnts_;
std::unordered_map<std::string, size_t> runtime_ref_cnts_;
};
class Executor {
......@@ -108,10 +77,14 @@ class Executor {
const std::string& fetch_holder_name = "fetch");
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>());
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids);
const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars =
std::vector<std::vector<std::string>>());
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
namespace framework {
GarbageCollector::GarbageCollector(const platform::Place &place,
size_t max_memory_size)
: max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
garbages_.reset(new GarbageQueue());
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}
CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place,
size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void CPUGarbageCollector::ClearCallback(const std::function<void()> &callback) {
callback();
}
#ifdef PADDLE_WITH_CUDA
UnsafeFastGPUGarbageCollector::UnsafeFastGPUGarbageCollector(
const platform::CUDAPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void UnsafeFastGPUGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback();
}
DefaultStreamGarbageCollector::DefaultStreamGarbageCollector(
const platform::CUDAPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void DefaultStreamGarbageCollector::Wait() const {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
void DefaultStreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {
platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}
StreamGarbageCollector::~StreamGarbageCollector() {
auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
cudaStream_t StreamGarbageCollector::stream() const { return stream_; }
void StreamGarbageCollector::Wait() const { callback_manager_->Wait(); }
void StreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback_manager_->AddCallback(callback);
}
#endif
} // namespace framework
} // namespace paddle
......@@ -14,7 +14,6 @@
#pragma once
#include <algorithm>
#include <deque>
#include <functional>
#include <memory>
......@@ -24,134 +23,74 @@
namespace paddle {
namespace framework {
// T should have memory_size() and clear() method
template <typename T>
class GarbageCollector {
public:
GarbageCollector(const platform::Place &place, size_t max_memory_size)
: max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
garbages_.reset(new std::deque<T *>());
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}
using GarbageQueue = std::deque<std::shared_ptr<memory::Allocation>>;
virtual ~GarbageCollector() {}
GarbageCollector(const platform::Place &place, size_t max_memory_size);
void Reset() {
std::lock_guard<std::mutex> guard(mutex_);
garbages_.reset(new std::deque<T *>());
cur_memory_size_ = 0;
}
virtual ~GarbageCollector() = default;
virtual void Wait() const {}
template <typename Container>
void Add(const Container &objs) {
Add(objs, []() {});
}
void Add(Container &&objs);
template <typename Container, typename Callback>
void Add(const Container &objs, Callback &&callback) {
std::shared_ptr<std::deque<T *>> clear_deque;
{
std::lock_guard<std::mutex> guard(mutex_);
for (auto *obj : objs) {
garbages_->push_back(obj);
cur_memory_size_ += obj->memory_size();
}
if (cur_memory_size_ >= max_memory_size_) {
cur_memory_size_ = 0;
clear_deque = garbages_;
garbages_.reset(new std::deque<T *>());
}
}
if (clear_deque != nullptr) {
callback();
ClearCallback([=]() {
for (auto *obj : *clear_deque) obj->clear();
});
}
}
virtual void Wait() const {}
void Add(Container &&objs, Callback &&callback);
protected:
virtual void ClearCallback(const std::function<void()> &callback) = 0;
platform::DeviceContext *dev_ctx_;
std::shared_ptr<std::deque<T *>> garbages_;
std::unique_ptr<GarbageQueue> garbages_;
mutable std::mutex mutex_;
const size_t max_memory_size_;
size_t cur_memory_size_ = 0;
size_t cur_memory_size_{0};
};
template <typename T>
class CPUGarbageCollector : public GarbageCollector<T> {
class CPUGarbageCollector : public GarbageCollector {
public:
CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {}
CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size);
protected:
void ClearCallback(const std::function<void()> &callback) override {
callback();
}
void ClearCallback(const std::function<void()> &callback) override;
};
#ifdef PADDLE_WITH_CUDA
template <typename T>
class DefaultStreamGarbageCollector : public GarbageCollector<T> {
class UnsafeFastGPUGarbageCollector : public GarbageCollector {
public:
DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {}
UnsafeFastGPUGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size);
cudaStream_t stream() const {
return static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
->stream();
}
protected:
void ClearCallback(const std::function<void()> &callback) override;
};
void Wait() const override {
this->dev_ctx_->Wait();
static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
class DefaultStreamGarbageCollector : public GarbageCollector {
public:
DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size);
void Wait() const override;
protected:
void ClearCallback(const std::function<void()> &callback) override {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
void ClearCallback(const std::function<void()> &callback) override;
};
template <typename T>
class StreamGarbageCollector : public GarbageCollector<T> {
class StreamGarbageCollector : public GarbageCollector {
public:
StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}
size_t max_memory_size);
~StreamGarbageCollector() {
auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
~StreamGarbageCollector();
void Wait() const override {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->Wait();
}
void Wait() const override;
cudaStream_t stream() const { return stream_; }
cudaStream_t stream() const;
protected:
void ClearCallback(const std::function<void()> &callback) override {
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->AddCallback(callback);
}
void ClearCallback(const std::function<void()> &callback) override;
private:
cudaStream_t stream_;
......@@ -159,5 +98,33 @@ class StreamGarbageCollector : public GarbageCollector<T> {
};
#endif
template <typename Container>
void GarbageCollector::Add(Container &&objs) {
Add(std::forward<Container>(objs), []() {});
}
template <typename Container, typename Callback>
void GarbageCollector::Add(Container &&objs, Callback &&callback) {
GarbageQueue *garbage_queue = nullptr;
{
std::lock_guard<std::mutex> guard(mutex_);
for (auto &obj : objs) {
if (!obj) continue;
cur_memory_size_ += obj->size();
garbages_->push_back(std::move(obj));
}
if (cur_memory_size_ >= max_memory_size_) {
cur_memory_size_ = 0;
garbage_queue = garbages_.release();
garbages_.reset(new GarbageQueue());
}
}
if (garbage_queue) {
callback();
ClearCallback([garbage_queue]() { delete garbage_queue; });
}
}
} // namespace framework
} // namespace paddle
......@@ -73,14 +73,21 @@ class Graph {
}
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
return attrs_.count(attr_name) > 0;
}
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
try {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} catch (boost::bad_any_cast &) {
PADDLE_THROW(
"Invalid attribute type of %s error, expected: %s, actual: %s",
attr_name, typeid(AttrType *).name(),
attrs_.at(attr_name).type().name());
}
}
template <typename AttrType>
......
......@@ -51,11 +51,18 @@ class Pass {
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"%s attr not registered for pass.", attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
try {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} catch (boost::bad_any_cast &) {
PADDLE_THROW(
"Invalid attribute type of %s error, expected: %s, actual: %s",
attr_name, typeid(AttrType *).name(),
attrs_.at(attr_name).type().name());
}
}
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
return attrs_.count(attr_name) > 0;
}
void Erase(const std::string &attr_name) {
......
......@@ -319,7 +319,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
#define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)
// clang-format off
// clang-format on
} // namespace framework
} // namespace paddle
......@@ -881,9 +881,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
PADDLE_ENFORCE(t->IsInitialized(),
"Input %s(%s) does not exist in Operator %s",
input.first, ipt_name, DebugString());
PADDLE_ENFORCE(t->IsInitialized(), "Input %s is not initialized: %s",
ipt_name, DebugString());
int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(
tmp == data_type || data_type == -1,
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -72,6 +73,26 @@ class ParallelExecutorPrivate {
}
}
}
std::unique_ptr<ir::Graph> PrepareGCAndRefCnts(
std::unique_ptr<ir::Graph> graph, size_t max_memory_size);
inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
void ResetRuntimeReferenceCount(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) {
for (auto &pair : global_ref_cnts_[i]) {
runtime_ref_cnts_[i][pair.first] = pair.second;
}
for (auto &fetch_name : fetch_tensors) {
runtime_ref_cnts_[i].erase(fetch_name);
}
runtime_ref_cnts_[i].erase(fetched_var_name);
}
}
std::vector<platform::Place> places_;
std::vector<Scope *> local_scopes_;
Scope *global_scope_; // not owned
......@@ -83,8 +104,76 @@ class ParallelExecutorPrivate {
bool own_local_scope_;
bool use_cuda_;
bool use_all_reduce_;
// global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
// then keeps unchanged
// Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
std::vector<details::ReferenceCountMap> global_ref_cnts_;
std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
details::GarbageCollectorMap gcs_;
};
std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
std::unique_ptr<ir::Graph> graph, size_t max_memory_size) {
for (size_t i = 0; i < places_.size(); ++i) {
auto &place = places_[i];
if (gcs_.count(place) > 0) {
continue;
}
std::unique_ptr<GarbageCollector> gc;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) {
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(new UnsafeFastGPUGarbageCollector(
boost::get<platform::CUDAPlace>(place), max_memory_size));
} else {
gc.reset(new StreamGarbageCollector(
boost::get<platform::CUDAPlace>(place), max_memory_size));
}
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
} else {
#endif
if (platform::is_cpu_place(place)) {
gc.reset(new CPUGarbageCollector(boost::get<platform::CPUPlace>(place),
max_memory_size));
VLOG(10) << "Created GarbageCollector at " << place;
} else {
PADDLE_THROW("Unsupported place for garbage collection");
}
#ifdef PADDLE_WITH_CUDA
}
#endif
gcs_.emplace(place, std::move(gc));
}
if (!gcs_.empty()) {
std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
&global_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(std::move(graph));
VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
&runtime_ref_cnts_);
eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
graph = eager_deletion_pass->Apply(std::move(graph));
VLOG(10) << "EagerDeletionPass Applied";
}
return graph;
}
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
return member_->local_scopes_;
}
......@@ -151,36 +240,18 @@ ParallelExecutor::ParallelExecutor(
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get());
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
for (auto &place : member_->places_) {
if (!platform::is_gpu_place(place)) continue;
auto gpu_place = boost::get<platform::CUDAPlace>(place);
if (gcs_[gpu_place.device] == nullptr) {
ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap());
cur_ref_cnts_[gpu_place.device].reset(
new details::AtomicReferenceCountMap());
gcs_[gpu_place.device].reset(
new StreamGarbageCollector<Tensor>(gpu_place, max_memory_size));
}
}
if (!gcs_.empty()) {
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
graph = ref_cnt_pass->Apply(std::move(graph));
graph->SetNotOwned("garbage_collector", &gcs_);
}
}
#else
std::unique_ptr<ir::Graph> graph =
build_strategy.Apply(main_program, member_->places_, loss_var_name,
params, member_->local_scopes_, member_->use_cuda_);
#endif
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
graph = member_->PrepareGCAndRefCnts(std::move(graph),
static_cast<size_t>(max_memory_size));
}
// Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars
std::vector<details::VariableInfo> var_infos;
......@@ -300,18 +371,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
#endif
platform::RecordBlock b(0);
#ifdef PADDLE_WITH_CUDA
if (!gcs_.empty()) {
ResetReferenceCount();
for (auto &pair : cur_ref_cnts_) {
auto &name_map = *(pair.second);
for (auto &fetch_name : fetch_tensors) {
name_map.erase(fetch_name);
}
name_map.erase(fetched_var_name);
}
if (member_->HasGarbageCollectors()) {
member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
}
#endif
auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data;
......@@ -355,13 +417,11 @@ ParallelExecutor::~ParallelExecutor() {
for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
// member_ must be destructed before gcs_ since the destructor of
// ReferenceCountOpHandle use raw pointers of gcs_ inside.
member_.reset();
delete member_;
}
} // namespace framework
} // namespace paddle
#ifdef PADDLE_WITH_CUDA
USE_PASS(reference_count_pass);
#endif
USE_PASS(eager_deletion_pass);
......@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -29,10 +28,6 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_pass.h"
#endif
namespace paddle {
namespace framework {
......@@ -75,24 +70,7 @@ class ParallelExecutor {
private:
void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
std::unique_ptr<ParallelExecutorPrivate> member_;
#ifdef PADDLE_WITH_CUDA
// ref_cnts_ is only initialized when ParallelExecutor constructs, and then
// keeps unchanged
// Before each iteration, cur_ref_cnts_ is reset to ref_cnts_
details::DeviceReferenceCountMap ref_cnts_;
details::AtomicDeviceReferenceCountMap cur_ref_cnts_;
details::DeviceGarbageCollectorMap gcs_;
void ResetReferenceCount() {
for (auto &pair1 : ref_cnts_) {
for (auto &pair2 : *(pair1.second)) {
(*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second;
}
}
}
#endif
ParallelExecutorPrivate *member_;
};
} // namespace framework
......
......@@ -38,6 +38,10 @@ DEFINE_double(
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
DEFINE_bool(fast_eager_deletion_mode, false,
"Fast eager deletion mode. If enabled, memory would release "
"immediately without waiting GPU kernel ends.");
// When in inference scenario, the scopes will not be written by two threads in
// a mean time, but a scope may be read by multiple threads concurrently, and
// the mutex will cause serious performance issue.
......@@ -58,6 +62,8 @@ int64_t GetEagerDeletionThreshold() {
(static_cast<int64_t>(1) << 30));
}
bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const {
......
......@@ -27,6 +27,7 @@ namespace paddle {
namespace framework {
int64_t GetEagerDeletionThreshold();
bool IsFastEagerDeletionModeEnabled();
class Scope;
......
......@@ -158,6 +158,10 @@ class Tensor {
const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; }
size_t offset() const { return offset_; }
std::shared_ptr<memory::Allocation> MoveMemoryHolder() {
return std::move(holder_);
}
private:
/*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_;
......
......@@ -32,6 +32,20 @@ static constexpr char kStepScopes[] = "StepScopes";
static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
namespace { // NOLINT
static std::string GetSkipEagerDeletionVarsDebugString(
const std::vector<std::string> &vars) {
std::string str = "Skip " + std::to_string(vars.size()) +
" var(s) in eager deletion mode: ";
for (auto &var : vars) {
str.append(var);
str.push_back(' ');
}
return str;
}
} // NOLINT
class WhileOp : public framework::OperatorBase {
public:
......@@ -59,7 +73,10 @@ class WhileOp : public framework::OperatorBase {
"Condition of while op must in CPU memory.");
bool is_test = Attr<bool>("is_test");
auto ctx = executor.Prepare(*program, block->ID());
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
......@@ -96,6 +113,10 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<std::vector<std::string>>(kSkipEagerDeletionVars,
"Vars that would skip eager deletion."
"Users should not set this manually.")
.SetDefault(std::vector<std::string>());
AddComment(R"DOC(
)DOC");
}
......@@ -119,7 +140,10 @@ class WhileGradOp : public framework::OperatorBase {
framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
auto ctx = executor.Prepare(*program, block->ID());
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
auto *step_scopes =
scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
......@@ -341,6 +365,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
// while operator could be renamed.
while_grad->SetAttr("original_output_grad", output_grads_list);
while_grad->SetAttr(kSkipEagerDeletionVars, std::vector<std::string>());
return std::unique_ptr<framework::OpDesc>(while_grad);
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/psroi_pool_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class PSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), "
"the input of PSROIPoolOp. "
"The format of input tensor is NCHW. Where N is the batch size, "
"C is the number of input channels, "
"H is the height of the input feature map, and "
"W is the width.");
AddInput("ROIs",
"(LoDTensor), "
"ROIs (Regions of Interest) to pool over. "
"should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]. "
"where (x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates. "
"The roi batch index can be calculated from LoD.");
AddOutput("Out",
"(Tensor), "
"the output of PSROIPoolOp is a 4-D Tensor with shape "
"(num_rois, output_channels, pooled_h, pooled_w).");
AddAttr<int>(
"output_channels",
"(int), "
"the number of channels of the output feature map. "
"For a task of C classes of objects, output_channels should be "
"(C + 1) for classification only.");
AddAttr<float>("spatial_scale",
"(float, default 1.0), "
"Multiplicative spatial scale factor "
"to translate ROI coords from their input scale "
"to the scale used when pooling.")
.SetDefault(1.0);
AddAttr<int>("pooled_height",
"(int, default 1), "
"the pooled output height.")
.SetDefault(1);
AddAttr<int>("pooled_width",
"(int, default 1), "
"the pooled output width.")
.SetDefault(1);
AddComment(R"Doc(
**PSROIPool Operator**
Position sensitive region of interest pooling (also known as PSROIPooling) is to perform
position-sensitive average pooling on regions of interest specified by input, takes as
input N position-sensitive score maps and a list of num_rois regions of interest.
PSROIPooling for R-FCN. Please refer to https://arxiv.org/abs/1605.06409 for more details.
)Doc");
}
};
class PSROIPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of PSROIPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ROIs"),
"Input(ROIs) of PSROIPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PSROIPoolOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");
PADDLE_ENFORCE(input_dims.size() == 4,
"The format of input tensor is NCHW");
PADDLE_ENFORCE(rois_dims.size() == 2,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]");
PADDLE_ENFORCE(rois_dims[1] == 4,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]");
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
int output_channels = ctx->Attrs().Get<int>("output_channels");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
PADDLE_ENFORCE(
input_dims[1] == output_channels * pooled_height * pooled_width,
"the channel of X(%d) should be equal to the product of "
"output_channels(%d), pooled_height(%d) and pooled_width(%d)",
input_dims[1], output_channels, pooled_height, pooled_width);
PADDLE_ENFORCE_GT(pooled_height, 0,
"The pooled output height must be greater than 0");
PADDLE_ENFORCE_GT(pooled_width, 0,
"The pooled output width must be greater than 0");
PADDLE_ENFORCE_GT(output_channels, 1,
"The pooled output channels must greater than 1");
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
"The spatial scale must greater than 0.");
auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] =
output_channels; // input_dims[1] / (pooled_height * pooled_width);
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Out", out_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
class PSROIPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"The gradient of X should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(psroi_pool, ops::PSROIPoolOp, ops::PSROIPoolOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
psroi_pool,
ops::CPUPSROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPSROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
psroi_pool_grad,
ops::CPUPSROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPSROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/psroi_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
template <typename T>
__global__ void GPUPSROIPoolForward(
const int nthreads, const T* input_data, const T* input_rois,
const float spatial_scale, const int input_channels, const int height,
const int width, const int output_channels, const int pooled_height,
const int pooled_width, const int* rois_batch_id_data, T* output_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
// The output is in order (n, c, ph, pw)
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % output_channels;
int n = i / pooled_width / pooled_height / output_channels;
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w = static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h = static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = max(roi_end_w - roi_start_w, (T)0.1);
// Compute w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(bin_size_h * static_cast<T>(ph) + roi_start_h);
int wstart = floor(bin_size_w * static_cast<T>(pw) + roi_start_w);
int hend = ceil(bin_size_h * static_cast<T>(ph + 1) + roi_start_h);
int wend = ceil(bin_size_w * static_cast<T>(pw + 1) + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0), width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
const T* offset_input_data =
input_data +
(roi_batch_id * input_channels + input_channel) * height * width;
T outsum = 0;
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
outsum += offset_input_data[input_index];
}
}
T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
output_data[i] = is_empty ? 0. : outsum / bin_area;
}
}
template <typename T>
__global__ void GPUPSROIPoolBackward(
const int nthreads, const T* input_rois, const T* output_grad_data,
const float spatial_scale, const int input_channels, const int height,
const int width, const int output_channels, const int pooled_height,
const int pooled_width, const int* rois_batch_id_data, T* input_grad_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
// The output is in order (n, c, ph, pw)
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % output_channels;
int n = i / pooled_width / pooled_height / output_channels;
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_offset =
(roi_batch_id * input_channels + input_channel) * height * width;
T* offset_input_grad_data = input_grad_data + input_offset;
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w = static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h = static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = max(roi_end_w - roi_start_w, (T)0.1);
// Compute w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(bin_size_h * static_cast<T>(ph) + roi_start_h);
int wstart = floor(bin_size_w * static_cast<T>(pw) + roi_start_w);
int hend = ceil(bin_size_h * static_cast<T>(ph + 1) + roi_start_h);
int wend = ceil(bin_size_w * static_cast<T>(pw + 1) + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0), width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Accumulate diff_val into input data
T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
T diff_val = is_empty ? 0. : output_grad_data[i] / bin_area;
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
platform::CudaAtomicAdd(offset_input_grad_data + input_index, diff_val);
}
}
}
}
template <typename Place, typename T>
class GPUPSROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* out = ctx.Output<Tensor>("Out");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto output_channels = ctx.Attr<int>("output_channels");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
PADDLE_ENFORCE_EQ(input_channels,
output_channels * pooled_height * pooled_width,
"the channels of input X should equal the product of "
"output_channels x pooled_height x pooled_width");
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
"The rois_batch_size and input(X) batch_size must be the same.");
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
"The rois_num from input and lod must be the same.");
// set rois batch id
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(platform::CPUPlace());
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
framework::Tensor rois_batch_id_list_gpu;
framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &rois_batch_id_list_gpu);
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
// call cuda kernel function
GPUPSROIPoolForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size, in->data<T>(), rois->data<T>(), spatial_scale,
input_channels, height, width, output_channels, pooled_height,
pooled_width, rois_batch_id_list_gpu.data<int>(),
out->mutable_data<T>(ctx.GetPlace()));
}
};
template <typename Place, typename T>
class GPUPSROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto output_channels = ctx.Attr<int>("output_channels");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
int rois_num = rois->dims()[0];
int input_channels = in->dims()[1];
int height = in->dims()[2];
int width = in->dims()[3];
if (input_grad) {
// set roi batch id
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
framework::Tensor rois_batch_id_list_gpu;
framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &rois_batch_id_list_gpu);
input_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.cuda_device_context(), input_grad, static_cast<T>(0));
int output_grad_size = output_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUPSROIPoolBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, rois->data<T>(), output_grad->data<T>(),
spatial_scale, input_channels, height, width, output_channels,
pooled_height, pooled_width, rois_batch_id_list_gpu.data<int>(),
input_grad->mutable_data<T>(ctx.GetPlace()));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
psroi_pool,
ops::GPUPSROIPoolOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUPSROIPoolOpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
psroi_pool_grad,
ops::GPUPSROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUPSROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CPUPSROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto output_channels = ctx.Attr<int>("output_channels");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
auto in_stride = framework::stride(in_dims);
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out->dims());
const T* input_data = in->data<T>();
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
"the rois_batch_size and input(X) batch_size should be the same.");
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num_with_lod, rois_num,
"the rois_num from input and lod must be the same");
PADDLE_ENFORCE_EQ(input_channels,
output_channels * pooled_height * pooled_width,
"the channels of input X should equal the product of "
"output_channels x pooled_height x pooled_width");
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
T* output_data = out->mutable_data<T>(ctx.GetPlace());
const T* input_rois = rois->data<T>();
// calculate psroipooling, parallel processing can be implemented per ROI
for (int n = 0; n < rois_num; ++n) {
// set roi batch id
int roi_batch_id = rois_batch_id_data[n];
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w =
static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h =
static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small rois to be 1 x 1
T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1);
// Compute bin size w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
// calculate each pixel of the output feature map.
int out_roi_offset = n * out_stride[0];
for (int c = 0; c < output_channels; ++c) {
// per category
int out_plane_offset = out_roi_offset + c * out_stride[1];
for (int ph = 0; ph < pooled_height; ++ph) {
int out_row_offset = out_plane_offset + ph * out_stride[2];
for (int pw = 0; pw < pooled_width; ++pw) {
// calculate w and h at input feature map
int hstart = floor(static_cast<T>(ph) * bin_size_h + roi_start_h);
int wstart = floor(static_cast<T>(pw) * bin_size_w + roi_start_w);
int hend = ceil(static_cast<T>(ph + 1) * bin_size_h + roi_start_h);
int wend = ceil(static_cast<T>(pw + 1) * bin_size_w + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart, 0), height);
wstart = std::min(std::max(wstart, 0), width);
hend = std::min(std::max(hend, 0), height);
wend = std::min(std::max(wend, 0), width);
int output_index = out_row_offset + pw;
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_plane_offset =
roi_batch_id * in_stride[0] + input_channel * in_stride[1];
const T* offset_input_data = input_data + input_plane_offset;
T out_sum = 0.;
bool is_empty = (hend <= hstart) || (wend <= wstart);
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * in_stride[2] + iw;
out_sum += offset_input_data[input_index];
}
}
T bin_area = (hend - hstart) * (wend - wstart);
output_data[output_index] = is_empty ? 0. : out_sum / bin_area;
}
}
}
}
return;
}
};
template <typename DeviceContext, typename T>
class CPUPSROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto output_channels = ctx.Attr<int>("output_channels");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
if (input_grad) {
auto in_dims = in->dims();
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
// set roi batch id
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
const T* input_rois = rois->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// set gradient of X to be 0. before backpropagate.
math::SetConstant<DeviceContext, T> set_zero;
set_zero(ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0));
// backpropagate gradient per output pixel
int output_grad_size = output_grad->numel();
for (int i = 0; i < output_grad_size; ++i) {
// The output is in order (n, c, ph, pw)
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % output_channels;
int n = i / pooled_width / pooled_height / output_channels;
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_offset =
(roi_batch_id * input_channels + input_channel) * height * width;
T* offset_input_grad_data = input_grad_data + input_offset;
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w =
static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h =
static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1);
// Compute w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(bin_size_h * static_cast<T>(ph) + roi_start_h);
int wstart = floor(bin_size_w * static_cast<T>(pw) + roi_start_w);
int hend = ceil(bin_size_h * static_cast<T>(ph + 1) + roi_start_h);
int wend = ceil(bin_size_w * static_cast<T>(pw + 1) + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart, 0), height);
hend = std::min(std::max(hend, 0), height);
wstart = std::min(std::max(wstart, 0), width);
wend = std::min(std::max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Accumulate diff_val into input data
T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
T diff_val = is_empty ? 0. : output_grad_data[i] / bin_area;
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
offset_input_grad_data[input_index] += diff_val;
}
}
}
}
return;
}
};
} // namespace operators
} // namespace paddle
......@@ -40,14 +40,14 @@ static py::object *GetPythonCallableObject(size_t i) {
return &g_py_callables[i];
}
std::string PythonObjectToString(const py::object &py_callable) {
static std::string PythonObjectToString(const py::object &py_callable) {
py::gil_scoped_acquire guard;
return py::str(*py_callable);
}
void CallPythonFunc(py::object *callable,
const std::vector<framework::LoDTensor> &ins,
std::vector<framework::LoDTensor *> *out) {
static void CallPythonFunc(py::object *callable,
const std::vector<framework::LoDTensor> &ins,
std::vector<framework::LoDTensor *> *out) {
py::gil_scoped_acquire guard;
py::tuple in_args(ins.size());
for (size_t i = 0; i < ins.size(); ++i) {
......@@ -56,8 +56,19 @@ void CallPythonFunc(py::object *callable,
auto ret = (*callable)(*in_args);
auto ret_tuple = py::cast<py::tuple>(ret);
PADDLE_ENFORCE_EQ(py::len(ret_tuple), out->size(), "Output number not match");
for (size_t i = 0; i < out->size(); ++i) {
size_t ret_num = py::len(ret_tuple);
size_t out_num = out->size();
if (ret_num != out_num) {
// Python function has no return values or returns None
// In this case, ret_num = 1 && ret[0] == None && out_num should be 0
// Otherwise, ret_num must be equal to out_num
PADDLE_ENFORCE(
ret_num == 1 && out_num == 0 &&
py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr,
"Output number not match. Expected %d, actual %d", out_num, ret_num);
}
for (size_t i = 0; i < out_num; ++i) {
if ((*out)[i] == nullptr) {
continue;
}
......
......@@ -16,6 +16,7 @@
#include <sys/time.h>
#include <algorithm>
#include <chrono> // NOLINT
#include <cstdlib>
#include <fstream>
......@@ -55,8 +56,7 @@ class CTRReader : public framework::FileReader {
PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
thread_num_ =
file_list_.size() > thread_num ? thread_num : file_list_.size();
thread_num_ = std::min<size_t>(file_list_.size(), thread_num);
queue_ = queue;
SplitFiles();
for (size_t i = 0; i < thread_num_; ++i) {
......@@ -95,10 +95,10 @@ class CTRReader : public framework::FileReader {
queue_->ReOpen();
VLOG(3) << "reopen success";
VLOG(3) << "thread_num " << thread_num_;
for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread(
std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_,
thread_id, &read_thread_status_, queue_)));
for (size_t thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread(std::bind(
&ReadThread, file_groups_[thread_id], slots_, batch_size_,
static_cast<int>(thread_id), &read_thread_status_, queue_)));
}
monitor_thread_.reset(new std::thread(
std::bind(&MonitorThread, &read_thread_status_, queue_)));
......
......@@ -56,9 +56,16 @@ ELSE()
set(MKLDNN_CTX_DEPS)
ENDIF()
nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce)
IF(WITH_GPU)
set(STREAM_CALLBACK_DEPS stream_callback_manager)
ELSE()
set(STREAM_CALLBACK_DEPS)
ENDIF()
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS}
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
......
......@@ -222,14 +222,10 @@ class CUDADeviceContext : public DeviceContext {
template <typename Callback>
void AddStreamCallback(Callback&& callback) const {
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->AddCallback(callback);
}
void WaitStreamCallback() const {
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->Wait();
}
void WaitStreamCallback() const { callback_manager_->Wait(); }
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
......@@ -260,9 +256,7 @@ class CUDADeviceContext : public DeviceContext {
mutable std::mutex mtx_;
// This lock is only used by callback
// If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
mutable std::mutex callback_mtx_;
// StreamCallbackManager is thread-safe
std::unique_ptr<StreamCallbackManager> callback_manager_;
mutable std::mutex cublas_mtx_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/stream_callback_manager.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
#if CUDA_VERSION >= 10000
static void CUDART_CB StreamCallbackFunc(void *user_data);
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status, void *user_data)
#endif
{
std::unique_ptr<std::function<void()>> func(
reinterpret_cast<std::function<void()> *>(user_data));
(*func)();
}
StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream)
: stream_(stream), thread_pool_(1) {}
void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
auto *callback_func = new std::function<void()>(std::move(callback));
auto *func = new std::function<void()>([this, callback_func] {
std::lock_guard<std::mutex> lock(mtx_);
last_future_ = thread_pool_.enqueue([callback_func] {
std::unique_ptr<std::function<void()>> releaser(callback_func);
(*callback_func)();
});
});
#if CUDA_VERSION >= 10000
PADDLE_ENFORCE(cudaLaunchHostFunc(stream_, StreamCallbackFunc, func));
#else
PADDLE_ENFORCE(cudaStreamAddCallback(stream_, StreamCallbackFunc, func, 0));
#endif
}
void StreamCallbackManager::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
{
std::lock_guard<std::mutex> lock(mtx_);
if (last_future_.valid()) {
last_future_.wait();
}
}
}
} // namespace platform
} // namespace paddle
......@@ -18,67 +18,32 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
class StreamCallbackManager;
struct StreamCallbackContext {
template <typename Callback>
inline StreamCallbackContext(const StreamCallbackManager *manager,
Callback &&callback)
: manager_(manager), callback_(callback) {}
const StreamCallbackManager *manager_; // do not own
std::function<void()> callback_;
};
// NOTE(zjl): clean StreamCallbackManager to make compilation faster
// Make StreamCallbackManager thread-safe
class StreamCallbackManager {
public:
explicit inline StreamCallbackManager(cudaStream_t stream = nullptr)
: stream_(stream), thread_pool_(new ThreadPool(1)) {}
explicit StreamCallbackManager(const cudaStream_t stream);
~StreamCallbackManager() = default;
template <typename Callback>
inline void AddCallback(Callback &&callback) const {
auto *stream_callback_context =
new StreamCallbackContext(this, std::forward<Callback>(callback));
#if CUDA_VERSION >= 10000
PADDLE_ENFORCE(cudaLaunchHostFunc(stream_,
StreamCallbackManager::StreamCallbackFunc,
stream_callback_context)); // NOLINT
#else
PADDLE_ENFORCE(cudaStreamAddCallback(
stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0)); // NOLINT
#endif
}
void AddCallback(std::function<void()> callback) const;
void Wait() const { thread_pool_.reset(new ThreadPool(1)); }
void Wait() const;
private:
const cudaStream_t stream_;
mutable std::unique_ptr<ThreadPool> thread_pool_;
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
#if CUDA_VERSION >= 10000
static void CUDART_CB StreamCallbackFunc(void *user_data)
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status, void *user_data)
#endif
{
auto *callback_context_ptr =
reinterpret_cast<StreamCallbackContext *>(user_data);
callback_context_ptr->manager_->thread_pool_->enqueue([=]() {
std::unique_ptr<StreamCallbackContext> callback_context(
callback_context_ptr);
callback_context->callback_();
});
}
mutable ::ThreadPool thread_pool_;
mutable std::mutex mtx_;
mutable std::future<void> last_future_;
};
} // namespace platform
......
......@@ -162,7 +162,7 @@ void PyCPUTensorSetFromArray(
paddle::platform::CPUPlace place) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
}
......@@ -182,7 +182,7 @@ inline void PyCPUTensorSetFromArray(
paddle::platform::CPUPlace place) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (int i = 0; i < array.ndim(); ++i) {
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
}
......@@ -200,7 +200,7 @@ void PyCUDATensorSetFromArray(
paddle::platform::CUDAPlace place) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
}
......@@ -221,7 +221,7 @@ inline void PyCUDATensorSetFromArray(
paddle::platform::CUDAPlace place) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
}
......@@ -240,7 +240,7 @@ void PyCUDAPinnedTensorSetFromArray(
const paddle::platform::CUDAPinnedPlace &place) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
}
......@@ -260,7 +260,7 @@ inline void PyCUDAPinnedTensorSetFromArray(
const paddle::platform::CUDAPinnedPlace &place) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
}
......
......@@ -126,9 +126,9 @@ def __bootstrap__():
'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_mkldnn',
'use_ngraph', 'initial_cpu_memory_in_mb', 'init_allocated_mem',
'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size",
'eager_delete_tensor_gb', 'allocator_strategy',
'reader_queue_speed_test_mode', 'print_sub_graph_dir',
'pe_profile_fname'
'eager_delete_tensor_gb', 'fast_eager_deletion_mode',
'allocator_strategy', 'reader_queue_speed_test_mode',
'print_sub_graph_dir', 'pe_profile_fname'
]
if 'Darwin' not in sysstr:
read_env_flags.append('use_pinned_memory')
......
......@@ -176,6 +176,7 @@ __all__ = [
'get_tensor_from_selected_rows',
'lstm',
'py_func',
'psroi_pool',
]
kIgnoreIndex = -100
......@@ -9127,11 +9128,11 @@ def get_tensor_from_selected_rows(x, name=None):
return out
class PyFuncWrapper(object):
class PyFuncRegistry(object):
_register_funcs = []
def __init__(self, func):
if func is None or not hasattr(func, '__call__'):
if func is None or not callable(func):
raise TypeError('func must be a Python function')
self._func = func
......@@ -9143,7 +9144,7 @@ class PyFuncWrapper(object):
1. For debug usage. Users can call
:code:`py_func.registered_func(idx)` method
to find the registered function coresponding
to find the registered function corresponding
to :code:`idx`.
2. For increasing reference count of self.
......@@ -9152,7 +9153,7 @@ class PyFuncWrapper(object):
segmentation fault error in C++ side.
May be lack of Python GC in C++ side?
'''
PyFuncWrapper._register_funcs.append(self)
PyFuncRegistry._register_funcs.append(self)
@classmethod
def registered_func(cls, idx):
......@@ -9249,8 +9250,8 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
raise TypeError(
'Output must be Variable/list(Variable)/tuple(Variable)')
fwd_func_id = PyFuncWrapper(func).id
bwd_func_id = PyFuncWrapper(
fwd_func_id = PyFuncRegistry(func).id
bwd_func_id = PyFuncRegistry(
backward_func).id if backward_func is not None else -1
for each_out in out_list:
......@@ -9288,5 +9289,59 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
# For debug usage
py_func.registered_func = PyFuncWrapper.registered_func
py_func.registered_func_num = PyFuncWrapper.registered_func_num
py_func.registered_func = PyFuncRegistry.registered_func
py_func.registered_func_num = PyFuncRegistry.registered_func_num
@templatedoc()
def psroi_pool(input,
rois,
output_channels,
spatial_scale,
pooled_height,
pooled_width,
name=None):
"""
${comment}
Args:
input (Variable): ${x_comment}
rois (Variable): ROIs (Regions of Interest) to pool over.
output_channels (integer): ${output_channels_comment}
spatial_scale (float): ${spatial_scale_comment} Default: 1.0
pooled_height (integer): ${pooled_height_comment} Default: 1
pooled_width (integer): ${pooled_width_comment} Default: 1
name (str, default None): The name of this layer.
Returns:
Variable: ${out_comment}.
Examples:
.. code-block:: python
pool_out = fluid.layers.psroi_pool(input=x, rois=rois, 490, 1.0, 7, 7)
"""
helper = LayerHelper('psroi_pool', **locals())
# check attrs
if not isinstance(output_channels, int):
raise TypeError("output_channels must be int type")
if not isinstance(spatial_scale, float):
raise TypeError("spatial_scale must be float type")
if not isinstance(pooled_height, int):
raise TypeError("pooled_height must be int type")
if not isinstance(pooled_width, int):
raise TypeError("pooled_width must be int type")
dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='psroi_pool',
inputs={'X': input,
'ROIs': rois},
outputs={'Out': out},
attrs={
'output_channels': output_channels,
'spatial_scale': spatial_scale,
'pooled_height': pooled_height,
'pooled_width': pooled_width
})
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
os.environ['CPU_NUM'] = '2'
import six
import unittest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2):
if use_cuda and not core.is_compiled_with_cuda():
print('Skip use_cuda=True because Paddle is not compiled with cuda')
return
word_dict = paddle.dataset.imdb.word_dict()
train_reader = paddle.batch(
paddle.dataset.imdb.train(word_dict), batch_size=batch_size)
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
cost = network(data, label, len(word_dict))
optimizer = fluid.optimizer.Adagrad(learning_rate=0.2)
optimizer.minimize(cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
reader = feeder.decorate_reader(
train_reader, multi_devices=use_parallel_executor)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if use_parallel_executor:
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=cost.name)
fetch_list = [cost.name]
else:
train_exe = exe
fetch_list = [cost]
for pass_id in six.moves.xrange(pass_num):
batch_id = 0
for data in reader():
train_exe.run(feed=data,
fetch_list=fetch_list if batch_id % 4 == 0 else [])
batch_id += 1
if batch_id > 16:
break
class TestBase(unittest.TestCase):
def setUp(self):
self.net = None
def test_network(self):
if self.net is None:
return
for use_cuda in [True, False]:
for use_parallel_executor in [False, True]:
print('network: {}, use_cuda: {}, use_parallel_executor: {}'.
format(self.net.__name__, use_cuda,
use_parallel_executor))
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(core.Scope()):
train(self.net, use_cuda, use_parallel_executor)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_eager_deletion_dynamic_rnn_base import TestBase
import paddle.fluid as fluid
def gru_net(data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2,
emb_lr=400.0):
emb = fluid.layers.embedding(
input=data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
gru_max_tanh = fluid.layers.tanh(gru_max)
fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
class GRUTest(TestBase):
def setUp(self):
self.net = gru_net
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from test_eager_deletion_dynamic_rnn_base import TestBase
import paddle.fluid as fluid
import unittest
def lstm_net(data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2,
emb_lr=30.0):
emb = fluid.layers.embedding(
input=data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)
lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
lstm_max_tanh = fluid.layers.tanh(lstm_max)
fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
class LSTMTest(TestBase):
def setUp(self):
self.net = lstm_net
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
from test_parallel_executor_mnist import TestMNIST
class EagerDeletionTestMNIST(TestMNIST):
pass
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
from test_parallel_executor_transformer import TestTransformer
class EagerDeletionTestTransformer(TestTransformer):
pass
if __name__ == '__main__':
unittest.main()
......@@ -511,6 +511,16 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_psroi_pool(self):
program = Program()
with program_guard(program):
x = layers.data(name="x", shape=[245, 30, 30], dtype="float32")
rois = layers.data(
name="rois", shape=[4], dtype="float32", lod_level=1)
output = layers.psroi_pool(x, rois, 5, 0.25, 7, 7)
self.assertIsNotNone(output)
print(str(program))
def test_roi_align(self):
program = Program()
with program_guard(program):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import math
import numpy as np
import unittest
from op_test import OpTest
class TestPSROIPoolOp(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.calc_psroi_pool()
self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)}
self.attrs = {
'output_channels': self.output_channels,
'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width
}
self.outputs = {'Out': self.outs}
def init_test_case(self):
self.batch_size = 3
self.channels = 3 * 2 * 2
self.height = 6
self.width = 4
self.x_dim = [self.batch_size, self.channels, self.height, self.width]
self.spatial_scale = 1.0 / 4.0
self.output_channels = 3
self.pooled_height = 2
self.pooled_width = 2
self.x = np.random.random(self.x_dim).astype('float32')
def make_rois(self):
rois = []
self.rois_lod = [[]]
for bno in range(self.batch_size):
self.rois_lod[0].append(bno + 1)
for i in range(bno + 1):
x1 = np.random.random_integers(
0, self.width // self.spatial_scale - self.pooled_width)
y1 = np.random.random_integers(
0, self.height // self.spatial_scale - self.pooled_height)
x2 = np.random.random_integers(x1 + self.pooled_width,
self.width // self.spatial_scale)
y2 = np.random.random_integers(
y1 + self.pooled_height, self.height // self.spatial_scale)
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype('float32')
def calc_psroi_pool(self):
output_shape = (self.rois_num, self.output_channels, self.pooled_height,
self.pooled_width)
out_data = np.zeros(output_shape)
for i in range(self.rois_num):
roi = self.rois[i]
roi_batch_id = int(roi[0])
roi_start_w = round(roi[1]) * self.spatial_scale
roi_start_h = round(roi[2]) * self.spatial_scale
roi_end_w = (round(roi[3]) + 1.) * self.spatial_scale
roi_end_h = (round(roi[4]) + 1.) * self.spatial_scale
roi_height = max(roi_end_h - roi_start_h, 0.1)
roi_width = max(roi_end_w - roi_start_w, 0.1)
bin_size_h = roi_height / float(self.pooled_height)
bin_size_w = roi_width / float(self.pooled_width)
x_i = self.x[roi_batch_id]
for c in range(self.output_channels):
for ph in range(self.pooled_height):
for pw in range(self.pooled_width):
hstart = int(
math.floor(float(ph) * bin_size_h + roi_start_h))
wstart = int(
math.floor(float(pw) * bin_size_w + roi_start_w))
hend = int(
math.ceil(
float(ph + 1) * bin_size_h + roi_start_h))
wend = int(
math.ceil(
float(pw + 1) * bin_size_w + roi_start_w))
hstart = min(max(hstart, 0), self.height)
hend = min(max(hend, 0), self.height)
wstart = min(max(wstart, 0), self.width)
wend = min(max(wend, 0), self.width)
c_in = (c * self.pooled_height + ph
) * self.pooled_width + pw
is_empty = (hend <= hstart) or (wend <= wstart)
out_sum = 0.
for ih in range(hstart, hend):
for iw in range(wstart, wend):
out_sum += x_i[c_in, ih, iw]
bin_area = (hend - hstart) * (wend - wstart)
out_data[i, c, ph, pw] = 0. if is_empty else (
out_sum / float(bin_area))
self.outs = out_data.astype('float32')
def setUp(self):
self.op_type = 'psroi_pool'
self.set_data()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
......@@ -12,12 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
import paddle
import unittest
import six
import numpy as np
dev_cnt = 2
if fluid.core.is_compiled_with_cuda():
dev_cnt = fluid.core.get_cuda_device_count()
os.environ['CPU_NUM'] = str(dev_cnt)
def tanh(x):
return np.tanh(x)
......@@ -86,17 +92,18 @@ def simple_fc_net(img, label, use_py_func_op):
out=loss,
backward_func=cross_entropy_grad,
skip_vars_in_backward_input=loss)
loss = fluid.layers.mean(loss)
return loss
def reader():
for _ in six.moves.range(100):
for _ in six.moves.range(dev_cnt * 100):
yield np.random.random([784]), np.random.random_integers(
size=[1], low=0, high=9)
def test_main(use_cuda, use_py_func_op):
def test_main(use_cuda, use_py_func_op, use_parallel_executor):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return None
......@@ -118,21 +125,32 @@ def test_main(use_cuda, use_py_func_op):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if use_parallel_executor:
exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=loss.name)
fetch_list = [loss.name]
else:
fetch_list = [loss]
ret = []
for epoch_id in six.moves.range(2):
for d in r():
L, = exe.run(feed=feeder.feed(d), fetch_list=[loss])
ret.append(L[0])
L, = exe.run(feed=feeder.feed(d), fetch_list=fetch_list)
ret.append(L)
return np.array(ret)
class TestPyFuncOp(unittest.TestCase):
class TestPyFuncOpUseExecutor(unittest.TestCase):
def setUp(self):
self.use_parallel_executor = False
def test_loss_diff(self):
losses = []
for use_cuda in [True, False]:
for use_py_func_op in [True, False]:
L = test_main(use_cuda, use_py_func_op)
L = test_main(use_cuda, use_py_func_op,
self.use_parallel_executor)
if L is not None:
losses.append(L)
......@@ -141,5 +159,10 @@ class TestPyFuncOp(unittest.TestCase):
self.assertAlmostEqual(max_diff, 0, delta=1e-3)
class TestPyFuncOpUseParallelExecutor(unittest.TestCase):
def setUp(self):
self.use_parallel_executor = True
if __name__ == '__main__':
unittest.main()
......@@ -15,7 +15,12 @@
from __future__ import print_function
import unittest
from functools import partial
import contextlib
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer
import paddle.fluid.regularizer as regularizer
......@@ -97,5 +102,134 @@ class TestL1DecayRegularizer(unittest.TestCase):
self.assertEqual(block.ops[-3].type, 'sign')
def bow_net(data,
label,
dict_dim,
is_sparse=False,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2):
"""
BOW net
This model is from https://github.com/PaddlePaddle/models:
fluid/PaddleNLP/text_classification/nets.py
"""
emb = fluid.layers.embedding(
input=data, is_sparse=is_sparse, size=[dict_dim, emb_dim])
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow)
fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
class TestRegularizer(unittest.TestCase):
def setUp(self):
self.word_dict = paddle.dataset.imdb.word_dict()
reader = paddle.batch(
paddle.dataset.imdb.train(self.word_dict), batch_size=8)()
self.train_data = [next(reader) for _ in range(5)]
def get_places(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
@contextlib.contextmanager
def scope_prog_guard(self, main_prog, startup_prog):
scope = fluid.core.Scope()
with fluid.unique_name.guard():
with fluid.scope_guard(scope):
with fluid.program_guard(main_prog, startup_prog):
yield
def run_program(self, place, feed_list):
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
exe.run(fluid.default_startup_program())
main_prog = fluid.default_main_program()
param_list = [var.name for var in main_prog.block(0).all_parameters()]
param_sum = []
for data in self.train_data:
out = exe.run(main_prog,
feed=feeder.feed(data),
fetch_list=param_list)
p_sum = 0
for v in out:
p_sum += np.sum(np.abs(v))
param_sum.append(p_sum)
return param_sum
def check_l2decay_regularizer(self, place, model):
main_prog = fluid.framework.Program()
startup_prog = fluid.framework.Program()
startup_prog.random_seed = 1
with self.scope_prog_guard(
main_prog=main_prog, startup_prog=startup_prog):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
avg_cost = model(data, label, len(self.word_dict))
optimizer = fluid.optimizer.Adagrad(
learning_rate=0.1,
regularization=fluid.regularizer.L2Decay(1.0))
optimizer.minimize(avg_cost)
param_sum = self.run_program(place, [data, label])
return param_sum
def check_l2decay(self, place, model):
main_prog = fluid.framework.Program()
startup_prog = fluid.framework.Program()
startup_prog.random_seed = 1
with self.scope_prog_guard(
main_prog=main_prog, startup_prog=startup_prog):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
avg_cost_l2 = model(data, label, len(self.word_dict))
param_list = fluid.default_main_program().block(0).all_parameters()
para_sum = []
for para in param_list:
para_mul = fluid.layers.square(x=para)
para_sum.append(fluid.layers.reduce_sum(input=para_mul))
avg_cost_l2 += fluid.layers.sums(para_sum) * .5
optimizer = fluid.optimizer.Adagrad(learning_rate=0.1)
optimizer.minimize(avg_cost_l2)
param_sum = self.run_program(place, [data, label])
return param_sum
def test_l2(self):
for place in self.get_places():
dense_sparse_p_sum = []
for sparse in [True, False]:
model = partial(bow_net, is_sparse=sparse)
framework_l2 = self.check_l2decay_regularizer(place, model)
l2 = self.check_l2decay(place, model)
assert len(l2) == len(framework_l2)
for i in range(len(l2)):
assert np.isclose(a=framework_l2[i], b=l2[i], rtol=5e-5)
dense_sparse_p_sum.append(framework_l2)
assert len(dense_sparse_p_sum[0]) == len(dense_sparse_p_sum[1])
for i in range(len(dense_sparse_p_sum[0])):
assert np.isclose(
a=dense_sparse_p_sum[0][i],
b=dense_sparse_p_sum[1][i],
rtol=5e-5)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册