Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
91ae7848
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
91ae7848
编写于
4月 27, 2020
作者:
L
liuwei1031
提交者:
GitHub
4月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
improve efficiency of runtime InferVarType (#22778) (#24181)
* cherry pick #22778
上级
57b062e1
变更
58
隐藏空白更改
内联
并排
Showing
58 changed file
with
967 addition
and
423 deletion
+967
-423
paddle/fluid/framework/ir/graph_test.cc
paddle/fluid/framework/ir/graph_test.cc
+2
-8
paddle/fluid/framework/var_type_inference.h
paddle/fluid/framework/var_type_inference.h
+286
-40
paddle/fluid/framework/var_type_inference_test.cc
paddle/fluid/framework/var_type_inference_test.cc
+186
-15
paddle/fluid/imperative/infer_var_type_context.h
paddle/fluid/imperative/infer_var_type_context.h
+131
-80
paddle/fluid/imperative/tests/test_layer.cc
paddle/fluid/imperative/tests/test_layer.cc
+149
-28
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+3
-2
paddle/fluid/operators/allclose_op.cc
paddle/fluid/operators/allclose_op.cc
+1
-2
paddle/fluid/operators/assign_op.cc
paddle/fluid/operators/assign_op.cc
+1
-5
paddle/fluid/operators/batch_norm_op.h
paddle/fluid/operators/batch_norm_op.h
+3
-2
paddle/fluid/operators/beam_search_decode_op.cc
paddle/fluid/operators/beam_search_decode_op.cc
+4
-6
paddle/fluid/operators/beam_search_op.cc
paddle/fluid/operators/beam_search_op.cc
+4
-6
paddle/fluid/operators/controlflow/get_places_op.cc
paddle/fluid/operators/controlflow/get_places_op.cc
+2
-3
paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
...fluid/operators/controlflow/tensor_array_read_write_op.cc
+6
-6
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+8
-7
paddle/fluid/operators/conv_op.h
paddle/fluid/operators/conv_op.h
+3
-2
paddle/fluid/operators/cross_entropy_op.cc
paddle/fluid/operators/cross_entropy_op.cc
+3
-2
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
+2
-4
paddle/fluid/operators/distributed_ops/split_ids_op.cc
paddle/fluid/operators/distributed_ops/split_ids_op.cc
+2
-4
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+3
-2
paddle/fluid/operators/eye_op.cc
paddle/fluid/operators/eye_op.cc
+1
-2
paddle/fluid/operators/fill_any_like_op.cc
paddle/fluid/operators/fill_any_like_op.cc
+2
-4
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+1
-2
paddle/fluid/operators/fill_op.cc
paddle/fluid/operators/fill_op.cc
+1
-2
paddle/fluid/operators/flip_op.cc
paddle/fluid/operators/flip_op.cc
+3
-2
paddle/fluid/operators/fused/fused_bn_activation_op.h
paddle/fluid/operators/fused/fused_bn_activation_op.h
+3
-2
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
+5
-4
paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
+2
-5
paddle/fluid/operators/group_norm_op.cc
paddle/fluid/operators/group_norm_op.cc
+3
-2
paddle/fluid/operators/hierarchical_sigmoid_op.cc
paddle/fluid/operators/hierarchical_sigmoid_op.cc
+14
-15
paddle/fluid/operators/instance_norm_op.h
paddle/fluid/operators/instance_norm_op.h
+3
-2
paddle/fluid/operators/lod_rank_table_op.cc
paddle/fluid/operators/lod_rank_table_op.cc
+2
-3
paddle/fluid/operators/lod_reset_op.cc
paddle/fluid/operators/lod_reset_op.cc
+12
-11
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
+2
-3
paddle/fluid/operators/lookup_table_op.cc
paddle/fluid/operators/lookup_table_op.cc
+5
-4
paddle/fluid/operators/lookup_table_v2_op.cc
paddle/fluid/operators/lookup_table_v2_op.cc
+5
-4
paddle/fluid/operators/mean_op.cc
paddle/fluid/operators/mean_op.cc
+3
-2
paddle/fluid/operators/merge_selected_rows_op.cc
paddle/fluid/operators/merge_selected_rows_op.cc
+4
-2
paddle/fluid/operators/mul_op.cc
paddle/fluid/operators/mul_op.cc
+3
-2
paddle/fluid/operators/nccl/nccl_op.cc
paddle/fluid/operators/nccl/nccl_op.cc
+1
-2
paddle/fluid/operators/nce_op.cc
paddle/fluid/operators/nce_op.cc
+4
-4
paddle/fluid/operators/optimizers/momentum_op.cc
paddle/fluid/operators/optimizers/momentum_op.cc
+9
-12
paddle/fluid/operators/optimizers/sgd_op.cc
paddle/fluid/operators/optimizers/sgd_op.cc
+9
-13
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+3
-2
paddle/fluid/operators/print_op.cc
paddle/fluid/operators/print_op.cc
+1
-3
paddle/fluid/operators/py_func_op.cc
paddle/fluid/operators/py_func_op.cc
+12
-15
paddle/fluid/operators/randperm_op.cc
paddle/fluid/operators/randperm_op.cc
+1
-2
paddle/fluid/operators/reader/read_op.cc
paddle/fluid/operators/reader/read_op.cc
+6
-6
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+3
-6
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+1
-2
paddle/fluid/operators/save_combine_op.cc
paddle/fluid/operators/save_combine_op.cc
+2
-3
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+1
-1
paddle/fluid/operators/scale_op.cc
paddle/fluid/operators/scale_op.cc
+1
-7
paddle/fluid/operators/selu_op.cc
paddle/fluid/operators/selu_op.cc
+3
-2
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+3
-2
paddle/fluid/operators/split_selected_rows_op.cc
paddle/fluid/operators/split_selected_rows_op.cc
+2
-3
paddle/fluid/operators/sum_op.cc
paddle/fluid/operators/sum_op.cc
+26
-33
paddle/fluid/operators/tensor_array_to_tensor_op.cc
paddle/fluid/operators/tensor_array_to_tensor_op.cc
+3
-3
paddle/fluid/operators/uniform_random_op.cc
paddle/fluid/operators/uniform_random_op.cc
+3
-5
未找到文件。
paddle/fluid/framework/ir/graph_test.cc
浏览文件 @
91ae7848
...
...
@@ -45,19 +45,13 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class
SumOpVarTypeInference
:
public
VarTypeInference
{
public:
void
operator
()(
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
->
Input
(
"X"
);
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
bool
any_input_is_lod_tensor
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
->
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
});
if
(
any_input_is_lod_tensor
)
{
if
(
ctx
->
InputTypeAnyOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
))
{
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
}
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetType
(
out_var_name
,
default_var_type
);
ctx
->
SetOutputType
(
"Out"
,
default_var_type
);
}
};
...
...
paddle/fluid/framework/var_type_inference.h
浏览文件 @
91ae7848
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
...
...
@@ -25,8 +26,14 @@ namespace framework {
class
OpDesc
;
class
BlockDesc
;
class
StaticGraphVarTypeInference
;
// default infer var type context
static
const
int
ALL_ELEMENTS
=
-
1
;
class
InferVarTypeContext
{
friend
class
StaticGraphVarTypeInference
;
public:
InferVarTypeContext
(
const
OpDesc
*
op
,
BlockDesc
*
block
)
:
op_
(
op
),
block_
(
block
)
{}
...
...
@@ -34,91 +41,267 @@ class InferVarTypeContext {
virtual
~
InferVarTypeContext
()
{}
virtual
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
GetAttr
(
name
);
}
virtual
bool
HasVar
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
return
block_
->
FindVarRecursive
(
name
)
!=
nullptr
;
}
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
inputs
=
op_
->
Inputs
();
auto
input
=
inputs
.
find
(
name
);
return
input
!=
inputs
.
end
()
&&
!
input
->
second
.
empty
();
}
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
outputs
=
op_
->
Outputs
();
auto
output
=
outputs
.
find
(
name
);
return
output
!=
outputs
.
end
()
&&
!
output
->
second
.
empty
();
}
virtual
const
std
::
vector
<
std
::
string
>&
Input
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
virtual
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
Inputs
().
at
(
name
).
size
();
}
virtual
const
std
::
string
&
InputVarName
(
const
std
::
string
&
name
,
const
int
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
Inputs
().
at
(
name
)[
index
];
}
virtual
bool
InputTypeAnyOf
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
inputs
=
op_
->
Input
(
name
);
return
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
this
,
&
type
](
const
std
::
string
&
name
)
{
return
this
->
GetVarType
(
name
)
==
type
;
});
}
virtual
bool
InputTypeAllOf
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
inputs
=
op_
->
Input
(
name
);
return
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
this
,
&
type
](
const
std
::
string
&
name
)
{
return
this
->
GetVarType
(
name
)
==
type
;
});
}
virtual
void
SyncTypeAndDataType
(
const
std
::
string
&
input_name
,
const
std
::
string
&
output_name
,
int
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
x_name
=
op_
->
Input
(
input_name
).
at
(
index
);
auto
&
out_name
=
op_
->
Output
(
output_name
).
at
(
index
);
if
(
x_name
!=
out_name
)
{
this
->
SetVarType
(
out_name
,
this
->
GetVarType
(
x_name
));
this
->
SetVarDataType
(
out_name
,
this
->
GetVarDataType
(
x_name
));
}
}
virtual
void
SetOutputType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
,
int
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
if
(
ALL_ELEMENTS
==
index
)
{
for
(
const
auto
&
var_name
:
op_
->
Output
(
name
))
{
this
->
SetVarType
(
var_name
,
type
);
}
}
else
{
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarType
(
var_name
,
type
);
}
}
virtual
proto
::
VarType
::
Type
GetInputType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
this
->
GetVarType
(
op_
->
Input
(
name
).
at
(
index
));
}
virtual
proto
::
VarType
::
Type
GetOutputType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
this
->
GetVarType
(
op_
->
Output
(
name
).
at
(
index
));
}
virtual
proto
::
VarType
::
Type
GetInputDataType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
this
->
GetVarDataType
(
op_
->
Input
(
name
).
at
(
index
));
}
virtual
void
SetOutputDataType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
,
int
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
if
(
ALL_ELEMENTS
==
index
)
{
for
(
const
auto
&
var_name
:
op_
->
Output
(
name
))
{
this
->
SetVarDataType
(
var_name
,
type
);
}
}
else
{
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarDataType
(
var_name
,
type
);
}
}
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputDataTypes
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
this
->
GetVarDataTypes
(
op_
->
Input
(
name
).
at
(
index
));
}
virtual
void
SetOutputDataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
,
const
int
&
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarDataTypes
(
var_name
,
multiple_data_type
);
}
virtual
std
::
vector
<
int64_t
>
GetInputShape
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Input
(
name
).
at
(
index
);
return
this
->
GetVarShape
(
var_name
);
}
virtual
void
SetOutputShape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
,
const
int
&
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarShape
(
var_name
,
dims
);
}
virtual
int32_t
GetInputLoDLevel
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Input
(
name
).
at
(
index
);
return
this
->
GetVarLoDLevel
(
var_name
);
}
virtual
void
SetOutputLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
,
const
int
&
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarLoDLevel
(
var_name
,
lod_level
);
}
// add a speical API for save_op
// avoid use this API for common logic
virtual
void
InsertVar
(
const
std
::
string
&
var_name
,
proto
::
VarType
::
Type
var_type
)
{
if
(
!
IsDygraph
())
this
->
SetVarType
(
var_name
,
var_type
);
}
virtual
bool
IsDygraph
()
const
{
return
false
;
}
protected:
virtual
bool
HasVar
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindVarRecursive
(
name
)
!=
nullptr
;
}
virtual
const
std
::
vector
<
std
::
string
>&
InputVars
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
Input
(
name
);
}
virtual
const
std
::
vector
<
std
::
string
>&
Output
(
virtual
const
std
::
vector
<
std
::
string
>&
Output
Vars
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
Output
(
name
);
}
virtual
proto
::
VarType
::
Type
GetType
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
virtual
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetType
();
}
virtual
void
SetType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
virtual
void
SetVarType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetType
(
type
);
}
virtual
proto
::
VarType
::
Type
GetDataType
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
virtual
proto
::
VarType
::
Type
GetVarDataType
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetDataType
();
}
virtual
void
SetDataType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
virtual
void
SetVarDataType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetDataType
(
type
);
}
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
(
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
Get
Var
DataTypes
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetDataTypes
();
}
virtual
void
SetDataTypes
(
virtual
void
Set
Var
DataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetDataTypes
(
multiple_data_type
);
}
virtual
std
::
vector
<
int64_t
>
GetShape
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
virtual
std
::
vector
<
int64_t
>
GetVarShape
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetShape
();
}
virtual
void
SetShape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
virtual
void
SetVarShape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetShape
(
dims
);
}
virtual
int32_t
GetLoDLevel
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
virtual
int32_t
GetVarLoDLevel
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetLoDLevel
();
}
virtual
void
SetLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
virtual
void
SetVarLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetLoDLevel
(
lod_level
);
}
...
...
@@ -133,22 +316,85 @@ class VarTypeInference {
virtual
void
operator
()(
InferVarTypeContext
*
context
)
const
=
0
;
// NOLINT
};
class
StaticGraphVarTypeInference
:
public
VarTypeInference
{
protected:
bool
HasVar
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
HasVar
(
name
);
}
const
std
::
vector
<
std
::
string
>&
Input
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
InputVars
(
name
);
}
const
std
::
vector
<
std
::
string
>&
Output
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
OutputVars
(
name
);
}
proto
::
VarType
::
Type
GetType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarType
(
name
);
}
void
SetType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
ctx
->
SetVarType
(
name
,
type
);
}
proto
::
VarType
::
Type
GetDataType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarDataType
(
name
);
}
void
SetDataType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
ctx
->
SetVarDataType
(
name
,
type
);
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarDataTypes
(
name
);
}
void
SetDataTypes
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
return
ctx
->
SetVarDataTypes
(
name
,
multiple_data_type
);
}
std
::
vector
<
int64_t
>
GetShape
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarShape
(
name
);
}
void
SetShape
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
const
{
ctx
->
SetVarShape
(
name
,
dims
);
}
int32_t
GetLoDLevel
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarLoDLevel
(
name
);
}
void
SetLoDLevel
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
int32_t
lod_level
)
const
{
ctx
->
SetVarLoDLevel
(
name
,
lod_level
);
}
};
class
PassInDtypeAndVarTypeToOutput
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
final
{
// NOLINT
auto
in_out_var_names
=
this
->
GetInputOutputWithSameType
();
auto
&
in_out_var_names
=
this
->
GetInputOutputWithSameType
();
for
(
auto
&
i_o_n
:
in_out_var_names
)
{
auto
&
x_name
=
ctx
->
Input
(
i_o_n
.
first
).
at
(
0
);
auto
&
out_name
=
ctx
->
Output
(
i_o_n
.
second
).
at
(
0
);
ctx
->
SetType
(
out_name
,
ctx
->
GetType
(
x_name
));
ctx
->
SetDataType
(
out_name
,
ctx
->
GetDataType
(
x_name
));
ctx
->
SyncTypeAndDataType
(
i_o_n
.
first
,
i_o_n
.
second
);
}
}
protected:
virtual
std
::
unordered_map
<
std
::
string
,
std
::
string
>
virtual
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
=
0
;
};
...
...
paddle/fluid/framework/var_type_inference_test.cc
浏览文件 @
91ae7848
...
...
@@ -24,13 +24,13 @@ namespace framework {
class
NOP
:
public
OperatorBase
{
public:
NOP
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
NOP
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
class
SumOpMaker
:
public
OpProtoAndCheckerMaker
{
...
...
@@ -44,20 +44,14 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class
SumOpVarTypeInference
:
public
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
->
Input
(
"X"
);
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
bool
any_input_is_lod_tensor
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
->
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
});
if
(
any_input_is_lod_tensor
)
{
if
(
ctx
->
InputTypeAnyOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
))
{
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
}
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetType
(
out_var_name
,
default_var_type
);
ctx
->
SetOutputType
(
"Out"
,
default_var_type
);
}
};
}
// namespace framework
...
...
@@ -71,9 +65,79 @@ REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
namespace
paddle
{
namespace
framework
{
class
TestStaticGraphVarTypeInference
:
public
StaticGraphVarTypeInference
{
public:
void
operator
()(
InferVarTypeContext
*
context
)
const
override
{}
bool
HasVar
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
HasVar
(
ctx
,
name
);
}
const
std
::
vector
<
std
::
string
>&
Input
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
Input
(
ctx
,
name
);
}
const
std
::
vector
<
std
::
string
>&
Output
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
Output
(
ctx
,
name
);
}
proto
::
VarType
::
Type
GetType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetType
(
ctx
,
name
);
}
void
SetType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
StaticGraphVarTypeInference
::
SetType
(
ctx
,
name
,
type
);
}
proto
::
VarType
::
Type
GetDataType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetDataType
(
ctx
,
name
);
}
void
SetDataType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
StaticGraphVarTypeInference
::
SetDataType
(
ctx
,
name
,
type
);
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetDataTypes
(
ctx
,
name
);
}
void
SetDataTypes
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
return
StaticGraphVarTypeInference
::
SetDataTypes
(
ctx
,
name
,
multiple_data_type
);
}
std
::
vector
<
int64_t
>
GetShape
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetShape
(
ctx
,
name
);
}
void
SetShape
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
const
{
StaticGraphVarTypeInference
::
SetShape
(
ctx
,
name
,
dims
);
}
int32_t
GetLoDLevel
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetLoDLevel
(
ctx
,
name
);
}
void
SetLoDLevel
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
int32_t
lod_level
)
const
{
StaticGraphVarTypeInference
::
SetLoDLevel
(
ctx
,
name
,
lod_level
);
}
};
TEST
(
InferVarType
,
sum_op
)
{
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"test_a"
,
"test_b"
,
"test_c"
});
op
->
SetOutput
(
"Out"
,
{
"test_out"
});
...
...
@@ -96,7 +160,7 @@ TEST(InferVarType, sum_op) {
TEST
(
InferVarType
,
sum_op_without_infer_var_type
)
{
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum_without_infer_var_type"
);
op
->
SetInput
(
"X"
,
{
"test2_a"
,
"test2_b"
,
"test2_c"
});
op
->
SetOutput
(
"Out"
,
{
"test2_out"
});
...
...
@@ -112,5 +176,112 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_out"
)
->
GetType
());
}
TEST
(
InferVarType
,
multiple_api
)
{
ProgramDesc
prog
;
auto
*
block
=
prog
.
MutableBlock
(
0
);
auto
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"sum_without_infer_var_type"
);
op
->
SetInput
(
"X"
,
{
"test2_a"
,
"test2_b"
});
op
->
SetOutput
(
"Out"
,
{
"test2_a_out"
,
"test2_b_out"
});
block
->
Var
(
"test2_a"
)
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
block
->
Var
(
"test2_b"
)
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
block
->
Var
(
"test2_a_out"
);
block
->
Var
(
"test2_b_out"
);
InferVarTypeContext
ctx
(
op
,
block
);
ASSERT_TRUE
(
ctx
.
HasInput
(
"X"
));
ASSERT_TRUE
(
ctx
.
HasOutput
(
"Out"
));
ASSERT_EQ
(
2u
,
ctx
.
InputSize
(
"X"
));
ASSERT_EQ
(
"test2_a"
,
ctx
.
InputVarName
(
"X"
,
0
));
ASSERT_EQ
(
proto
::
VarType
::
SELECTED_ROWS
,
ctx
.
GetInputType
(
"X"
));
ASSERT_TRUE
(
ctx
.
InputTypeAllOf
(
"X"
,
proto
::
VarType
::
SELECTED_ROWS
));
ASSERT_FALSE
(
ctx
.
InputTypeAnyOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
));
ctx
.
SyncTypeAndDataType
(
"X"
,
"Out"
);
ASSERT_EQ
(
proto
::
VarType
::
SELECTED_ROWS
,
ctx
.
GetOutputType
(
"Out"
));
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
ctx
.
GetOutputType
(
"Out"
,
1
));
ctx
.
SetOutputType
(
"Out"
,
proto
::
VarType
::
SELECTED_ROWS
,
ALL_ELEMENTS
);
ctx
.
SetOutputType
(
"Out"
,
proto
::
VarType
::
LOD_TENSOR
,
1
);
ASSERT_EQ
(
proto
::
VarType
::
SELECTED_ROWS
,
ctx
.
GetOutputType
(
"Out"
));
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
ctx
.
GetOutputType
(
"Out"
,
1
));
ASSERT_EQ
(
0
,
ctx
.
GetInputDataType
(
"X"
));
ctx
.
SetOutputDataType
(
"Out"
,
proto
::
VarType
::
FP32
,
ALL_ELEMENTS
);
ctx
.
SetOutputDataType
(
"Out"
,
proto
::
VarType
::
INT8
,
1
);
ASSERT_EQ
(
proto
::
VarType
::
FP32
,
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_a_out"
)
->
GetDataType
());
ASSERT_EQ
(
proto
::
VarType
::
INT8
,
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_b_out"
)
->
GetDataType
());
ASSERT_FALSE
(
ctx
.
IsDygraph
());
// test StaticGraphVarTypeInference
TestStaticGraphVarTypeInference
infer
;
ASSERT_TRUE
(
infer
.
HasVar
(
&
ctx
,
"test2_a"
));
ASSERT_EQ
(
infer
.
Input
(
&
ctx
,
"X"
).
size
(),
infer
.
Output
(
&
ctx
,
"Out"
).
size
());
ASSERT_EQ
(
proto
::
VarType
::
FP32
,
infer
.
GetDataType
(
&
ctx
,
"test2_a_out"
));
infer
.
SetDataType
(
&
ctx
,
"test2_a_out"
,
proto
::
VarType
::
FP64
);
ASSERT_EQ
(
proto
::
VarType
::
FP64
,
infer
.
GetDataType
(
&
ctx
,
"test2_a_out"
));
ASSERT_EQ
(
proto
::
VarType
::
SELECTED_ROWS
,
infer
.
GetType
(
&
ctx
,
"test2_a_out"
));
infer
.
SetType
(
&
ctx
,
"test2_a_out"
,
proto
::
VarType
::
LOD_TENSOR
);
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
infer
.
GetType
(
&
ctx
,
"test2_a_out"
));
ASSERT_ANY_THROW
(
infer
.
GetDataTypes
(
&
ctx
,
"test2_a_out"
));
ASSERT_ANY_THROW
(
infer
.
SetDataTypes
(
&
ctx
,
"test2_a_out"
,
{}));
ASSERT_EQ
(
0u
,
infer
.
GetShape
(
&
ctx
,
"test2_a_out"
).
size
());
infer
.
SetShape
(
&
ctx
,
"test2_a_out"
,
{
1
,
3
,
3
,
});
ASSERT_EQ
(
3u
,
infer
.
GetShape
(
&
ctx
,
"test2_a_out"
).
size
());
ASSERT_EQ
(
0
,
infer
.
GetLoDLevel
(
&
ctx
,
"test2_a_out"
));
infer
.
SetLoDLevel
(
&
ctx
,
"test2_a_out"
,
2
);
ASSERT_EQ
(
2
,
infer
.
GetLoDLevel
(
&
ctx
,
"test2_a_out"
));
}
TEST
(
InferVarType
,
test_enforce_check
)
{
InferVarTypeContext
ctx
(
nullptr
,
nullptr
);
ASSERT_ANY_THROW
(
ctx
.
HasInput
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
HasOutput
(
"Out"
));
ASSERT_ANY_THROW
(
ctx
.
InputSize
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
InputVarName
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
InputTypeAnyOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
.
InputTypeAllOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
.
SyncTypeAndDataType
(
"X"
,
"Out"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputType
(
"Out"
,
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
.
GetInputType
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
GetOutputType
(
"Out"
));
ASSERT_ANY_THROW
(
ctx
.
GetInputDataType
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputDataType
(
"Out"
,
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
.
GetInputDataTypes
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputDataTypes
(
"Out"
,
{}));
ASSERT_ANY_THROW
(
ctx
.
GetInputShape
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputShape
(
"Out"
,
{}));
ASSERT_ANY_THROW
(
ctx
.
GetInputLoDLevel
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputLoDLevel
(
"Out"
,
1
));
ASSERT_ANY_THROW
(
ctx
.
InsertVar
(
"var"
,
proto
::
VarType
::
LOD_TENSOR
));
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/imperative/infer_var_type_context.h
浏览文件 @
91ae7848
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
...
...
@@ -35,30 +36,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
:
InferVarTypeContext
(
nullptr
,
nullptr
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs_map
),
input_names_
(),
output_names_
(),
var_set_
()
{
input_names_
.
reserve
(
inputs_
.
size
());
for
(
auto
&
it
:
inputs_
)
{
for
(
auto
&
var
:
it
.
second
)
{
if
(
var
)
{
input_names_
[
it
.
first
].
emplace_back
(
var
->
Name
());
var_set_
[
var
->
Name
()]
=
var
.
get
();
}
}
}
output_names_
.
reserve
(
outputs_
.
size
());
for
(
auto
&
it
:
outputs_
)
{
for
(
auto
&
var
:
it
.
second
)
{
if
(
var
)
{
output_names_
[
it
.
first
].
emplace_back
(
var
->
Name
());
var_set_
[
var
->
Name
()]
=
var
.
get
();
}
}
}
}
attrs_
(
attrs_map
)
{}
virtual
~
RuntimeInferVarTypeContext
()
{}
...
...
@@ -70,10 +48,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
return
iter
->
second
;
}
bool
HasVar
(
const
std
::
string
&
name
)
const
override
{
return
var_set_
.
count
(
name
)
>
0
;
}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
inputs_
.
find
(
name
);
return
(
it
!=
inputs_
.
end
()
&&
it
->
second
.
size
()
>
0
);
...
...
@@ -84,93 +58,173 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
return
(
it
!=
outputs_
.
end
()
&&
it
->
second
.
size
()
>
0
);
}
const
std
::
vector
<
std
::
string
>&
Input
(
const
std
::
string
&
name
)
const
override
{
auto
iter
=
input_names_
.
find
(
name
);
PADDLE_ENFORCE_EQ
(
iter
!=
input_names_
.
end
(),
true
,
platform
::
errors
::
NotFound
(
"Cannot find input var %s"
,
name
));
return
iter
->
second
;
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
return
inputs_
.
at
(
name
).
size
();
}
const
std
::
vector
<
std
::
string
>&
Output
(
const
std
::
string
&
name
)
const
override
{
auto
iter
=
output_names_
.
find
(
name
);
const
std
::
string
&
InputVarName
(
const
std
::
string
&
name
,
const
int
index
=
0
)
const
{
return
inputs_
.
at
(
name
)[
index
]
->
Name
();
}
PADDLE_ENFORCE_EQ
(
iter
!=
output_names_
.
end
(),
true
,
platform
::
errors
::
NotFound
(
"Cannot find output var %s"
,
name
));
return
iter
->
second
;
bool
InputTypeAnyOf
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
const
override
{
auto
&
inputs
=
inputs_
.
at
(
name
);
return
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
type
](
const
std
::
shared_ptr
<
VarType
>&
var
)
{
return
var
->
Type
()
==
type
;
});
}
framework
::
proto
::
VarType
::
Type
GetType
(
const
std
::
string
&
name
)
const
override
{
auto
iter
=
var_set_
.
find
(
name
);
bool
InputTypeAllOf
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
const
override
{
auto
&
inputs
=
inputs_
.
at
(
name
);
return
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
type
](
const
std
::
shared_ptr
<
VarType
>&
var
)
{
return
var
->
Type
()
==
type
;
});
}
PADDLE_ENFORCE_EQ
(
iter
!=
var_set_
.
end
(),
true
,
platform
::
errors
::
NotFound
(
"Cannot find var %s in GetType"
,
name
));
return
iter
->
second
->
Type
();
void
SyncTypeAndDataType
(
const
std
::
string
&
input_name
,
const
std
::
string
&
output_name
,
int
index
=
0
)
override
{
auto
in_var
=
inputs_
.
at
(
input_name
)[
index
];
auto
out_var
=
outputs_
.
at
(
output_name
)[
index
];
if
(
in_var
!=
out_var
)
{
this
->
SetVarBaseType
(
out_var
,
in_var
->
Type
());
this
->
SetVarBaseDataType
(
out_var
,
in_var
->
DataType
());
}
}
void
SetType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
override
{
if
(
name
==
"kLookupTablePath"
)
{
VLOG
(
2
)
<<
"SUPER UGLY FIX, remove this when move imperative mode in C++"
;
void
SetOutputType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
,
int
index
=
0
)
override
{
if
(
index
==
framework
::
ALL_ELEMENTS
)
{
for
(
auto
&
item
:
outputs_
.
at
(
name
))
{
this
->
SetVarBaseType
(
item
,
type
);
}
}
else
{
var_set_
[
name
]
->
SetType
(
type
);
if
((
var_set_
[
name
]
->
MutableVar
()
->
IsInitialized
()
==
true
)
&&
(
var_set_
[
name
]
->
MutableVar
()
->
Type
()
!=
type
))
{
var_set_
[
name
]
->
MutableVar
()
->
Clear
();
auto
&
var
=
outputs_
.
at
(
name
)[
index
];
this
->
SetVarBaseType
(
var
,
type
);
}
}
void
SetVarBaseType
(
std
::
shared_ptr
<
VarType
>
out
,
framework
::
proto
::
VarType
::
Type
type
)
{
out
->
SetType
(
type
);
if
((
out
->
MutableVar
()
->
IsInitialized
()
==
true
)
&&
(
out
->
MutableVar
()
->
Type
()
!=
type
))
{
out
->
MutableVar
()
->
Clear
();
}
}
void
SetVarBaseDataType
(
std
::
shared_ptr
<
VarType
>
out
,
framework
::
proto
::
VarType
::
Type
type
)
{
out
->
SetDataType
(
type
);
}
framework
::
proto
::
VarType
::
Type
GetInputType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
override
{
return
inputs_
.
at
(
name
)[
index
]
->
Type
();
}
framework
::
proto
::
VarType
::
Type
GetOutputType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
override
{
return
outputs_
.
at
(
name
)[
index
]
->
Type
();
}
framework
::
proto
::
VarType
::
Type
GetInputDataType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
override
{
return
inputs_
.
at
(
name
)[
index
]
->
DataType
();
}
void
SetOutputDataType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
,
int
index
=
0
)
override
{
if
(
framework
::
ALL_ELEMENTS
==
index
)
{
for
(
auto
&
item
:
outputs_
.
at
(
name
))
{
this
->
SetVarBaseDataType
(
item
,
type
);
}
}
else
{
auto
&
var
=
outputs_
.
at
(
name
)[
index
];
this
->
SetVarBaseDataType
(
var
,
type
);
}
}
framework
::
proto
::
VarType
::
Type
GetDataType
(
bool
IsDygraph
()
const
override
{
return
true
;
}
protected:
bool
HasVar
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"HasVar is not supported in runtime InferVarType"
));
}
const
std
::
vector
<
std
::
string
>&
InputVars
(
const
std
::
string
&
name
)
const
override
{
auto
iter
=
var_set_
.
find
(
name
);
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"InputVars is not supported in runtime InferVarType"
));
}
PADDLE_ENFORCE_EQ
(
iter
!=
var_set_
.
end
(),
true
,
platform
::
errors
::
NotFound
(
"Cannot find var %s in GetDataType"
,
name
));
return
iter
->
second
->
DataType
();
const
std
::
vector
<
std
::
string
>&
OutputVars
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"OutputVars is not supported in runtime InferVarType"
));
}
framework
::
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not manipulate var in runtime InferVarType"
));
}
void
SetVarType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not manipulate var in runtime InferVarType"
));
}
void
SetDataType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
override
{
var_set_
[
name
]
->
SetDataType
(
type
);
framework
::
proto
::
VarType
::
Type
GetVarDataType
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not manipulate var in runtime InferVarType"
));
}
void
SetVarDataType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not manipulate var in runtime InferVarType"
));
}
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
GetDataTypes
(
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
Get
Var
DataTypes
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"GetDataTypes is not supported in runtime InferVarType"
));
"Get
Var
DataTypes is not supported in runtime InferVarType"
));
}
void
SetDataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
multiple_data_type
)
override
{
void
Set
Var
DataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
multiple_data_type
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"SetDataTypes is not supported in runtime InferVarType"
));
"Set
Var
DataTypes is not supported in runtime InferVarType"
));
}
std
::
vector
<
int64_t
>
GetShape
(
const
std
::
string
&
name
)
const
override
{
std
::
vector
<
int64_t
>
Get
Var
Shape
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not handle Shape in runtime InferVarType"
));
}
void
SetShape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
override
{
void
Set
Var
Shape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not handle Shape in runtime InferVarType"
));
}
int32_t
GetLoDLevel
(
const
std
::
string
&
name
)
const
override
{
int32_t
Get
Var
LoDLevel
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not handle LoDLevel in runtime InferVarType"
));
}
void
SetLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
override
{
void
Set
Var
LoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not handle LoDLevel in runtime InferVarType"
));
}
...
...
@@ -179,9 +233,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
const
NameVarMap
<
VarType
>&
inputs_
;
const
NameVarMap
<
VarType
>&
outputs_
;
const
framework
::
AttributeMap
&
attrs_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
input_names_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
output_names_
;
std
::
unordered_map
<
std
::
string
,
VarType
*>
var_set_
;
};
}
// namespace imperative
...
...
paddle/fluid/imperative/tests/test_layer.cc
浏览文件 @
91ae7848
...
...
@@ -37,33 +37,154 @@ using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
using
var_pair
=
std
::
pair
<
std
::
string
,
vb_vector
>
;
template
<
typename
VarType
>
class
TestRuntimeInferVarTypeContext
:
public
RuntimeInferVarTypeContext
<
VarType
>
{
public:
TestRuntimeInferVarTypeContext
(
const
NameVarMap
<
VarType
>&
inputs
,
const
NameVarMap
<
VarType
>&
outputs
,
const
framework
::
AttributeMap
&
attrs_map
)
:
RuntimeInferVarTypeContext
<
VarType
>
(
inputs
,
outputs
,
attrs_map
)
{}
bool
HasVar
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
HasVar
(
name
);
}
const
std
::
vector
<
std
::
string
>&
InputVars
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
InputVars
(
name
);
}
const
std
::
vector
<
std
::
string
>&
OutputVars
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
OutputVars
(
name
);
}
framework
::
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarType
(
name
);
}
void
SetVarType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarType
(
name
,
type
);
}
framework
::
proto
::
VarType
::
Type
GetVarDataType
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarDataType
(
name
);
}
void
SetVarDataType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarDataType
(
name
,
type
);
}
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
GetVarDataTypes
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarDataTypes
(
name
);
}
void
SetVarDataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarDataTypes
(
name
,
multiple_data_type
);
}
std
::
vector
<
int64_t
>
GetVarShape
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarShape
(
name
);
}
void
SetVarShape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarShape
(
name
,
dims
);
}
int32_t
GetVarLoDLevel
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarLoDLevel
(
name
);
}
void
SetVarLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarLoDLevel
(
name
,
lod_level
);
}
};
TEST
(
test_layer
,
test_runtime_context
)
{
std
::
shared_ptr
<
imperative
::
VarBase
>
vin
(
new
imperative
::
VarBase
(
false
,
"vin"
));
std
::
shared_ptr
<
imperative
::
VarBase
>
vin_b
(
new
imperative
::
VarBase
(
false
,
"vin_b"
));
std
::
shared_ptr
<
imperative
::
VarBase
>
vout
(
new
imperative
::
VarBase
(
false
,
"vout"
));
var_pair
in_pair
=
var_pair
(
"X"
,
vb_vector
(
1
,
vin
));
var_pair
out_pair
=
var_pair
(
"Out"
,
vb_vector
(
1
,
vout
));
std
::
shared_ptr
<
imperative
::
VarBase
>
vout_b
(
new
imperative
::
VarBase
(
false
,
"vout_b"
));
var_pair
in_pair
=
var_pair
(
"X"
,
{
vin
,
vin_b
});
var_pair
out_pair
=
var_pair
(
"Out"
,
{
vout
,
vout_b
});
imperative
::
NameVarBaseMap
ins
=
{
in_pair
};
imperative
::
NameVarBaseMap
outs
=
{
out_pair
};
framework
::
AttributeMap
attrs
;
auto
*
ctx
=
new
imperative
::
RuntimeInferVarTypeContext
<
imperative
::
VarBase
>
(
ins
,
outs
,
attrs
);
ASSERT_TRUE
(
ctx
->
HasVar
(
"vin"
));
auto
*
ctx
=
new
imperative
::
TestRuntimeInferVarTypeContext
<
imperative
::
VarBase
>
(
ins
,
outs
,
attrs
);
ASSERT_TRUE
(
ctx
->
HasInput
(
"X"
));
ASSERT_TRUE
(
ctx
->
HasOutput
(
"Out"
));
ASSERT_ANY_THROW
(
ctx
->
GetDataTypes
(
"vin"
));
ASSERT_EQ
(
2u
,
ctx
->
InputSize
(
"X"
));
ASSERT_EQ
(
"vin"
,
ctx
->
InputVarName
(
"X"
,
0
));
ASSERT_TRUE
(
ctx
->
InputTypeAnyOf
(
"X"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_TRUE
(
ctx
->
InputTypeAllOf
(
"X"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_EQ
(
framework
::
proto
::
VarType
::
LOD_TENSOR
,
ctx
->
GetInputType
(
"X"
));
ASSERT_EQ
(
framework
::
proto
::
VarType
::
FP32
,
ctx
->
GetInputDataType
(
"X"
));
ctx
->
SyncTypeAndDataType
(
"X"
,
"Out"
);
ASSERT_EQ
(
framework
::
proto
::
VarType
::
FP32
,
vout
->
DataType
());
ASSERT_EQ
(
framework
::
proto
::
VarType
::
LOD_TENSOR
,
ctx
->
GetOutputType
(
"Out"
));
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
,
framework
::
ALL_ELEMENTS
);
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
ASSERT_EQ
(
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
,
vout
->
Type
());
ASSERT_EQ
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
,
vout_b
->
Type
());
ctx
->
SetOutputDataType
(
"Out"
,
framework
::
proto
::
VarType
::
FP64
,
framework
::
ALL_ELEMENTS
);
ctx
->
SetOutputDataType
(
"Out"
,
framework
::
proto
::
VarType
::
INT8
);
ASSERT_EQ
(
framework
::
proto
::
VarType
::
INT8
,
vout
->
DataType
());
ASSERT_EQ
(
framework
::
proto
::
VarType
::
FP64
,
vout_b
->
DataType
());
// no throw, but do nothing
ASSERT_NO_THROW
(
ctx
->
InsertVar
(
"vout"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_EQ
(
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
,
vout
->
Type
());
ASSERT_ANY_THROW
(
ctx
->
HasVar
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
InputVars
(
"X"
));
ASSERT_ANY_THROW
(
ctx
->
OutputVars
(
"Out"
));
ASSERT_ANY_THROW
(
ctx
->
GetVarType
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
SetVarType
(
"vin"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
->
GetVarDataType
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
SetVarDataType
(
"vout"
,
framework
::
proto
::
VarType
::
FP32
));
ASSERT_ANY_THROW
(
ctx
->
GetVarDataTypes
(
"vin"
));
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
NullType
;
ASSERT_ANY_THROW
(
ctx
->
SetDataTypes
(
"vin"
,
NullType
));
ASSERT_ANY_THROW
(
ctx
->
GetShape
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
GetLoDLevel
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
SetLoDLevel
(
"vin"
,
2
));
ASSERT_ANY_THROW
(
ctx
->
SetVarDataTypes
(
"vin"
,
NullType
));
ASSERT_ANY_THROW
(
ctx
->
GetVarShape
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
SetVarShape
(
"vin"
,
{}));
ASSERT_ANY_THROW
(
ctx
->
GetVarLoDLevel
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
SetVarLoDLevel
(
"vin"
,
2
));
ASSERT_TRUE
(
ctx
->
IsDygraph
());
}
std
::
string
LayerDebugString
(
const
std
::
string
&
op_type
,
const
NameVarBaseMap
&
ins
,
const
NameVarBaseMap
&
outs
);
std
::
string
LayerDebugString
(
const
std
::
string
&
op_type
,
const
NameVarBaseMap
&
ins
,
const
NameVarBaseMap
&
outs
);
TEST
(
test_layer
,
test_debug_string
)
{
platform
::
CPUPlace
place
;
...
...
@@ -71,7 +192,7 @@ TEST(test_layer, test_debug_string) {
new
imperative
::
VarBase
(
false
,
"vin"
));
var_pair
in_pair
=
var_pair
(
"X"
,
vb_vector
(
1
,
vin
));
auto
test_func
=
[
&
](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
vout
)
{
auto
test_func
=
[
&
](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
vout
)
{
var_pair
out_pair
=
var_pair
(
"Out"
,
vb_vector
(
1
,
vout
));
imperative
::
NameVarBaseMap
ins
=
{
in_pair
};
imperative
::
NameVarBaseMap
outs
=
{
out_pair
};
...
...
@@ -124,26 +245,26 @@ TEST(test_layer, test_debug_string) {
}
static
std
::
shared_ptr
<
imperative
::
GradOpNode
>
CreateGradNode
(
size_t
id
,
const
std
::
string
&
type
,
const
imperative
::
NameVarBaseMap
&
ins
,
const
imperative
::
NameVarBaseMap
&
outs
,
const
framework
::
AttributeMap
&
attrs
,
const
platform
::
Place
&
place
)
{
size_t
id
,
const
std
::
string
&
type
,
const
imperative
::
NameVarBaseMap
&
ins
,
const
imperative
::
NameVarBaseMap
&
outs
,
const
framework
::
AttributeMap
&
attrs
,
const
platform
::
Place
&
place
)
{
auto
node
=
std
::
make_shared
<
imperative
::
GradOpNode
>
();
auto
*
op
=
&
(
node
->
emplace_back
());
auto
*
op
=
&
(
node
->
emplace_back
());
op
->
SetId
(
id
);
op
->
SetPlace
(
place
);
op
->
SetType
(
type
);
op
->
SetAttrMap
(
attrs
);
for
(
auto
&
pair
:
ins
)
{
for
(
auto
&
pair
:
ins
)
{
std
::
vector
<
std
::
shared_ptr
<
VariableWrapper
>>
vars
;
for
(
auto
&
var
:
pair
.
second
)
{
for
(
auto
&
var
:
pair
.
second
)
{
vars
.
emplace_back
(
var
->
SharedVar
());
}
op
->
SetInput
(
pair
.
first
,
vars
,
false
);
}
for
(
auto
&
pair
:
outs
)
{
for
(
auto
&
pair
:
outs
)
{
std
::
vector
<
std
::
shared_ptr
<
VariableWrapper
>>
vars
;
for
(
auto
&
var
:
pair
.
second
)
{
for
(
auto
&
var
:
pair
.
second
)
{
vars
.
emplace_back
(
var
->
SharedVar
());
}
op
->
SetOutput
(
pair
.
first
,
vars
,
false
);
...
...
@@ -173,7 +294,7 @@ TEST(test_layer, test_clear_backward_info) {
node
->
InsertGradPendingNode
(
pending_node
);
ASSERT_EQ
(
node
->
size
(),
1UL
);
auto
*
op
=
&
(
node
->
back
());
auto
*
op
=
&
(
node
->
back
());
ASSERT_GT
(
op
->
GetInsMap
().
size
(),
0UL
);
ASSERT_GT
(
op
->
GetOutsMap
().
size
(),
0UL
);
...
...
@@ -196,10 +317,10 @@ TEST(test_layer, test_varbase_basic) {
std
::
shared_ptr
<
imperative
::
VarBase
>
vin_with_grad
(
new
imperative
::
VarBase
(
true
,
"vin"
));
ASSERT_ANY_THROW
(
vin
->
MutableGradVar
());
ASSERT_NO_THROW
(
ASSERT_TRUE
(
dynamic_cast
<
framework
::
Variable
*>
(
ASSERT_NO_THROW
(
ASSERT_TRUE
(
dynamic_cast
<
framework
::
Variable
*>
(
vin_with_grad
->
MutableGradVar
())
!=
0
));
ASSERT_TRUE
(
dynamic_cast
<
framework
::
Variable
*>
(
vin_with_grad
->
MutableGradVar
())
!=
0
);
ASSERT_TRUE
(
dynamic_cast
<
framework
::
Variable
*>
(
vin_with_grad
->
MutableGradVar
())
!=
0
);
vin_with_grad
->
SetOverridedStopGradient
(
false
);
ASSERT_FALSE
(
vin_with_grad
->
OverridedStopGradient
());
ASSERT_NO_FATAL_FAILURE
(
vin_with_grad
->
SetPersistable
(
true
));
...
...
@@ -228,9 +349,9 @@ TEST(test_layer, test_dygraph_execution_context) {
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"mul"
,
{},
{},
{},
false
);
paddle
::
platform
::
CPUPlace
cpu_place
;
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
::
Instance
();
auto
*
dev_ctx
=
pool
.
Get
(
cpu_place
);
auto
*
dev_ctx
=
pool
.
Get
(
cpu_place
);
paddle
::
framework
::
RuntimeContext
ctx
({},
{});
framework
::
Scope
scope
;
...
...
paddle/fluid/operators/activation_op.cc
浏览文件 @
91ae7848
...
...
@@ -129,9 +129,10 @@ class ActivationOp : public framework::OperatorWithKernel {
class
ActivationOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/allclose_op.cc
浏览文件 @
91ae7848
...
...
@@ -103,8 +103,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
class
AllcloseOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetDataType
(
out_var_name
,
framework
::
proto
::
VarType
::
BOOL
);
ctx
->
SetOutputDataType
(
"Out"
,
framework
::
proto
::
VarType
::
BOOL
);
}
};
...
...
paddle/fluid/operators/assign_op.cc
浏览文件 @
91ae7848
...
...
@@ -60,11 +60,7 @@ class AssignOp : public framework::OperatorWithKernel {
class
AssignInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Out"
)[
0
];
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"X"
)[
0
]);
auto
input_data_type
=
ctx
->
GetDataType
(
ctx
->
Input
(
"X"
)[
0
]);
ctx
->
SetType
(
out_var_name
,
input_type
);
ctx
->
SetDataType
(
out_var_name
,
input_data_type
);
ctx
->
SyncTypeAndDataType
(
"X"
,
"Out"
);
}
};
...
...
paddle/fluid/operators/batch_norm_op.h
浏览文件 @
91ae7848
...
...
@@ -171,9 +171,10 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
class
BatchNormOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/beam_search_decode_op.cc
浏览文件 @
91ae7848
...
...
@@ -204,12 +204,10 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class
BeamSearchDecodeInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
->
Output
(
"SentenceIds"
))
{
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
for
(
auto
&
o
:
ctx
->
Output
(
"SentenceScores"
))
{
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
SetOutputType
(
"SentenceIds"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
,
framework
::
ALL_ELEMENTS
);
ctx
->
SetOutputType
(
"SentenceScores"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/beam_search_op.cc
浏览文件 @
91ae7848
...
...
@@ -122,12 +122,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
class
BeamSearchInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
->
Output
(
"selected_ids"
))
{
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
for
(
auto
&
o
:
ctx
->
Output
(
"selected_scores"
))
{
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
SetOutputType
(
"selected_ids"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
,
framework
::
ALL_ELEMENTS
);
ctx
->
SetOutputType
(
"selected_scores"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/controlflow/get_places_op.cc
浏览文件 @
91ae7848
...
...
@@ -92,9 +92,8 @@ execution.
class
GetPlacesInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o_name
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetType
(
o_name
,
framework
::
proto
::
VarType
::
PLACE_LIST
);
}
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
PLACE_LIST
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
浏览文件 @
91ae7848
...
...
@@ -111,15 +111,15 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
}
};
class
WriteToArrayInferVarType
:
public
framework
::
VarTypeInference
{
class
WriteToArrayInferVarType
:
public
framework
::
StaticGraph
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_name
=
ctx
->
Input
(
"X"
)[
0
];
auto
out_name
=
ctx
->
Output
(
"Out"
)[
0
];
auto
x_name
=
Input
(
ctx
,
"X"
)[
0
];
auto
out_name
=
Output
(
ctx
,
"Out"
)[
0
];
VLOG
(
10
)
<<
"Set Variable "
<<
out_name
<<
" as LOD_TENSOR_ARRAY"
;
ctx
->
SetType
(
out_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
if
(
ctx
->
HasVar
(
x_name
))
{
ctx
->
SetDataType
(
out_name
,
ctx
->
GetDataType
(
x_name
));
SetType
(
ctx
,
out_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
if
(
HasVar
(
ctx
,
x_name
))
{
SetDataType
(
ctx
,
out_name
,
GetDataType
(
ctx
,
x_name
));
}
}
};
...
...
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
91ae7848
...
...
@@ -398,18 +398,19 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class
WhileGradOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
WhileGradOpVarTypeInference
:
public
framework
::
StaticGraphVarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
p_names
=
ctx
->
Input
(
kX
);
auto
pg_ig_names
=
ctx
->
Output
(
framework
::
GradVarName
(
kX
));
auto
p_names
=
Input
(
ctx
,
kX
);
auto
pg_ig_names
=
Output
(
ctx
,
framework
::
GradVarName
(
kX
));
for
(
size_t
i
=
0
;
i
<
p_names
.
size
();
++
i
)
{
if
(
ctx
->
HasVar
(
pg_ig_names
[
i
]))
{
if
(
HasVar
(
ctx
,
pg_ig_names
[
i
]))
{
VLOG
(
5
)
<<
"Setting "
<<
pg_ig_names
[
i
]
<<
" following "
<<
p_names
[
i
]
<<
" type: "
<<
ctx
->
GetType
(
p_names
[
i
]);
ctx
->
SetType
(
pg_ig_names
[
i
],
ctx
->
GetType
(
p_names
[
i
]));
ctx
->
SetDataType
(
pg_ig_names
[
i
],
ctx
->
GetDataType
(
p_names
[
i
]));
<<
" type: "
<<
GetType
(
ctx
,
p_names
[
i
]);
SetType
(
ctx
,
pg_ig_names
[
i
],
GetType
(
ctx
,
p_names
[
i
]));
SetDataType
(
ctx
,
pg_ig_names
[
i
],
GetDataType
(
ctx
,
p_names
[
i
]));
}
}
}
...
...
paddle/fluid/operators/conv_op.h
浏览文件 @
91ae7848
...
...
@@ -254,10 +254,11 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
class
ConvOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{
{
"Input"
,
/*->*/
"Output"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/cross_entropy_op.cc
浏览文件 @
91ae7848
...
...
@@ -177,9 +177,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
class
CrossEntropyOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
浏览文件 @
91ae7848
...
...
@@ -115,10 +115,8 @@ class MergeIdsOp : public framework::OperatorWithKernel {
class
MergeIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"Ids"
)[
0
]);
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetType
(
out_var
,
input_type
);
}
auto
input_type
=
ctx
->
GetInputType
(
"Ids"
);
ctx
->
SetOutputType
(
"Out"
,
input_type
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/distributed_ops/split_ids_op.cc
浏览文件 @
91ae7848
...
...
@@ -73,10 +73,8 @@ class SplitIdsOp : public framework::OperatorWithKernel {
class
SplitIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"Ids"
)[
0
]);
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetType
(
out_var
,
input_type
);
}
auto
input_type
=
ctx
->
GetInputType
(
"Ids"
);
ctx
->
SetOutputType
(
"Out"
,
input_type
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
91ae7848
...
...
@@ -119,9 +119,10 @@ class ElementwiseOp : public framework::OperatorWithKernel {
class
ElementwiseOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/eye_op.cc
浏览文件 @
91ae7848
...
...
@@ -49,8 +49,7 @@ class EyeOpVarTypeInference : public framework::VarTypeInference {
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetDataType
(
out_var_name
,
data_type
);
ctx
->
SetOutputDataType
(
"Out"
,
data_type
);
}
};
...
...
paddle/fluid/operators/fill_any_like_op.cc
浏览文件 @
91ae7848
...
...
@@ -72,14 +72,12 @@ The output will have the same shape and dtype as the input.
class
FillAnyLikeVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
if
(
var_data_type
<
0
)
{
const
auto
&
input_var_name
=
ctx
->
Input
(
"X"
).
front
();
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
input_var_name
));
ctx
->
SetOutputDataType
(
"Out"
,
ctx
->
GetInputDataType
(
"X"
));
}
else
{
ctx
->
Set
DataType
(
out_var_name
,
var_data_type
);
ctx
->
Set
OutputDataType
(
"Out"
,
var_data_type
);
}
}
};
...
...
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
91ae7848
...
...
@@ -64,8 +64,7 @@ class FillConstantOpVarTypeInference : public framework::VarTypeInference {
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetDataType
(
out_var_name
,
data_type
);
ctx
->
SetOutputDataType
(
"Out"
,
data_type
);
}
};
...
...
paddle/fluid/operators/fill_op.cc
浏览文件 @
91ae7848
...
...
@@ -63,8 +63,7 @@ class FillOpVarTypeInference : public framework::VarTypeInference {
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetDataType
(
out_var_name
,
data_type
);
ctx
->
SetOutputDataType
(
"Out"
,
data_type
);
}
};
...
...
paddle/fluid/operators/flip_op.cc
浏览文件 @
91ae7848
...
...
@@ -114,9 +114,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
class
FlipOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/fused/fused_bn_activation_op.h
浏览文件 @
91ae7848
...
...
@@ -85,9 +85,10 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
class
FusedBatchNormActOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
浏览文件 @
91ae7848
...
...
@@ -146,19 +146,20 @@ class FusedEmbeddingSeqPoolOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
(
);
auto
out_var_name
=
framework
::
GradVarName
(
"W"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetOutputType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
Set
Output
Type
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
Set
DataType
(
out_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]
));
ctx
->
Set
OutputDataType
(
out_var_name
,
ctx
->
GetInputDataType
(
"W"
));
}
};
...
...
paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
浏览文件 @
91ae7848
...
...
@@ -83,11 +83,8 @@ class GetTensorFromSelectedRowsOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
{
// NOLINT
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
auto
in_var_name
=
ctx
->
Input
(
"X"
).
front
();
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
in_var_name
));
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetOutputDataType
(
"Out"
,
ctx
->
GetInputDataType
(
"X"
));
}
};
...
...
paddle/fluid/operators/group_norm_op.cc
浏览文件 @
91ae7848
...
...
@@ -216,9 +216,10 @@ DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut,
class
GroupNormOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
{{
"X"
,
/*->*/
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/hierarchical_sigmoid_op.cc
浏览文件 @
91ae7848
...
...
@@ -229,31 +229,30 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
w_grad_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
(
);
auto
has_bias_grad_var
=
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Bias"
)
);
std
::
string
bias_grad_var_name
;
bool
hasBias
=
false
;
if
(
has_bias_grad_var
)
{
hasBias
=
true
;
bias_grad_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"Bias"
)).
front
(
);
auto
w_grad_var_name
=
framework
::
GradVarName
(
"W"
);
auto
bias_grad_var_name
=
framework
::
GradVarName
(
"Bias"
);
if
(
ctx
->
HasOutput
(
bias_grad_var_name
))
{
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"Bias"
)
<<
" is set to LoDTensor"
;
ctx
->
SetOutputType
(
bias_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetOutputType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetOutputType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
if
(
hasBias
)
{
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"Bias"
)
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
bias_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
SetDataType
(
w_grad_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]));
ctx
->
SetOutputDataType
(
w_grad_var_name
,
ctx
->
GetInputDataType
(
"W"
));
}
};
...
...
paddle/fluid/operators/instance_norm_op.h
浏览文件 @
91ae7848
...
...
@@ -123,9 +123,10 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
class
InstanceNormOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
"Y"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/lod_rank_table_op.cc
浏览文件 @
91ae7848
...
...
@@ -65,9 +65,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
class
LoDRankTableInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_RANK_TABLE
);
}
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_RANK_TABLE
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/lod_reset_op.cc
浏览文件 @
91ae7848
...
...
@@ -76,24 +76,25 @@ class LoDResetOp : public framework::OperatorWithKernel {
}
};
class
LoDResetOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
LoDResetOpVarTypeInference
:
public
framework
::
StaticGraphVarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_var_name
=
ctx
->
Input
(
"X"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
auto
x_var_name
=
Input
(
ctx
,
"X"
).
front
();
auto
out_var_name
=
Output
(
ctx
,
"Out"
).
front
();
bool
append
=
boost
::
get
<
bool
>
(
ctx
->
GetAttr
(
"append"
));
if
(
ctx
->
HasInput
(
"Y"
))
{
auto
y_var_name
=
ctx
->
Input
(
"Y"
).
front
();
auto
y_lod_level
=
std
::
max
(
ctx
->
GetLoDLevel
(
y_var_name
),
1
);
ctx
->
SetLoDLevel
(
out_var_name
,
y_lod_level
);
auto
y_var_name
=
Input
(
ctx
,
"Y"
).
front
();
auto
y_lod_level
=
std
::
max
(
GetLoDLevel
(
ctx
,
y_var_name
),
1
);
SetLoDLevel
(
ctx
,
out_var_name
,
y_lod_level
);
}
else
if
(
append
)
{
auto
x_lod_level
=
std
::
max
(
ctx
->
GetLoDLevel
(
x_var_name
),
1
);
ctx
->
SetLoDLevel
(
out_var_name
,
x_lod_level
);
auto
x_lod_level
=
std
::
max
(
GetLoDLevel
(
ctx
,
x_var_name
),
1
);
SetLoDLevel
(
ctx
,
out_var_name
,
x_lod_level
);
}
else
{
ctx
->
SetLoDLevel
(
out_var_name
,
1
);
SetLoDLevel
(
ctx
,
out_var_name
,
1
);
}
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
x_var_name
));
ctx
->
SetType
(
out_var_name
,
paddle
::
framework
::
proto
::
VarType
::
LOD_TENSOR
);
SetDataType
(
ctx
,
out_var_name
,
GetDataType
(
ctx
,
x_var_name
));
SetType
(
ctx
,
out_var_name
,
paddle
::
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
};
...
...
paddle/fluid/operators/lod_tensor_to_array_op.cc
浏览文件 @
91ae7848
...
...
@@ -221,9 +221,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
class
LoDTensorToArrayInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
}
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/lookup_table_op.cc
浏览文件 @
91ae7848
...
...
@@ -173,19 +173,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
class
LookupTableOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
(
);
auto
out_var_name
=
framework
::
GradVarName
(
"W"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetOutputType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
Set
Output
Type
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
Set
DataType
(
out_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]
));
ctx
->
Set
OutputDataType
(
out_var_name
,
ctx
->
GetInputDataType
(
"W"
));
}
};
...
...
paddle/fluid/operators/lookup_table_v2_op.cc
浏览文件 @
91ae7848
...
...
@@ -160,19 +160,20 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel {
class
LookupTableV2OpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
(
);
auto
out_var_name
=
framework
::
GradVarName
(
"W"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"lookup_table_v2_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetOutputType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
VLOG
(
3
)
<<
"lookup_table_v2_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
Set
Output
Type
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
Set
DataType
(
out_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]
));
ctx
->
Set
OutputDataType
(
out_var_name
,
ctx
->
GetInputDataType
(
"W"
));
}
};
...
...
paddle/fluid/operators/mean_op.cc
浏览文件 @
91ae7848
...
...
@@ -45,9 +45,10 @@ Mean Operator calculates the mean of all elements in X.
class
MeanOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/merge_selected_rows_op.cc
浏览文件 @
91ae7848
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/merge_selected_rows_op.h"
#include <unordered_map>
namespace
paddle
{
namespace
operators
{
...
...
@@ -79,9 +80,10 @@ Example:
class
MergeSelectedRowsOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/mul_op.cc
浏览文件 @
91ae7848
...
...
@@ -200,9 +200,10 @@ or not. But the output only shares the LoD information with input $X$.
class
MulOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/nccl/nccl_op.cc
浏览文件 @
91ae7848
...
...
@@ -61,8 +61,7 @@ class NCCLInitOp : public framework::OperatorBase {
class
NCCLInitOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Communicator"
).
front
();
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
RAW
);
ctx
->
SetOutputType
(
"Communicator"
,
framework
::
proto
::
VarType
::
RAW
);
}
};
...
...
paddle/fluid/operators/nce_op.cc
浏览文件 @
91ae7848
...
...
@@ -280,20 +280,20 @@ class NCEOpGrad : public framework::OperatorWithKernel {
class
NCEOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
weight_grad
=
ctx
->
Output
(
framework
::
GradVarName
(
"Weight"
)).
front
(
);
auto
weight_grad
=
framework
::
GradVarName
(
"Weight"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
weight_grad
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
Set
Output
Type
(
weight_grad
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
weight_grad
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
Set
Output
Type
(
weight_grad
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
Set
DataType
(
weight_grad
,
ctx
->
GetDataType
(
ctx
->
Input
(
"Input"
)[
0
]
));
ctx
->
Set
OutputDataType
(
weight_grad
,
ctx
->
GetInputDataType
(
"Input"
));
}
};
...
...
paddle/fluid/operators/optimizers/momentum_op.cc
浏览文件 @
91ae7848
...
...
@@ -22,18 +22,15 @@ using Tensor = framework::Tensor;
class
MomentumOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
input_var
=
ctx
->
Input
(
"Param"
)[
0
];
for
(
auto
&
out_var
:
ctx
->
Output
(
"ParamOut"
))
{
if
(
ctx
->
GetType
(
input_var
)
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
if
(
ctx
->
GetType
(
input_var
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
else
{
PADDLE_THROW
(
"Only support LodTensor and SelectedRows, Unexpected Input Type."
);
}
}
auto
in_var_type
=
ctx
->
GetInputType
(
"Param"
);
PADDLE_ENFORCE_EQ
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
true
,
platform
::
errors
::
InvalidArgument
(
"Only support LodTensor and SelectedRows, Unexpected Input Type."
));
ctx
->
SetOutputType
(
"ParamOut"
,
in_var_type
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/optimizers/sgd_op.cc
浏览文件 @
91ae7848
...
...
@@ -75,19 +75,15 @@ class SGDOp : public framework::OperatorWithKernel {
class
SGDOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
input_var_n
=
ctx
->
Input
(
"Param"
)[
0
];
auto
in_var_type
=
ctx
->
GetType
(
input_var_n
);
PADDLE_ENFORCE
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s"
,
input_var_n
,
in_var_type
);
for
(
auto
&
out_var_n
:
ctx
->
Output
(
"ParamOut"
))
{
if
(
ctx
->
GetType
(
out_var_n
)
!=
in_var_type
)
{
ctx
->
SetType
(
out_var_n
,
in_var_type
);
}
}
auto
in_var_type
=
ctx
->
GetInputType
(
"Param"
);
PADDLE_ENFORCE_EQ
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
true
,
platform
::
errors
::
InvalidArgument
(
"The input Var's type should be LoDtensor or "
"SelectedRows, but the received type is %s"
,
in_var_type
));
ctx
->
SetOutputType
(
"ParamOut"
,
in_var_type
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
91ae7848
...
...
@@ -422,9 +422,10 @@ Example:
class
PoolOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/print_op.cc
浏览文件 @
91ae7848
...
...
@@ -260,9 +260,7 @@ class PrintOpInferShape : public framework::InferShapeBase {
class
PrintOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"In"
)[
0
]);
auto
out_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetType
(
out_name
,
input_type
);
ctx
->
SetOutputType
(
"Out"
,
ctx
->
GetInputType
(
"In"
));
}
};
...
...
paddle/fluid/operators/py_func_op.cc
浏览文件 @
91ae7848
...
...
@@ -116,12 +116,11 @@ static void CallPythonFunc(py::object *callable,
}
}
class
PyFuncOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
PyFuncOpVarTypeInference
:
public
framework
::
StaticGraph
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
bool
has_out
=
(
ctx
->
HasOutput
(
"Out"
)
&&
!
ctx
->
Output
(
"Out"
).
empty
());
bool
has_in
=
(
ctx
->
HasInput
(
"X"
)
&&
!
ctx
->
Input
(
"X"
).
empty
());
bool
has_out
=
ctx
->
HasOutput
(
"Out"
);
bool
has_in
=
ctx
->
HasInput
(
"X"
);
/**
* X or Out can be empty, so that py_func can be more flexible
...
...
@@ -147,7 +146,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
* the corresponding forward variable
*/
const
std
::
string
kGradVarSuffix
=
framework
::
kGradVarSuffix
;
auto
&
out_var_names
=
ctx
->
Output
(
"Out"
);
auto
&
out_var_names
=
Output
(
ctx
,
"Out"
);
for
(
auto
&
out_var_name
:
out_var_names
)
{
if
(
out_var_name
==
framework
::
kEmptyVarName
||
out_var_name
.
size
()
<
kGradVarSuffix
.
size
())
{
...
...
@@ -157,19 +156,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
size_t
len
=
out_var_name
.
size
()
-
kGradVarSuffix
.
size
();
if
(
out_var_name
.
substr
(
len
)
==
kGradVarSuffix
)
{
auto
fwd_var_name
=
out_var_name
.
substr
(
0
,
len
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasVar
(
out_var_name
),
true
,
platform
::
errors
::
InvalidArgument
(
"Backward variable %s not found"
,
out_var_name
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasVar
(
fwd_var_name
),
true
,
platform
::
errors
::
InvalidArgument
(
"Backward variable %s not found"
,
fwd_var_name
));
OP_INOUT_CHECK
(
HasVar
(
ctx
,
out_var_name
),
"Var"
,
out_var_name
,
"py_func"
);
OP_INOUT_CHECK
(
HasVar
(
ctx
,
fwd_var_name
),
"Var"
,
fwd_var_name
,
"py_func"
);
VLOG
(
10
)
<<
"Infer var_desc of Output("
<<
out_var_name
<<
") as Input("
<<
fwd_var_name
<<
")"
;
ctx
->
SetShape
(
out_var_name
,
ctx
->
GetShape
(
fwd_var_name
));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
fwd_var_name
));
ctx
->
SetLoDLevel
(
out_var_name
,
ctx
->
GetLoDLevel
(
fwd_var_name
));
ctx
->
SetType
(
out_var_name
,
ctx
->
GetType
(
fwd_var_name
));
SetShape
(
ctx
,
out_var_name
,
GetShape
(
ctx
,
fwd_var_name
));
SetDataType
(
ctx
,
out_var_name
,
GetDataType
(
ctx
,
fwd_var_name
));
SetLoDLevel
(
ctx
,
out_var_name
,
GetLoDLevel
(
ctx
,
fwd_var_name
));
SetType
(
ctx
,
out_var_name
,
GetType
(
ctx
,
fwd_var_name
));
}
}
}
...
...
paddle/fluid/operators/randperm_op.cc
浏览文件 @
91ae7848
...
...
@@ -75,8 +75,7 @@ class RandpermOpVarTypeInference : public framework::VarTypeInference {
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetDataType
(
out_var_name
,
var_data_type
);
ctx
->
SetOutputDataType
(
"Out"
,
var_data_type
);
}
};
...
...
paddle/fluid/operators/reader/read_op.cc
浏览文件 @
91ae7848
...
...
@@ -70,18 +70,18 @@ class ReadInferShape : public framework::InferShapeBase {
}
};
class
ReadInferVarType
:
public
framework
::
VarTypeInference
{
class
ReadInferVarType
:
public
framework
::
StaticGraph
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
bool
infer_out
=
boost
::
get
<
bool
>
(
ctx
->
GetAttr
(
"infer_out"
));
if
(
infer_out
)
{
std
::
string
reader_name
=
ctx
->
Input
(
"Reader"
)[
0
];
std
::
vector
<
std
::
string
>
out_names
=
ctx
->
Output
(
"Out"
);
auto
dtypes
=
ctx
->
GetDataTypes
(
reader_name
);
std
::
string
reader_name
=
Input
(
ctx
,
"Reader"
)[
0
];
auto
&
out_names
=
Output
(
ctx
,
"Out"
);
auto
dtypes
=
GetDataTypes
(
ctx
,
reader_name
);
PADDLE_ENFORCE_EQ
(
dtypes
.
size
(),
out_names
.
size
());
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
ctx
->
SetType
(
out_names
[
i
],
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetDataType
(
out_names
[
i
],
dtypes
[
i
]);
SetType
(
ctx
,
out_names
[
i
],
framework
::
proto
::
VarType
::
LOD_TENSOR
);
SetDataType
(
ctx
,
out_names
[
i
],
dtypes
[
i
]);
}
}
}
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
91ae7848
...
...
@@ -100,8 +100,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
void
FileReaderInferVarType
::
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
{
std
::
string
reader_name
=
ctx
->
Output
(
"Out"
)[
0
];
ctx
->
SetType
(
reader_name
,
framework
::
proto
::
VarType
::
READER
);
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
READER
);
}
void
DecoratedReaderInferShape
::
operator
()(
...
...
@@ -125,10 +124,8 @@ void DecoratedReaderInferShape::operator()(
void
DecoratedReaderInferVarType
::
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
{
const
std
::
string
&
in_reader_name
=
ctx
->
Input
(
"UnderlyingReader"
)[
0
];
const
std
::
string
&
out_reader_name
=
ctx
->
Output
(
"Out"
)[
0
];
ctx
->
SetType
(
out_reader_name
,
framework
::
proto
::
VarType
::
READER
);
ctx
->
SetDataTypes
(
out_reader_name
,
ctx
->
GetDataTypes
(
in_reader_name
));
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
READER
);
ctx
->
SetOutputDataTypes
(
"Out"
,
ctx
->
GetInputDataTypes
(
"UnderlyingReader"
));
}
void
DecoratedReaderMakerBase
::
Make
()
{
...
...
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
浏览文件 @
91ae7848
...
...
@@ -58,8 +58,7 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
auto
data_type
=
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"out_dtype"
)));
if
(
data_type
>=
0
)
{
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetDataType
(
out_var_name
,
data_type
);
ctx
->
SetOutputDataType
(
"Out"
,
data_type
);
}
}
};
...
...
paddle/fluid/operators/save_combine_op.cc
浏览文件 @
91ae7848
...
...
@@ -85,9 +85,8 @@ to a file on disk.
class
SaveCombineOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
->
Output
(
"Y"
))
{
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
RAW
);
}
ctx
->
SetOutputType
(
"Y"
,
framework
::
proto
::
VarType
::
RAW
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
91ae7848
...
...
@@ -73,7 +73,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
ctx
->
SetType
(
LOOKUP_TABLE_PATH
,
var_type
);
ctx
->
InsertVar
(
LOOKUP_TABLE_PATH
,
var_type
);
}
};
...
...
paddle/fluid/operators/scale_op.cc
浏览文件 @
91ae7848
...
...
@@ -82,13 +82,7 @@ $$Out = scale*(X + bias)$$
class
ScaleOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
in_var_name
=
ctx
->
Input
(
"X"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
if
(
in_var_name
!=
out_var_name
)
{
ctx
->
SetType
(
out_var_name
,
ctx
->
GetType
(
in_var_name
));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
in_var_name
));
}
ctx
->
SyncTypeAndDataType
(
"X"
,
"Out"
);
}
};
...
...
paddle/fluid/operators/selu_op.cc
浏览文件 @
91ae7848
...
...
@@ -45,9 +45,10 @@ class SeluOp : public framework::OperatorWithKernel {
class
SeluOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/softmax_op.cc
浏览文件 @
91ae7848
...
...
@@ -145,9 +145,10 @@ For each row $i$ and each column $j$ in the matrix, we have:
class
SoftmaxOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
...
...
paddle/fluid/operators/split_selected_rows_op.cc
浏览文件 @
91ae7848
...
...
@@ -64,9 +64,8 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
class
SplitSelectedRowsOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/sum_op.cc
浏览文件 @
91ae7848
...
...
@@ -210,43 +210,36 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
class
SumOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
->
Input
(
"X"
);
auto
var_type
=
framework
::
proto
::
VarType
::
SELECTED_ROWS
;
for
(
auto
&
name
:
ctx
->
Input
(
"X"
))
{
VLOG
(
10
)
<<
name
<<
" "
<<
ctx
->
GetType
(
name
);
}
bool
any_input_is_lod_tensor
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
ctx
](
const
std
::
string
&
name
)
{
return
ctx
->
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR
;
});
auto
is_tensor_array
=
[
ctx
](
const
std
::
string
&
name
)
{
return
ctx
->
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
};
bool
any_input_is_tensor_array
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
is_tensor_array
);
bool
all_inputs_are_tensor_array
=
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
is_tensor_array
);
if
(
!
ctx
->
IsDygraph
())
{
auto
var_type
=
framework
::
proto
::
VarType
::
SELECTED_ROWS
;
if
(
VLOG_IS_ON
(
10
))
{
for
(
size_t
ind
=
0
;
ind
<
ctx
->
InputSize
(
"X"
);
++
ind
)
{
VLOG
(
10
)
<<
ctx
->
InputVarName
(
"X"
,
ind
)
<<
" "
<<
ctx
->
GetInputType
(
"X"
,
ind
);
}
}
if
(
any_input_is_tensor_array
)
{
if
(
!
all_inputs_are_tensor_array
)
{
std
::
ostringstream
os
;
for
(
auto
&
each
:
inputs
)
{
os
<<
" "
<<
each
<<
" type is "
<<
ctx
->
GetType
(
each
)
<<
"
\n
"
;
if
(
ctx
->
InputTypeAnyOf
(
"X"
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
))
{
if
(
!
ctx
->
InputTypeAllOf
(
"X"
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
))
{
std
::
ostringstream
os
;
for
(
size_t
ind
=
0
;
ind
<
ctx
->
InputSize
(
"X"
);
++
ind
)
{
os
<<
" "
<<
ctx
->
InputVarName
(
"X"
,
ind
)
<<
" type is "
<<
ctx
->
GetInputType
(
"X"
,
ind
)
<<
"
\n
"
;
}
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Not all inputs are tensor array:
\n
%s"
,
os
.
str
()));
}
PADDLE_ENFORCE_EQ
(
all_inputs_are_tensor_array
,
true
,
"Not all inputs are tensor array:
\n
%s"
,
os
.
str
());
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
}
else
if
(
ctx
->
InputTypeAnyOf
(
"X"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
))
{
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR
;
}
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
}
else
if
(
any_input_is_lod_tensor
)
{
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR
;
}
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
(
);
ctx
->
SetType
(
out_var_name
,
var_type
);
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
inputs
.
front
()));
ctx
->
SetOutputType
(
"Out"
,
var_type
);
ctx
->
SetOutputDataType
(
"Out"
,
ctx
->
GetInputDataType
(
"X"
)
);
}
}
};
...
...
paddle/fluid/operators/tensor_array_to_tensor_op.cc
浏览文件 @
91ae7848
...
...
@@ -213,9 +213,9 @@ class LoDTensorArray2TensorGradInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
out_var
:
ctx
->
Output
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
}
ctx
->
SetOutputType
(
framework
::
GradVarName
(
"X"
),
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
,
framework
::
ALL_ELEMENTS
);
}
};
...
...
paddle/fluid/operators/uniform_random_op.cc
浏览文件 @
91ae7848
...
...
@@ -232,15 +232,13 @@ uniform distribution. The random result is in set [min, max).
class
UniformRandomOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
if
(
ctx
->
GetType
(
out_var_name
)
!=
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
if
(
ctx
->
GetOutputType
(
"Out"
)
!=
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
Set
DataType
(
out_var_name
,
var_data_type
);
ctx
->
Set
OutputDataType
(
"Out"
,
var_data_type
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录