From 678a259a93ff606dd2260c4672a662a8cd93897d Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 26 Aug 2021 17:06:29 +0800 Subject: [PATCH] Support Multi-Stream, Single-Thread in New Executor (#35024) * Modify into QueueSync QueueAsync * fix complie on MacOS * fix pointer * fix conflict * polish unittest * fix windows fetch error * polish code according reviewer * fix device_guard on CPU place --- .../framework/new_executor/interpretercore.cc | 290 ++++++++++++++++-- .../framework/new_executor/interpretercore.h | 15 + .../new_executor/new_executor_defs.h | 36 ++- .../new_executor/standalone_executor.cc | 1 + paddle/fluid/operators/memcpy_d2h_op.cc | 148 +++++++++ paddle/fluid/operators/memcpy_d2h_op.h | 78 +++++ paddle/fluid/operators/memcpy_h2d_op.cc | 148 +++++++++ paddle/fluid/operators/memcpy_h2d_op.h | 76 +++++ .../interpreter/test_standalone_executor.py | 71 +++++ 9 files changed, 830 insertions(+), 33 deletions(-) create mode 100644 paddle/fluid/operators/memcpy_d2h_op.cc create mode 100644 paddle/fluid/operators/memcpy_d2h_op.h create mode 100644 paddle/fluid/operators/memcpy_h2d_op.cc create mode 100644 paddle/fluid/operators/memcpy_h2d_op.h diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index ec5a0362b6..8a9b478ca7 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -13,9 +13,127 @@ // limitations under the License. #include "paddle/fluid/framework/new_executor/interpretercore.h" +#include + namespace paddle { namespace framework { +static constexpr char kMemcpyH2D[] = "memcpy_h2d"; +static constexpr char kMemcpyD2H[] = "memcpy_d2h"; +namespace { +std::string GetMemcpyType(const platform::Place& src_place, + const platform::Place& dst_place) { + PADDLE_ENFORCE_EQ(platform::is_same_place(src_place, dst_place), false, + platform::errors::PreconditionNotMet( + "Required src_place shall be different with dst_place, " + "but received same place: %s", + src_place)); + if (platform::is_gpu_place(dst_place)) { + return kMemcpyH2D; + } else if (platform::is_gpu_place(src_place)) { + return kMemcpyD2H; + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Not support Memcpy typ : %s -> %s", src_place, dst_place)); + } +} + +/* + * Parse the var_ids that need to be associated with an event. + * The caller should guarantee front_op and back_op satisfy the + * following conditions: + * 1. kQueueAsync -> kQueueAsync + * 2. kQueueAsync -> kQueueSync + * + * For example: matmul(gpu) -> out_var -> memcpy_d2h + * out_var should be associated with an event. 
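+ *
+ * Likewise, memcpy_h2d(h2d stream) -> out_var -> matmul(gpu compute stream)
+ * crosses two different streams (kQueueAsync -> kQueueAsync), so out_var is
+ * also associated with an event that the downstream stream waits on, rather
+ * than blocking the host thread.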
+ */ +std::vector ParseEventVarIds(const Instruction& cur_instr, + const Instruction& next_instr) { + std::unordered_set unique_var_ids; + for (auto& item : cur_instr.output_index_) { + unique_var_ids.insert(item.second.begin(), item.second.end()); + } + + std::vector new_event_var_ids; + for (auto& item : next_instr.input_index_) { + for (auto var_id : item.second) { + if (unique_var_ids.count(var_id) > 0) { + new_event_var_ids.push_back(var_id); + } + } + } + return new_event_var_ids; +} + +void AssociateInputWithEvents( + const std::vector& new_event_var_id, Instruction* next_instr, + std::map>* var_id2event, + bool is_sync) { +#ifdef PADDLE_WITH_CUDA + for (auto var_id : new_event_var_id) { + if (var_id2event->count(var_id) == 0) { + auto cuda_event = std::make_shared( + platform::get_cuda_flags(false, false, false)); + var_id2event->emplace(var_id, std::move(cuda_event)); + } + // Add events for next_instr.inputs + next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id), + is_sync); + } +#endif +} + +void ParseDirectAndEventRunOps( + const std::vector& op_func_nodes, + const std::vector& downstream_ops, size_t op_index, + std::map>* var_id2event, + std::vector* instructions) { + auto& op_func_type = op_func_nodes[op_index].type_; + auto& cur_instr = instructions->at(op_index); + auto& next_instruction = cur_instr.next_instruction_; + + if (op_func_type == OpFuncType::kQueueSync) { + // all downstream ops of kQueueSync can directly run, such as CPU -> Any + next_instruction.direct_run_ = downstream_ops; + } else { // kQueueAsync + std::vector event_var_ids; + for (auto next_op_id : downstream_ops) { + auto& next_instr = instructions->at(next_op_id); + // case 1: GPU -> GPU(same stream) + if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) { + next_instruction.direct_run_.emplace_back(next_op_id); + continue; + } + // Always insert events between different stream + auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr); + event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(), + new_event_var_ids.end()); + + bool is_sync = + (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync); + AssociateInputWithEvents(new_event_var_ids, &next_instr, var_id2event, + is_sync); + + if (is_sync) { // GPU -> CPU + next_instruction.synchronize_run_.emplace_back(next_op_id); + } else { // GPU -> GPU(different stream) + next_instruction.event_wait_run_.emplace_back(next_op_id); + } + } +#ifdef PADDLE_WITH_CUDA + // Create events for these cross-stream vars + VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() + << " event_var_ids.size: " << event_var_ids.size(); + for (auto var_id : event_var_ids) { + cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id), + false /*not used*/); + } +#endif + } +} +} // namespace + InterpreterCore::InterpreterCore(const platform::Place& place, const ProgramDesc& main_prog, VariableScope* global_scope, @@ -24,6 +142,8 @@ InterpreterCore::InterpreterCore(const platform::Place& place, : place_(place), main_program_(main_prog), global_scope_(global_scope), + d2h_ctx_pool_({place}), + h2d_ctx_pool_({place}), fetch_context_pool_({place}) { is_build_ = false; feed_names_ = feed_names; @@ -88,8 +208,11 @@ void InterpreterCore::Convert() { vec_meta_info_.resize(global_scope_->var_list.size()); for (size_t i = 0; i < vec_func_list_.size(); ++i) { Instruction temp_inst; + auto* op_base = op_list_[i]; + temp_inst.dev_ctx_ = + ParseDeviceContextForInstruction(vec_func_list_[i], *op_base); temp_inst.kernel_func_.compute_func_ 
= vec_func_list_[i].kernel_func_; - temp_inst.kernel_func_.operator_base_ = op_list_[i]; + temp_inst.kernel_func_.operator_base_ = op_base; temp_inst.input_index_ = vec_func_list_[i].input_index; temp_inst.output_index_ = vec_func_list_[i].output_index; @@ -130,7 +253,9 @@ void InterpreterCore::Convert() { filter_next.push_back(item); } } - vec_instruction_[i].next_instruction_.direct_run_ = filter_next; + + ParseDirectAndEventRunOps(vec_func_list_, filter_next, i, &var_id2event_, + &vec_instruction_); // checkout ouput for (auto& item : vec_instruction_[i].output_index_) { @@ -146,6 +271,8 @@ void InterpreterCore::Convert() { for (auto inst_id : filter_next) { dependecy_count_[inst_id]++; } + vec_instruction_[i].next_instruction_.all_next_ops_ = + std::move(filter_next); } for (size_t i = 0; i < vec_instruction_.size(); ++i) { @@ -187,8 +314,7 @@ void InterpreterCore::BuildInstructionCtx(Instruction* instr_node, instr_node->infershape_ctx_.reset( new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get())); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(place); + auto* dev_ctx = instr_node->dev_ctx_; if (instr_node->kernel_func_.operator_base_->Type() == "fetch_v2") { dev_ctx = fetch_context_pool_.Get(place); } @@ -230,10 +356,16 @@ void InterpreterCore::ExecuteInstructionList( auto instr_id = working_queue.front(); working_queue.pop(); auto& instr_node = vec_instr[instr_id]; + // step1 : stream_wait (non-block host) or sync (block host) + StreamWaitEventOrSync(instr_node); + // step2: run instruction RunInstruction(instr_node); - - auto& next_instr = instr_node.next_instruction_.direct_run_; ++run_op_number; + // step3: insert event for out_vars if needed + RecordEventInstruction(instr_node, vec_func_list_[instr_id]); + + // step4: update working_queue + auto& next_instr = instr_node.next_instruction_.all_next_ops_; for (auto next_i : next_instr) { --working_dependecy_count[next_i]; @@ -297,10 +429,10 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, std::vector* vec_func_list, VariableScope* var_scope) { auto& global_block = pdesc.Block(0); + auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); for (auto& op : global_block.AllOps()) { - VLOG(3) << op->Type(); - // << op->Type() << endl; + VLOG(3) << "Build OpFuncNode from : " << op->Type(); auto& info = OpInfoMap::Instance().Get(op->Type()); @@ -311,11 +443,10 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, if (info.Checker() != nullptr) { info.Checker()->Check(&op_attr_map); } + // step 1. 
Prepare VariableValueMap of input/output auto op_base = info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map); - OpFuncNode op_func_node; - VariableValueMap ins_map; std::map> ins_name2id; for (auto& var_name_item : inputs_names) { @@ -348,15 +479,16 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, outs_name2id[var_name_item.first] = vec_ids; } + OpFuncNode op_func_node; op_func_node.input_index = ins_name2id; op_func_node.output_index = outs_name2id; + // step 2: construct RuntimeContext and analysis KernelType RuntimeContext runtime_context({}, {}); runtime_context.inputs.swap(ins_map); runtime_context.outputs.swap(outs_map); RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context); static_cast(op_base)->InferShape( &infer_shape_ctx); - auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); auto kernels_iter = all_op_kernels.find(op->Type()); PADDLE_ENFORCE_NE( kernels_iter, all_op_kernels.end(), @@ -365,19 +497,37 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, op->Type())); OpKernelMap& kernels = kernels_iter->second; - // auto place = platform::CPUPlace(); - // auto place = platform::CUDAPlace(0); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); Scope scope; - auto exec_ctx = - ExecutionContext(*op_base, scope, *dev_ctx, runtime_context); auto expected_kernel_key = dynamic_cast(op_base) - ->GetExpectedKernelType(exec_ctx); + ->GetExpectedKernelType( + ExecutionContext(*op_base, scope, *dev_ctx, runtime_context)); + + // consider device_guard context + bool need_change_place = + (op_base->HasAttr("op_device") && + (op_base->Attr("op_device").length() > 0)); + if (need_change_place) { + auto& op_device = op_base->Attr("op_device"); + if (op_device == "cpu" || platform::is_cpu_place(place)) { + VLOG(3) << "Switch into CPUPlace by device_guard."; + expected_kernel_key.place_ = platform::CPUPlace(); + } else if (op_device.find("gpu") != std::string::npos && + platform::is_gpu_place(place)) { + VLOG(3) << "Switch into " << place << " by device_guard."; + expected_kernel_key.place_ = place; + } else { + PADDLE_THROW( + platform::errors::Fatal("Unsupported current place %s", op_device)); + } + } + VLOG(3) << "expected_kernel_key : " << expected_kernel_key; + // step 3. Insert memcpy_op if needed VariableValueMap& ins_map_temp = runtime_context.inputs; - for (auto& var_name_item : ins_map_temp) { for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto var = var_name_item.second[i]; @@ -408,7 +558,9 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, copy_out_map["Out"] = {new_var_name}; AttributeMap attr_map; attr_map["dst_place_type"] = - is_cpu_place(place) ? 0 : is_gpu_place(place) ? 1 : -1; + is_cpu_place(expected_kernel_key.place_) + ? 0 + : is_gpu_place(expected_kernel_key.place_) ? 
1 : -1; std::map> copy_ins_name2id; copy_ins_name2id["X"] = ins_name2id[var_name_item.first]; @@ -423,8 +575,15 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, VariableValueMap copy_outs_value_map; copy_outs_value_map["Out"] = {v}; - auto& copy_info = OpInfoMap::Instance().Get("memcpy"); - auto copy_op = copy_info.Creator()("memcpy", copy_in_map, + // memcpy_d2h, memcpy_h2d + auto memcpy_op_type = GetMemcpyType(kernel_type_for_var.place_, + expected_kernel_key.place_); + VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).", + memcpy_op_type, x_iter->second[i], + kernel_type_for_var.place_, new_var_name, + expected_kernel_key.place_); + auto& copy_info = OpInfoMap::Instance().Get(memcpy_op_type); + auto copy_op = copy_info.Creator()(memcpy_op_type, copy_in_map, copy_out_map, attr_map); OpFuncNode copy_op_func_node; copy_op_func_node.input_index = copy_ins_name2id; @@ -437,16 +596,14 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, copy_runtime_context); static_cast(copy_op) ->InferShape(©_infer_shape_ctx); - auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); - auto kernels_iter = all_op_kernels.find("memcpy"); + + auto kernels_iter = all_op_kernels.find(memcpy_op_type); PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(), platform::errors::Unavailable( "There are no kernels which are registered in " "the memcpy operator.")); OpKernelMap& kernels = kernels_iter->second; - platform::DeviceContextPool& pool = - platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); Scope scope; auto copy_exec_ctx = @@ -458,6 +615,9 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, copy_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); copy_op_func_node.kernel_func_(copy_exec_ctx); + VLOG(3) << "Run " << memcpy_op_type << " done."; + copy_op_func_node.type_ = OpFuncType::kQueueAsync; + copy_op_func_node.dev_ctx_ = dev_ctx; op_list->push_back(copy_op); vec_func_list->push_back(copy_op_func_node); @@ -465,8 +625,27 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, } } } - + // step 4. 
Run op kernel op_list->push_back(op_base); + VLOG(3) << op_base->Type() + << " : expected_kernel_key : " << expected_kernel_key; + + if (platform::is_gpu_place(expected_kernel_key.place_)) { + op_func_node.type_ = OpFuncType::kQueueAsync; + } else if (platform::is_cpu_place(expected_kernel_key.place_)) { + op_func_node.type_ = OpFuncType::kQueueSync; + } else { + PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s", + expected_kernel_key.place_)); + } + + if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) { + dev_ctx = pool.Get(expected_kernel_key.place_); + } + op_func_node.dev_ctx_ = dev_ctx; + + auto exec_ctx = + ExecutionContext(*op_base, scope, *dev_ctx, runtime_context); auto kernel_iter = kernels.find(expected_kernel_key); PADDLE_ENFORCE_NE(kernel_iter, kernels.end(), @@ -477,8 +656,69 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); op_func_node.kernel_func_(exec_ctx); vec_func_list->push_back(op_func_node); + VLOG(3) << "run " << op_base->Type() << " done."; } } +platform::DeviceContext* InterpreterCore::ParseDeviceContextForInstruction( + const OpFuncNode& op_func_node, const OperatorBase& op_base) { + auto& op_type = op_base.Type(); + auto* dev_ctx = op_func_node.dev_ctx_; + if (op_type == kMemcpyH2D) { + VLOG(3) << "Get dev_ctx from d2h_context_pool_"; + dev_ctx = d2h_ctx_pool_.Get(place_); + } else if (op_type == kMemcpyD2H) { + VLOG(3) << "Get dev_ctx from h2d_context_pool_"; + dev_ctx = h2d_ctx_pool_.Get(place_); + } + + return dev_ctx; +} + +void InterpreterCore::RecordEventInstruction(const Instruction& instruction, + const OpFuncNode& op_func_node) { + // If InterpreterCore in on CPUPlace, do nothing. + if (platform::is_cpu_place(place_)) return; + +#ifdef PADDLE_WITH_CUDA + const platform::CUDADeviceContext* dev_ctx = + reinterpret_cast( + instruction.dev_ctx_); + for (auto& event : instruction.output_events_) { + VLOG(3) << "Record event in out_var_id: " << event.var_id_; + event.event_->Record(*(dev_ctx->context()->Stream())); + } +#endif +} + +void InterpreterCore::WaitOrSync(const std::vector& events, + const platform::DeviceContext* dev_ctx) { +#ifdef PADDLE_WITH_CUDA + auto* cuda_dev_ctx = + reinterpret_cast(dev_ctx); + + for (auto& event : events) { + if (event.is_sync_) { + VLOG(3) << "host sync wait in_var_id " << event.var_id_; + event.event_->Synchronize(); + } else { + VLOG(3) << "stream async wait in_var_id " << event.var_id_; + cuda_dev_ctx->context()->Stream()->WaitEvent( + event.event_->GetRawCudaEvent()); + } + } +#endif +} + +void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) { + // If InterpreterCore in on CPUPlace, do nothing. 
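+  // Otherwise, block or stream-wait on all input events of this instruction
+  // before launching its kernel (see WaitOrSync above).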
+ if (platform::is_cpu_place(place_)) return; + + VLOG(3) << "Deal StreamWaitEventOrSync for " + << instruction.kernel_func_.operator_base_->Type(); + auto* dev_ctx = instruction.dev_ctx_; + + WaitOrSync(instruction.intput_events_, dev_ctx); +} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index b4db692385..63a57e0c03 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -24,6 +24,7 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/event.h" namespace paddle { namespace framework { @@ -63,9 +64,22 @@ class InterpreterCore { void BuildVariableScope(const framework::ProgramDesc& pdesc, VariableScope* var_scope); + platform::DeviceContext* ParseDeviceContextForInstruction( + const OpFuncNode& op_func_node, const OperatorBase& op_base); + + void RecordEventInstruction(const Instruction& instruction, + const OpFuncNode& op_func_node); + + void WaitOrSync(const std::vector& events, + const platform::DeviceContext* dev_ctx); + + void StreamWaitEventOrSync(const Instruction& instruction); + const platform::Place& place_; ProgramDesc main_program_; VariableScope* global_scope_; + platform::DeviceContextPool d2h_ctx_pool_; + platform::DeviceContextPool h2d_ctx_pool_; std::vector vec_meta_info_; std::vector vec_func_list_; @@ -80,6 +94,7 @@ class InterpreterCore { bool is_build_; std::vector feed_names_; + std::map> var_id2event_; platform::DeviceContextPool fetch_context_pool_; }; diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 33ea943fdb..27526a9455 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -19,6 +19,7 @@ #include #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/event.h" namespace paddle { namespace framework { @@ -41,22 +42,30 @@ struct VariableScope { std::map name2id; }; +struct EventRun { + explicit EventRun(size_t op_id) : op_id_(op_id) {} + size_t op_id_; +}; struct NextInstruction { std::vector direct_run_; + std::vector event_wait_run_; + std::vector synchronize_run_; + std::vector all_next_ops_; }; -struct EventInter {}; +struct EventInter { + explicit EventInter(size_t var_id, std::shared_ptr event, + bool is_sync) + : var_id_(var_id), event_(event), is_sync_(is_sync) {} + size_t var_id_; + std::shared_ptr event_; + bool is_sync_; +}; struct InstructionInfo { std::vector dependecy_count_; }; -struct EventRun { - EventInter event_inter; - std::vector same_device_run_; - std::vector synchronized_run; -}; - struct Instruction { OpKernelFunc kernel_func_; std::shared_ptr runtime_ctx_; @@ -67,7 +76,16 @@ struct Instruction { std::vector gc_check_var_list; NextInstruction next_instruction_; - std::vector vec_event_list_; + + std::vector intput_events_; + std::vector output_events_; + + platform::DeviceContext* dev_ctx_; // not owned +}; + +enum class OpFuncType { + kQueueAsync, // GPU Kernel or d2h, h2d, send, recv, broadcast + kQueueSync, // CPU kernel, block host }; struct OpFuncNode { @@ -76,6 +94,8 @@ struct OpFuncNode { std::map> output_index; OpKernelComputeFunc kernel_func_; + platform::DeviceContext* dev_ctx_; // not owned + OpFuncType type_; }; } // namespace 
framework diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 3c41910490..f43d1f9f9a 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -91,6 +91,7 @@ std::shared_ptr StandaloneExecutor::GetInterpreterCore( auto iter = interpretercores_.find(oss.str()); if (iter == interpretercores_.end()) { + VLOG(3) << "create interpreter_core for " << oss.str(); auto core = std::make_shared( place_, main_prog_, &global_scope_, feed_names, fetch_names); interpretercores_.emplace(oss.str(), core); diff --git a/paddle/fluid/operators/memcpy_d2h_op.cc b/paddle/fluid/operators/memcpy_d2h_op.cc new file mode 100644 index 0000000000..41b8b36791 --- /dev/null +++ b/paddle/fluid/operators/memcpy_d2h_op.cc @@ -0,0 +1,148 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/memcpy_d2h_op.h" + +#include + +namespace paddle { +namespace framework { +class OpDesc; +class InferShapeContext; +template +class EmptyGradOpMaker; +} // namespace framework +namespace imperative { +class OpBase; +} // namespace imperative +namespace platform { +struct CPUPlace; +struct CUDAPlace; +struct float16; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace operators { + +class MemcpyD2HOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + auto type = ctx->GetInputsVarType("X")[0]; + if (type == framework::proto::VarType::SELECTED_ROWS || + type == framework::proto::VarType::LOD_TENSOR) { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + if (type == framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class MemcpyD2HInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + ctx->SyncTypeAndDataType("X", "Out"); + } +}; + +class MemcpyD2HKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *x = ctx.InputVar("X"); + if (x == nullptr) { + return; + } + PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true, + platform::errors::NotFound( + "Output(Out) of memcpy_d2h_op is not found.")); + auto *out = ctx.OutputVar("Out"); + // Get dev_ctx from 
ExecutionContext, it's D2H stream + auto &dev_ctx = ctx.device_context(); + auto dst_place_type = ctx.Attr("dst_place_type"); + framework::VisitVarType(*x, MemcpyD2HFunctor(out, dev_ctx, dst_place_type)); + } +}; + +class MemcpyD2HOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(LoDTensor) The input variable "); + AddOutput("Out", + "(LoDTensor) The type of output " + "is the same as input X."); + AddAttr( + "dst_place_type", + "Determine the dst place of tensor copy. " + "By Now it ONLY support NPUPlace/CUDAPlace <-> CUDAPinnedPlace/CPU" + "Other place type is Unimplemented and will cause ERROR." + "0: dst is on CPUPlace. " + "1: dst is on CUDAPinnedPlace. "); + AddComment(R"DOC( + MemcpyD2H Operator. + By now, it ONLY supports the memcopy between NPUPlace/CUDAPlace <-> CUDAPinnedPlace/CPU. + You would have to update it if you want other more capacities. +Out = X, when type in [LoDTensor] +raise error if the type is not listed above. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR( + memcpy_d2h, ops::MemcpyD2HOp, ops::MemcpyD2HOpProtoMaker, + ops::MemcpyD2HInferVarType, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy_d2h, float, ops::MemcpyD2HKernel, double, + ops::MemcpyD2HKernel, int, ops::MemcpyD2HKernel, + int64_t, ops::MemcpyD2HKernel, bool, + ops::MemcpyD2HKernel, plat::float16, + ops::MemcpyD2HKernel); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM) +REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy_d2h, float, ops::MemcpyD2HKernel, double, + ops::MemcpyD2HKernel, int, ops::MemcpyD2HKernel, + int64_t, ops::MemcpyD2HKernel, bool, + ops::MemcpyD2HKernel, plat::float16, + ops::MemcpyD2HKernel); +#endif + +#ifdef PADDLE_WITH_ASCEND_CL +REGISTER_OP_NPU_KERNEL_FUNCTOR(memcpy_d2h, float, ops::MemcpyD2HKernel, double, + ops::MemcpyD2HKernel, int, ops::MemcpyD2HKernel, + int64_t, ops::MemcpyD2HKernel, bool, + ops::MemcpyD2HKernel, plat::float16, + ops::MemcpyD2HKernel); +#endif diff --git a/paddle/fluid/operators/memcpy_d2h_op.h b/paddle/fluid/operators/memcpy_d2h_op.h new file mode 100644 index 0000000000..6f9890d332 --- /dev/null +++ b/paddle/fluid/operators/memcpy_d2h_op.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace platform { +class DeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace framework { +class LoDTensor; +class Variable; +class SelectedRows; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class MemcpyD2HFunctor { + public: + MemcpyD2HFunctor(framework::Variable *out, + const platform::DeviceContext &dev_ctx, + const int dst_place_type) + : out_(out), dev_ctx_(dev_ctx), dst_place_type_(dst_place_type) {} + + void operator()(const framework::LoDTensor &lod_tensor) const { + auto &out_tensor = *out_->GetMutable(); + + if (dst_place_type_ == 1) { + framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_, + &out_tensor); + } else if (dst_place_type_ == 0) { + framework::TensorCopySync(lod_tensor, platform::CPUPlace(), &out_tensor); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "memcpy dst_place_type: %d is not supported yet.", dst_place_type_)); + } + out_tensor.set_lod(lod_tensor.lod()); + } + + void operator()(const framework::SelectedRows &rows) const { + // (JZ-LIANG) to support SelectedRows + PADDLE_THROW(platform::errors::Unimplemented( + "Memcpy for SelectedRows is NOT support yet.")); + } + + template + void operator()(const T &v) const { + PADDLE_ENFORCE_EQ( + true, false, + platform::errors::PermissionDenied( + "Not support type for Memcpy op with type %s", typeid(T).name())); + } + + private: + framework::Variable *out_; + const platform::DeviceContext &dev_ctx_; + const int dst_place_type_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/memcpy_h2d_op.cc b/paddle/fluid/operators/memcpy_h2d_op.cc new file mode 100644 index 0000000000..e439be1620 --- /dev/null +++ b/paddle/fluid/operators/memcpy_h2d_op.cc @@ -0,0 +1,148 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/memcpy_h2d_op.h" + +#include + +namespace paddle { +namespace framework { +class OpDesc; +class InferShapeContext; +template +class EmptyGradOpMaker; +} // namespace framework +namespace imperative { +class OpBase; +} // namespace imperative +namespace platform { +struct CPUPlace; +struct CUDAPlace; +struct float16; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace operators { + +class MemcpyH2DOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + auto type = ctx->GetInputsVarType("X")[0]; + if (type == framework::proto::VarType::SELECTED_ROWS || + type == framework::proto::VarType::LOD_TENSOR) { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + if (type == framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class MemcpyH2DInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + ctx->SyncTypeAndDataType("X", "Out"); + } +}; + +class MemcpyH2DKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *x = ctx.InputVar("X"); + if (x == nullptr) { + return; + } + PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true, + platform::errors::NotFound( + "Output(Out) of memcpy_d2h_op is not found.")); + auto *out = ctx.OutputVar("Out"); + // Get dev_ctx from ExecutionContext, it's H2D stream + auto &dev_ctx = ctx.device_context(); + auto dst_place_type = ctx.Attr("dst_place_type"); + framework::VisitVarType(*x, MemcpyH2DFunctor(out, dev_ctx, dst_place_type)); + } +}; + +class MemcpyH2DOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(LoDTensor) The input variable "); + AddOutput("Out", + "(LoDTensor) The type of output " + "is the same as input X."); + AddAttr( + "dst_place_type", + "Determine the dst place of tensor copy. " + "By Now it ONLY support CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace " + "Other place type is Unimplemented and will cause ERROR." + "0: dst is on CUDAPlace. " + "1: dst is on NPUPlace. "); + AddComment(R"DOC( + MemcpyD2H Operator. + By now, it ONLY supports the memcopy between CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace. + You would have to update it if you want other more capacities. +Out = X, when type in [LoDTensor] +raise error if the type is not listed above. 
+)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR( + memcpy_h2d, ops::MemcpyH2DOp, ops::MemcpyH2DOpProtoMaker, + ops::MemcpyH2DInferVarType, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy_h2d, float, ops::MemcpyH2DKernel, double, + ops::MemcpyH2DKernel, int, ops::MemcpyH2DKernel, + int64_t, ops::MemcpyH2DKernel, bool, + ops::MemcpyH2DKernel, plat::float16, + ops::MemcpyH2DKernel); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM) +REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy_h2d, float, ops::MemcpyH2DKernel, double, + ops::MemcpyH2DKernel, int, ops::MemcpyH2DKernel, + int64_t, ops::MemcpyH2DKernel, bool, + ops::MemcpyH2DKernel, plat::float16, + ops::MemcpyH2DKernel); +#endif + +#ifdef PADDLE_WITH_ASCEND_CL +REGISTER_OP_NPU_KERNEL_FUNCTOR(memcpy_h2d, float, ops::MemcpyH2DKernel, double, + ops::MemcpyH2DKernel, int, ops::MemcpyH2DKernel, + int64_t, ops::MemcpyH2DKernel, bool, + ops::MemcpyH2DKernel, plat::float16, + ops::MemcpyH2DKernel); +#endif diff --git a/paddle/fluid/operators/memcpy_h2d_op.h b/paddle/fluid/operators/memcpy_h2d_op.h new file mode 100644 index 0000000000..3998db6731 --- /dev/null +++ b/paddle/fluid/operators/memcpy_h2d_op.h @@ -0,0 +1,76 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace platform { +class DeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace framework { +class LoDTensor; +class Variable; +class SelectedRows; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class MemcpyH2DFunctor { + public: + MemcpyH2DFunctor(framework::Variable *out, + const platform::DeviceContext &dev_ctx, + const int dst_place_type) + : out_(out), dev_ctx_(dev_ctx), dst_place_type_(dst_place_type) {} + + void operator()(const framework::LoDTensor &lod_tensor) const { + auto &out_tensor = *out_->GetMutable(); + + if (dst_place_type_ == 0 || dst_place_type_ == 1) { + framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, + &out_tensor); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "memcpy dst_place_type: %d is not supported yet.", dst_place_type_)); + } + out_tensor.set_lod(lod_tensor.lod()); + } + + void operator()(const framework::SelectedRows &rows) const { + // (JZ-LIANG) to support SelectedRows + PADDLE_THROW(platform::errors::Unimplemented( + "Memcpy for SelectedRows is NOT support yet.")); + } + + template + void operator()(const T &v) const { + PADDLE_ENFORCE_EQ( + true, false, + platform::errors::PermissionDenied( + "Not support type for Memcpy op with type %s", typeid(T).name())); + } + + private: + framework::Variable *out_; + const platform::DeviceContext &dev_ctx_; + const int dst_place_type_; +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py index bfed9621c9..2286765680 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py +++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py @@ -57,5 +57,76 @@ class LinearTestCase(unittest.TestCase): }, [a.name, c.name]) +class MultiStreamModelTestCase(unittest.TestCase): + def setUp(self): + self.iter_n = 2 + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + + def build_program(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + with paddle.static.device_guard('cpu'): + data = paddle.ones([4, 64], dtype='float32', name='data') + + # data -> [memcpy_h2d] -> data' -> [matmul] -> out ->[add] -> add_out + with paddle.static.device_guard('gpu'): + weight = paddle.randn([64, 64], name='weight') # gpu + matmul_out = paddle.matmul( + data, weight, name='matmul_out') # gpus + bias = paddle.ones([4, 64], dtype='float32', name='bias') + add_out = paddle.add(matmul_out, bias, name='add_out') + + # add_out -> [memcpy_d2h] -> add_out' -> [sub] -> sub_out -> [tanh] -> tanh_out + with paddle.static.device_guard('cpu'): + sub_out = paddle.subtract(add_out, data, name='sub_out') + tanh_out = paddle.tanh(sub_out, name='tanh_out') + + with paddle.static.device_guard('gpu'): + bias_1 = paddle.add(bias, sub_out, name='bias_1') + out_before = paddle.tanh(bias_1, name='out_before') + out_last = paddle.subtract(tanh_out, data, name='out_last') + + out = paddle.add(out_before, out_last, name='out') + mean = paddle.mean(out, 
name='mean_out') + + return main_program, startup_program, [mean] + + def test_multi_stream(self): + ground_truths = self.run_raw_executor() + res = self.run_new_executor() + for gt, out in zip(ground_truths, res): + self.assertEqual(gt[0], out[0]) + + def run_raw_executor(self): + paddle.seed(2020) + main_program, startup_program, fetch_list = self.build_program() + + exe = paddle.static.Executor(self.place) + exe.run(startup_program) + + outs = [] + for i in range(self.iter_n): + outs.append(exe.run(main_program, fetch_list=fetch_list)) + return outs + + def run_new_executor(self): + paddle.seed(2020) + main_program, startup_program, fetch_list = self.build_program() + fetch_list = [x.name for x in fetch_list] + + p = core.Place() + p.set_place(self.place) + inter_core = StandaloneExecutor(p, startup_program.desc, + main_program.desc, core.Scope()) + outs = [] + for i in range(self.iter_n): + outs.append( + np.array(inter_core.run({}, fetch_list)._move_to_list()[0])) + return outs + + if __name__ == "__main__": unittest.main() -- GitLab