Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8b87d5eb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8b87d5eb
编写于
11月 24, 2021
作者:
A
Aurelius84
提交者:
GitHub
11月 24, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NewExe] Support HandleComplexGradToRealGrad to cast complex into Real (#37450)
上级
1c969d20
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
145 addition
and
23 deletion
+145
-23
paddle/fluid/framework/new_executor/data_transfer.cc
paddle/fluid/framework/new_executor/data_transfer.cc
+112
-4
paddle/fluid/framework/new_executor/data_transfer.h
paddle/fluid/framework/new_executor/data_transfer.h
+15
-3
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+2
-2
paddle/fluid/framework/new_executor/interpretercore_util.cc
paddle/fluid/framework/new_executor/interpretercore_util.cc
+11
-9
paddle/fluid/framework/new_executor/interpretercore_util.h
paddle/fluid/framework/new_executor/interpretercore_util.h
+0
-1
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+1
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+0
-4
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+4
-0
未找到文件。
paddle/fluid/framework/new_executor/data_transfer.cc
浏览文件 @
8b87d5eb
...
...
@@ -62,6 +62,24 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
return
is_transferred
;
}
void
DataTranferHelper
::
RunAndConstructShareNode
(
const
std
::
string
&
src_var_name
,
const
std
::
string
&
dst_var_name
,
std
::
vector
<
OpFuncNode
>*
op_func_nodes
)
{
VariableNameMap
in_name_map
=
{{
"X"
,
{
src_var_name
}}};
VariableNameMap
out_name_map
=
{{
"Out"
,
{
dst_var_name
}}};
AttributeMap
attr_map
;
std
::
string
op_type
(
"share_data"
);
auto
&
op_info
=
OpInfoMap
::
Instance
().
Get
(
op_type
);
auto
op
=
std
::
shared_ptr
<
OperatorBase
>
(
op_info
.
Creator
()(
op_type
,
in_name_map
,
out_name_map
,
attr_map
));
VLOG
(
3
)
<<
string
::
Sprintf
(
"Insert %s with %s -> %s."
,
op_type
,
src_var_name
,
dst_var_name
);
RunAndConstructOpFuncNode
(
op
,
src_var_name
,
dst_var_name
,
op_func_nodes
);
}
void
DataTranferHelper
::
RunAndConstructOpFuncNode
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
&
var_name
,
const
std
::
string
&
new_var_name
,
...
...
@@ -133,7 +151,7 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
VariableNameMap
out_name_map
=
{{
"Out"
,
{
*
new_var_name
}}};
AttributeMap
attr_map
=
{{
"dst_layout"
,
static_cast
<
int
>
(
out_layout
)}};
// 3. Create transfer_op
// 3. Create transfer_
layout_
op
std
::
string
op_type
(
"transfer_layout"
);
auto
&
op_info
=
OpInfoMap
::
Instance
().
Get
(
op_type
);
auto
op
=
std
::
shared_ptr
<
OperatorBase
>
(
...
...
@@ -154,9 +172,10 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
*
new_var_name
=
var_name
+
"_dtype_"
+
std
::
to_string
(
var_scope
->
VarSize
()
+
1
);
auto
*
ptr
=
local_scope
->
Var
(
new_var_name
);
var_scope
->
SetVarDesc
(
var_name
,
nullptr
);
auto
var_type
=
var_scope
->
Var
(
var_name
)
->
Type
();
InitializeVariable
(
ptr
,
static_cast
<
proto
::
VarType
::
Type
>
(
var_type
));
VLOG
(
3
)
<<
"Create Variable "
<<
*
new_var_name
<<
" locally, which pointer is "
<<
ptr
<<
"Variable Type "
<<
var_type
;
...
...
@@ -171,7 +190,7 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
// NOTE(Aurelius84): In whice case use_mkldnn = true?
attr_map
[
"use_mkldnn"
]
=
false
;
// 3. Create transfer_op
// 3. Create transfer_
dtype_
op
std
::
string
op_type
(
"transfer_dtype"
);
auto
&
op_info
=
OpInfoMap
::
Instance
().
Get
(
op_type
);
auto
op
=
std
::
shared_ptr
<
OperatorBase
>
(
...
...
@@ -209,7 +228,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
:
platform
::
is_gpu_place
(
dst_place
)
?
1
:
-
1
;
AttributeMap
attr_map
=
{{
"dst_place_type"
,
dst_place_type
}};
// 3. Create
transfer
_op
// 3. Create
memcpy_d2h_op or memcpy_h2d
_op
std
::
string
op_type
=
get_memcpy_type
(
src_place
,
dst_place
);
auto
&
op_info
=
OpInfoMap
::
Instance
().
Get
(
op_type
);
auto
op
=
std
::
shared_ptr
<
OperatorBase
>
(
...
...
@@ -303,6 +322,95 @@ std::string get_memcpy_type(const platform::Place& src_place,
}
}
void
HandleComplexGradToRealGrad
(
const
OpFuncNode
&
op_func_node
,
const
platform
::
Place
&
place
,
const
VariableNameMap
&
out_names
,
VariableValueMap
*
out_vars
,
VariableScope
*
var_scope
,
std
::
vector
<
OpFuncNode
>*
op_func_nodes
,
framework
::
Scope
*
local_scope
)
{
DataTranferHelper
data_transfer_helper
(
place
,
var_scope
);
for
(
auto
&
var_name_item
:
out_names
)
{
std
::
vector
<
Variable
*>&
vars
=
out_vars
->
at
(
var_name_item
.
first
);
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
// 1. find grad_var & check whether is complex tensor
auto
var_name
=
var_name_item
.
second
[
i
];
auto
orig_var_name
=
framework
::
GradOriginalVarName
(
var_name
);
// only focus on gradient var
if
(
var_name
==
orig_var_name
)
{
VLOG
(
3
)
<<
"skip "
<<
var_name
<<
" with same name as "
<<
orig_var_name
;
continue
;
}
auto
*
grad_var
=
vars
[
i
];
// skip nullptr var
if
(
grad_var
==
nullptr
)
{
VLOG
(
3
)
<<
"skip grad_var with nullptr"
;
continue
;
}
// don't process LoDTensorArray temporarily,
// add support if necessary for complex number calculations in the future
if
(
!
framework
::
VarIsTensor
(
*
grad_var
))
{
VLOG
(
3
)
<<
"skip grad_var with LoDTensorArray type"
;
continue
;
}
auto
*
grad_tensor
=
framework
::
GetMutableLoDTensorOrSelectedRowsValueFromVar
(
grad_var
);
// skip nullptr tensor
if
(
grad_tensor
==
nullptr
||
!
grad_tensor
->
IsInitialized
())
{
VLOG
(
3
)
<<
"skip with grad_tensor not IsInitialized"
;
continue
;
}
// only focus on complex dtype now
auto
src_type
=
grad_tensor
->
type
();
if
(
!
framework
::
IsComplexType
(
src_type
))
{
VLOG
(
3
)
<<
"skip grad_tensor with not complexType"
;
continue
;
}
// 2. find forward var & check whether need to cast
auto
*
var
=
var_scope
->
FindVar
(
orig_var_name
);
// if forward var not exists, do nothing
if
(
var
==
nullptr
)
{
VLOG
(
3
)
<<
"skip "
<<
orig_var_name
<<
" with not found in var_scope"
;
continue
;
}
if
(
!
framework
::
VarIsTensor
(
*
var
))
{
VLOG
(
3
)
<<
"skip "
<<
orig_var_name
<<
" with LoDTensorArray."
;
continue
;
}
const
auto
*
tensor
=
framework
::
GetLoDTensorOrSelectedRowsValueFromVar
(
*
var
);
PADDLE_ENFORCE_NOT_NULL
(
tensor
,
platform
::
errors
::
Unavailable
(
"Forward tensor is nullptr when handle complex data to real."
));
// only need record type, the allocation may have been released
auto
dst_type
=
tensor
->
saved_type
();
// only focus on real dtype and need casting
if
(
framework
::
IsComplexType
(
dst_type
))
{
continue
;
}
// 3. cast complex grad to real grad inplacely
VLOG
(
3
)
<<
"Transform "
<<
framework
::
DataTypeToString
(
src_type
)
<<
" var `"
<<
var_name
<<
"` to "
<<
framework
::
DataTypeToString
(
dst_type
)
<<
" real var in static graph."
;
// NOTE(Aurelius84): Consider to define a complex2real op to deal this
// case.
std
::
string
new_var_name
;
auto
op
=
TransferDtype
(
var_name
,
&
new_var_name
,
src_type
,
dst_type
,
var_scope
,
local_scope
);
data_transfer_helper
.
RunAndConstructOpFuncNode
(
op
,
var_name
,
new_var_name
,
op_func_nodes
);
data_transfer_helper
.
RunAndConstructShareNode
(
new_var_name
,
var_name
,
op_func_nodes
);
}
}
}
}
// namespace interpreter
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/new_executor/data_transfer.h
浏览文件 @
8b87d5eb
...
...
@@ -37,14 +37,18 @@ class DataTranferHelper {
const
std
::
string
&
var_name
,
std
::
string
*
new_var_name
,
std
::
vector
<
OpFuncNode
>*
new_op_func_nodes
,
bool
use_local_scope
);
private:
platform
::
Place
place_
;
VariableScope
*
var_scope_
;
void
RunAndConstructShareNode
(
const
std
::
string
&
src_var_name
,
const
std
::
string
&
dst_var_name
,
std
::
vector
<
OpFuncNode
>*
op_func_nodes
)
;
void
RunAndConstructOpFuncNode
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
&
var_name
,
const
std
::
string
&
new_var_name
,
std
::
vector
<
OpFuncNode
>*
op_func_nodes
);
private:
platform
::
Place
place_
;
VariableScope
*
var_scope_
;
};
void
ApplyDataTransform
(
const
OpKernelType
&
expected_kernel_key
,
...
...
@@ -54,6 +58,14 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
std
::
vector
<
OpFuncNode
>*
op_func_nodes
,
bool
use_local_scope
=
true
);
void
HandleComplexGradToRealGrad
(
const
OpFuncNode
&
op_func_node
,
const
platform
::
Place
&
place
,
const
VariableNameMap
&
out_names
,
VariableValueMap
*
out_vars
,
VariableScope
*
var_scope
,
std
::
vector
<
OpFuncNode
>*
op_func_nodes
,
framework
::
Scope
*
local_scope
);
std
::
string
get_memcpy_type
(
const
platform
::
Place
&
src_place
,
const
platform
::
Place
&
dst_place
);
...
...
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
8b87d5eb
...
...
@@ -90,7 +90,7 @@ paddle::framework::FetchList InterpreterCore::Run(
// return Fetch Tensors
auto
*
fetch_var
=
global_scope_
->
Var
(
interpreter
::
kFetchVarName
);
return
*
(
fetch_var
->
GetMutable
<
framework
::
FetchList
>
());
return
std
::
move
(
*
fetch_var
->
GetMutable
<
framework
::
FetchList
>
());
}
paddle
::
framework
::
FetchList
InterpreterCore
::
Run
(
...
...
@@ -124,7 +124,7 @@ paddle::framework::FetchList InterpreterCore::Run(
// return Fetch Tensors
auto
*
fetch_var
=
global_scope_
->
Var
(
interpreter
::
kFetchVarName
);
return
*
(
fetch_var
->
GetMutable
<
framework
::
FetchList
>
());
return
std
::
move
(
*
fetch_var
->
GetMutable
<
framework
::
FetchList
>
());
}
void
InterpreterCore
::
BuildOperatorDependences
()
{
...
...
paddle/fluid/framework/new_executor/interpretercore_util.cc
浏览文件 @
8b87d5eb
...
...
@@ -328,20 +328,14 @@ void build_op_func_list(const platform::Place& place,
->
GetExpectedKernelType
(
ExecutionContext
(
*
op
,
scope
,
*
dev_ctx
,
runtime_context
));
// consider device_guard()
apply_device_guard
(
op
,
place
,
&
expected_kernel_key
);
// change device by the device_guard()
// change device by the device_guard()
apply_device_guard
(
op
,
place
,
&
expected_kernel_key
);
VLOG
(
3
)
<<
"expected_kernel_key : "
<<
expected_kernel_key
;
// step 3. apply data transforms and insert data transfer ops
VariableValueMap
&
ins_map_temp
=
runtime_context
.
inputs
;
std
::
vector
<
OpFuncNode
>
new_op_func_nodes
;
ApplyDataTransform
(
expected_kernel_key
,
place
,
&
ins_map_temp
,
var_scope
,
&
op_func_node
,
&
new_op_func_nodes
,
use_local_scope
);
for
(
auto
&
item
:
new_op_func_nodes
)
{
vec_func_list
->
emplace_back
(
std
::
move
(
item
));
}
&
op_func_node
,
vec_func_list
,
use_local_scope
);
// step 4. Run op kernel
VLOG
(
3
)
<<
op
->
Type
()
<<
" : expected_kernel_key : "
<<
expected_kernel_key
;
...
...
@@ -370,6 +364,14 @@ void build_op_func_list(const platform::Place& place,
op_func_node
.
kernel_func_
=
OpKernelComputeFunc
(
kernel_iter
->
second
);
op_func_node
.
kernel_func_
(
exec_ctx
);
// post-process grad_op.outputs if need cast complex grad into real grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
if
(
framework
::
IsComplexType
(
expected_kernel_key
.
data_type_
))
{
interpreter
::
HandleComplexGradToRealGrad
(
op_func_node
,
place
,
outputs_names
,
&
runtime_context
.
outputs
,
var_scope
,
vec_func_list
,
local_scope
);
}
}
vec_func_list
->
emplace_back
(
op_func_node
);
...
...
paddle/fluid/framework/new_executor/interpretercore_util.h
浏览文件 @
8b87d5eb
...
...
@@ -51,7 +51,6 @@ namespace framework {
namespace
interpreter
{
using
AtomicVectorSizeT
=
std
::
vector
<
std
::
unique_ptr
<
std
::
atomic
<
size_t
>>>
;
static
constexpr
char
kFetchVarName
[]
=
"fetch"
;
class
AsyncWorkQueue
{
public:
...
...
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
8b87d5eb
...
...
@@ -374,6 +374,7 @@ class Instruction {
namespace
interpreter
{
static
constexpr
char
kMemcpyH2D
[]
=
"memcpy_h2d"
;
static
constexpr
char
kMemcpyD2H
[]
=
"memcpy_d2h"
;
static
constexpr
char
kFetchVarName
[]
=
"fetch"
;
static
bool
IsMemcpyH2D
(
const
Instruction
&
instr
)
{
return
instr
.
OpBase
()
->
Type
()
==
kMemcpyH2D
;
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
8b87d5eb
...
...
@@ -479,10 +479,6 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
static
bool
VarIsTensor
(
const
Variable
&
var
)
{
return
var
.
IsType
<
LoDTensor
>
()
||
var
.
IsType
<
SelectedRows
>
();
}
const
Tensor
*
GetLoDTensorOrSelectedRowsValueFromVar
(
const
Variable
&
var
)
{
if
(
var
.
IsType
<
LoDTensor
>
())
{
return
static_cast
<
const
Tensor
*>
(
&
(
var
.
Get
<
LoDTensor
>
()));
...
...
paddle/fluid/framework/operator.h
浏览文件 @
8b87d5eb
...
...
@@ -114,6 +114,10 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
}
}
inline
bool
VarIsTensor
(
const
Variable
&
var
)
{
return
var
.
IsType
<
LoDTensor
>
()
||
var
.
IsType
<
SelectedRows
>
();
}
const
Tensor
*
GetLoDTensorOrSelectedRowsValueFromVar
(
const
Variable
&
var
);
Tensor
*
GetMutableLoDTensorOrSelectedRowsValueFromVar
(
Variable
*
var
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录