Unverified commit 814f7211, authored by Leo Chen, committed by GitHub

[new-exec] async prepare deps (#40713)

* async prepare deps

* fix bug that std::future is not set

* add ut

* refine code

* fix standalone ut

* disable prof
Parent 97a20d75
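
The diff below replaces the synchronous construction of the per-step atomic counters: `AsyncWorkQueue::PrepareAtomicDeps` and `PrepareAtomicVarRef` now post a task to the work queue and hand the result back through a `std::promise`/`std::future` pair, so the counters for the next step are built while the current step executes. A minimal, self-contained sketch of that pattern follows; it uses a detached `std::thread` in place of Paddle's work queue, and `PrepareCounters` is an illustrative name, not Paddle's API.

```cpp
#include <atomic>
#include <future>
#include <iostream>
#include <memory>
#include <thread>
#include <vector>

// Sketch: build a vector of atomic counters on a background worker and
// publish it through a std::promise/std::future pair, mirroring the
// double-buffering done by AsyncWorkQueue::PrepareAtomicDeps.
using AtomicVec = std::vector<std::atomic<size_t>>;

std::future<std::unique_ptr<AtomicVec>> PrepareCounters(
    std::vector<size_t> counts) {
  auto p = std::make_shared<std::promise<std::unique_ptr<AtomicVec>>>();
  auto fut = p->get_future();
  // Paddle queues this lambda on its "Prepare" queue; a detached thread
  // keeps the example self-contained.
  std::thread([counts = std::move(counts), p] {
    auto counters = std::make_unique<AtomicVec>(counts.size());
    for (size_t i = 0; i < counts.size(); ++i) {
      (*counters)[i] = counts[i];  // atomic store of the initial dep count
    }
    p->set_value(std::move(counters));  // publish the prepared counters
  }).detach();
  return fut;
}

int main() {
  // Prepare counters for the "next step" while the current step would run.
  auto fut = PrepareCounters({0, 2, 1});
  std::unique_ptr<AtomicVec> deps = fut.get();  // consume when the step starts
  std::cout << (*deps)[1].fetch_sub(1) << std::endl;  // prints 2
  return 0;
}
```

In the actual change, the preparation task is queued on the dedicated "Prepare" queue (index 2 in the queue group), and the raw pointers obtained from the consumed `unique_ptr` are passed into the per-instruction lambdas for the current step.
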
......@@ -23,7 +23,7 @@ cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
# cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
# skip win32 since wget is not installed by default on windows machine.
# skip COVERAGE_CI since the test runs slowly because of instrumentation.
if (WITH_CUDA AND WITH_TESTING AND NOT WIN32 AND NOT WITH_COVERAGE AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON")
if (WITH_GPU AND WITH_TESTING AND NOT WIN32 AND NOT WITH_COVERAGE AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON")
add_custom_target(
download_program
COMMAND wget -nc https://paddle-ci.gz.bcebos.com/new_exec/lm_main_program
......
......@@ -41,6 +41,7 @@ namespace paddle {
namespace framework {
// NOTE(Aurelius84): Need a better strategy to determine it.
static constexpr size_t kHostNumThreads = 4;
static constexpr size_t kDeviceNumThreads = 1;
bool IsInterpretercoreFastGCEnabled() {
return FLAGS_fast_eager_deletion_mode && FLAGS_use_stream_safe_cuda_allocator;
......@@ -54,8 +55,8 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
global_scope_(global_scope),
stream_analyzer_(place) {
is_build_ = false;
async_work_queue_.reset(
new interpreter::AsyncWorkQueue(kHostNumThreads, &main_thread_blocker_));
async_work_queue_.reset(new interpreter::AsyncWorkQueue(
kHostNumThreads, kDeviceNumThreads, &main_thread_blocker_));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (IsInterpretercoreFastGCEnabled()) {
......@@ -271,6 +272,10 @@ void InterpreterCore::Convert(
if (FLAGS_new_executor_use_inplace) {
BuildInplace();
}
// prepare for the first time.
async_work_queue_->PrepareAtomicDeps(dependecy_count_);
async_work_queue_->PrepareAtomicVarRef(vec_meta_info);
}
bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) {
......@@ -388,18 +393,18 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
: global_scope_->GetMutableScope();
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
{
platform::RecordEvent infershape_event(
"infer_shape", platform::TracerEventType::OperatorInner, 1,
platform::EventRole::kInnerOp);
// If it is OperatorBase, InferShape do nothing.
if (op_with_kernel != nullptr)
if (op_with_kernel != nullptr) {
platform::RecordEvent infershape_event(
"infer_shape", platform::TracerEventType::OperatorInner, 1,
platform::EventRole::kInnerOp);
// If it is OperatorBase, InferShape do nothing.
op_with_kernel->Info().infer_shape_(
instr_node.InnerInferShapeContext().get());
}
}
if (op_with_kernel != nullptr &&
FLAGS_new_executor_use_inplace) { // TODO(xiongkun03) Does operator
// base support inplace ?
if (op_with_kernel != nullptr && FLAGS_new_executor_use_inplace) {
// TODO(xiongkun03) Does operator base support inplace ?
for (auto& pair : instr_node.InplaceInfo()) {
const auto& in = paddle::framework::details::GetTensorFromVar(pair.first);
auto* out =
......@@ -409,6 +414,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
}
}
}
{
platform::RecordEvent compute_event(
"compute", platform::TracerEventType::OperatorInner, 1,
......@@ -458,16 +464,24 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
void InterpreterCore::ExecuteInstructionList(
const std::vector<Instruction>& vec_instr) {
// NOTE(zhiqiu): get the prepared deps from std::future, and async prepare
// those for the next step
auto atomic_deps = async_work_queue_->AtomicDeps();
auto atomic_var_ref = async_work_queue_->AtomicVarRef();
async_work_queue_->PrepareAtomicDeps(dependecy_count_);
async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
unfinished_op_numer_ = vec_instr.size();
exception_holder_.Clear();
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
[&, i] { RunInstructionAsync(i); });
async_work_queue_->AddTask(vec_instr.at(i).KernelType(), [
this, i, atomic_deps = atomic_deps.get(),
atomic_var_ref = atomic_var_ref.get()
] { RunInstructionAsync(i, atomic_deps, atomic_var_ref); });
}
}
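
A subtlety that the "fix bug that std::future is not set" bullet likely touches on: a `std::future` is one-shot, so once `AtomicDeps()`/`AtomicVarRef()` consume the futures at the top of `ExecuteInstructionList`, the counters must immediately be re-prepared for the following step (and `Convert()` prepares them once "for the first time"). A tiny sketch of that one-shot behavior, standard C++ only, not Paddle code:

```cpp
#include <future>
#include <iostream>

int main() {
  std::promise<int> p;
  std::future<int> f = p.get_future();
  p.set_value(42);

  std::cout << f.get() << std::endl;        // ok: prints 42, consumes the shared state
  std::cout << std::boolalpha << f.valid()  // prints false: the future is now empty,
            << std::endl;                   // so another get() would be invalid
  return 0;
}
```

This is also why `AtomicDeps()` and `AtomicVarRef()` now return the `std::unique_ptr` obtained from `future::get()` rather than a reference to a member vector.
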
......@@ -490,11 +504,16 @@ void InterpreterCore::ExecuteInstructionList(
}
void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops) {
const Instruction& instr, std::queue<size_t>* reserved_next_ops,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref) {
VLOG(4) << "atomic 1:" << atomic_deps;
auto& next_instr = instr.NextInstructions();
auto& atomic_deps = async_work_queue_->AtomicDeps();
auto IsReady = [&](size_t next_id) {
return atomic_deps[next_id]->fetch_sub(1, std::memory_order_relaxed) == 1;
auto IsReady = [atomic_deps](size_t next_id) {
VLOG(4) << "atomic:" << atomic_deps << " " << &(*atomic_deps)[next_id]
<< " " << next_id;
return (*atomic_deps)[next_id].fetch_sub(1, std::memory_order_relaxed) == 1;
};
if (instr.KernelType() == OpFuncType::kQueueAsync) {
......@@ -503,7 +522,9 @@ void InterpreterCore::RunNextInstructions(
if (IsReady(next_id)) {
async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); });
[this, next_id, atomic_deps, atomic_var_ref]() {
RunInstructionAsync(next_id, atomic_deps, atomic_var_ref);
});
}
}
// keep all async_ops running in current thread
......@@ -523,7 +544,9 @@ void InterpreterCore::RunNextInstructions(
if (IsReady(next_id)) {
async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); });
[this, next_id, atomic_deps, atomic_var_ref] {
RunInstructionAsync(next_id, atomic_deps, atomic_var_ref);
});
}
}
auto direct_run_ops = interpreter::merge_vector(next_instr.SyncRunIds(),
......@@ -539,14 +562,18 @@ void InterpreterCore::RunNextInstructions(
// move rest ops into other threads
async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); });
[this, next_id, atomic_deps, atomic_var_ref] {
RunInstructionAsync(next_id, atomic_deps, atomic_var_ref);
});
}
}
if (first_op != 0) reserved_next_ops->push(first_op);
}
}
void InterpreterCore::RunInstructionAsync(size_t instr_id) {
void InterpreterCore::RunInstructionAsync(
size_t instr_id, std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref) {
std::queue<size_t> ready_ops;
ready_ops.push(instr_id);
while (!ready_ops.empty()) {
......@@ -571,7 +598,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
RecordStreamForGC(instr_node);
#endif
CheckGC(instr_node);
CheckGC(instr_node, atomic_var_ref);
} catch (platform::EnforceNotMet& ex) {
framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex);
exception_holder_.Catch(std::make_exception_ptr(std::move(ex)));
......@@ -605,7 +632,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
interpreter::RecordEvent(instr_node, place_);
RunNextInstructions(instr_node, &ready_ops);
RunNextInstructions(instr_node, &ready_ops, atomic_deps, atomic_var_ref);
}
}
......@@ -703,17 +730,19 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
}
#endif
void InterpreterCore::CheckGC(const Instruction& instr) {
void InterpreterCore::CheckGC(
const Instruction& instr,
std::vector<std::atomic<size_t>>* atomic_var_ref) {
size_t instr_id = instr.Id();
auto& var_scope = *global_scope_;
auto& atomic_var_ref = async_work_queue_->AtomicVarRef();
for (auto var_id : instr.GCCheckVars()) {
VLOG(4) << "GC " << global_scope_->GetNameById(var_id) << " "
<< var_scope.VarDesc(var_id);
VLOG(4) << "atomic:" << atomic_var_ref << " " << &(*atomic_var_ref)[var_id]
<< " " << var_id;
bool is_ready =
atomic_var_ref[var_id]->fetch_sub(1, std::memory_order_relaxed) == 1;
(*atomic_var_ref)[var_id].fetch_sub(1, std::memory_order_relaxed) == 1;
// ignore all persistable var while GC
if (var_scope.VarDesc(var_id) && var_scope.VarDesc(var_id)->Persistable()) {
continue;
......
......@@ -76,11 +76,16 @@ class InterpreterCore {
void RecordStreamForGC(const Instruction& instr);
#endif
void CheckGC(const Instruction& instr);
void CheckGC(const Instruction& instr,
std::vector<std::atomic<size_t>>* atomic_var_ref);
void RunInstructionAsync(size_t instr_id);
void RunInstructionAsync(size_t instr_id,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref);
void RunNextInstructions(const Instruction& instr_id,
std::queue<size_t>* reserved_next_ops);
std::queue<size_t>* reserved_next_ops,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref);
void BuildSkipShareLoDInfo();
......
......@@ -44,32 +44,37 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
using VariableIdMap = std::map<std::string, std::vector<int>>;
AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps(
void AsyncWorkQueue::PrepareAtomicDeps(
const std::vector<size_t>& dependecy_count) {
if (atomic_deps_.size() != dependecy_count.size()) {
atomic_deps_.clear();
std::generate_n(std::back_inserter(atomic_deps_), dependecy_count.size(),
[] { return std::make_unique<std::atomic<size_t>>(0); });
}
for (size_t i = 0; i < dependecy_count.size(); ++i) {
atomic_deps_[i]->store(dependecy_count[i]);
}
return atomic_deps_;
VLOG(4) << "PrepareAtomicDeps";
auto p = std::make_shared<
std::promise<std::unique_ptr<std::vector<std::atomic<size_t>>>>>();
atomic_deps_ = p->get_future();
queue_group_->AddTask(2, [&dependecy_count, p] {
auto* op_deps =
new std::vector<std::atomic<size_t>>(dependecy_count.size());
for (size_t i = 0; i < dependecy_count.size(); ++i) {
(*op_deps)[i] = dependecy_count[i];
}
VLOG(4) << "AtomicDeps:" << op_deps << " " << (*op_deps).size();
p->set_value(std::unique_ptr<std::vector<std::atomic<size_t>>>(op_deps));
});
}
AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicVarRef(
void AsyncWorkQueue::PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info) {
if (atomic_var_ref_.size() != vec_meta_info.size()) {
atomic_var_ref_.clear();
std::generate_n(std::back_inserter(atomic_var_ref_), vec_meta_info.size(),
[] { return std::make_unique<std::atomic<size_t>>(0); });
}
for (size_t i = 0; i < vec_meta_info.size(); ++i) {
atomic_var_ref_[i]->store(vec_meta_info[i].var_ref_count_);
}
return atomic_var_ref_;
VLOG(4) << "PrepareAtomicVarRef";
auto p = std::make_shared<
std::promise<std::unique_ptr<std::vector<std::atomic<size_t>>>>>();
atomic_var_ref_ = p->get_future();
queue_group_->AddTask(2, [&vec_meta_info, p] {
auto* var_ref = new std::vector<std::atomic<size_t>>(vec_meta_info.size());
for (size_t i = 0; i < vec_meta_info.size(); ++i) {
(*var_ref)[i] = vec_meta_info[i].var_ref_count_;
}
VLOG(4) << "AtomicVarRef:" << var_ref << " " << (*var_ref).size();
p->set_value(std::unique_ptr<std::vector<std::atomic<size_t>>>(var_ref));
});
}
bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
......
......@@ -50,11 +50,13 @@ namespace framework {
namespace interpreter {
using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
using AtomicVectorSizeT =
std::future<std::unique_ptr<std::vector<std::atomic<size_t>>>>;
class AsyncWorkQueue {
public:
AsyncWorkQueue(size_t host_num_threads, EventsWaiter* waiter)
AsyncWorkQueue(size_t host_num_threads, size_t deivce_num_threads,
EventsWaiter* waiter)
: host_num_thread_(host_num_threads) {
std::vector<WorkQueueOptions> group_options;
// for execute host Kernel
......@@ -66,6 +68,13 @@ class AsyncWorkQueue {
/*events_waiter*/ waiter);
// for launch device Kernel
group_options.emplace_back(/*name*/ "DeviceKernelLaunch",
/*num_threads*/ deivce_num_threads,
/*allow_spinning*/ true,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
// for prepare deps and others
group_options.emplace_back(/*name*/ "Prepare",
/*num_threads*/ 1,
/*allow_spinning*/ true,
/*track_task*/ false,
......@@ -74,10 +83,8 @@ class AsyncWorkQueue {
queue_group_ = CreateWorkQueueGroup(group_options);
}
AtomicVectorSizeT& PrepareAtomicDeps(
const std::vector<size_t>& dependecy_count);
AtomicVectorSizeT& PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info);
void PrepareAtomicDeps(const std::vector<size_t>& dependecy_count);
void PrepareAtomicVarRef(const std::vector<VariableMetaInfo>& vec_meta_info);
// void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); }
......@@ -85,8 +92,12 @@ class AsyncWorkQueue {
void Cancel() { queue_group_->Cancel(); }
AtomicVectorSizeT& AtomicDeps() { return atomic_deps_; }
AtomicVectorSizeT& AtomicVarRef() { return atomic_var_ref_; }
std::unique_ptr<std::vector<std::atomic<size_t>>> AtomicDeps() {
return atomic_deps_.get();
}
std::unique_ptr<std::vector<std::atomic<size_t>>> AtomicVarRef() {
return atomic_var_ref_.get();
}
private:
size_t host_num_thread_;
......
......@@ -20,45 +20,65 @@
// #include "gperftools/profiler.h"
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
#include "paddle/phi/core/kernel_registry.h"
USE_OP_ITSELF(fill_constant);
USE_OP(uniform_random);
USE_OP_ITSELF(uniform_random);
USE_OP(lookup_table);
USE_OP(transpose2);
USE_OP_ITSELF(transpose2);
USE_OP_ITSELF(reshape2);
USE_OP(split);
USE_OP(slice);
USE_OP(concat);
USE_OP(matmul);
USE_OP_ITSELF(split);
USE_OP_ITSELF(slice);
USE_OP_ITSELF(concat);
USE_OP_ITSELF(matmul);
USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(tanh);
USE_OP(elementwise_mul);
USE_OP_ITSELF(elementwise_mul);
USE_OP(softmax_with_cross_entropy);
USE_OP_ITSELF(reduce_mean);
USE_OP_ITSELF(reduce_sum);
USE_OP_ITSELF(reduce_sum_grad);
USE_OP_ITSELF(reduce_mean_grad);
USE_OP_ITSELF(reshape2_grad);
USE_OP(softmax_with_cross_entropy_grad);
USE_OP_ITSELF(softmax_with_cross_entropy_grad);
USE_OP_ITSELF(elementwise_add_grad);
USE_OP(matmul_grad);
USE_OP(square);
USE_OP(transpose2_grad);
USE_OP_ITSELF(matmul_grad);
USE_OP_ITSELF(square);
USE_OP_ITSELF(transpose2_grad);
USE_OP(concat_grad);
USE_OP_ITSELF(elementwise_mul_grad);
USE_OP_ITSELF(sigmoid_grad);
USE_OP_ITSELF(tanh_grad);
USE_OP(sum);
USE_OP(slice_grad);
USE_OP(lookup_table_grad);
USE_OP_ITSELF(slice_grad);
USE_OP_ITSELF(lookup_table_grad);
USE_OP(sqrt);
USE_OP(elementwise_max);
USE_OP_ITSELF(elementwise_div);
USE_OP(sgd);
USE_OP_ITSELF(sgd);
USE_OP(squared_l2_norm);
USE_OP(memcpy_h2d);
USE_OP(memcpy_d2h);
USE_OP_ITSELF(memcpy_h2d);
USE_OP_ITSELF(memcpy_d2h);
PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(uniform_random_raw, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(transpose, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(reshape, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(split, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_raw, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(reshape_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(transpose_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sum, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sum_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT);
DECLARE_double(eager_delete_tensor_gb);
namespace paddle {
......
......@@ -144,6 +144,9 @@ class Scope : public ScopeBase {
void Rename(const std::string& origin_name,
const std::string& new_name) const;
// Return the number of variables in scope
size_t Size() { return vars_.size(); }
// Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const;
......
......@@ -1761,6 +1761,7 @@ All parameter, weight, gradient are variables in Paddle.
out (core.Variable|None): the found variable or None.
)DOC",
py::return_value_policy::reference)
.def("size", &Scope::Size)
.def("erase", &Scope::EraseVars, py::arg("names"),
R"DOC(
Find variable named :code:`name` in the current scope or
......@@ -2857,6 +2858,9 @@ All parameter, weight, gradient are variables in Paddle.
.def("run",
[](StandaloneExecutor &self, std::vector<std::string> feed_names,
std::vector<std::string> fetch_names) {
platform::RecordEvent record_event(
"StandaloneExecutor:run",
platform::TracerEventType::UserDefined, 1);
paddle::framework::FetchList ret;
{
pybind11::gil_scoped_release release;
......
......@@ -59,6 +59,13 @@ class TestScope(unittest.TestCase):
# It is not allowed to delete a nonexistent scope.
scope._remove_from_pool()
def test_size(self):
paddle_c = paddle.fluid.core
scope = paddle_c.Scope()
var_a = scope.var("var_a")
self.assertEqual(scope.size(), 1)
self.assertIsNotNone(scope.find_var('var_a'))
if __name__ == '__main__':
unittest.main()