Unverified commit 8b87d5eb authored by Aurelius84, committed by GitHub

[NewExe] Support HandleComplexGradToRealGrad to cast complex into Real (#37450)

Parent 1c969d20
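This commit adds a post-processing step to the new executor: after a grad op runs with a complex kernel, any complex-typed gradient whose forward variable is real-typed is cast back to the forward variable's dtype. For reference, a standalone sketch (not Paddle code, not part of the diff below) of what "cast complex into Real" means numerically: the imaginary part is discarded.

// Standalone illustration: casting a complex gradient to a real gradient keeps
// only the real part, matching the dtype of the forward variable.
#include <complex>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::complex<float>> complex_grad = {{1.5f, 2.0f}, {-0.5f, 3.0f}};
  std::vector<float> real_grad;
  real_grad.reserve(complex_grad.size());
  for (const auto& g : complex_grad) {
    real_grad.push_back(g.real());  // imaginary part is dropped
  }
  for (float g : real_grad) {
    std::cout << g << " ";  // prints: 1.5 -0.5
  }
  return 0;
}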
@@ -62,6 +62,24 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
return is_transferred;
}
void DataTranferHelper::RunAndConstructShareNode(
const std::string& src_var_name, const std::string& dst_var_name,
std::vector<OpFuncNode>* op_func_nodes) {
VariableNameMap in_name_map = {{"X", {src_var_name}}};
VariableNameMap out_name_map = {{"Out", {dst_var_name}}};
AttributeMap attr_map;
std::string op_type("share_data");
auto& op_info = OpInfoMap::Instance().Get(op_type);
auto op = std::shared_ptr<OperatorBase>(
op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));
VLOG(3) << string::Sprintf("Insert %s with %s -> %s.", op_type, src_var_name,
dst_var_name);
RunAndConstructOpFuncNode(op, src_var_name, dst_var_name, op_func_nodes);
}
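RunAndConstructShareNode builds a share_data op (Out aliases X) through the operator registry, so a transferred result can be handed back under another variable name without a copy. A minimal standalone analogue (not Paddle code) of the name-to-creator lookup used above, with hypothetical types:

// Standalone analogue of OpInfoMap::Instance().Get(op_type).Creator():
// operators are constructed through a registry keyed by op type.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct OperatorBase {
  virtual ~OperatorBase() = default;
  virtual void Run() = 0;
};
struct ShareDataOp : OperatorBase {
  void Run() override { std::cout << "share_data: Out shares X's buffer\n"; }
};

int main() {
  // hypothetical registry standing in for OpInfoMap
  std::map<std::string, std::function<std::unique_ptr<OperatorBase>()>> registry;
  registry["share_data"] = [] { return std::make_unique<ShareDataOp>(); };
  auto op = registry.at("share_data")();  // look up by op type, then construct
  op->Run();
  return 0;
}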
void DataTranferHelper::RunAndConstructOpFuncNode(
const std::shared_ptr<OperatorBase>& op, const std::string& var_name,
const std::string& new_var_name,
@@ -133,7 +151,7 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
AttributeMap attr_map = {{"dst_layout", static_cast<int>(out_layout)}};
// 3. Create transfer_layout_op
std::string op_type("transfer_layout");
auto& op_info = OpInfoMap::Instance().Get(op_type);
auto op = std::shared_ptr<OperatorBase>(
@@ -154,9 +172,10 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
*new_var_name =
var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1);
auto* ptr = local_scope->Var(new_var_name);
var_scope->SetVarDesc(var_name, nullptr);
auto var_type = var_scope->Var(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type "
<< var_type;
@@ -171,7 +190,7 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
// NOTE(Aurelius84): In which case use_mkldnn = true?
attr_map["use_mkldnn"] = false; attr_map["use_mkldnn"] = false;
// 3. Create transfer_op // 3. Create transfer_dtype_op
std::string op_type("transfer_dtype"); std::string op_type("transfer_dtype");
auto& op_info = OpInfoMap::Instance().Get(op_type); auto& op_info = OpInfoMap::Instance().Get(op_type);
auto op = std::shared_ptr<OperatorBase>( auto op = std::shared_ptr<OperatorBase>(
...@@ -209,7 +228,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name, ...@@ -209,7 +228,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
: platform::is_gpu_place(dst_place) ? 1 : -1; : platform::is_gpu_place(dst_place) ? 1 : -1;
AttributeMap attr_map = {{"dst_place_type", dst_place_type}}; AttributeMap attr_map = {{"dst_place_type", dst_place_type}};
// 3. Create transfer_op // 3. Create memcpy_d2h_op or memcpy_h2d_op
std::string op_type = get_memcpy_type(src_place, dst_place); std::string op_type = get_memcpy_type(src_place, dst_place);
auto& op_info = OpInfoMap::Instance().Get(op_type); auto& op_info = OpInfoMap::Instance().Get(op_type);
auto op = std::shared_ptr<OperatorBase>( auto op = std::shared_ptr<OperatorBase>(
...@@ -303,6 +322,95 @@ std::string get_memcpy_type(const platform::Place& src_place, ...@@ -303,6 +322,95 @@ std::string get_memcpy_type(const platform::Place& src_place,
} }
} }
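get_memcpy_type chooses between the memcpy_d2h and memcpy_h2d ops (declared as kMemcpyD2H / kMemcpyH2D later in this diff) from the source and destination places. A hedged standalone sketch of that selection, assuming it is driven only by host/device placement; the fallback value is an assumption, not the real implementation:

// Hedged sketch (not the actual implementation) of choosing a memcpy op by direction.
#include <iostream>
#include <string>

enum class Place { kCPU, kGPU };

std::string GetMemcpyTypeSketch(Place src, Place dst) {
  if (src == Place::kGPU && dst == Place::kCPU) return "memcpy_d2h";
  if (src == Place::kCPU && dst == Place::kGPU) return "memcpy_h2d";
  return "memcpy";  // assumption: generic fallback for other combinations
}

int main() {
  std::cout << GetMemcpyTypeSketch(Place::kGPU, Place::kCPU) << "\n";  // memcpy_d2h
  return 0;
}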
void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
const platform::Place& place,
const VariableNameMap& out_names,
VariableValueMap* out_vars,
VariableScope* var_scope,
std::vector<OpFuncNode>* op_func_nodes,
framework::Scope* local_scope) {
DataTranferHelper data_transfer_helper(place, var_scope);
for (auto& var_name_item : out_names) {
std::vector<Variable*>& vars = out_vars->at(var_name_item.first);
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
// 1. find grad_var & check whether is complex tensor
auto var_name = var_name_item.second[i];
auto orig_var_name = framework::GradOriginalVarName(var_name);
// only focus on gradient var
if (var_name == orig_var_name) {
VLOG(3) << "skip " << var_name << " with same name as "
<< orig_var_name;
continue;
}
auto* grad_var = vars[i];
// skip nullptr var
if (grad_var == nullptr) {
VLOG(3) << "skip grad_var with nullptr";
continue;
}
// don't process LoDTensorArray temporarily,
// add support if necessary for complex number calculations in the future
if (!framework::VarIsTensor(*grad_var)) {
VLOG(3) << "skip grad_var with LoDTensorArray type";
continue;
}
auto* grad_tensor =
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(grad_var);
// skip nullptr tensor
if (grad_tensor == nullptr || !grad_tensor->IsInitialized()) {
VLOG(3) << "skip with grad_tensor not IsInitialized";
continue;
}
// only focus on complex dtype now
auto src_type = grad_tensor->type();
if (!framework::IsComplexType(src_type)) {
VLOG(3) << "skip grad_tensor with not complexType";
continue;
}
// 2. find forward var & check whether need to cast
auto* var = var_scope->FindVar(orig_var_name);
// if forward var not exists, do nothing
if (var == nullptr) {
VLOG(3) << "skip " << orig_var_name << " with not found in var_scope";
continue;
}
if (!framework::VarIsTensor(*var)) {
VLOG(3) << "skip " << orig_var_name << " with LoDTensorArray.";
continue;
}
const auto* tensor =
framework::GetLoDTensorOrSelectedRowsValueFromVar(*var);
PADDLE_ENFORCE_NOT_NULL(
tensor,
platform::errors::Unavailable(
"Forward tensor is nullptr when handle complex data to real."));
// only need record type, the allocation may have been released
auto dst_type = tensor->saved_type();
// only focus on real dtype and need casting
if (framework::IsComplexType(dst_type)) {
continue;
}
// 3. cast complex grad to real grad inplacely
VLOG(3) << "Transform " << framework::DataTypeToString(src_type)
<< " var `" << var_name << "` to "
<< framework::DataTypeToString(dst_type)
<< " real var in static graph.";
// NOTE(Aurelius84): Consider to define a complex2real op to deal this
// case.
std::string new_var_name;
auto op = TransferDtype(var_name, &new_var_name, src_type, dst_type,
var_scope, local_scope);
data_transfer_helper.RunAndConstructOpFuncNode(op, var_name, new_var_name,
op_func_nodes);
data_transfer_helper.RunAndConstructShareNode(new_var_name, var_name,
op_func_nodes);
}
}
}
}  // namespace interpreter
}  // namespace framework
}  // namespace paddle
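Putting the pieces together, the pass above walks a grad op's outputs and, for each complex gradient whose forward variable is real, inserts a dtype cast plus a share-back. A hypothetical walk-through (variable names are illustrative only):

// forward var : x        saved_type = float32
// grad var    : x@GRAD   type       = complex64  (produced by a complex grad kernel)
// HandleComplexGradToRealGrad then appends two OpFuncNodes:
//   transfer_dtype : x@GRAD           -> x@GRAD_dtype_<n>   (complex64 -> float32)
//   share_data     : x@GRAD_dtype_<n> -> x@GRAD             (result shared back in place)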
@@ -37,14 +37,18 @@ class DataTranferHelper {
const std::string& var_name, std::string* new_var_name,
std::vector<OpFuncNode>* new_op_func_nodes, bool use_local_scope);
void RunAndConstructShareNode(const std::string& src_var_name,
const std::string& dst_var_name,
std::vector<OpFuncNode>* op_func_nodes);
void RunAndConstructOpFuncNode(const std::shared_ptr<OperatorBase>& op,
const std::string& var_name,
const std::string& new_var_name,
std::vector<OpFuncNode>* op_func_nodes);
private:
platform::Place place_;
VariableScope* var_scope_;
};
void ApplyDataTransform(const OpKernelType& expected_kernel_key,
@@ -54,6 +58,14 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
std::vector<OpFuncNode>* op_func_nodes,
bool use_local_scope = true);
void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
const platform::Place& place,
const VariableNameMap& out_names,
VariableValueMap* out_vars,
VariableScope* var_scope,
std::vector<OpFuncNode>* op_func_nodes,
framework::Scope* local_scope);
std::string get_memcpy_type(const platform::Place& src_place,
const platform::Place& dst_place);
......
@@ -90,7 +90,7 @@ paddle::framework::FetchList InterpreterCore::Run(
// return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>());
}
paddle::framework::FetchList InterpreterCore::Run(
@@ -124,7 +124,7 @@ paddle::framework::FetchList InterpreterCore::Run(
// return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>());
}
void InterpreterCore::BuildOperatorDependences() {
......
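Both Run overloads now return std::move(*fetch_var->GetMutable<framework::FetchList>()) instead of copying the fetch list out of the global scope. A standalone illustration (not Paddle code) of the effect of moving a container owned elsewhere into the return value:

// Returning std::move(*ptr) transfers the buffer of a container owned by a
// scope-like object instead of copying its elements.
#include <cassert>
#include <string>
#include <utility>
#include <vector>

std::vector<std::string> TakeFetchResults(std::vector<std::string>* owned_by_scope) {
  return std::move(*owned_by_scope);  // moves the buffer; no element-wise copy
}

int main() {
  std::vector<std::string> scope_var = {"fetch_0", "fetch_1"};
  auto results = TakeFetchResults(&scope_var);
  assert(results.size() == 2);
  // scope_var is now in a valid but moved-from (typically empty) state.
  return 0;
}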
@@ -328,20 +328,14 @@ void build_op_func_list(const platform::Place& place,
->GetExpectedKernelType(
ExecutionContext(*op, scope, *dev_ctx, runtime_context));
// change device by the device_guard()
apply_device_guard(op, place, &expected_kernel_key);
VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
// step 3. apply data transforms and insert data transfer ops
VariableValueMap& ins_map_temp = runtime_context.inputs;
std::vector<OpFuncNode> new_op_func_nodes;
ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
&op_func_node, vec_func_list, use_local_scope);
for (auto& item : new_op_func_nodes) {
vec_func_list->emplace_back(std::move(item));
}
// step 4. Run op kernel
VLOG(3) << op->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
@@ -370,6 +364,14 @@ void build_op_func_list(const platform::Place& place,
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
op_func_node.kernel_func_(exec_ctx);
// post-process grad_op.outputs if need cast complex grad into real grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
if (framework::IsComplexType(expected_kernel_key.data_type_)) {
interpreter::HandleComplexGradToRealGrad(
op_func_node, place, outputs_names, &runtime_context.outputs,
var_scope, vec_func_list, local_scope);
}
}
vec_func_list->emplace_back(op_func_node);
......
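The new post-processing in build_op_func_list only fires when the kernel actually ran with a complex data type, which keeps the common real-typed path untouched. A standalone sketch (not the framework's IsComplexType) of that kind of dtype predicate used as a gate:

// Standalone sketch of a complex-dtype predicate; the enum is hypothetical.
#include <iostream>

enum class DataType { kFloat32, kFloat64, kComplex64, kComplex128 };

bool IsComplexTypeSketch(DataType t) {
  return t == DataType::kComplex64 || t == DataType::kComplex128;
}

int main() {
  std::cout << std::boolalpha << IsComplexTypeSketch(DataType::kComplex64) << " "  // true
            << IsComplexTypeSketch(DataType::kFloat32) << "\n";                    // false
  return 0;
}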
@@ -51,7 +51,6 @@ namespace framework {
namespace interpreter {
using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
static constexpr char kFetchVarName[] = "fetch";
class AsyncWorkQueue {
public:
......
@@ -374,6 +374,7 @@ class Instruction {
namespace interpreter {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
static constexpr char kFetchVarName[] = "fetch";
static bool IsMemcpyH2D(const Instruction& instr) {
return instr.OpBase()->Type() == kMemcpyH2D;
......
@@ -479,10 +479,6 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
static bool VarIsTensor(const Variable& var) {
return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
}
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
if (var.IsType<LoDTensor>()) {
return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
......
@@ -114,6 +114,10 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
}
}
inline bool VarIsTensor(const Variable& var) {
return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
}
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
......
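VarIsTensor, now exposed inline in the header so the new pass can reuse it, simply asks whether a variable currently holds a LoDTensor or SelectedRows. A standalone analogue (not Paddle code) of such a holds-one-of predicate over a type-erased variable:

// Standalone analogue of VarIsTensor over a variant-like variable; all types here
// are stand-ins, not the framework's classes.
#include <iostream>
#include <string>
#include <variant>

struct LoDTensor {};
struct SelectedRows {};
using Variable = std::variant<std::monostate, LoDTensor, SelectedRows, std::string>;

bool VarIsTensorSketch(const Variable& var) {
  return std::holds_alternative<LoDTensor>(var) ||
         std::holds_alternative<SelectedRows>(var);
}

int main() {
  Variable a = LoDTensor{};
  Variable b = std::string("some non-tensor payload");
  std::cout << std::boolalpha << VarIsTensorSketch(a) << " "  // true
            << VarIsTensorSketch(b) << "\n";                   // false
  return 0;
}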