Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3a0d7bf0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3a0d7bf0
编写于
4月 25, 2022
作者:
C
Chen Weihang
提交者:
GitHub
4月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize dygraph GetExpectedKernelType perf (#42154)
* opt dygraph scheduling * revert part impl
上级
13190707
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
68 addition
and
19 deletion
+68
-19
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+42
-5
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+7
-5
paddle/fluid/imperative/execution_context.h
paddle/fluid/imperative/execution_context.h
+13
-5
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/transpose_op.cc
+1
-1
paddle/phi/core/kernel_context.h
paddle/phi/core/kernel_context.h
+5
-3
未找到文件。
paddle/fluid/framework/operator.cc
浏览文件 @
3a0d7bf0
...
...
@@ -940,7 +940,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
return
((
op_with_kernel
.
kernel_type
())
&&
(
op_with_kernel
.
kernel_type
()
->
data_layout_
==
framework
::
DataLayout
::
kMKLDNN
));
}
catch
(
std
::
bad_cast
exp
)
{
}
catch
(
const
std
::
bad_cast
&
exp
)
{
return
false
;
}
}
...
...
@@ -1965,6 +1965,36 @@ Scope* OperatorWithKernel::PrepareData(
}
void
OperatorWithKernel
::
ParseInputDataType
(
const
Variable
*
var
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
*
data_type
)
const
{
if
(
var
!=
nullptr
)
{
const
Tensor
*
t
=
nullptr
;
if
(
var
->
IsType
<
Tensor
>
())
{
t
=
&
var
->
Get
<
Tensor
>
();
}
else
if
(
var
->
IsType
<
LoDTensor
>
())
{
t
=
&
var
->
Get
<
LoDTensor
>
();
}
else
if
(
var
->
IsType
<
phi
::
SelectedRows
>
())
{
t
=
&
(
var
->
Get
<
phi
::
SelectedRows
>
().
value
());
}
else
if
(
var
->
IsType
<
LoDTensorArray
>
())
{
auto
t_arr
=
&
var
->
Get
<
LoDTensorArray
>
();
for
(
size_t
j
=
0
;
j
<
t_arr
->
size
();
j
++
)
{
if
(
t_arr
->
at
(
j
).
IsInitialized
())
{
t
=
&
(
t_arr
->
at
(
j
));
}
}
}
if
(
t
!=
nullptr
)
{
PADDLE_ENFORCE_EQ
(
t
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"The %s Op's Input Variable `%s` "
"contains uninitialized Tensor."
,
Type
(),
name
));
*
data_type
=
paddle
::
framework
::
TransToProtoVarType
(
t
->
dtype
());
}
}
}
void
OperatorWithKernel
::
ParseMultiInputDataType
(
const
std
::
vector
<
Variable
*>&
vars
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
*
data_type
)
const
{
proto
::
VarType
::
Type
default_data_type
=
...
...
@@ -2015,9 +2045,12 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
proto
::
VarType
::
Type
dafault_data_type
=
static_cast
<
proto
::
VarType
::
Type
>
(
-
1
);
proto
::
VarType
::
Type
data_type
=
dafault_data_type
;
for
(
auto
&
input
:
ctx
.
InNameList
())
{
const
std
::
vector
<
Variable
*>
vars
=
ctx
.
MultiInputVar
(
input
);
ParseInputDataType
(
vars
,
input
,
&
data_type
);
for
(
auto
*
name
:
ctx
.
InNameList
())
{
if
(
ctx
.
InputSize
(
*
name
)
==
1UL
)
{
ParseInputDataType
(
ctx
.
InputVar
(
*
name
),
*
name
,
&
data_type
);
}
else
{
ParseMultiInputDataType
(
ctx
.
MultiInputVar
(
*
name
),
*
name
,
&
data_type
);
}
}
PADDLE_ENFORCE_NE
(
data_type
,
dafault_data_type
,
...
...
@@ -2031,7 +2064,11 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
proto
::
VarType
::
Type
dafault_data_type
=
static_cast
<
proto
::
VarType
::
Type
>
(
-
1
);
proto
::
VarType
::
Type
data_type
=
dafault_data_type
;
ParseInputDataType
(
ctx
.
MultiInputVar
(
name
),
name
,
&
data_type
);
if
(
ctx
.
InputSize
(
name
)
==
1UL
)
{
ParseInputDataType
(
ctx
.
InputVar
(
name
),
name
,
&
data_type
);
}
else
{
ParseMultiInputDataType
(
ctx
.
MultiInputVar
(
name
),
name
,
&
data_type
);
}
PADDLE_ENFORCE_NE
(
data_type
,
dafault_data_type
,
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/framework/operator.h
浏览文件 @
3a0d7bf0
...
...
@@ -333,12 +333,12 @@ class ExecutionContext {
return
it
->
second
;
}
virtual
std
::
vector
<
std
::
string
>
InNameList
()
const
{
std
::
vector
<
std
::
string
>
vec_temp
;
virtual
paddle
::
SmallVector
<
const
std
::
string
*
>
InNameList
()
const
{
paddle
::
SmallVector
<
const
std
::
string
*
>
vec_temp
;
vec_temp
.
reserve
(
ctx_
.
inputs
.
size
());
for
(
auto
&
input
:
ctx_
.
inputs
)
{
vec_temp
.
push_back
(
input
.
first
);
vec_temp
.
push_back
(
&
input
.
first
);
}
return
vec_temp
;
...
...
@@ -680,9 +680,11 @@ class OperatorWithKernel : public OperatorBase {
// By default all input data must be same.
proto
::
VarType
::
Type
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
// used for IndicateDataType
void
ParseInputDataType
(
const
std
::
vector
<
Variable
*>&
vars
,
const
std
::
string
&
name
,
void
ParseInputDataType
(
const
Variable
*
vars
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
*
data_type
)
const
;
void
ParseMultiInputDataType
(
const
std
::
vector
<
Variable
*>&
vars
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
*
data_type
)
const
;
// used for IndicateOrPromoteVarDataTypes
Tensor
*
GetTensorFormInputSafely
(
const
ExecutionContext
&
ctx
,
const
std
::
string
&
name
)
const
;
...
...
paddle/fluid/imperative/execution_context.h
浏览文件 @
3a0d7bf0
...
...
@@ -117,12 +117,12 @@ class DygraphExecutionContext : public framework::ExecutionContext {
return
it
->
second
;
}
std
::
vector
<
std
::
string
>
InNameList
()
const
override
{
std
::
vector
<
std
::
string
>
vec_temp
;
paddle
::
SmallVector
<
const
std
::
string
*
>
InNameList
()
const
override
{
paddle
::
SmallVector
<
const
std
::
string
*
>
vec_temp
;
vec_temp
.
reserve
(
var_map_in_
.
size
());
for
(
auto
&
v
:
var_map_in_
)
{
vec_temp
.
push_back
(
v
.
first
);
vec_temp
.
push_back
(
&
v
.
first
);
}
return
vec_temp
;
...
...
@@ -144,11 +144,19 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}
size_t
InputSize
(
const
std
::
string
&
name
)
const
override
{
return
InputNames
(
name
).
size
();
auto
it
=
var_map_in_
.
find
(
name
);
PADDLE_ENFORCE_NE
(
it
,
var_map_in_
.
end
(),
platform
::
errors
::
NotFound
(
"Can not find [%s] in Input"
,
name
));
return
it
->
second
.
size
();
}
size_t
OutputSize
(
const
std
::
string
&
name
)
const
override
{
return
OutputNames
(
name
).
size
();
auto
it
=
var_map_out_
.
find
(
name
);
PADDLE_ENFORCE_NE
(
it
,
var_map_out_
.
end
(),
platform
::
errors
::
NotFound
(
"Can not find [%s] in Output"
,
name
));
return
it
->
second
.
size
();
}
const
Variable
*
InputVar
(
const
std
::
string
&
name
)
const
override
{
...
...
paddle/fluid/operators/transpose_op.cc
浏览文件 @
3a0d7bf0
...
...
@@ -90,7 +90,7 @@ class TransposeOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
auto
&
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#ifdef PADDLE_WITH_MKLDNN
...
...
paddle/phi/core/kernel_context.h
浏览文件 @
3a0d7bf0
...
...
@@ -22,6 +22,7 @@
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/small_vector.h"
...
...
@@ -139,10 +140,11 @@ class KernelContext {
paddle
::
SmallVector
<
const
TensorBase
*>
inputs_
;
paddle
::
SmallVector
<
TensorBase
*>
outputs_
;
paddle
::
SmallVector
<
Attribute
>
attrs_
;
paddle
::
SmallVector
<
Attribute
,
kAttrSmallVectorSize
>
attrs_
;
paddle
::
SmallVector
<
std
::
pair
<
int
,
int
>>
input_range_
;
paddle
::
SmallVector
<
std
::
pair
<
int
,
int
>>
output_range_
;
paddle
::
SmallVector
<
std
::
pair
<
int
,
int
>
,
kInputSmallVectorSize
>
input_range_
;
paddle
::
SmallVector
<
std
::
pair
<
int
,
int
>
,
kOutputSmallVectorSize
>
output_range_
;
};
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录