From 63d2d722409ec4e260828f9c862d15e755bb9653 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 28 Dec 2022 14:15:40 +0800 Subject: [PATCH] [new-exec] Ahead-Of-Time choosing kernel (#48789) * add skip run * alloc minimum memory * skip check_size in Alloc * skip check_size in Alloc * skip check_size in Alloc * fix cases when tensor is initialized or empty * alloc empty output for place info * add test * increase timeout * format code * skip cpu * add cudnn_deterministic * fit for hostAlloc * follow comments * change check_size to fake_alloc --- paddle/fluid/eager/eager_tensor.h | 3 +- .../new_executor/interpreter/data_transfer.cc | 41 +++--- .../new_executor/interpreter/data_transfer.h | 15 ++- .../interpreter/interpreter_util.cc | 119 +++++++++++++++++- .../interpreter/interpreter_util.h | 6 +- .../framework/new_executor/interpretercore.cc | 96 +++++++------- .../framework/new_executor/interpretercore.h | 1 + paddle/fluid/framework/operator.cc | 2 + paddle/fluid/operators/batch_norm_op.cc | 1 + paddle/phi/core/dense_tensor.cc | 36 +++--- paddle/phi/core/dense_tensor.h | 3 +- paddle/phi/core/device_context.cc | 40 +++--- paddle/phi/core/device_context.h | 6 +- paddle/phi/core/extended_tensor.cc | 3 +- paddle/phi/core/extended_tensor.h | 3 +- paddle/phi/core/kernel_context.h | 2 + paddle/phi/core/selected_rows.h | 5 +- paddle/phi/core/selected_rows_impl.cc | 5 +- paddle/phi/core/selected_rows_impl.h | 3 +- paddle/phi/core/sparse_coo_tensor.cc | 6 +- paddle/phi/core/sparse_coo_tensor.h | 3 +- paddle/phi/core/sparse_csr_tensor.cc | 6 +- paddle/phi/core/sparse_csr_tensor.h | 3 +- paddle/phi/core/string_tensor.cc | 35 +++--- paddle/phi/core/string_tensor.h | 3 +- paddle/phi/core/tensor_array.cc | 6 +- paddle/phi/core/tensor_array.h | 3 +- paddle/phi/core/tensor_base.h | 3 +- .../standalone_executor/CMakeLists.txt | 2 + ...t_standalone_executor_aot_choose_kernel.py | 86 +++++++++++++ 30 files changed, 403 insertions(+), 143 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_executor_aot_choose_kernel.py diff --git a/paddle/fluid/eager/eager_tensor.h b/paddle/fluid/eager/eager_tensor.h index 3c5ca540062..22b6d705538 100644 --- a/paddle/fluid/eager/eager_tensor.h +++ b/paddle/fluid/eager/eager_tensor.h @@ -136,7 +136,8 @@ class VariableCompatTensor void* AllocateFrom(phi::Allocator* allocator, phi::DataType dtype, - size_t requested_size = 0) override { + size_t requested_size = 0, + bool fake_alloc = false) override { PADDLE_THROW(paddle::platform::errors::Unavailable( "VariableCompatTensor does not support `AllocateFrom` method.")); } diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc index dd16c484dca..d41a1dca448 100644 --- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc @@ -33,7 +33,8 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, std::string* new_var_name, std::vector* op_func_nodes, bool use_local_scope, - bool is_fetch_v2) { + bool is_fetch_v2, + bool skip_run) { bool is_transferred = false; auto* src_var_name = &var_name; @@ -48,7 +49,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, is_fetch_v2); if (op) { RunAndConstructOpFuncNode( - op, *src_var_name, *new_var_name, op_func_nodes); + op, *src_var_name, *new_var_name, op_func_nodes, skip_run); } // update src_var_name src_var_name = new_var_name; @@ 
-64,7 +65,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, scope_); if (op) { RunAndConstructOpFuncNode( - op, *src_var_name, *new_var_name, op_func_nodes); + op, *src_var_name, *new_var_name, op_func_nodes, skip_run); } // update src_var_name src_var_name = new_var_name; @@ -79,7 +80,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, *src_var_name, new_var_name, src_place, dst_place, var_scope_, scope_); if (op) { RunAndConstructOpFuncNode( - op, *src_var_name, *new_var_name, op_func_nodes); + op, *src_var_name, *new_var_name, op_func_nodes, skip_run); } is_transferred = true; } @@ -89,7 +90,8 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, void DataTranferHelper::RunAndConstructShareNode( const std::string& src_var_name, const std::string& dst_var_name, - std::vector* op_func_nodes) { + std::vector* op_func_nodes, + bool skip_run) { VariableNameMap in_name_map = {{"X", {src_var_name}}}; VariableNameMap out_name_map = {{"Out", {dst_var_name}}}; AttributeMap attr_map; @@ -102,14 +104,16 @@ void DataTranferHelper::RunAndConstructShareNode( VLOG(3) << string::Sprintf( "Insert %s with %s -> %s.", op_type, src_var_name, dst_var_name); - RunAndConstructOpFuncNode(op, src_var_name, dst_var_name, op_func_nodes); + RunAndConstructOpFuncNode( + op, src_var_name, dst_var_name, op_func_nodes, skip_run); } void DataTranferHelper::RunAndConstructOpFuncNode( const std::shared_ptr& op, const std::string& var_name, const std::string& new_var_name, - std::vector* new_op_func_nodes) { + std::vector* new_op_func_nodes, + bool skip_run) { auto& op_type = op->Type(); // 1. Construct RuntimeContext @@ -172,7 +176,13 @@ void DataTranferHelper::RunAndConstructOpFuncNode( phi::KernelContext phi_kernel_context; op_with_kernel->BuildPhiKernelContext( runtime_context, dev_ctx, &phi_kernel_context); - (*new_op_func_node.phi_kernel_)(&phi_kernel_context); + if (!skip_run) { + (*new_op_func_node.phi_kernel_)(&phi_kernel_context); + } else { + FakeInitializeOutputs(new_op_func_node.phi_kernel_, + op_with_kernel->PhiKernelSignature(), + &phi_kernel_context); + } } const phi::Place& place = dev_ctx->GetPlace(); @@ -425,7 +435,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, VariableScope* var_scope, OpFuncNode* op_func_node, std::vector* new_op_func_nodes, - bool use_local_scope) { + bool use_local_scope, + bool skip_run) { Scope* local_scope = use_local_scope ? 
var_scope->GetMutableLocalScope() : var_scope->GetMutableScope(); @@ -500,7 +511,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, op_base->Type() == "fetch_v2"); if (op) { data_transfer_helper.RunAndConstructOpFuncNode( - op, var_name, new_var_name, new_op_func_nodes); + op, var_name, new_var_name, new_op_func_nodes, skip_run); } is_transferred = true; } else { @@ -524,7 +535,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, &new_var_name, new_op_func_nodes, use_local_scope, - op_base->Type() == "fetch_v2"); + op_base->Type() == "fetch_v2", + skip_run); } if (is_transferred) { @@ -575,7 +587,8 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, VariableValueMap* out_vars, VariableScope* var_scope, std::vector* op_func_nodes, - framework::Scope* local_scope) { + framework::Scope* local_scope, + bool skip_run) { DataTranferHelper data_transfer_helper(place, var_scope, local_scope); for (auto& var_name_item : out_names) { std::vector& vars = out_vars->at(var_name_item.first); @@ -651,9 +664,9 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, auto op = TransferDtype( var_name, &new_var_name, src_type, dst_type, var_scope, local_scope); data_transfer_helper.RunAndConstructOpFuncNode( - op, var_name, new_var_name, op_func_nodes); + op, var_name, new_var_name, op_func_nodes, skip_run); data_transfer_helper.RunAndConstructShareNode( - new_var_name, var_name, op_func_nodes); + new_var_name, var_name, op_func_nodes, skip_run); } } } diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.h b/paddle/fluid/framework/new_executor/interpreter/data_transfer.h index 6503179fe7c..e74fe8066e6 100644 --- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.h +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.h @@ -40,16 +40,19 @@ class DataTranferHelper { std::string* new_var_name, std::vector* new_op_func_nodes, bool use_local_scope, - bool is_fetch_v2); + bool is_fetch_v2, + bool skip_run = false); void RunAndConstructShareNode(const std::string& src_var_name, const std::string& dst_var_name, - std::vector* op_func_nodes); + std::vector* op_func_nodes, + bool skip_run = false); void RunAndConstructOpFuncNode(const std::shared_ptr& op, const std::string& var_name, const std::string& new_var_name, - std::vector* op_func_nodes); + std::vector* op_func_nodes, + bool skip_run = false); private: platform::Place place_; @@ -64,7 +67,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, VariableScope* var_scope, OpFuncNode* op_func_node, std::vector* op_func_nodes, - bool use_local_scope = true); + bool use_local_scope = true, + bool skip_run = false); void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, const platform::Place& place, @@ -72,7 +76,8 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, VariableValueMap* out_vars, VariableScope* var_scope, std::vector* op_func_nodes, - framework::Scope* local_scope); + framework::Scope* local_scope, + bool skip_run = false); inline bool need_device_transform(const OpKernelType& kernel_type_for_var, const OpKernelType& expected_kernel_key) { diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 651ebf4c437..637de3ee1d0 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -38,6 +38,11 @@ 
PADDLE_DEFINE_EXPORTED_bool(
     false,
     "Log memory stats after each op runs, just used for debug.");
 
+PADDLE_DEFINE_EXPORTED_bool(
+    new_executor_static_build,
+    false,
+    "Build the InterpreterCore statically without running.");
+
 DECLARE_bool(use_mkldnn);
 DECLARE_bool(check_nan_inf);
 
@@ -157,6 +162,33 @@ bool IsMemcpyOp(const Instruction& instr) {
   return IsMemcpyD2H(instr) || IsMemcpyH2D(instr);
 }
 
+bool IsBlockContainsOnlyPhiKernel(const framework::BlockDesc& block) {
+  bool res = true;
+  for (auto& op : block.AllOps()) {
+    auto op_type = op->Type();
+    if (op_type == "feed" || op_type == "fetch_v2") {
+      continue;
+    }
+    auto has_phi_kernel =
+        !phi::KernelFactory::Instance()
+             .SelectKernelMap(phi::TransToPhiKernelName(op_type))
+             .empty();
+
+    if (!has_phi_kernel) {
+      auto kernel_iter = OperatorWithKernel::AllOpKernels().find(op_type);
+      if (kernel_iter != OperatorWithKernel::AllOpKernels().end()) {
+        VLOG(4) << op_type << " has no phi kernel, but has fluid kernel.";
+        res = false;
+      } else {
+        VLOG(4) << op_type << " has no phi kernel, and no fluid kernel.";
+      }
+    } else {
+      VLOG(4) << op_type << " has phi kernel";
+    }
+  }
+  return res;
+}
+
 void AddFetch(const std::vector& fetch_names,
               framework::BlockDesc* block) {
   auto* fetch_holder = block->Var(kFetchVarName);
@@ -476,7 +508,66 @@ void HandleOperatorBase(const platform::Place& place,
   op_func_node->dev_ctx_ = dev_ctx;
 }
 
-void BuildOpFuncList(const platform::Place& place,
+void FakeInitializeOutputs(phi::Kernel* phi_kernel,
+                           phi::KernelSignature* kernel_sig,
+                           phi::KernelContext* phi_kernel_context) {
+  auto output_defs = phi_kernel->args_def().output_defs();
+  auto out_names = kernel_sig->output_names;
+
+  for (size_t i = 0; i < out_names.size(); ++i) {
+    VLOG(4) << out_names[i];
+    // calculate the start and end index of the output tensors
+    size_t start_idx = phi_kernel_context->OutputRangeAt(i).first;
+    size_t end_idx = phi_kernel_context->OutputRangeAt(i).second;
+    for (size_t j = start_idx; j < end_idx; ++j) {
+      auto* out_tensor = phi_kernel_context->MutableOutputAt(j);
+      if (out_tensor == nullptr) {
+        VLOG(4) << "Output " << out_names[i] << " is nullptr";
+        continue;
+      }
+      auto backend = output_defs[j].backend;
+      auto* dev_ctx =
+          &(phi_kernel_context->GetDeviceContext());
+
+      if (phi::DenseTensor::classof(out_tensor)) {
+        if (!out_tensor->initialized()) {
+          VLOG(4) << "DenseTensor fake alloc 0 bytes of type "
+                  << out_tensor->dtype() << " on backend " << backend << " "
+                  << out_tensor;
+          if (backend == phi::TransToPhiBackend(dev_ctx->GetPlace())) {
+            dev_ctx->Alloc(out_tensor,
+                           out_tensor->dtype(),
+                           /*requested_size=*/0,
+                           /*pinned=*/false,
+                           /*fake_alloc=*/true);
+          } else {
+            if (backend == phi::Backend::CPU ||
+                backend == phi::Backend::ONEDNN) {
+              dev_ctx->HostAlloc(out_tensor,
+                                 out_tensor->dtype(),
+                                 /*requested_size=*/0,
+                                 /*fake_alloc=*/true);
+            }
+          }
+        }
+      } else if (phi::SparseCooTensor::classof(out_tensor)) {
+        // todo
+        VLOG(4) << "SparseCooTensor";
+      } else if (phi::SparseCsrTensor::classof(out_tensor)) {
+        // todo
+        VLOG(4) << "SparseCsrTensor";
+      } else {
+        PADDLE_THROW(phi::errors::Unimplemented(
+            "Only support "
+            "DenseTensor/SparseCooTensor/SparseCsrTensor "
+            "now"));
+        VLOG(4) << "SparseCooTensor";
+      }
+    }
+  }
+}
+
+bool BuildOpFuncList(const platform::Place& place,
                      const framework::BlockDesc& block,
                      const std::set& skip_gc_vars,
                      std::vector* vec_func_list,
@@ -490,6 +581,10 @@ void BuildOpFuncList(const platform::Place& place,
   // Step 1: create all ops for current block.
CreateAllOps(block, &ops_unique); + auto skip_run = + FLAGS_new_executor_static_build && IsBlockContainsOnlyPhiKernel(block); + VLOG(4) << "Static build: " << skip_run; + if (!execution_config.used_for_jit) { // If gc is enabled and block size > 1 const ProgramDesc& main_program = *block.Program(); @@ -676,6 +771,7 @@ void BuildOpFuncList(const platform::Place& place, } } } + VLOG(4) << "if run phi kernel? : " << run_phi_kernel; if (!run_phi_kernel) { op_with_kernel->ChooseKernel(exec_ctx); @@ -704,12 +800,14 @@ void BuildOpFuncList(const platform::Place& place, var_scope, &op_func_node, vec_func_list, - use_local_scope); + use_local_scope, + skip_run); VLOG(4) << "apply data transform done. "; // step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc // for why. if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) && op->Attr(kAllKernelsMustComputeRuntimeShape))) { + VLOG(4) << "infer shape"; InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); // TODO(Aurelius84): In case of control flow ops, they are NOT @@ -722,11 +820,21 @@ void BuildOpFuncList(const platform::Place& place, phi::KernelContext phi_kernel_context; op_with_kernel->BuildPhiKernelContext( runtime_context, dev_ctx, &phi_kernel_context); - (*op_func_node.phi_kernel_)(&phi_kernel_context); + if (!skip_run) { + (*op_func_node.phi_kernel_)(&phi_kernel_context); + } else { + FakeInitializeOutputs(op_func_node.phi_kernel_, + op_with_kernel->PhiKernelSignature(), + &phi_kernel_context); + } } else { // the place of exec_ctx maybe has changed. - op_func_node.kernel_func_(ExecutionContext( - *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context)); + if (!skip_run) { + op_func_node.kernel_func_(ExecutionContext( + *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context)); + } else { + // TODO(zhiqiu): is it needed to support fluid kernel? 
+ } } // post-process grad_op.outputs if need cast complex grad into real @@ -812,6 +920,7 @@ void BuildOpFuncList(const platform::Place& place, interpreter::LogDeviceMemoryStats(place); } + return skip_run; } void LogDeviceMemoryStats(const platform::Place& place) { diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h index d6652d26541..72653fe916e 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h @@ -82,7 +82,7 @@ bool IsSupportedHeterPlace(const phi::Place& place); void AddFetch(const std::vector& fetch_names, framework::BlockDesc* block); -void BuildOpFuncList(const platform::Place& place, +bool BuildOpFuncList(const platform::Place& place, const framework::BlockDesc& block, const std::set& skip_gc_vars, std::vector* vec_func_list, @@ -96,6 +96,10 @@ void BuildVariableScope(const framework::BlockDesc& block, void LogDeviceMemoryStats(const platform::Place& place); +void FakeInitializeOutputs(phi::Kernel* phi_kernel, + phi::KernelSignature* kernel_sig, + phi::KernelContext* phi_kernel_context); + } // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 5ad12071bd3..e7b80372614 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -188,6 +188,35 @@ interpreter::CostInfo InterpreterCore::DryRun( return cost_info; } +void InterpreterCore::RunImpl() { + // For the program that only run once, it is no need to + // create work_queue, so the async_work_queue_ is created + // until the second step run. + async_work_queue_ = GetWorkQueue(); + + // lazy initialization of gc, do not create gc is the program only run once + if (!gc_) { + gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_); + } + + if (execution_config_.used_for_jit && (sync_op_num_ == 0)) { + VLOG(4) << "Tracing Instruction List"; + TraceInstructionList(vec_instruction_); + } else { + ExecuteInstructionList(vec_instruction_); + } +#ifdef PADDLE_WITH_ASCEND_CL + if (platform::is_npu_place(place_)) { + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (platform::is_custom_place(place_)) { + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + } +#endif +} + paddle::framework::FetchList InterpreterCore::Run( const std::vector& feed_names, const std::vector& feed_tensors) { @@ -201,33 +230,9 @@ paddle::framework::FetchList InterpreterCore::Run( Prepare(feed_names, feed_tensors, is_build); if (is_build) { - // For the program that only run once, it is no need to - // create work_queue, so the async_work_queue_ is created - // until the second step run. 
- async_work_queue_ = GetWorkQueue(); - - // lazy initialization of gc, do not create gc is the program only run once - if (!gc_) { - gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_); - } - - if (execution_config_.used_for_jit && (sync_op_num_ == 0)) { - VLOG(4) << "Tracing Instruction List"; - TraceInstructionList(vec_instruction_); - } else { - ExecuteInstructionList(vec_instruction_); - } -#ifdef PADDLE_WITH_ASCEND_CL - if (platform::is_npu_place(place_)) { - platform::DeviceContextPool::Instance().Get(place_)->Wait(); - } -#endif -#ifdef PADDLE_WITH_CUSTOM_DEVICE - if (platform::is_custom_place(place_)) { - platform::DeviceContextPool::Instance().Get(place_)->Wait(); - } -#endif + RunImpl(); } + if (HasLocalScope()) { ClearLoDTensorArrayInLocalScope(); } @@ -255,7 +260,7 @@ paddle::framework::FetchList InterpreterCore::Run( block_, &var_scope_, HasLocalScope()); std::vector op_func_nodes; - paddle::framework::interpreter::BuildOpFuncList( + auto skip_run = paddle::framework::interpreter::BuildOpFuncList( place_, block_, execution_config_.skip_gc_vars, @@ -268,33 +273,12 @@ paddle::framework::FetchList InterpreterCore::Run( Convert(&op_func_nodes); is_build_ = true; UpdateSyncOpNum(); - } else { - // For the program that only run once, it is no need to - // create work_queue, so the async_work_queue_ is created - // until the second step run. - async_work_queue_ = GetWorkQueue(); - - // lazy initialization of gc, do not create gc is the program only run once - if (!gc_) { - gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_); - } - - if (execution_config_.used_for_jit && (sync_op_num_ == 0)) { - VLOG(4) << "Tracing Instruction List"; - TraceInstructionList(vec_instruction_); - } else { - ExecuteInstructionList(vec_instruction_); - } -#ifdef PADDLE_WITH_ASCEND_CL - if (platform::is_npu_place(place_)) { - platform::DeviceContextPool::Instance().Get(place_)->Wait(); + if (skip_run) { + VLOG(4) << "RUN impl"; + RunImpl(); } -#endif -#ifdef PADDLE_WITH_CUSTOM_DEVICE - if (platform::is_custom_place(place_)) { - platform::DeviceContextPool::Instance().Get(place_)->Wait(); - } -#endif + } else { + RunImpl(); } if (HasLocalScope()) { @@ -1197,7 +1181,7 @@ void InterpreterCore::Prepare(const std::vector& feed_names, block_, &var_scope_, HasLocalScope()); FeedInput(); std::vector op_func_nodes; - paddle::framework::interpreter::BuildOpFuncList( + auto skip_run = paddle::framework::interpreter::BuildOpFuncList( place_, block_, execution_config_.skip_gc_vars, @@ -1210,6 +1194,10 @@ void InterpreterCore::Prepare(const std::vector& feed_names, Convert(&op_func_nodes); UpdateSyncOpNum(); is_build_ = true; + if (skip_run) { + VLOG(4) << "RUN impl"; + RunImpl(); + } } // NOTE: Because feed_tensor will be GC after // paddle::framework::BuildOpFuncList, so we should diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index a09942387a9..74ff5c56365 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -98,6 +98,7 @@ class InterpreterCore { void SetFeedVarsInplaceSkip(const std::vector& feed_names); // execution + void RunImpl(); void ExecuteInstructionList(const std::vector& vec_instr); void RunInstructionAsync(size_t instr_id); void RunInstruction(const Instruction& instr_node); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 9773d90c5cd..ae216b1e499 100644 --- 
a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -3028,6 +3028,7 @@ void OperatorWithKernel::BuildPhiKernelContext( (i == 0 ? 0 : phi_kernel_context->OutputRangeAt(i - 1).second); if (it == ctx.outputs.end() || it->second.empty()) { + VLOG(4) << "Output " << output_names[i] << " not found"; // Deal with the case that some outputs are not found or be NULL when run // the kernel. // For example : the outputs of matmul_grad are dx and dy, @@ -3073,6 +3074,7 @@ void OperatorWithKernel::BuildPhiKernelContext( framework::ToTypeName(var->Type()))); } } else { + VLOG(4) << "Output " << output_names[i] << " is nullptr"; phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); } } diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index b4a24c84bcc..32cb10ec890 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -553,6 +553,7 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOpInferVarType, ops::BatchNormGradMaker, ops::BatchNormGradMaker); + REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp, ops::BatchNormDoubleGradMaker, diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index 3d717969afa..73685f598f8 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -98,7 +98,8 @@ bool DenseTensor::IsSharedWith(const DenseTensor& b) const { void* DenseTensor::AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size) { + size_t requested_size, + bool fake_alloc) { PADDLE_ENFORCE_NOT_NULL( allocator, phi::errors::InvalidArgument( @@ -107,21 +108,28 @@ void* DenseTensor::AllocateFrom(Allocator* allocator, VLOG(10) << "change data type in mutbale_data, target dtype - " << dtype; meta_.dtype = dtype; } - PADDLE_ENFORCE( - valid(), - phi::errors::PreconditionNotMet( - "The meta data must be valid when call the mutable data function.")); + size_t bytes = numel() * SizeOf(this->dtype()); - if (requested_size) { - PADDLE_ENFORCE_GE(requested_size, - bytes, - phi::errors::InvalidArgument( - "The reserved size %d should be enough to meet the " - "volume required by metadata %d.", - requested_size, - bytes)); - bytes = requested_size; + + if (fake_alloc) { + bytes = 0; + } else { + PADDLE_ENFORCE( + valid(), + phi::errors::PreconditionNotMet("The meta data must be valid when " + "call the mutable data function.")); + if (requested_size) { + PADDLE_ENFORCE_GE(requested_size, + bytes, + phi::errors::InvalidArgument( + "The reserved size %d should be enough to meet the " + "volume required by metadata %d.", + requested_size, + bytes)); + bytes = requested_size; + } } + // NOTE(paddle-dev): In case of the allocator of storage_ is different with // the incoming allocator, we will re-alloc data using the incoming // allocator. See DeviceContext.Alloc in core/device_context.cc. diff --git a/paddle/phi/core/dense_tensor.h b/paddle/phi/core/dense_tensor.h index c5f38b76216..a6c5c358262 100644 --- a/paddle/phi/core/dense_tensor.h +++ b/paddle/phi/core/dense_tensor.h @@ -125,7 +125,8 @@ class DenseTensor : public TensorBase, /// \return The mutable data pointer value of type T. void* AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size = 0) override; + size_t requested_size = 0, + bool fake_alloc = false) override; /// \brief Check if allocation is shared with other objects. /// \return Whether the allocation is shared with other objects. 
diff --git a/paddle/phi/core/device_context.cc b/paddle/phi/core/device_context.cc index 60747e36185..721b92aa584 100644 --- a/paddle/phi/core/device_context.cc +++ b/paddle/phi/core/device_context.cc @@ -134,7 +134,8 @@ struct DeviceContext::Impl { const Place& place, DataType dtype = DataType::UNDEFINED, size_t requested_size = 0, - bool pinned = false) const { + bool pinned = false, + bool fake_alloc = false) const { PADDLE_ENFORCE_NOT_NULL( tensor, phi::errors::InvalidArgument( @@ -148,9 +149,10 @@ struct DeviceContext::Impl { if (tensor->initialized() && tensor->place() != place) { ClearHolder(tensor); } - auto* allocator = tensor->numel() == 0 && requested_size == 0 - ? zero_allocator_ - : (pinned ? pinned_allocator_ : device_allocator_); + auto* allocator = + (tensor->numel() == 0 || fake_alloc) && requested_size == 0 + ? zero_allocator_ + : (pinned ? pinned_allocator_ : device_allocator_); #ifdef PADDLE_WITH_CUDA bool must_cuda_graph_allocator = (tensor->numel() != 0) && !pinned; if (must_cuda_graph_allocator && @@ -164,7 +166,7 @@ struct DeviceContext::Impl { } #endif return tensor->AllocateFrom( - const_cast(allocator), dtype, requested_size); + const_cast(allocator), dtype, requested_size, fake_alloc); } template @@ -178,7 +180,8 @@ struct DeviceContext::Impl { void* HostAlloc(TensorBase* tensor, DataType dtype = DataType::UNDEFINED, - size_t requested_size = 0) const { + size_t requested_size = 0, + bool fake_alloc = false) const { PADDLE_ENFORCE_NOT_NULL( tensor, phi::errors::InvalidArgument( @@ -190,9 +193,11 @@ struct DeviceContext::Impl { ClearHolder(tensor); } auto* allocator = - tensor->numel() == 0 ? host_zero_allocator_ : host_allocator_; + (tensor->numel() == 0 || fake_alloc) && requested_size == 0 + ? host_zero_allocator_ + : host_allocator_; return tensor->AllocateFrom( - const_cast(allocator), dtype, requested_size); + const_cast(allocator), dtype, requested_size, fake_alloc); } template @@ -342,12 +347,18 @@ const Allocator& DeviceContext::GetPinnedAllocator() const { void* DeviceContext::Alloc(TensorBase* tensor, DataType dtype, size_t requested_size, - bool pinned) const { + bool pinned, + bool fake_alloc) const { if (pinned) { - return impl_->Alloc( - tensor, GetPinnedPlace(GetPlace()), dtype, requested_size, pinned); + return impl_->Alloc(tensor, + GetPinnedPlace(GetPlace()), + dtype, + requested_size, + pinned, + fake_alloc); } - return impl_->Alloc(tensor, GetPlace(), dtype, requested_size, pinned); + return impl_->Alloc( + tensor, GetPlace(), dtype, requested_size, pinned, fake_alloc); } template @@ -363,8 +374,9 @@ T* DeviceContext::Alloc(TensorBase* tensor, void* DeviceContext::HostAlloc(TensorBase* tensor, DataType dtype, - size_t requested_size) const { - return impl_->HostAlloc(tensor, dtype, requested_size); + size_t requested_size, + bool fake_alloc) const { + return impl_->HostAlloc(tensor, dtype, requested_size, fake_alloc); } template diff --git a/paddle/phi/core/device_context.h b/paddle/phi/core/device_context.h index 9114490d1a7..f233a5f185c 100644 --- a/paddle/phi/core/device_context.h +++ b/paddle/phi/core/device_context.h @@ -149,7 +149,8 @@ class PADDLE_API DeviceContext { void* Alloc(TensorBase*, DataType dtype, size_t requested_size = 0, - bool pinned = false) const; + bool pinned = false, + bool fake_alloc = false) const; template T* Alloc(TensorBase* tensor, @@ -161,7 +162,8 @@ class PADDLE_API DeviceContext { */ void* HostAlloc(TensorBase* tensor, DataType dtype, - size_t requested_size = 0) const; + size_t requested_size = 0, + bool 
fake_alloc = false) const; template T* HostAlloc(TensorBase* tensor, size_t requested_size = 0) const; diff --git a/paddle/phi/core/extended_tensor.cc b/paddle/phi/core/extended_tensor.cc index 6ffbcf40122..e5b5c3773f8 100644 --- a/paddle/phi/core/extended_tensor.cc +++ b/paddle/phi/core/extended_tensor.cc @@ -53,7 +53,8 @@ bool ExtendedTensor::initialized() const { void* ExtendedTensor::AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size) { + size_t requested_size, + bool fake_alloc) { PADDLE_THROW(phi::errors::Unavailable( "ExtendedTensor does not support `AllocateFrom` method.")); } diff --git a/paddle/phi/core/extended_tensor.h b/paddle/phi/core/extended_tensor.h index 404e1014bb3..66c4987fb4c 100644 --- a/paddle/phi/core/extended_tensor.h +++ b/paddle/phi/core/extended_tensor.h @@ -49,7 +49,8 @@ class ExtendedTensor : public TensorBase { void* AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size = 0) override; + size_t requested_size = 0, + bool fake_alloc = false) override; }; } // namespace phi diff --git a/paddle/phi/core/kernel_context.h b/paddle/phi/core/kernel_context.h index 58afc8c5fd2..020312fdbcb 100644 --- a/paddle/phi/core/kernel_context.h +++ b/paddle/phi/core/kernel_context.h @@ -119,6 +119,8 @@ class KernelContext { return static_cast(outputs_.at(idx)); } + TensorBase* MutableOutputAt(size_t idx) { return outputs_.at(idx); } + template std::vector MutableOutputBetween(size_t start, size_t end) { std::vector v; diff --git a/paddle/phi/core/selected_rows.h b/paddle/phi/core/selected_rows.h index c011605809e..08d02bee40d 100644 --- a/paddle/phi/core/selected_rows.h +++ b/paddle/phi/core/selected_rows.h @@ -90,8 +90,9 @@ class SelectedRows : public TensorBase, void* AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size = 0) override { - return impl_->AllocateFrom(allocator, dtype, requested_size); + size_t requested_size = 0, + bool fake_alloc = false) override { + return impl_->AllocateFrom(allocator, dtype, requested_size, fake_alloc); } /* diff --git a/paddle/phi/core/selected_rows_impl.cc b/paddle/phi/core/selected_rows_impl.cc index f099ea711f1..45a724fe6f5 100644 --- a/paddle/phi/core/selected_rows_impl.cc +++ b/paddle/phi/core/selected_rows_impl.cc @@ -94,8 +94,9 @@ struct TensorFillVisitor { void* SelectedRowsImpl::AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size) { - return value_->AllocateFrom(allocator, dtype, requested_size); + size_t requested_size, + bool fake_alloc) { + return value_->AllocateFrom(allocator, dtype, requested_size, fake_alloc); } bool SelectedRowsImpl::HasKey(int64_t key) const { diff --git a/paddle/phi/core/selected_rows_impl.h b/paddle/phi/core/selected_rows_impl.h index 3c54b59a159..d4a42a9653b 100644 --- a/paddle/phi/core/selected_rows_impl.h +++ b/paddle/phi/core/selected_rows_impl.h @@ -109,7 +109,8 @@ class SelectedRowsImpl { void* AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size = 0); + size_t requested_size = 0, + bool fake_alloc = false); /* * @brief Get the index of the key from id_to_index_ map. 
If the key not diff --git a/paddle/phi/core/sparse_coo_tensor.cc b/paddle/phi/core/sparse_coo_tensor.cc index 8df031421fe..6d3296e2852 100644 --- a/paddle/phi/core/sparse_coo_tensor.cc +++ b/paddle/phi/core/sparse_coo_tensor.cc @@ -67,8 +67,10 @@ SparseCooTensor SparseCooTensor::operator=(const SparseCooTensor& other) { void* SparseCooTensor::AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size) { - return non_zero_elements_.AllocateFrom(allocator, dtype, requested_size); + size_t requested_size, + bool fake_alloc) { + return non_zero_elements_.AllocateFrom( + allocator, dtype, requested_size, fake_alloc); } int64_t SparseCooTensor::nnz() const { diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h index a28229996c8..13fc7d444b4 100644 --- a/paddle/phi/core/sparse_coo_tensor.h +++ b/paddle/phi/core/sparse_coo_tensor.h @@ -170,7 +170,8 @@ class SparseCooTensor : public TensorBase, /// \brief This function is not recommended void* AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size = 0) override; + size_t requested_size = 0, + bool fake_alloc = false) override; /// \brief get the sparse dim int32_t sparse_dim() const; diff --git a/paddle/phi/core/sparse_csr_tensor.cc b/paddle/phi/core/sparse_csr_tensor.cc index 5c793048ea3..0b4662760c0 100644 --- a/paddle/phi/core/sparse_csr_tensor.cc +++ b/paddle/phi/core/sparse_csr_tensor.cc @@ -82,8 +82,10 @@ SparseCsrTensor& SparseCsrTensor::operator=(const SparseCsrTensor& other) { void* SparseCsrTensor::AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size) { - return non_zero_elements_.AllocateFrom(allocator, dtype, requested_size); + size_t requested_size, + bool fake_alloc) { + return non_zero_elements_.AllocateFrom( + allocator, dtype, requested_size, fake_alloc); } void SparseCsrTensor::Resize(const DDim& dense_dims, diff --git a/paddle/phi/core/sparse_csr_tensor.h b/paddle/phi/core/sparse_csr_tensor.h index 2acb35915a9..4d607188d2e 100644 --- a/paddle/phi/core/sparse_csr_tensor.h +++ b/paddle/phi/core/sparse_csr_tensor.h @@ -62,7 +62,8 @@ class SparseCsrTensor : public TensorBase, /// \brief This function is not recommended void* AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size = 0) override; + size_t requested_size = 0, + bool fake_alloc = false) override; public: /// \brief Returns the name of the class for type traits. 
diff --git a/paddle/phi/core/string_tensor.cc b/paddle/phi/core/string_tensor.cc
index 89272e1de59..bab2a8772c6 100644
--- a/paddle/phi/core/string_tensor.cc
+++ b/paddle/phi/core/string_tensor.cc
@@ -130,25 +130,32 @@ void StringTensor::init_holder() {
 
 void* StringTensor::AllocateFrom(Allocator* allocator,
                                  DataType dtype,
-                                 size_t requested_size) {
+                                 size_t requested_size,
+                                 bool fake_alloc) {
   PADDLE_ENFORCE_NOT_NULL(
       allocator,
       errors::InvalidArgument(
           "Required allocator shall not be nullptr, but received nullptr."));
-  PADDLE_ENFORCE(
-      valid(),
-      errors::PreconditionNotMet(
-          "The meta data must be valid when call the mutable data function."));
+
   size_t bytes = numel() * SizeOf(this->dtype());
-  if (requested_size) {
-    PADDLE_ENFORCE_GE(requested_size,
-                      bytes,
-                      errors::InvalidArgument(
-                          "The reserved size %d should be enough to meet the "
-                          "volume required by metadata %d.",
-                          requested_size,
-                          bytes));
-    bytes = requested_size;
+  if (fake_alloc) {
+    bytes = 0;
+  } else {
+    PADDLE_ENFORCE(
+        valid(),
+        errors::PreconditionNotMet("The meta data must be valid when call the "
+                                   "mutable data function."));
+    if (requested_size) {
+      PADDLE_ENFORCE_GE(requested_size,
+                        bytes,
+                        errors::InvalidArgument(
+                            "The reserved size %d should be enough to meet the "
+                            "volume required by metadata %d.",
+                            requested_size,
+                            bytes));
+
+      bytes = requested_size;
+    }
   }
 
   if (!holder_ || holder_->size() < bytes + meta_.offset) {
diff --git a/paddle/phi/core/string_tensor.h b/paddle/phi/core/string_tensor.h
index 80d6b69aa6c..ccf89d88e42 100644
--- a/paddle/phi/core/string_tensor.h
+++ b/paddle/phi/core/string_tensor.h
@@ -123,7 +123,8 @@ class StringTensor : public TensorBase,
   }
   void* AllocateFrom(Allocator* allocator,
                      DataType dtype,
-                     size_t requested_size = 0) override;
+                     size_t requested_size = 0,
+                     bool fake_alloc = false) override;
 
   dtype::pstring* mutable_data(const phi::Place& place,
                                size_t requested_size = 0);
diff --git a/paddle/phi/core/tensor_array.cc b/paddle/phi/core/tensor_array.cc
index 2007e71d5a0..43089d95254 100644
--- a/paddle/phi/core/tensor_array.cc
+++ b/paddle/phi/core/tensor_array.cc
@@ -65,9 +65,11 @@ bool TensorArray::valid() const {
 /// \return Void pointer
 void* TensorArray::AllocateFrom(Allocator* allocator,
                                 DataType dtype,
-                                size_t requested_size) {
+                                size_t requested_size,
+                                bool fake_alloc) {
   for (size_t i = 0; i < tensors_.size(); i++) {
-    tensors_[i].AllocateFrom(allocator, tensors_[i].dtype(), requested_size);
+    tensors_[i].AllocateFrom(
+        allocator, tensors_[i].dtype(), requested_size, fake_alloc);
   }
   return nullptr;
 }
diff --git a/paddle/phi/core/tensor_array.h b/paddle/phi/core/tensor_array.h
index 6d834a9375a..14679429ea7 100644
--- a/paddle/phi/core/tensor_array.h
+++ b/paddle/phi/core/tensor_array.h
@@ -83,7 +83,8 @@ class TensorArray : public TensorBase,
   /// \return Void pointer
   void* AllocateFrom(Allocator* allocator,
                      DataType dtype,
-                     size_t requested_size = 0) override;
+                     size_t requested_size = 0,
+                     bool fake_alloc = false) override;
 
   bool empty() const { return tensors_.empty(); }
 
diff --git a/paddle/phi/core/tensor_base.h b/paddle/phi/core/tensor_base.h
index 3dc0e455a63..351eec8d2a5 100644
--- a/paddle/phi/core/tensor_base.h
+++ b/paddle/phi/core/tensor_base.h
@@ -66,7 +66,8 @@ class TensorBase {
   /// \return The mutable data pointer value of type T.
virtual void* AllocateFrom(Allocator* allocator, DataType dtype, - size_t requested_size = 0) = 0; + size_t requested_size = 0, + bool fake_alloc = false) = 0; /// \brief Return the type information of the derived class to support /// safely downcast in non-rtti environment. diff --git a/python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt b/python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt index a9832154200..d6a1fa1c9be 100644 --- a/python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt @@ -25,3 +25,5 @@ py_test_modules( FLAGS_host_trace_level=10 FLAGS_static_executor_perfstat_filepath=./perfstat) set_tests_properties(test_standalone_cross_step_overlap PROPERTIES TIMEOUT 30) +set_tests_properties(test_standalone_executor_aot_choose_kernel + PROPERTIES TIMEOUT 60) diff --git a/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_executor_aot_choose_kernel.py b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_executor_aot_choose_kernel.py new file mode 100644 index 00000000000..4281d8b76f4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_executor_aot_choose_kernel.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle +from paddle.framework import set_flags + +paddle.enable_static() + + +def build_resnet50(): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + image = paddle.static.data( + name='image', shape=[32, 3, 224, 224], dtype='float32' + ) + label = paddle.static.data(name='label', shape=[32], dtype='int64') + model = paddle.vision.models.resnet50() + prediction = model(image) + loss = paddle.nn.functional.cross_entropy(input=prediction, label=label) + loss = paddle.mean(loss) + adam = paddle.optimizer.Adam(learning_rate=0.001) + adam.minimize(loss) + + return main_program, startup_program, loss + + +class TestAOTChooseKernel(unittest.TestCase): + def test_aot_choose_kernel(self): + if not paddle.fluid.core.is_compiled_with_cuda(): + return + + def run(aot_choose_kernel=None): + paddle.seed(2022) + np.random.seed(2022) + + main_program, startup_program, loss = build_resnet50() + + scope = paddle.static.Scope() + exe = paddle.static.Executor() + + set_flags({'FLAGS_cudnn_deterministic': 1}) + if aot_choose_kernel: + set_flags({'FLAGS_new_executor_static_build': 1}) + else: + set_flags({'FLAGS_new_executor_static_build': 0}) + + with paddle.static.scope_guard(scope): + exe.run(startup_program) + + for i in range(10): + feed = { + 'image': np.random.randint( + 0, 256, size=[32, 3, 224, 224] + ).astype('float32'), + 'label': np.random.randint(0, 1000, size=[32]).astype( + 'int64' + ), + } + loss_ = exe.run(main_program, feed=feed, fetch_list=[loss]) + return loss_ + + loss1 = run(aot_choose_kernel=True) + loss2 = run(aot_choose_kernel=False) + + self.assertEqual(loss1, loss2) + + +if __name__ == "__main__": + unittest.main() -- GitLab
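The feature in this patch is driven entirely by the new FLAGS_new_executor_static_build flag: when it is on and the block contains only phi kernels, BuildOpFuncList fake-allocates kernel outputs to choose kernels ahead of time instead of actually running the ops at build time, and the interpreter then executes the pre-built instruction list. The snippet below is a minimal Python sketch of how the flag can be exercised under the standalone executor; set_flags and the flag name come from the test above, while the tiny fc network is illustrative only and not part of this patch.

import numpy as np

import paddle
from paddle.framework import set_flags

paddle.enable_static()

# Build a trivial static program (illustrative network, not from the patch).
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
    out = paddle.static.nn.fc(x, size=2)

# Enable the ahead-of-time (static) build introduced by this patch.
set_flags({'FLAGS_new_executor_static_build': 1})

exe = paddle.static.Executor()
exe.run(startup_prog)
res = exe.run(main_prog,
              feed={'x': np.random.rand(4, 8).astype('float32')},
              fetch_list=[out])
print(res[0].shape)  # (4, 2)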