From 8b87d5eb5af0d7ae8d77498374a94747c081abd2 Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Wed, 24 Nov 2021 11:33:42 +0800
Subject: [PATCH] [NewExe] Support HandleComplexGradToRealGrad to cast complex
 into Real (#37450)

---
 .../framework/new_executor/data_transfer.cc   | 116 +++++++++++++++++-
 .../framework/new_executor/data_transfer.h    |  18 ++-
 .../framework/new_executor/interpretercore.cc |   4 +-
 .../new_executor/interpretercore_util.cc      |  20 +--
 .../new_executor/interpretercore_util.h       |   1 -
 .../new_executor/new_executor_defs.h          |   1 +
 paddle/fluid/framework/operator.cc            |   4 -
 paddle/fluid/framework/operator.h             |   4 +
 8 files changed, 145 insertions(+), 23 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc
index 5d0db4b4028..ca3647e7d3d 100644
--- a/paddle/fluid/framework/new_executor/data_transfer.cc
+++ b/paddle/fluid/framework/new_executor/data_transfer.cc
@@ -62,6 +62,24 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
   return is_transferred;
 }
 
+void DataTranferHelper::RunAndConstructShareNode(
+    const std::string& src_var_name, const std::string& dst_var_name,
+    std::vector<OpFuncNode>* op_func_nodes) {
+  VariableNameMap in_name_map = {{"X", {src_var_name}}};
+  VariableNameMap out_name_map = {{"Out", {dst_var_name}}};
+  AttributeMap attr_map;
+
+  std::string op_type("share_data");
+  auto& op_info = OpInfoMap::Instance().Get(op_type);
+  auto op = std::shared_ptr<OperatorBase>(
+      op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));
+
+  VLOG(3) << string::Sprintf("Insert %s with %s -> %s.", op_type, src_var_name,
+                             dst_var_name);
+
+  RunAndConstructOpFuncNode(op, src_var_name, dst_var_name, op_func_nodes);
+}
+
 void DataTranferHelper::RunAndConstructOpFuncNode(
     const std::shared_ptr<OperatorBase>& op, const std::string& var_name,
     const std::string& new_var_name,
@@ -133,7 +151,7 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
   VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
   AttributeMap attr_map = {{"dst_layout", static_cast<int>(out_layout)}};
 
-  // 3. Create transfer_op
+  // 3. Create transfer_layout_op
   std::string op_type("transfer_layout");
   auto& op_info = OpInfoMap::Instance().Get(op_type);
   auto op = std::shared_ptr<OperatorBase>(
@@ -154,9 +172,10 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
   *new_var_name =
       var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1);
   auto* ptr = local_scope->Var(new_var_name);
-
+  var_scope->SetVarDesc(var_name, nullptr);
   auto var_type = var_scope->Var(var_name)->Type();
   InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
+
   VLOG(3) << "Create Variable " << *new_var_name
           << " locally, which pointer is " << ptr << "Variable Type "
           << var_type;
@@ -171,7 +190,7 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
   // NOTE(Aurelius84): In which case is use_mkldnn = true?
   attr_map["use_mkldnn"] = false;
 
-  // 3. Create transfer_op
+  // 3. Create transfer_dtype_op
   std::string op_type("transfer_dtype");
   auto& op_info = OpInfoMap::Instance().Get(op_type);
   auto op = std::shared_ptr<OperatorBase>(
@@ -209,7 +228,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
                            : platform::is_gpu_place(dst_place) ? 1 : -1;
   AttributeMap attr_map = {{"dst_place_type", dst_place_type}};
 
-  // 3. Create transfer_op
+  // 3. Create memcpy_d2h_op or memcpy_h2d_op
   std::string op_type = get_memcpy_type(src_place, dst_place);
   auto& op_info = OpInfoMap::Instance().Get(op_type);
   auto op = std::shared_ptr<OperatorBase>(
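For reference, get_memcpy_type (whose closing braces appear as context in the next hunk) picks between the two memcpy op types from the placement of the source and destination. The following is a simplified sketch, not the verbatim implementation; it assumes only CPU and GPU places are involved and reuses the kMemcpyH2D/kMemcpyD2H constants already defined in new_executor_defs.h, while the real function presumably also rejects unsupported place pairs:

  // Sketch: choose the memcpy op name from the transfer direction.
  std::string get_memcpy_type(const platform::Place& src_place,
                              const platform::Place& dst_place) {
    if (platform::is_gpu_place(dst_place)) {
      return kMemcpyH2D;  // copying onto the device: host-to-device
    }
    return kMemcpyD2H;    // copying off the device: device-to-host
  }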
@@ -303,6 +322,95 @@ std::string get_memcpy_type(const platform::Place& src_place,
   }
 }
 
+void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
+                                 const platform::Place& place,
+                                 const VariableNameMap& out_names,
+                                 VariableValueMap* out_vars,
+                                 VariableScope* var_scope,
+                                 std::vector<OpFuncNode>* op_func_nodes,
+                                 framework::Scope* local_scope) {
+  DataTranferHelper data_transfer_helper(place, var_scope);
+  for (auto& var_name_item : out_names) {
+    std::vector<Variable*>& vars = out_vars->at(var_name_item.first);
+    for (size_t i = 0; i < var_name_item.second.size(); ++i) {
+      // 1. find grad_var & check whether it is a complex tensor
+      auto var_name = var_name_item.second[i];
+      auto orig_var_name = framework::GradOriginalVarName(var_name);
+      // only focus on gradient var
+      if (var_name == orig_var_name) {
+        VLOG(3) << "skip " << var_name << " with same name as "
+                << orig_var_name;
+        continue;
+      }
+      auto* grad_var = vars[i];
+      // skip nullptr var
+      if (grad_var == nullptr) {
+        VLOG(3) << "skip grad_var with nullptr";
+        continue;
+      }
+      // don't process LoDTensorArray for now; add support later if
+      // complex-number calculations need it
+      if (!framework::VarIsTensor(*grad_var)) {
+        VLOG(3) << "skip grad_var with LoDTensorArray type";
+        continue;
+      }
+      auto* grad_tensor =
+          framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(grad_var);
+      // skip nullptr tensor
+      if (grad_tensor == nullptr || !grad_tensor->IsInitialized()) {
+        VLOG(3) << "skip with grad_tensor not IsInitialized";
+        continue;
+      }
+      // only focus on complex dtype now
+      auto src_type = grad_tensor->type();
+      if (!framework::IsComplexType(src_type)) {
+        VLOG(3) << "skip grad_tensor with not complexType";
+        continue;
+      }
+
+      // 2. find forward var & check whether a cast is needed
+      auto* var = var_scope->FindVar(orig_var_name);
+      // if the forward var does not exist, do nothing
+      if (var == nullptr) {
+        VLOG(3) << "skip " << orig_var_name << " with not found in var_scope";
+        continue;
+      }
+      if (!framework::VarIsTensor(*var)) {
+        VLOG(3) << "skip " << orig_var_name << " with LoDTensorArray.";
+        continue;
+      }
+      const auto* tensor =
+          framework::GetLoDTensorOrSelectedRowsValueFromVar(*var);
+      PADDLE_ENFORCE_NOT_NULL(
+          tensor,
+          platform::errors::Unavailable(
+              "Forward tensor is nullptr when handling complex data to real."));
+      // only the saved type is needed; the allocation may have been released
+      auto dst_type = tensor->saved_type();
+      // a cast is needed only when the forward dtype is real
+      if (framework::IsComplexType(dst_type)) {
+        continue;
+      }
+
+      // 3. cast complex grad to real grad in place
+      VLOG(3) << "Transform " << framework::DataTypeToString(src_type)
+              << " var `" << var_name << "` to "
+              << framework::DataTypeToString(dst_type)
+              << " real var in static graph.";
+
+      // NOTE(Aurelius84): Consider defining a complex2real op to handle
+      // this case.
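+      // Until such an op exists, the cast is assembled from two existing
+      // ops: transfer_dtype materializes the real-typed result in a fresh
+      // temporary variable, and share_data shares it back into var_name.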
+      std::string new_var_name;
+      auto op = TransferDtype(var_name, &new_var_name, src_type, dst_type,
+                              var_scope, local_scope);
+      data_transfer_helper.RunAndConstructOpFuncNode(op, var_name,
+                                                     new_var_name,
+                                                     op_func_nodes);
+      data_transfer_helper.RunAndConstructShareNode(new_var_name, var_name,
+                                                    op_func_nodes);
+    }
+  }
+}
+
 }  // namespace interpreter
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/data_transfer.h b/paddle/fluid/framework/new_executor/data_transfer.h
index d66188709f7..7744e955c85 100644
--- a/paddle/fluid/framework/new_executor/data_transfer.h
+++ b/paddle/fluid/framework/new_executor/data_transfer.h
@@ -37,14 +37,18 @@ class DataTranferHelper {
              const std::string& var_name, std::string* new_var_name,
              std::vector<OpFuncNode>* new_op_func_nodes, bool use_local_scope);
 
-  private:
-  platform::Place place_;
-  VariableScope* var_scope_;
+  void RunAndConstructShareNode(const std::string& src_var_name,
+                                const std::string& dst_var_name,
+                                std::vector<OpFuncNode>* op_func_nodes);
 
   void RunAndConstructOpFuncNode(const std::shared_ptr<OperatorBase>& op,
                                  const std::string& var_name,
                                  const std::string& new_var_name,
                                  std::vector<OpFuncNode>* op_func_nodes);
+
+ private:
+  platform::Place place_;
+  VariableScope* var_scope_;
 };
 
 void ApplyDataTransform(const OpKernelType& expected_kernel_key,
@@ -54,6 +58,14 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
                         std::vector<OpFuncNode>* op_func_nodes,
                         bool use_local_scope = true);
 
+void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
+                                 const platform::Place& place,
+                                 const VariableNameMap& out_names,
+                                 VariableValueMap* out_vars,
+                                 VariableScope* var_scope,
+                                 std::vector<OpFuncNode>* op_func_nodes,
+                                 framework::Scope* local_scope);
+
 std::string get_memcpy_type(const platform::Place& src_place,
                             const platform::Place& dst_place);
 
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index c7b922ad9e4..71cd49bd7ef 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -90,7 +90,7 @@ paddle::framework::FetchList InterpreterCore::Run(
 
   // return Fetch Tensors
   auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
-  return *(fetch_var->GetMutable<framework::FetchList>());
+  return std::move(*fetch_var->GetMutable<framework::FetchList>());
 }
 
 paddle::framework::FetchList InterpreterCore::Run(
@@ -124,7 +124,7 @@ paddle::framework::FetchList InterpreterCore::Run(
 
   // return Fetch Tensors
   auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
-  return *(fetch_var->GetMutable<framework::FetchList>());
+  return std::move(*fetch_var->GetMutable<framework::FetchList>());
 }
 
 void InterpreterCore::BuildOperatorDependences() {
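Both Run overloads above now move the FetchList out of the scope's fetch variable instead of copying it, so the caller takes ownership of the fetched tensors and the variable holds a moved-from (empty) list until the next run refills it. A hypothetical call site, with every name except Run and FetchList assumed for illustration:

  // `core` is an InterpreterCore; feed_names/feed_tensors are its inputs.
  paddle::framework::FetchList fetches = core.Run(feed_names, feed_tensors);
  // `fetches` now owns the results; no copy of the fetched tensors is made.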
diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc
index 061b2e7806b..fe4d1546ea1 100644
--- a/paddle/fluid/framework/new_executor/interpretercore_util.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc
@@ -328,20 +328,14 @@ void build_op_func_list(const platform::Place& place,
                             ->GetExpectedKernelType(
                                 ExecutionContext(*op, scope, *dev_ctx, runtime_context));
 
-      // consider device_guard()
-      apply_device_guard(
-          op, place,
-          &expected_kernel_key);  // change device by the device_guard()
+      // change device by the device_guard()
+      apply_device_guard(op, place, &expected_kernel_key);
       VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
 
       // step 3. apply data transforms and insert data transfer ops
       VariableValueMap& ins_map_temp = runtime_context.inputs;
-      std::vector<OpFuncNode> new_op_func_nodes;
       ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
-                         &op_func_node, &new_op_func_nodes, use_local_scope);
-      for (auto& item : new_op_func_nodes) {
-        vec_func_list->emplace_back(std::move(item));
-      }
+                         &op_func_node, vec_func_list, use_local_scope);
 
       // step 4. Run op kernel
       VLOG(3) << op->Type() << " : expected_kernel_key : " << expected_kernel_key;
@@ -370,6 +364,14 @@ void build_op_func_list(const platform::Place& place,
 
       op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
       op_func_node.kernel_func_(exec_ctx);
+
+      // post-process grad_op.outputs if we need to cast complex grad into real grad.
+      // NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it.
+      if (framework::IsComplexType(expected_kernel_key.data_type_)) {
+        interpreter::HandleComplexGradToRealGrad(
+            op_func_node, place, outputs_names, &runtime_context.outputs,
+            var_scope, vec_func_list, local_scope);
+      }
     }
 
     vec_func_list->emplace_back(op_func_node);
diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h
index 9fc93afb5de..c92cea6c97c 100644
--- a/paddle/fluid/framework/new_executor/interpretercore_util.h
+++ b/paddle/fluid/framework/new_executor/interpretercore_util.h
@@ -51,7 +51,6 @@ namespace framework {
 namespace interpreter {
 
 using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
-static constexpr char kFetchVarName[] = "fetch";
 
 class AsyncWorkQueue {
  public:
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h
index 94631b7ae64..76071f21819 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.h
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -374,6 +374,7 @@ class Instruction {
 namespace interpreter {
 static constexpr char kMemcpyH2D[] = "memcpy_h2d";
 static constexpr char kMemcpyD2H[] = "memcpy_d2h";
+static constexpr char kFetchVarName[] = "fetch";
 
 static bool IsMemcpyH2D(const Instruction& instr) {
   return instr.OpBase()->Type() == kMemcpyH2D;
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index cde5d1353d0..1a60acf49a4 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -479,10 +479,6 @@ void OperatorBase::GenerateTemporaryNames() {
   }
 }
 
-static bool VarIsTensor(const Variable& var) {
-  return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
-}
-
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
   if (var.IsType<LoDTensor>()) {
     return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 6a5bac393ed..725657dd817 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -114,6 +114,10 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
   }
 }
 
+inline bool VarIsTensor(const Variable& var) {
+  return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
+}
+
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
 Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
 
--
GitLab
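Taken together, the new path works like this: after running a kernel whose expected dtype is complex, build_op_func_list hands the grad op's outputs to HandleComplexGradToRealGrad, which appends a transfer_dtype node followed by a share_data node for every complex-typed gradient whose forward variable is real-typed. Below is a minimal sketch of that two-op insertion as a hypothetical helper, CastComplexGradToReal, condensing steps 1 through 3 of HandleComplexGradToRealGrad; the helper name is invented, and var_scope, local_scope, place, op_func_nodes, and the dtype pair are assumed to be set up as in build_op_func_list:

  using paddle::framework::OpFuncNode;
  using paddle::framework::Scope;
  using paddle::framework::VariableScope;
  using paddle::framework::interpreter::DataTranferHelper;
  namespace proto = paddle::framework::proto;

  // Hypothetical condensation of HandleComplexGradToRealGrad's core steps.
  void CastComplexGradToReal(const std::string& var_name,   // e.g. "x@GRAD"
                             proto::VarType::Type src_type,  // complex dtype
                             proto::VarType::Type dst_type,  // real fwd dtype
                             const paddle::platform::Place& place,
                             VariableScope* var_scope, Scope* local_scope,
                             std::vector<OpFuncNode>* op_func_nodes) {
    DataTranferHelper helper(place, var_scope);
    // 1. transfer_dtype: write real(var_name) into a fresh temporary.
    std::string new_var_name;
    auto cast_op = TransferDtype(var_name, &new_var_name, src_type, dst_type,
                                 var_scope, local_scope);
    helper.RunAndConstructOpFuncNode(cast_op, var_name, new_var_name,
                                     op_func_nodes);
    // 2. share_data: var_name now shares the real-typed buffer, so the
    //    cast is effectively in place from the graph's point of view.
    helper.RunAndConstructShareNode(new_var_name, var_name, op_func_nodes);
  }

As the NOTE in HandleComplexGradToRealGrad says, a dedicated complex2real op could replace this transfer_dtype plus share_data pair; until then, the pair keeps the change local to the new executor's op-function list without touching kernel code.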