Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
91ae7848
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
91ae7848
编写于
4月 27, 2020
作者:
L
liuwei1031
提交者:
GitHub
4月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
improve efficiency of runtime InferVarType (#22778) (#24181)
* cherry pick #22778
上级
57b062e1
变更
58
显示空白变更内容
内联
并排
Showing
58 changed file
with
967 addition
and
423 deletion
+967
-423
paddle/fluid/framework/ir/graph_test.cc
paddle/fluid/framework/ir/graph_test.cc
+2
-8
paddle/fluid/framework/var_type_inference.h
paddle/fluid/framework/var_type_inference.h
+286
-40
paddle/fluid/framework/var_type_inference_test.cc
paddle/fluid/framework/var_type_inference_test.cc
+186
-15
paddle/fluid/imperative/infer_var_type_context.h
paddle/fluid/imperative/infer_var_type_context.h
+131
-80
paddle/fluid/imperative/tests/test_layer.cc
paddle/fluid/imperative/tests/test_layer.cc
+149
-28
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+3
-2
paddle/fluid/operators/allclose_op.cc
paddle/fluid/operators/allclose_op.cc
+1
-2
paddle/fluid/operators/assign_op.cc
paddle/fluid/operators/assign_op.cc
+1
-5
paddle/fluid/operators/batch_norm_op.h
paddle/fluid/operators/batch_norm_op.h
+3
-2
paddle/fluid/operators/beam_search_decode_op.cc
paddle/fluid/operators/beam_search_decode_op.cc
+4
-6
paddle/fluid/operators/beam_search_op.cc
paddle/fluid/operators/beam_search_op.cc
+4
-6
paddle/fluid/operators/controlflow/get_places_op.cc
paddle/fluid/operators/controlflow/get_places_op.cc
+2
-3
paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
...fluid/operators/controlflow/tensor_array_read_write_op.cc
+6
-6
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+8
-7
paddle/fluid/operators/conv_op.h
paddle/fluid/operators/conv_op.h
+3
-2
paddle/fluid/operators/cross_entropy_op.cc
paddle/fluid/operators/cross_entropy_op.cc
+3
-2
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
+2
-4
paddle/fluid/operators/distributed_ops/split_ids_op.cc
paddle/fluid/operators/distributed_ops/split_ids_op.cc
+2
-4
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+3
-2
paddle/fluid/operators/eye_op.cc
paddle/fluid/operators/eye_op.cc
+1
-2
paddle/fluid/operators/fill_any_like_op.cc
paddle/fluid/operators/fill_any_like_op.cc
+2
-4
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+1
-2
paddle/fluid/operators/fill_op.cc
paddle/fluid/operators/fill_op.cc
+1
-2
paddle/fluid/operators/flip_op.cc
paddle/fluid/operators/flip_op.cc
+3
-2
paddle/fluid/operators/fused/fused_bn_activation_op.h
paddle/fluid/operators/fused/fused_bn_activation_op.h
+3
-2
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
+5
-4
paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
+2
-5
paddle/fluid/operators/group_norm_op.cc
paddle/fluid/operators/group_norm_op.cc
+3
-2
paddle/fluid/operators/hierarchical_sigmoid_op.cc
paddle/fluid/operators/hierarchical_sigmoid_op.cc
+14
-15
paddle/fluid/operators/instance_norm_op.h
paddle/fluid/operators/instance_norm_op.h
+3
-2
paddle/fluid/operators/lod_rank_table_op.cc
paddle/fluid/operators/lod_rank_table_op.cc
+2
-3
paddle/fluid/operators/lod_reset_op.cc
paddle/fluid/operators/lod_reset_op.cc
+12
-11
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
+2
-3
paddle/fluid/operators/lookup_table_op.cc
paddle/fluid/operators/lookup_table_op.cc
+5
-4
paddle/fluid/operators/lookup_table_v2_op.cc
paddle/fluid/operators/lookup_table_v2_op.cc
+5
-4
paddle/fluid/operators/mean_op.cc
paddle/fluid/operators/mean_op.cc
+3
-2
paddle/fluid/operators/merge_selected_rows_op.cc
paddle/fluid/operators/merge_selected_rows_op.cc
+4
-2
paddle/fluid/operators/mul_op.cc
paddle/fluid/operators/mul_op.cc
+3
-2
paddle/fluid/operators/nccl/nccl_op.cc
paddle/fluid/operators/nccl/nccl_op.cc
+1
-2
paddle/fluid/operators/nce_op.cc
paddle/fluid/operators/nce_op.cc
+4
-4
paddle/fluid/operators/optimizers/momentum_op.cc
paddle/fluid/operators/optimizers/momentum_op.cc
+9
-12
paddle/fluid/operators/optimizers/sgd_op.cc
paddle/fluid/operators/optimizers/sgd_op.cc
+9
-13
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+3
-2
paddle/fluid/operators/print_op.cc
paddle/fluid/operators/print_op.cc
+1
-3
paddle/fluid/operators/py_func_op.cc
paddle/fluid/operators/py_func_op.cc
+12
-15
paddle/fluid/operators/randperm_op.cc
paddle/fluid/operators/randperm_op.cc
+1
-2
paddle/fluid/operators/reader/read_op.cc
paddle/fluid/operators/reader/read_op.cc
+6
-6
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+3
-6
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+1
-2
paddle/fluid/operators/save_combine_op.cc
paddle/fluid/operators/save_combine_op.cc
+2
-3
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+1
-1
paddle/fluid/operators/scale_op.cc
paddle/fluid/operators/scale_op.cc
+1
-7
paddle/fluid/operators/selu_op.cc
paddle/fluid/operators/selu_op.cc
+3
-2
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+3
-2
paddle/fluid/operators/split_selected_rows_op.cc
paddle/fluid/operators/split_selected_rows_op.cc
+2
-3
paddle/fluid/operators/sum_op.cc
paddle/fluid/operators/sum_op.cc
+26
-33
paddle/fluid/operators/tensor_array_to_tensor_op.cc
paddle/fluid/operators/tensor_array_to_tensor_op.cc
+3
-3
paddle/fluid/operators/uniform_random_op.cc
paddle/fluid/operators/uniform_random_op.cc
+3
-5
未找到文件。
paddle/fluid/framework/ir/graph_test.cc
浏览文件 @
91ae7848
...
@@ -45,19 +45,13 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
...
@@ -45,19 +45,13 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class
SumOpVarTypeInference
:
public
VarTypeInference
{
class
SumOpVarTypeInference
:
public
VarTypeInference
{
public:
public:
void
operator
()(
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
->
Input
(
"X"
);
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
bool
any_input_is_lod_tensor
=
std
::
any_of
(
if
(
ctx
->
InputTypeAnyOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
))
{
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
->
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
});
if
(
any_input_is_lod_tensor
)
{
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
}
}
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetOutputType
(
"Out"
,
default_var_type
);
ctx
->
SetType
(
out_var_name
,
default_var_type
);
}
}
};
};
...
...
paddle/fluid/framework/var_type_inference.h
浏览文件 @
91ae7848
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <algorithm>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
@@ -25,8 +26,14 @@ namespace framework {
...
@@ -25,8 +26,14 @@ namespace framework {
class
OpDesc
;
class
OpDesc
;
class
BlockDesc
;
class
BlockDesc
;
class
StaticGraphVarTypeInference
;
// default infer var type context
// default infer var type context
static
const
int
ALL_ELEMENTS
=
-
1
;
class
InferVarTypeContext
{
class
InferVarTypeContext
{
friend
class
StaticGraphVarTypeInference
;
public:
public:
InferVarTypeContext
(
const
OpDesc
*
op
,
BlockDesc
*
block
)
InferVarTypeContext
(
const
OpDesc
*
op
,
BlockDesc
*
block
)
:
op_
(
op
),
block_
(
block
)
{}
:
op_
(
op
),
block_
(
block
)
{}
...
@@ -34,91 +41,267 @@ class InferVarTypeContext {
...
@@ -34,91 +41,267 @@ class InferVarTypeContext {
virtual
~
InferVarTypeContext
()
{}
virtual
~
InferVarTypeContext
()
{}
virtual
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
virtual
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
GetAttr
(
name
);
return
op_
->
GetAttr
(
name
);
}
}
virtual
bool
HasVar
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
return
block_
->
FindVarRecursive
(
name
)
!=
nullptr
;
}
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
{
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
inputs
=
op_
->
Inputs
();
auto
&
inputs
=
op_
->
Inputs
();
auto
input
=
inputs
.
find
(
name
);
auto
input
=
inputs
.
find
(
name
);
return
input
!=
inputs
.
end
()
&&
!
input
->
second
.
empty
();
return
input
!=
inputs
.
end
()
&&
!
input
->
second
.
empty
();
}
}
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
{
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
outputs
=
op_
->
Outputs
();
auto
&
outputs
=
op_
->
Outputs
();
auto
output
=
outputs
.
find
(
name
);
auto
output
=
outputs
.
find
(
name
);
return
output
!=
outputs
.
end
()
&&
!
output
->
second
.
empty
();
return
output
!=
outputs
.
end
()
&&
!
output
->
second
.
empty
();
}
}
virtual
const
std
::
vector
<
std
::
string
>&
Input
(
const
std
::
string
&
name
)
const
{
virtual
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
Inputs
().
at
(
name
).
size
();
}
virtual
const
std
::
string
&
InputVarName
(
const
std
::
string
&
name
,
const
int
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
Inputs
().
at
(
name
)[
index
];
}
virtual
bool
InputTypeAnyOf
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
inputs
=
op_
->
Input
(
name
);
return
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
this
,
&
type
](
const
std
::
string
&
name
)
{
return
this
->
GetVarType
(
name
)
==
type
;
});
}
virtual
bool
InputTypeAllOf
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
inputs
=
op_
->
Input
(
name
);
return
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
this
,
&
type
](
const
std
::
string
&
name
)
{
return
this
->
GetVarType
(
name
)
==
type
;
});
}
virtual
void
SyncTypeAndDataType
(
const
std
::
string
&
input_name
,
const
std
::
string
&
output_name
,
int
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
x_name
=
op_
->
Input
(
input_name
).
at
(
index
);
auto
&
out_name
=
op_
->
Output
(
output_name
).
at
(
index
);
if
(
x_name
!=
out_name
)
{
this
->
SetVarType
(
out_name
,
this
->
GetVarType
(
x_name
));
this
->
SetVarDataType
(
out_name
,
this
->
GetVarDataType
(
x_name
));
}
}
virtual
void
SetOutputType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
,
int
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
if
(
ALL_ELEMENTS
==
index
)
{
for
(
const
auto
&
var_name
:
op_
->
Output
(
name
))
{
this
->
SetVarType
(
var_name
,
type
);
}
}
else
{
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarType
(
var_name
,
type
);
}
}
virtual
proto
::
VarType
::
Type
GetInputType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
this
->
GetVarType
(
op_
->
Input
(
name
).
at
(
index
));
}
virtual
proto
::
VarType
::
Type
GetOutputType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
this
->
GetVarType
(
op_
->
Output
(
name
).
at
(
index
));
}
virtual
proto
::
VarType
::
Type
GetInputDataType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
this
->
GetVarDataType
(
op_
->
Input
(
name
).
at
(
index
));
}
virtual
void
SetOutputDataType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
,
int
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
if
(
ALL_ELEMENTS
==
index
)
{
for
(
const
auto
&
var_name
:
op_
->
Output
(
name
))
{
this
->
SetVarDataType
(
var_name
,
type
);
}
}
else
{
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarDataType
(
var_name
,
type
);
}
}
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputDataTypes
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
this
->
GetVarDataTypes
(
op_
->
Input
(
name
).
at
(
index
));
}
virtual
void
SetOutputDataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
,
const
int
&
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarDataTypes
(
var_name
,
multiple_data_type
);
}
virtual
std
::
vector
<
int64_t
>
GetInputShape
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Input
(
name
).
at
(
index
);
return
this
->
GetVarShape
(
var_name
);
}
virtual
void
SetOutputShape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
,
const
int
&
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarShape
(
var_name
,
dims
);
}
virtual
int32_t
GetInputLoDLevel
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Input
(
name
).
at
(
index
);
return
this
->
GetVarLoDLevel
(
var_name
);
}
virtual
void
SetOutputLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
,
const
int
&
index
=
0
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
auto
&
var_name
=
op_
->
Output
(
name
).
at
(
index
);
this
->
SetVarLoDLevel
(
var_name
,
lod_level
);
}
// add a speical API for save_op
// avoid use this API for common logic
virtual
void
InsertVar
(
const
std
::
string
&
var_name
,
proto
::
VarType
::
Type
var_type
)
{
if
(
!
IsDygraph
())
this
->
SetVarType
(
var_name
,
var_type
);
}
virtual
bool
IsDygraph
()
const
{
return
false
;
}
protected:
virtual
bool
HasVar
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindVarRecursive
(
name
)
!=
nullptr
;
}
virtual
const
std
::
vector
<
std
::
string
>&
InputVars
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
Input
(
name
);
return
op_
->
Input
(
name
);
}
}
virtual
const
std
::
vector
<
std
::
string
>&
Output
(
virtual
const
std
::
vector
<
std
::
string
>&
Output
Vars
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
op_
);
PADDLE_ENFORCE_NOT_NULL
(
op_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
return
op_
->
Output
(
name
);
return
op_
->
Output
(
name
);
}
}
virtual
proto
::
VarType
::
Type
GetType
(
const
std
::
string
&
name
)
const
{
virtual
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetType
();
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetType
();
}
}
virtual
void
SetType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
{
virtual
void
SetVarType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"op_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetType
(
type
);
block_
->
FindRecursiveOrCreateVar
(
name
).
SetType
(
type
);
}
}
virtual
proto
::
VarType
::
Type
GetDataType
(
const
std
::
string
&
name
)
const
{
virtual
proto
::
VarType
::
Type
GetVarDataType
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetDataType
();
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetDataType
();
}
}
virtual
void
SetDataType
(
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
{
virtual
void
SetVarDataType
(
const
std
::
string
&
name
,
PADDLE_ENFORCE_NOT_NULL
(
block_
);
proto
::
VarType
::
Type
type
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetDataType
(
type
);
block_
->
FindRecursiveOrCreateVar
(
name
).
SetDataType
(
type
);
}
}
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
(
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
Get
Var
DataTypes
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetDataTypes
();
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetDataTypes
();
}
}
virtual
void
SetDataTypes
(
virtual
void
Set
Var
DataTypes
(
const
std
::
string
&
name
,
const
std
::
string
&
name
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetDataTypes
(
multiple_data_type
);
block_
->
FindRecursiveOrCreateVar
(
name
).
SetDataTypes
(
multiple_data_type
);
}
}
virtual
std
::
vector
<
int64_t
>
GetShape
(
const
std
::
string
&
name
)
const
{
virtual
std
::
vector
<
int64_t
>
GetVarShape
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetShape
();
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetShape
();
}
}
virtual
void
SetShape
(
const
std
::
string
&
name
,
virtual
void
Set
Var
Shape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
{
const
std
::
vector
<
int64_t
>&
dims
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetShape
(
dims
);
block_
->
FindRecursiveOrCreateVar
(
name
).
SetShape
(
dims
);
}
}
virtual
int32_t
GetLoDLevel
(
const
std
::
string
&
name
)
const
{
virtual
int32_t
GetVarLoDLevel
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetLoDLevel
();
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetLoDLevel
();
}
}
virtual
void
SetLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
{
virtual
void
SetVarLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
,
platform
::
errors
::
PreconditionNotMet
(
"block_ should not be null"
));
block_
->
FindRecursiveOrCreateVar
(
name
).
SetLoDLevel
(
lod_level
);
block_
->
FindRecursiveOrCreateVar
(
name
).
SetLoDLevel
(
lod_level
);
}
}
...
@@ -133,22 +316,85 @@ class VarTypeInference {
...
@@ -133,22 +316,85 @@ class VarTypeInference {
virtual
void
operator
()(
InferVarTypeContext
*
context
)
const
=
0
;
// NOLINT
virtual
void
operator
()(
InferVarTypeContext
*
context
)
const
=
0
;
// NOLINT
};
};
class
StaticGraphVarTypeInference
:
public
VarTypeInference
{
protected:
bool
HasVar
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
HasVar
(
name
);
}
const
std
::
vector
<
std
::
string
>&
Input
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
InputVars
(
name
);
}
const
std
::
vector
<
std
::
string
>&
Output
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
OutputVars
(
name
);
}
proto
::
VarType
::
Type
GetType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarType
(
name
);
}
void
SetType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
ctx
->
SetVarType
(
name
,
type
);
}
proto
::
VarType
::
Type
GetDataType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarDataType
(
name
);
}
void
SetDataType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
ctx
->
SetVarDataType
(
name
,
type
);
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarDataTypes
(
name
);
}
void
SetDataTypes
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
return
ctx
->
SetVarDataTypes
(
name
,
multiple_data_type
);
}
std
::
vector
<
int64_t
>
GetShape
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarShape
(
name
);
}
void
SetShape
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
const
{
ctx
->
SetVarShape
(
name
,
dims
);
}
int32_t
GetLoDLevel
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
ctx
->
GetVarLoDLevel
(
name
);
}
void
SetLoDLevel
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
int32_t
lod_level
)
const
{
ctx
->
SetVarLoDLevel
(
name
,
lod_level
);
}
};
class
PassInDtypeAndVarTypeToOutput
:
public
framework
::
VarTypeInference
{
class
PassInDtypeAndVarTypeToOutput
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
final
{
// NOLINT
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
final
{
// NOLINT
auto
in_out_var_names
=
this
->
GetInputOutputWithSameType
();
auto
&
in_out_var_names
=
this
->
GetInputOutputWithSameType
();
for
(
auto
&
i_o_n
:
in_out_var_names
)
{
for
(
auto
&
i_o_n
:
in_out_var_names
)
{
auto
&
x_name
=
ctx
->
Input
(
i_o_n
.
first
).
at
(
0
);
ctx
->
SyncTypeAndDataType
(
i_o_n
.
first
,
i_o_n
.
second
);
auto
&
out_name
=
ctx
->
Output
(
i_o_n
.
second
).
at
(
0
);
ctx
->
SetType
(
out_name
,
ctx
->
GetType
(
x_name
));
ctx
->
SetDataType
(
out_name
,
ctx
->
GetDataType
(
x_name
));
}
}
}
}
protected:
protected:
virtual
std
::
unordered_map
<
std
::
string
,
std
::
string
>
virtual
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
=
0
;
GetInputOutputWithSameType
()
const
=
0
;
};
};
...
...
paddle/fluid/framework/var_type_inference_test.cc
浏览文件 @
91ae7848
...
@@ -24,13 +24,13 @@ namespace framework {
...
@@ -24,13 +24,13 @@ namespace framework {
class
NOP
:
public
OperatorBase
{
class
NOP
:
public
OperatorBase
{
public:
public:
NOP
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
NOP
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
private:
void
RunImpl
(
const
Scope
&
scope
,
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
SumOpMaker
:
public
OpProtoAndCheckerMaker
{
class
SumOpMaker
:
public
OpProtoAndCheckerMaker
{
...
@@ -44,20 +44,14 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
...
@@ -44,20 +44,14 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class
SumOpVarTypeInference
:
public
VarTypeInference
{
class
SumOpVarTypeInference
:
public
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
->
Input
(
"X"
);
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
bool
any_input_is_lod_tensor
=
std
::
any_of
(
if
(
ctx
->
InputTypeAnyOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
))
{
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
->
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
});
if
(
any_input_is_lod_tensor
)
{
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
}
}
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetOutputType
(
"Out"
,
default_var_type
);
ctx
->
SetType
(
out_var_name
,
default_var_type
);
}
}
};
};
}
// namespace framework
}
// namespace framework
...
@@ -71,9 +65,79 @@ REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
...
@@ -71,9 +65,79 @@ REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
TestStaticGraphVarTypeInference
:
public
StaticGraphVarTypeInference
{
public:
void
operator
()(
InferVarTypeContext
*
context
)
const
override
{}
bool
HasVar
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
HasVar
(
ctx
,
name
);
}
const
std
::
vector
<
std
::
string
>&
Input
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
Input
(
ctx
,
name
);
}
const
std
::
vector
<
std
::
string
>&
Output
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
Output
(
ctx
,
name
);
}
proto
::
VarType
::
Type
GetType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetType
(
ctx
,
name
);
}
void
SetType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
StaticGraphVarTypeInference
::
SetType
(
ctx
,
name
,
type
);
}
proto
::
VarType
::
Type
GetDataType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetDataType
(
ctx
,
name
);
}
void
SetDataType
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
proto
::
VarType
::
Type
type
)
const
{
StaticGraphVarTypeInference
::
SetDataType
(
ctx
,
name
,
type
);
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetDataTypes
(
ctx
,
name
);
}
void
SetDataTypes
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
return
StaticGraphVarTypeInference
::
SetDataTypes
(
ctx
,
name
,
multiple_data_type
);
}
std
::
vector
<
int64_t
>
GetShape
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetShape
(
ctx
,
name
);
}
void
SetShape
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
const
{
StaticGraphVarTypeInference
::
SetShape
(
ctx
,
name
,
dims
);
}
int32_t
GetLoDLevel
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
)
const
{
return
StaticGraphVarTypeInference
::
GetLoDLevel
(
ctx
,
name
);
}
void
SetLoDLevel
(
InferVarTypeContext
*
ctx
,
const
std
::
string
&
name
,
int32_t
lod_level
)
const
{
StaticGraphVarTypeInference
::
SetLoDLevel
(
ctx
,
name
,
lod_level
);
}
};
TEST
(
InferVarType
,
sum_op
)
{
TEST
(
InferVarType
,
sum_op
)
{
ProgramDesc
prog
;
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum"
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"test_a"
,
"test_b"
,
"test_c"
});
op
->
SetInput
(
"X"
,
{
"test_a"
,
"test_b"
,
"test_c"
});
op
->
SetOutput
(
"Out"
,
{
"test_out"
});
op
->
SetOutput
(
"Out"
,
{
"test_out"
});
...
@@ -96,7 +160,7 @@ TEST(InferVarType, sum_op) {
...
@@ -96,7 +160,7 @@ TEST(InferVarType, sum_op) {
TEST
(
InferVarType
,
sum_op_without_infer_var_type
)
{
TEST
(
InferVarType
,
sum_op_without_infer_var_type
)
{
ProgramDesc
prog
;
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum_without_infer_var_type"
);
op
->
SetType
(
"sum_without_infer_var_type"
);
op
->
SetInput
(
"X"
,
{
"test2_a"
,
"test2_b"
,
"test2_c"
});
op
->
SetInput
(
"X"
,
{
"test2_a"
,
"test2_b"
,
"test2_c"
});
op
->
SetOutput
(
"Out"
,
{
"test2_out"
});
op
->
SetOutput
(
"Out"
,
{
"test2_out"
});
...
@@ -112,5 +176,112 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
...
@@ -112,5 +176,112 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_out"
)
->
GetType
());
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_out"
)
->
GetType
());
}
}
TEST
(
InferVarType
,
multiple_api
)
{
ProgramDesc
prog
;
auto
*
block
=
prog
.
MutableBlock
(
0
);
auto
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"sum_without_infer_var_type"
);
op
->
SetInput
(
"X"
,
{
"test2_a"
,
"test2_b"
});
op
->
SetOutput
(
"Out"
,
{
"test2_a_out"
,
"test2_b_out"
});
block
->
Var
(
"test2_a"
)
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
block
->
Var
(
"test2_b"
)
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
block
->
Var
(
"test2_a_out"
);
block
->
Var
(
"test2_b_out"
);
InferVarTypeContext
ctx
(
op
,
block
);
ASSERT_TRUE
(
ctx
.
HasInput
(
"X"
));
ASSERT_TRUE
(
ctx
.
HasOutput
(
"Out"
));
ASSERT_EQ
(
2u
,
ctx
.
InputSize
(
"X"
));
ASSERT_EQ
(
"test2_a"
,
ctx
.
InputVarName
(
"X"
,
0
));
ASSERT_EQ
(
proto
::
VarType
::
SELECTED_ROWS
,
ctx
.
GetInputType
(
"X"
));
ASSERT_TRUE
(
ctx
.
InputTypeAllOf
(
"X"
,
proto
::
VarType
::
SELECTED_ROWS
));
ASSERT_FALSE
(
ctx
.
InputTypeAnyOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
));
ctx
.
SyncTypeAndDataType
(
"X"
,
"Out"
);
ASSERT_EQ
(
proto
::
VarType
::
SELECTED_ROWS
,
ctx
.
GetOutputType
(
"Out"
));
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
ctx
.
GetOutputType
(
"Out"
,
1
));
ctx
.
SetOutputType
(
"Out"
,
proto
::
VarType
::
SELECTED_ROWS
,
ALL_ELEMENTS
);
ctx
.
SetOutputType
(
"Out"
,
proto
::
VarType
::
LOD_TENSOR
,
1
);
ASSERT_EQ
(
proto
::
VarType
::
SELECTED_ROWS
,
ctx
.
GetOutputType
(
"Out"
));
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
ctx
.
GetOutputType
(
"Out"
,
1
));
ASSERT_EQ
(
0
,
ctx
.
GetInputDataType
(
"X"
));
ctx
.
SetOutputDataType
(
"Out"
,
proto
::
VarType
::
FP32
,
ALL_ELEMENTS
);
ctx
.
SetOutputDataType
(
"Out"
,
proto
::
VarType
::
INT8
,
1
);
ASSERT_EQ
(
proto
::
VarType
::
FP32
,
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_a_out"
)
->
GetDataType
());
ASSERT_EQ
(
proto
::
VarType
::
INT8
,
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_b_out"
)
->
GetDataType
());
ASSERT_FALSE
(
ctx
.
IsDygraph
());
// test StaticGraphVarTypeInference
TestStaticGraphVarTypeInference
infer
;
ASSERT_TRUE
(
infer
.
HasVar
(
&
ctx
,
"test2_a"
));
ASSERT_EQ
(
infer
.
Input
(
&
ctx
,
"X"
).
size
(),
infer
.
Output
(
&
ctx
,
"Out"
).
size
());
ASSERT_EQ
(
proto
::
VarType
::
FP32
,
infer
.
GetDataType
(
&
ctx
,
"test2_a_out"
));
infer
.
SetDataType
(
&
ctx
,
"test2_a_out"
,
proto
::
VarType
::
FP64
);
ASSERT_EQ
(
proto
::
VarType
::
FP64
,
infer
.
GetDataType
(
&
ctx
,
"test2_a_out"
));
ASSERT_EQ
(
proto
::
VarType
::
SELECTED_ROWS
,
infer
.
GetType
(
&
ctx
,
"test2_a_out"
));
infer
.
SetType
(
&
ctx
,
"test2_a_out"
,
proto
::
VarType
::
LOD_TENSOR
);
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
infer
.
GetType
(
&
ctx
,
"test2_a_out"
));
ASSERT_ANY_THROW
(
infer
.
GetDataTypes
(
&
ctx
,
"test2_a_out"
));
ASSERT_ANY_THROW
(
infer
.
SetDataTypes
(
&
ctx
,
"test2_a_out"
,
{}));
ASSERT_EQ
(
0u
,
infer
.
GetShape
(
&
ctx
,
"test2_a_out"
).
size
());
infer
.
SetShape
(
&
ctx
,
"test2_a_out"
,
{
1
,
3
,
3
,
});
ASSERT_EQ
(
3u
,
infer
.
GetShape
(
&
ctx
,
"test2_a_out"
).
size
());
ASSERT_EQ
(
0
,
infer
.
GetLoDLevel
(
&
ctx
,
"test2_a_out"
));
infer
.
SetLoDLevel
(
&
ctx
,
"test2_a_out"
,
2
);
ASSERT_EQ
(
2
,
infer
.
GetLoDLevel
(
&
ctx
,
"test2_a_out"
));
}
TEST
(
InferVarType
,
test_enforce_check
)
{
InferVarTypeContext
ctx
(
nullptr
,
nullptr
);
ASSERT_ANY_THROW
(
ctx
.
HasInput
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
HasOutput
(
"Out"
));
ASSERT_ANY_THROW
(
ctx
.
InputSize
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
InputVarName
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
InputTypeAnyOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
.
InputTypeAllOf
(
"X"
,
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
.
SyncTypeAndDataType
(
"X"
,
"Out"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputType
(
"Out"
,
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
.
GetInputType
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
GetOutputType
(
"Out"
));
ASSERT_ANY_THROW
(
ctx
.
GetInputDataType
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputDataType
(
"Out"
,
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
.
GetInputDataTypes
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputDataTypes
(
"Out"
,
{}));
ASSERT_ANY_THROW
(
ctx
.
GetInputShape
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputShape
(
"Out"
,
{}));
ASSERT_ANY_THROW
(
ctx
.
GetInputLoDLevel
(
"X"
));
ASSERT_ANY_THROW
(
ctx
.
SetOutputLoDLevel
(
"Out"
,
1
));
ASSERT_ANY_THROW
(
ctx
.
InsertVar
(
"var"
,
proto
::
VarType
::
LOD_TENSOR
));
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/imperative/infer_var_type_context.h
浏览文件 @
91ae7848
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
@@ -35,30 +36,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
...
@@ -35,30 +36,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
:
InferVarTypeContext
(
nullptr
,
nullptr
),
:
InferVarTypeContext
(
nullptr
,
nullptr
),
inputs_
(
inputs
),
inputs_
(
inputs
),
outputs_
(
outputs
),
outputs_
(
outputs
),
attrs_
(
attrs_map
),
attrs_
(
attrs_map
)
{}
input_names_
(),
output_names_
(),
var_set_
()
{
input_names_
.
reserve
(
inputs_
.
size
());
for
(
auto
&
it
:
inputs_
)
{
for
(
auto
&
var
:
it
.
second
)
{
if
(
var
)
{
input_names_
[
it
.
first
].
emplace_back
(
var
->
Name
());
var_set_
[
var
->
Name
()]
=
var
.
get
();
}
}
}
output_names_
.
reserve
(
outputs_
.
size
());
for
(
auto
&
it
:
outputs_
)
{
for
(
auto
&
var
:
it
.
second
)
{
if
(
var
)
{
output_names_
[
it
.
first
].
emplace_back
(
var
->
Name
());
var_set_
[
var
->
Name
()]
=
var
.
get
();
}
}
}
}
virtual
~
RuntimeInferVarTypeContext
()
{}
virtual
~
RuntimeInferVarTypeContext
()
{}
...
@@ -70,10 +48,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
...
@@ -70,10 +48,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
return
iter
->
second
;
return
iter
->
second
;
}
}
bool
HasVar
(
const
std
::
string
&
name
)
const
override
{
return
var_set_
.
count
(
name
)
>
0
;
}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
inputs_
.
find
(
name
);
auto
it
=
inputs_
.
find
(
name
);
return
(
it
!=
inputs_
.
end
()
&&
it
->
second
.
size
()
>
0
);
return
(
it
!=
inputs_
.
end
()
&&
it
->
second
.
size
()
>
0
);
...
@@ -84,93 +58,173 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
...
@@ -84,93 +58,173 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
return
(
it
!=
outputs_
.
end
()
&&
it
->
second
.
size
()
>
0
);
return
(
it
!=
outputs_
.
end
()
&&
it
->
second
.
size
()
>
0
);
}
}
const
std
::
vector
<
std
::
string
>&
Input
(
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
override
{
return
inputs_
.
at
(
name
).
size
();
auto
iter
=
input_names_
.
find
(
name
);
PADDLE_ENFORCE_EQ
(
iter
!=
input_names_
.
end
(),
true
,
platform
::
errors
::
NotFound
(
"Cannot find input var %s"
,
name
));
return
iter
->
second
;
}
}
const
std
::
vector
<
std
::
string
>&
Output
(
const
std
::
string
&
InputVarName
(
const
std
::
string
&
name
,
const
std
::
string
&
name
)
const
override
{
const
int
index
=
0
)
const
{
auto
iter
=
output_names_
.
find
(
name
);
return
inputs_
.
at
(
name
)[
index
]
->
Name
();
}
PADDLE_ENFORCE_EQ
(
bool
InputTypeAnyOf
(
const
std
::
string
&
name
,
iter
!=
output_names_
.
end
(),
true
,
framework
::
proto
::
VarType
::
Type
type
)
const
override
{
platform
::
errors
::
NotFound
(
"Cannot find output var %s"
,
name
));
auto
&
inputs
=
inputs_
.
at
(
name
);
return
iter
->
second
;
return
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
type
](
const
std
::
shared_ptr
<
VarType
>&
var
)
{
return
var
->
Type
()
==
type
;
});
}
}
framework
::
proto
::
VarType
::
Type
GetType
(
bool
InputTypeAllOf
(
const
std
::
string
&
name
,
const
std
::
string
&
name
)
const
override
{
framework
::
proto
::
VarType
::
Type
type
)
const
override
{
auto
iter
=
var_set_
.
find
(
name
);
auto
&
inputs
=
inputs_
.
at
(
name
);
return
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
type
](
const
std
::
shared_ptr
<
VarType
>&
var
)
{
return
var
->
Type
()
==
type
;
});
}
PADDLE_ENFORCE_EQ
(
void
SyncTypeAndDataType
(
const
std
::
string
&
input_name
,
iter
!=
var_set_
.
end
(),
true
,
const
std
::
string
&
output_name
,
platform
::
errors
::
NotFound
(
"Cannot find var %s in GetType"
,
name
));
int
index
=
0
)
override
{
return
iter
->
second
->
Type
();
auto
in_var
=
inputs_
.
at
(
input_name
)[
index
];
auto
out_var
=
outputs_
.
at
(
output_name
)[
index
];
if
(
in_var
!=
out_var
)
{
this
->
SetVarBaseType
(
out_var
,
in_var
->
Type
());
this
->
SetVarBaseDataType
(
out_var
,
in_var
->
DataType
());
}
}
}
void
SetType
(
const
std
::
string
&
name
,
void
SetOutputType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
override
{
framework
::
proto
::
VarType
::
Type
type
,
if
(
name
==
"kLookupTablePath"
)
{
int
index
=
0
)
override
{
VLOG
(
2
)
<<
"SUPER UGLY FIX, remove this when move imperative mode in C++"
;
if
(
index
==
framework
::
ALL_ELEMENTS
)
{
for
(
auto
&
item
:
outputs_
.
at
(
name
))
{
this
->
SetVarBaseType
(
item
,
type
);
}
}
else
{
auto
&
var
=
outputs_
.
at
(
name
)[
index
];
this
->
SetVarBaseType
(
var
,
type
);
}
}
void
SetVarBaseType
(
std
::
shared_ptr
<
VarType
>
out
,
framework
::
proto
::
VarType
::
Type
type
)
{
out
->
SetType
(
type
);
if
((
out
->
MutableVar
()
->
IsInitialized
()
==
true
)
&&
(
out
->
MutableVar
()
->
Type
()
!=
type
))
{
out
->
MutableVar
()
->
Clear
();
}
}
void
SetVarBaseDataType
(
std
::
shared_ptr
<
VarType
>
out
,
framework
::
proto
::
VarType
::
Type
type
)
{
out
->
SetDataType
(
type
);
}
framework
::
proto
::
VarType
::
Type
GetInputType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
override
{
return
inputs_
.
at
(
name
)[
index
]
->
Type
();
}
framework
::
proto
::
VarType
::
Type
GetOutputType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
override
{
return
outputs_
.
at
(
name
)[
index
]
->
Type
();
}
framework
::
proto
::
VarType
::
Type
GetInputDataType
(
const
std
::
string
&
name
,
const
int
&
index
=
0
)
const
override
{
return
inputs_
.
at
(
name
)[
index
]
->
DataType
();
}
void
SetOutputDataType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
,
int
index
=
0
)
override
{
if
(
framework
::
ALL_ELEMENTS
==
index
)
{
for
(
auto
&
item
:
outputs_
.
at
(
name
))
{
this
->
SetVarBaseDataType
(
item
,
type
);
}
}
else
{
}
else
{
var_set_
[
name
]
->
SetType
(
type
);
auto
&
var
=
outputs_
.
at
(
name
)[
index
];
if
((
var_set_
[
name
]
->
MutableVar
()
->
IsInitialized
()
==
true
)
&&
this
->
SetVarBaseDataType
(
var
,
type
);
(
var_set_
[
name
]
->
MutableVar
()
->
Type
()
!=
type
))
{
var_set_
[
name
]
->
MutableVar
()
->
Clear
();
}
}
}
}
bool
IsDygraph
()
const
override
{
return
true
;
}
protected:
bool
HasVar
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"HasVar is not supported in runtime InferVarType"
));
}
}
framework
::
proto
::
VarType
::
Type
GetDataType
(
const
std
::
vector
<
std
::
string
>&
InputVars
(
const
std
::
string
&
name
)
const
override
{
const
std
::
string
&
name
)
const
override
{
auto
iter
=
var_set_
.
find
(
name
);
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"InputVars is not supported in runtime InferVarType"
));
}
PADDLE_ENFORCE_EQ
(
const
std
::
vector
<
std
::
string
>&
OutputVars
(
iter
!=
var_set_
.
end
(),
true
,
const
std
::
string
&
name
)
const
override
{
platform
::
errors
::
NotFound
(
"Cannot find var %s in GetDataType"
,
name
));
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
return
iter
->
second
->
DataType
();
"OutputVars is not supported in runtime InferVarType"
));
}
framework
::
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not manipulate var in runtime InferVarType"
));
}
}
void
Set
Data
Type
(
const
std
::
string
&
name
,
void
Set
Var
Type
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
override
{
framework
::
proto
::
VarType
::
Type
type
)
override
{
var_set_
[
name
]
->
SetDataType
(
type
);
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not manipulate var in runtime InferVarType"
));
}
framework
::
proto
::
VarType
::
Type
GetVarDataType
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not manipulate var in runtime InferVarType"
));
}
void
SetVarDataType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not manipulate var in runtime InferVarType"
));
}
}
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
GetDataTypes
(
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
Get
Var
DataTypes
(
const
std
::
string
&
name
)
const
override
{
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"GetDataTypes is not supported in runtime InferVarType"
));
"Get
Var
DataTypes is not supported in runtime InferVarType"
));
}
}
void
SetDataTypes
(
const
std
::
string
&
name
,
void
Set
Var
DataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
multiple_data_type
)
override
{
multiple_data_type
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"SetDataTypes is not supported in runtime InferVarType"
));
"Set
Var
DataTypes is not supported in runtime InferVarType"
));
}
}
std
::
vector
<
int64_t
>
GetShape
(
const
std
::
string
&
name
)
const
override
{
std
::
vector
<
int64_t
>
Get
Var
Shape
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not handle Shape in runtime InferVarType"
));
"Do not handle Shape in runtime InferVarType"
));
}
}
void
SetShape
(
const
std
::
string
&
name
,
void
Set
Var
Shape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
override
{
const
std
::
vector
<
int64_t
>&
dims
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not handle Shape in runtime InferVarType"
));
"Do not handle Shape in runtime InferVarType"
));
}
}
int32_t
GetLoDLevel
(
const
std
::
string
&
name
)
const
override
{
int32_t
Get
Var
LoDLevel
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not handle LoDLevel in runtime InferVarType"
));
"Do not handle LoDLevel in runtime InferVarType"
));
}
}
void
SetLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
override
{
void
Set
Var
LoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Do not handle LoDLevel in runtime InferVarType"
));
"Do not handle LoDLevel in runtime InferVarType"
));
}
}
...
@@ -179,9 +233,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
...
@@ -179,9 +233,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
const
NameVarMap
<
VarType
>&
inputs_
;
const
NameVarMap
<
VarType
>&
inputs_
;
const
NameVarMap
<
VarType
>&
outputs_
;
const
NameVarMap
<
VarType
>&
outputs_
;
const
framework
::
AttributeMap
&
attrs_
;
const
framework
::
AttributeMap
&
attrs_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
input_names_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
output_names_
;
std
::
unordered_map
<
std
::
string
,
VarType
*>
var_set_
;
};
};
}
// namespace imperative
}
// namespace imperative
...
...
paddle/fluid/imperative/tests/test_layer.cc
浏览文件 @
91ae7848
...
@@ -37,33 +37,154 @@ using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
...
@@ -37,33 +37,154 @@ using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
using
var_pair
=
std
::
pair
<
std
::
string
,
vb_vector
>
;
using
var_pair
=
std
::
pair
<
std
::
string
,
vb_vector
>
;
template
<
typename
VarType
>
class
TestRuntimeInferVarTypeContext
:
public
RuntimeInferVarTypeContext
<
VarType
>
{
public:
TestRuntimeInferVarTypeContext
(
const
NameVarMap
<
VarType
>&
inputs
,
const
NameVarMap
<
VarType
>&
outputs
,
const
framework
::
AttributeMap
&
attrs_map
)
:
RuntimeInferVarTypeContext
<
VarType
>
(
inputs
,
outputs
,
attrs_map
)
{}
bool
HasVar
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
HasVar
(
name
);
}
const
std
::
vector
<
std
::
string
>&
InputVars
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
InputVars
(
name
);
}
const
std
::
vector
<
std
::
string
>&
OutputVars
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
OutputVars
(
name
);
}
framework
::
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarType
(
name
);
}
void
SetVarType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarType
(
name
,
type
);
}
framework
::
proto
::
VarType
::
Type
GetVarDataType
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarDataType
(
name
);
}
void
SetVarDataType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarDataType
(
name
,
type
);
}
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
GetVarDataTypes
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarDataTypes
(
name
);
}
void
SetVarDataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarDataTypes
(
name
,
multiple_data_type
);
}
std
::
vector
<
int64_t
>
GetVarShape
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarShape
(
name
);
}
void
SetVarShape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarShape
(
name
,
dims
);
}
int32_t
GetVarLoDLevel
(
const
std
::
string
&
name
)
const
{
return
RuntimeInferVarTypeContext
<
VarType
>::
GetVarLoDLevel
(
name
);
}
void
SetVarLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
{
RuntimeInferVarTypeContext
<
VarType
>::
SetVarLoDLevel
(
name
,
lod_level
);
}
};
TEST
(
test_layer
,
test_runtime_context
)
{
TEST
(
test_layer
,
test_runtime_context
)
{
std
::
shared_ptr
<
imperative
::
VarBase
>
vin
(
std
::
shared_ptr
<
imperative
::
VarBase
>
vin
(
new
imperative
::
VarBase
(
false
,
"vin"
));
new
imperative
::
VarBase
(
false
,
"vin"
));
std
::
shared_ptr
<
imperative
::
VarBase
>
vin_b
(
new
imperative
::
VarBase
(
false
,
"vin_b"
));
std
::
shared_ptr
<
imperative
::
VarBase
>
vout
(
std
::
shared_ptr
<
imperative
::
VarBase
>
vout
(
new
imperative
::
VarBase
(
false
,
"vout"
));
new
imperative
::
VarBase
(
false
,
"vout"
));
var_pair
in_pair
=
var_pair
(
"X"
,
vb_vector
(
1
,
vin
));
std
::
shared_ptr
<
imperative
::
VarBase
>
vout_b
(
var_pair
out_pair
=
var_pair
(
"Out"
,
vb_vector
(
1
,
vout
));
new
imperative
::
VarBase
(
false
,
"vout_b"
));
var_pair
in_pair
=
var_pair
(
"X"
,
{
vin
,
vin_b
});
var_pair
out_pair
=
var_pair
(
"Out"
,
{
vout
,
vout_b
});
imperative
::
NameVarBaseMap
ins
=
{
in_pair
};
imperative
::
NameVarBaseMap
ins
=
{
in_pair
};
imperative
::
NameVarBaseMap
outs
=
{
out_pair
};
imperative
::
NameVarBaseMap
outs
=
{
out_pair
};
framework
::
AttributeMap
attrs
;
framework
::
AttributeMap
attrs
;
auto
*
ctx
=
new
imperative
::
RuntimeInferVarTypeContext
<
imperative
::
VarBase
>
(
auto
*
ctx
=
new
imperative
::
TestRuntimeInferVarTypeContext
<
imperative
::
VarBase
>
(
ins
,
outs
,
attrs
);
ins
,
outs
,
attrs
);
ASSERT_TRUE
(
ctx
->
HasVar
(
"vin"
));
ASSERT_TRUE
(
ctx
->
HasInput
(
"X"
));
ASSERT_TRUE
(
ctx
->
HasInput
(
"X"
));
ASSERT_TRUE
(
ctx
->
HasOutput
(
"Out"
));
ASSERT_TRUE
(
ctx
->
HasOutput
(
"Out"
));
ASSERT_ANY_THROW
(
ctx
->
GetDataTypes
(
"vin"
));
ASSERT_EQ
(
2u
,
ctx
->
InputSize
(
"X"
));
ASSERT_EQ
(
"vin"
,
ctx
->
InputVarName
(
"X"
,
0
));
ASSERT_TRUE
(
ctx
->
InputTypeAnyOf
(
"X"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_TRUE
(
ctx
->
InputTypeAllOf
(
"X"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_EQ
(
framework
::
proto
::
VarType
::
LOD_TENSOR
,
ctx
->
GetInputType
(
"X"
));
ASSERT_EQ
(
framework
::
proto
::
VarType
::
FP32
,
ctx
->
GetInputDataType
(
"X"
));
ctx
->
SyncTypeAndDataType
(
"X"
,
"Out"
);
ASSERT_EQ
(
framework
::
proto
::
VarType
::
FP32
,
vout
->
DataType
());
ASSERT_EQ
(
framework
::
proto
::
VarType
::
LOD_TENSOR
,
ctx
->
GetOutputType
(
"Out"
));
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
,
framework
::
ALL_ELEMENTS
);
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
ASSERT_EQ
(
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
,
vout
->
Type
());
ASSERT_EQ
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
,
vout_b
->
Type
());
ctx
->
SetOutputDataType
(
"Out"
,
framework
::
proto
::
VarType
::
FP64
,
framework
::
ALL_ELEMENTS
);
ctx
->
SetOutputDataType
(
"Out"
,
framework
::
proto
::
VarType
::
INT8
);
ASSERT_EQ
(
framework
::
proto
::
VarType
::
INT8
,
vout
->
DataType
());
ASSERT_EQ
(
framework
::
proto
::
VarType
::
FP64
,
vout_b
->
DataType
());
// no throw, but do nothing
ASSERT_NO_THROW
(
ctx
->
InsertVar
(
"vout"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_EQ
(
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
,
vout
->
Type
());
ASSERT_ANY_THROW
(
ctx
->
HasVar
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
InputVars
(
"X"
));
ASSERT_ANY_THROW
(
ctx
->
OutputVars
(
"Out"
));
ASSERT_ANY_THROW
(
ctx
->
GetVarType
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
SetVarType
(
"vin"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
));
ASSERT_ANY_THROW
(
ctx
->
GetVarDataType
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
SetVarDataType
(
"vout"
,
framework
::
proto
::
VarType
::
FP32
));
ASSERT_ANY_THROW
(
ctx
->
GetVarDataTypes
(
"vin"
));
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
NullType
;
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
NullType
;
ASSERT_ANY_THROW
(
ctx
->
SetDataTypes
(
"vin"
,
NullType
));
ASSERT_ANY_THROW
(
ctx
->
SetVarDataTypes
(
"vin"
,
NullType
));
ASSERT_ANY_THROW
(
ctx
->
GetShape
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
GetVarShape
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
GetLoDLevel
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
SetVarShape
(
"vin"
,
{}));
ASSERT_ANY_THROW
(
ctx
->
SetLoDLevel
(
"vin"
,
2
));
ASSERT_ANY_THROW
(
ctx
->
GetVarLoDLevel
(
"vin"
));
ASSERT_ANY_THROW
(
ctx
->
SetVarLoDLevel
(
"vin"
,
2
));
ASSERT_TRUE
(
ctx
->
IsDygraph
());
}
}
std
::
string
LayerDebugString
(
const
std
::
string
&
op_type
,
std
::
string
LayerDebugString
(
const
std
::
string
&
op_type
,
const
NameVarBaseMap
&
ins
,
const
NameVarBaseMap
&
ins
,
const
NameVarBaseMap
&
outs
);
const
NameVarBaseMap
&
outs
);
TEST
(
test_layer
,
test_debug_string
)
{
TEST
(
test_layer
,
test_debug_string
)
{
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
...
@@ -71,7 +192,7 @@ TEST(test_layer, test_debug_string) {
...
@@ -71,7 +192,7 @@ TEST(test_layer, test_debug_string) {
new
imperative
::
VarBase
(
false
,
"vin"
));
new
imperative
::
VarBase
(
false
,
"vin"
));
var_pair
in_pair
=
var_pair
(
"X"
,
vb_vector
(
1
,
vin
));
var_pair
in_pair
=
var_pair
(
"X"
,
vb_vector
(
1
,
vin
));
auto
test_func
=
[
&
](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
vout
)
{
auto
test_func
=
[
&
](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
vout
)
{
var_pair
out_pair
=
var_pair
(
"Out"
,
vb_vector
(
1
,
vout
));
var_pair
out_pair
=
var_pair
(
"Out"
,
vb_vector
(
1
,
vout
));
imperative
::
NameVarBaseMap
ins
=
{
in_pair
};
imperative
::
NameVarBaseMap
ins
=
{
in_pair
};
imperative
::
NameVarBaseMap
outs
=
{
out_pair
};
imperative
::
NameVarBaseMap
outs
=
{
out_pair
};
...
@@ -124,26 +245,26 @@ TEST(test_layer, test_debug_string) {
...
@@ -124,26 +245,26 @@ TEST(test_layer, test_debug_string) {
}
}
static
std
::
shared_ptr
<
imperative
::
GradOpNode
>
CreateGradNode
(
static
std
::
shared_ptr
<
imperative
::
GradOpNode
>
CreateGradNode
(
size_t
id
,
const
std
::
string
&
type
,
const
imperative
::
NameVarBaseMap
&
ins
,
size_t
id
,
const
std
::
string
&
type
,
const
imperative
::
NameVarBaseMap
&
ins
,
const
imperative
::
NameVarBaseMap
&
outs
,
const
imperative
::
NameVarBaseMap
&
outs
,
const
framework
::
AttributeMap
&
attrs
,
const
platform
::
Place
&
place
)
{
const
framework
::
AttributeMap
&
attrs
,
const
platform
::
Place
&
place
)
{
auto
node
=
std
::
make_shared
<
imperative
::
GradOpNode
>
();
auto
node
=
std
::
make_shared
<
imperative
::
GradOpNode
>
();
auto
*
op
=
&
(
node
->
emplace_back
());
auto
*
op
=
&
(
node
->
emplace_back
());
op
->
SetId
(
id
);
op
->
SetId
(
id
);
op
->
SetPlace
(
place
);
op
->
SetPlace
(
place
);
op
->
SetType
(
type
);
op
->
SetType
(
type
);
op
->
SetAttrMap
(
attrs
);
op
->
SetAttrMap
(
attrs
);
for
(
auto
&
pair
:
ins
)
{
for
(
auto
&
pair
:
ins
)
{
std
::
vector
<
std
::
shared_ptr
<
VariableWrapper
>>
vars
;
std
::
vector
<
std
::
shared_ptr
<
VariableWrapper
>>
vars
;
for
(
auto
&
var
:
pair
.
second
)
{
for
(
auto
&
var
:
pair
.
second
)
{
vars
.
emplace_back
(
var
->
SharedVar
());
vars
.
emplace_back
(
var
->
SharedVar
());
}
}
op
->
SetInput
(
pair
.
first
,
vars
,
false
);
op
->
SetInput
(
pair
.
first
,
vars
,
false
);
}
}
for
(
auto
&
pair
:
outs
)
{
for
(
auto
&
pair
:
outs
)
{
std
::
vector
<
std
::
shared_ptr
<
VariableWrapper
>>
vars
;
std
::
vector
<
std
::
shared_ptr
<
VariableWrapper
>>
vars
;
for
(
auto
&
var
:
pair
.
second
)
{
for
(
auto
&
var
:
pair
.
second
)
{
vars
.
emplace_back
(
var
->
SharedVar
());
vars
.
emplace_back
(
var
->
SharedVar
());
}
}
op
->
SetOutput
(
pair
.
first
,
vars
,
false
);
op
->
SetOutput
(
pair
.
first
,
vars
,
false
);
...
@@ -173,7 +294,7 @@ TEST(test_layer, test_clear_backward_info) {
...
@@ -173,7 +294,7 @@ TEST(test_layer, test_clear_backward_info) {
node
->
InsertGradPendingNode
(
pending_node
);
node
->
InsertGradPendingNode
(
pending_node
);
ASSERT_EQ
(
node
->
size
(),
1UL
);
ASSERT_EQ
(
node
->
size
(),
1UL
);
auto
*
op
=
&
(
node
->
back
());
auto
*
op
=
&
(
node
->
back
());
ASSERT_GT
(
op
->
GetInsMap
().
size
(),
0UL
);
ASSERT_GT
(
op
->
GetInsMap
().
size
(),
0UL
);
ASSERT_GT
(
op
->
GetOutsMap
().
size
(),
0UL
);
ASSERT_GT
(
op
->
GetOutsMap
().
size
(),
0UL
);
...
@@ -196,10 +317,10 @@ TEST(test_layer, test_varbase_basic) {
...
@@ -196,10 +317,10 @@ TEST(test_layer, test_varbase_basic) {
std
::
shared_ptr
<
imperative
::
VarBase
>
vin_with_grad
(
std
::
shared_ptr
<
imperative
::
VarBase
>
vin_with_grad
(
new
imperative
::
VarBase
(
true
,
"vin"
));
new
imperative
::
VarBase
(
true
,
"vin"
));
ASSERT_ANY_THROW
(
vin
->
MutableGradVar
());
ASSERT_ANY_THROW
(
vin
->
MutableGradVar
());
ASSERT_NO_THROW
(
ASSERT_TRUE
(
dynamic_cast
<
framework
::
Variable
*>
(
ASSERT_NO_THROW
(
ASSERT_TRUE
(
dynamic_cast
<
framework
::
Variable
*>
(
vin_with_grad
->
MutableGradVar
())
!=
0
));
vin_with_grad
->
MutableGradVar
())
!=
0
));
ASSERT_TRUE
(
dynamic_cast
<
framework
::
Variable
*>
(
ASSERT_TRUE
(
vin_with_grad
->
MutableGradVar
())
!=
0
);
dynamic_cast
<
framework
::
Variable
*>
(
vin_with_grad
->
MutableGradVar
())
!=
0
);
vin_with_grad
->
SetOverridedStopGradient
(
false
);
vin_with_grad
->
SetOverridedStopGradient
(
false
);
ASSERT_FALSE
(
vin_with_grad
->
OverridedStopGradient
());
ASSERT_FALSE
(
vin_with_grad
->
OverridedStopGradient
());
ASSERT_NO_FATAL_FAILURE
(
vin_with_grad
->
SetPersistable
(
true
));
ASSERT_NO_FATAL_FAILURE
(
vin_with_grad
->
SetPersistable
(
true
));
...
@@ -228,9 +349,9 @@ TEST(test_layer, test_dygraph_execution_context) {
...
@@ -228,9 +349,9 @@ TEST(test_layer, test_dygraph_execution_context) {
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"mul"
,
{},
{},
{},
false
);
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"mul"
,
{},
{},
{},
false
);
paddle
::
platform
::
CPUPlace
cpu_place
;
paddle
::
platform
::
CPUPlace
cpu_place
;
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
::
Instance
();
paddle
::
platform
::
DeviceContextPool
::
Instance
();
auto
*
dev_ctx
=
pool
.
Get
(
cpu_place
);
auto
*
dev_ctx
=
pool
.
Get
(
cpu_place
);
paddle
::
framework
::
RuntimeContext
ctx
({},
{});
paddle
::
framework
::
RuntimeContext
ctx
({},
{});
framework
::
Scope
scope
;
framework
::
Scope
scope
;
...
...
paddle/fluid/operators/activation_op.cc
浏览文件 @
91ae7848
...
@@ -129,9 +129,10 @@ class ActivationOp : public framework::OperatorWithKernel {
...
@@ -129,9 +129,10 @@ class ActivationOp : public framework::OperatorWithKernel {
class
ActivationOpInferVarType
class
ActivationOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/allclose_op.cc
浏览文件 @
91ae7848
...
@@ -103,8 +103,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
...
@@ -103,8 +103,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
class
AllcloseOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
AllcloseOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetOutputDataType
(
"Out"
,
framework
::
proto
::
VarType
::
BOOL
);
ctx
->
SetDataType
(
out_var_name
,
framework
::
proto
::
VarType
::
BOOL
);
}
}
};
};
...
...
paddle/fluid/operators/assign_op.cc
浏览文件 @
91ae7848
...
@@ -60,11 +60,7 @@ class AssignOp : public framework::OperatorWithKernel {
...
@@ -60,11 +60,7 @@ class AssignOp : public framework::OperatorWithKernel {
class
AssignInferVarType
:
public
framework
::
VarTypeInference
{
class
AssignInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Out"
)[
0
];
ctx
->
SyncTypeAndDataType
(
"X"
,
"Out"
);
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"X"
)[
0
]);
auto
input_data_type
=
ctx
->
GetDataType
(
ctx
->
Input
(
"X"
)[
0
]);
ctx
->
SetType
(
out_var_name
,
input_type
);
ctx
->
SetDataType
(
out_var_name
,
input_data_type
);
}
}
};
};
...
...
paddle/fluid/operators/batch_norm_op.h
浏览文件 @
91ae7848
...
@@ -171,9 +171,10 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
...
@@ -171,9 +171,10 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
class
BatchNormOpInferVarType
class
BatchNormOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/beam_search_decode_op.cc
浏览文件 @
91ae7848
...
@@ -204,12 +204,10 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
...
@@ -204,12 +204,10 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class
BeamSearchDecodeInferVarType
:
public
framework
::
VarTypeInference
{
class
BeamSearchDecodeInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
->
Output
(
"SentenceIds"
))
{
ctx
->
SetOutputType
(
"SentenceIds"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
,
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
framework
::
ALL_ELEMENTS
);
}
ctx
->
SetOutputType
(
"SentenceScores"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
,
for
(
auto
&
o
:
ctx
->
Output
(
"SentenceScores"
))
{
framework
::
ALL_ELEMENTS
);
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
};
};
...
...
paddle/fluid/operators/beam_search_op.cc
浏览文件 @
91ae7848
...
@@ -122,12 +122,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
...
@@ -122,12 +122,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
class
BeamSearchInferVarType
:
public
framework
::
VarTypeInference
{
class
BeamSearchInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
->
Output
(
"selected_ids"
))
{
ctx
->
SetOutputType
(
"selected_ids"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
,
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
framework
::
ALL_ELEMENTS
);
}
ctx
->
SetOutputType
(
"selected_scores"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
,
for
(
auto
&
o
:
ctx
->
Output
(
"selected_scores"
))
{
framework
::
ALL_ELEMENTS
);
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
};
};
...
...
paddle/fluid/operators/controlflow/get_places_op.cc
浏览文件 @
91ae7848
...
@@ -92,9 +92,8 @@ execution.
...
@@ -92,9 +92,8 @@ execution.
class
GetPlacesInferVarType
:
public
framework
::
VarTypeInference
{
class
GetPlacesInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o_name
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
PLACE_LIST
,
ctx
->
SetType
(
o_name
,
framework
::
proto
::
VarType
::
PLACE_LIST
);
framework
::
ALL_ELEMENTS
);
}
}
}
};
};
...
...
paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
浏览文件 @
91ae7848
...
@@ -111,15 +111,15 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
...
@@ -111,15 +111,15 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
}
}
};
};
class
WriteToArrayInferVarType
:
public
framework
::
VarTypeInference
{
class
WriteToArrayInferVarType
:
public
framework
::
StaticGraph
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_name
=
ctx
->
Input
(
"X"
)[
0
];
auto
x_name
=
Input
(
ctx
,
"X"
)[
0
];
auto
out_name
=
ctx
->
Output
(
"Out"
)[
0
];
auto
out_name
=
Output
(
ctx
,
"Out"
)[
0
];
VLOG
(
10
)
<<
"Set Variable "
<<
out_name
<<
" as LOD_TENSOR_ARRAY"
;
VLOG
(
10
)
<<
"Set Variable "
<<
out_name
<<
" as LOD_TENSOR_ARRAY"
;
ctx
->
SetType
(
out_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
SetType
(
ctx
,
out_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
if
(
ctx
->
HasVar
(
x_name
))
{
if
(
HasVar
(
ctx
,
x_name
))
{
ctx
->
SetDataType
(
out_name
,
ctx
->
GetDataType
(
x_name
));
SetDataType
(
ctx
,
out_name
,
GetDataType
(
ctx
,
x_name
));
}
}
}
}
};
};
...
...
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
91ae7848
...
@@ -398,18 +398,19 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -398,18 +398,19 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
}
}
};
};
class
WhileGradOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
WhileGradOpVarTypeInference
:
public
framework
::
StaticGraphVarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
p_names
=
ctx
->
Input
(
kX
);
auto
p_names
=
Input
(
ctx
,
kX
);
auto
pg_ig_names
=
ctx
->
Output
(
framework
::
GradVarName
(
kX
));
auto
pg_ig_names
=
Output
(
ctx
,
framework
::
GradVarName
(
kX
));
for
(
size_t
i
=
0
;
i
<
p_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
p_names
.
size
();
++
i
)
{
if
(
ctx
->
HasVar
(
pg_ig_names
[
i
]))
{
if
(
HasVar
(
ctx
,
pg_ig_names
[
i
]))
{
VLOG
(
5
)
<<
"Setting "
<<
pg_ig_names
[
i
]
<<
" following "
<<
p_names
[
i
]
VLOG
(
5
)
<<
"Setting "
<<
pg_ig_names
[
i
]
<<
" following "
<<
p_names
[
i
]
<<
" type: "
<<
ctx
->
GetType
(
p_names
[
i
]);
<<
" type: "
<<
GetType
(
ctx
,
p_names
[
i
]);
ctx
->
SetType
(
pg_ig_names
[
i
],
ctx
->
GetType
(
p_names
[
i
]));
SetType
(
ctx
,
pg_ig_names
[
i
],
GetType
(
ctx
,
p_names
[
i
]));
ctx
->
SetDataType
(
pg_ig_names
[
i
],
ctx
->
GetDataType
(
p_names
[
i
]));
SetDataType
(
ctx
,
pg_ig_names
[
i
],
GetDataType
(
ctx
,
p_names
[
i
]));
}
}
}
}
}
}
...
...
paddle/fluid/operators/conv_op.h
浏览文件 @
91ae7848
...
@@ -254,10 +254,11 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -254,10 +254,11 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
class
ConvOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
class
ConvOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{
{
"Input"
,
/*->*/
"Output"
}};
{
"Input"
,
/*->*/
"Output"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/cross_entropy_op.cc
浏览文件 @
91ae7848
...
@@ -177,9 +177,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
...
@@ -177,9 +177,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
class
CrossEntropyOpInferVarType
class
CrossEntropyOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
浏览文件 @
91ae7848
...
@@ -115,10 +115,8 @@ class MergeIdsOp : public framework::OperatorWithKernel {
...
@@ -115,10 +115,8 @@ class MergeIdsOp : public framework::OperatorWithKernel {
class
MergeIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
class
MergeIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"Ids"
)[
0
]);
auto
input_type
=
ctx
->
GetInputType
(
"Ids"
);
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetOutputType
(
"Out"
,
input_type
,
framework
::
ALL_ELEMENTS
);
ctx
->
SetType
(
out_var
,
input_type
);
}
}
}
};
};
...
...
paddle/fluid/operators/distributed_ops/split_ids_op.cc
浏览文件 @
91ae7848
...
@@ -73,10 +73,8 @@ class SplitIdsOp : public framework::OperatorWithKernel {
...
@@ -73,10 +73,8 @@ class SplitIdsOp : public framework::OperatorWithKernel {
class
SplitIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
class
SplitIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"Ids"
)[
0
]);
auto
input_type
=
ctx
->
GetInputType
(
"Ids"
);
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetOutputType
(
"Out"
,
input_type
,
framework
::
ALL_ELEMENTS
);
ctx
->
SetType
(
out_var
,
input_type
);
}
}
}
};
};
...
...
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
91ae7848
...
@@ -119,9 +119,10 @@ class ElementwiseOp : public framework::OperatorWithKernel {
...
@@ -119,9 +119,10 @@ class ElementwiseOp : public framework::OperatorWithKernel {
class
ElementwiseOpInferVarType
class
ElementwiseOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/eye_op.cc
浏览文件 @
91ae7848
...
@@ -49,8 +49,7 @@ class EyeOpVarTypeInference : public framework::VarTypeInference {
...
@@ -49,8 +49,7 @@ class EyeOpVarTypeInference : public framework::VarTypeInference {
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetOutputDataType
(
"Out"
,
data_type
);
ctx
->
SetDataType
(
out_var_name
,
data_type
);
}
}
};
};
...
...
paddle/fluid/operators/fill_any_like_op.cc
浏览文件 @
91ae7848
...
@@ -72,14 +72,12 @@ The output will have the same shape and dtype as the input.
...
@@ -72,14 +72,12 @@ The output will have the same shape and dtype as the input.
class
FillAnyLikeVarTypeInference
:
public
framework
::
VarTypeInference
{
class
FillAnyLikeVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
if
(
var_data_type
<
0
)
{
if
(
var_data_type
<
0
)
{
const
auto
&
input_var_name
=
ctx
->
Input
(
"X"
).
front
();
ctx
->
SetOutputDataType
(
"Out"
,
ctx
->
GetInputDataType
(
"X"
));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
input_var_name
));
}
else
{
}
else
{
ctx
->
Set
DataType
(
out_var_name
,
var_data_type
);
ctx
->
Set
OutputDataType
(
"Out"
,
var_data_type
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
91ae7848
...
@@ -64,8 +64,7 @@ class FillConstantOpVarTypeInference : public framework::VarTypeInference {
...
@@ -64,8 +64,7 @@ class FillConstantOpVarTypeInference : public framework::VarTypeInference {
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetOutputDataType
(
"Out"
,
data_type
);
ctx
->
SetDataType
(
out_var_name
,
data_type
);
}
}
};
};
...
...
paddle/fluid/operators/fill_op.cc
浏览文件 @
91ae7848
...
@@ -63,8 +63,7 @@ class FillOpVarTypeInference : public framework::VarTypeInference {
...
@@ -63,8 +63,7 @@ class FillOpVarTypeInference : public framework::VarTypeInference {
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetOutputDataType
(
"Out"
,
data_type
);
ctx
->
SetDataType
(
out_var_name
,
data_type
);
}
}
};
};
...
...
paddle/fluid/operators/flip_op.cc
浏览文件 @
91ae7848
...
@@ -114,9 +114,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -114,9 +114,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
class
FlipOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
class
FlipOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/fused/fused_bn_activation_op.h
浏览文件 @
91ae7848
...
@@ -85,9 +85,10 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -85,9 +85,10 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
class
FusedBatchNormActOpInferVarType
class
FusedBatchNormActOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
浏览文件 @
91ae7848
...
@@ -146,19 +146,20 @@ class FusedEmbeddingSeqPoolOpGradVarTypeInference
...
@@ -146,19 +146,20 @@ class FusedEmbeddingSeqPoolOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
(
);
auto
out_var_name
=
framework
::
GradVarName
(
"W"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetOutputType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
}
else
{
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
Set
Output
Type
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
->
Set
DataType
(
out_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]
));
ctx
->
Set
OutputDataType
(
out_var_name
,
ctx
->
GetInputDataType
(
"W"
));
}
}
};
};
...
...
paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
浏览文件 @
91ae7848
...
@@ -83,11 +83,8 @@ class GetTensorFromSelectedRowsOpVarTypeInference
...
@@ -83,11 +83,8 @@ class GetTensorFromSelectedRowsOpVarTypeInference
:
public
framework
::
VarTypeInference
{
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
{
// NOLINT
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
{
// NOLINT
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
auto
in_var_name
=
ctx
->
Input
(
"X"
).
front
();
ctx
->
SetOutputDataType
(
"Out"
,
ctx
->
GetInputDataType
(
"X"
));
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
in_var_name
));
}
}
};
};
...
...
paddle/fluid/operators/group_norm_op.cc
浏览文件 @
91ae7848
...
@@ -216,9 +216,10 @@ DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut,
...
@@ -216,9 +216,10 @@ DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut,
class
GroupNormOpInferVarType
class
GroupNormOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
{{
"X"
,
/*->*/
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/hierarchical_sigmoid_op.cc
浏览文件 @
91ae7848
...
@@ -229,31 +229,30 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
...
@@ -229,31 +229,30 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
w_grad_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
(
);
auto
w_grad_var_name
=
framework
::
GradVarName
(
"W"
);
auto
has_bias_grad_var
=
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Bias"
)
);
auto
bias_grad_var_name
=
framework
::
GradVarName
(
"Bias"
);
std
::
string
bias_grad_var_name
;
if
(
ctx
->
HasOutput
(
bias_grad_var_name
))
{
bool
hasBias
=
false
;
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
if
(
has_bias_grad_var
)
{
<<
framework
::
GradVarName
(
"Bias"
)
<<
" is set to LoDTensor"
;
hasBias
=
true
;
ctx
->
SetOutputType
(
bias_grad_var_name
,
bias_grad_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"Bias"
)).
front
(
);
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetOutputType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
}
else
{
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetOutputType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
if
(
hasBias
)
{
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
ctx
->
SetOutputDataType
(
w_grad_var_name
,
ctx
->
GetInputDataType
(
"W"
));
<<
framework
::
GradVarName
(
"Bias"
)
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
bias_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
SetDataType
(
w_grad_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]));
}
}
};
};
...
...
paddle/fluid/operators/instance_norm_op.h
浏览文件 @
91ae7848
...
@@ -123,9 +123,10 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
...
@@ -123,9 +123,10 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
class
InstanceNormOpInferVarType
class
InstanceNormOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
"Y"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
"Y"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/lod_rank_table_op.cc
浏览文件 @
91ae7848
...
@@ -65,9 +65,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
...
@@ -65,9 +65,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
class
LoDRankTableInferVarType
:
public
framework
::
VarTypeInference
{
class
LoDRankTableInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_RANK_TABLE
,
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_RANK_TABLE
);
framework
::
ALL_ELEMENTS
);
}
}
}
};
};
...
...
paddle/fluid/operators/lod_reset_op.cc
浏览文件 @
91ae7848
...
@@ -76,24 +76,25 @@ class LoDResetOp : public framework::OperatorWithKernel {
...
@@ -76,24 +76,25 @@ class LoDResetOp : public framework::OperatorWithKernel {
}
}
};
};
class
LoDResetOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
LoDResetOpVarTypeInference
:
public
framework
::
StaticGraphVarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_var_name
=
ctx
->
Input
(
"X"
).
front
();
auto
x_var_name
=
Input
(
ctx
,
"X"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
auto
out_var_name
=
Output
(
ctx
,
"Out"
).
front
();
bool
append
=
boost
::
get
<
bool
>
(
ctx
->
GetAttr
(
"append"
));
bool
append
=
boost
::
get
<
bool
>
(
ctx
->
GetAttr
(
"append"
));
if
(
ctx
->
HasInput
(
"Y"
))
{
if
(
ctx
->
HasInput
(
"Y"
))
{
auto
y_var_name
=
ctx
->
Input
(
"Y"
).
front
();
auto
y_var_name
=
Input
(
ctx
,
"Y"
).
front
();
auto
y_lod_level
=
std
::
max
(
ctx
->
GetLoDLevel
(
y_var_name
),
1
);
auto
y_lod_level
=
std
::
max
(
GetLoDLevel
(
ctx
,
y_var_name
),
1
);
ctx
->
SetLoDLevel
(
out_var_name
,
y_lod_level
);
SetLoDLevel
(
ctx
,
out_var_name
,
y_lod_level
);
}
else
if
(
append
)
{
}
else
if
(
append
)
{
auto
x_lod_level
=
std
::
max
(
ctx
->
GetLoDLevel
(
x_var_name
),
1
);
auto
x_lod_level
=
std
::
max
(
GetLoDLevel
(
ctx
,
x_var_name
),
1
);
ctx
->
SetLoDLevel
(
out_var_name
,
x_lod_level
);
SetLoDLevel
(
ctx
,
out_var_name
,
x_lod_level
);
}
else
{
}
else
{
ctx
->
SetLoDLevel
(
out_var_name
,
1
);
SetLoDLevel
(
ctx
,
out_var_name
,
1
);
}
}
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
x_var_name
));
SetDataType
(
ctx
,
out_var_name
,
GetDataType
(
ctx
,
x_var_name
));
ctx
->
SetType
(
out_var_name
,
paddle
::
framework
::
proto
::
VarType
::
LOD_TENSOR
);
SetType
(
ctx
,
out_var_name
,
paddle
::
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
};
};
...
...
paddle/fluid/operators/lod_tensor_to_array_op.cc
浏览文件 @
91ae7848
...
@@ -221,9 +221,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
...
@@ -221,9 +221,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
class
LoDTensorToArrayInferVarType
:
public
framework
::
VarTypeInference
{
class
LoDTensorToArrayInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
,
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
framework
::
ALL_ELEMENTS
);
}
}
}
};
};
...
...
paddle/fluid/operators/lookup_table_op.cc
浏览文件 @
91ae7848
...
@@ -173,19 +173,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
...
@@ -173,19 +173,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
class
LookupTableOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
class
LookupTableOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
(
);
auto
out_var_name
=
framework
::
GradVarName
(
"W"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetOutputType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
}
else
{
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
Set
Output
Type
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
->
Set
DataType
(
out_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]
));
ctx
->
Set
OutputDataType
(
out_var_name
,
ctx
->
GetInputDataType
(
"W"
));
}
}
};
};
...
...
paddle/fluid/operators/lookup_table_v2_op.cc
浏览文件 @
91ae7848
...
@@ -160,19 +160,20 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel {
...
@@ -160,19 +160,20 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel {
class
LookupTableV2OpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
class
LookupTableV2OpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
(
);
auto
out_var_name
=
framework
::
GradVarName
(
"W"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"lookup_table_v2_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
3
)
<<
"lookup_table_v2_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetOutputType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
}
else
{
VLOG
(
3
)
<<
"lookup_table_v2_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
3
)
<<
"lookup_table_v2_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
Set
Output
Type
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
->
Set
DataType
(
out_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]
));
ctx
->
Set
OutputDataType
(
out_var_name
,
ctx
->
GetInputDataType
(
"W"
));
}
}
};
};
...
...
paddle/fluid/operators/mean_op.cc
浏览文件 @
91ae7848
...
@@ -45,9 +45,10 @@ Mean Operator calculates the mean of all elements in X.
...
@@ -45,9 +45,10 @@ Mean Operator calculates the mean of all elements in X.
class
MeanOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
class
MeanOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/merge_selected_rows_op.cc
浏览文件 @
91ae7848
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/merge_selected_rows_op.h"
#include "paddle/fluid/operators/merge_selected_rows_op.h"
#include <unordered_map>
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -79,9 +80,10 @@ Example:
...
@@ -79,9 +80,10 @@ Example:
class
MergeSelectedRowsOpInferVarType
class
MergeSelectedRowsOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/mul_op.cc
浏览文件 @
91ae7848
...
@@ -200,9 +200,10 @@ or not. But the output only shares the LoD information with input $X$.
...
@@ -200,9 +200,10 @@ or not. But the output only shares the LoD information with input $X$.
class
MulOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
class
MulOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/nccl/nccl_op.cc
浏览文件 @
91ae7848
...
@@ -61,8 +61,7 @@ class NCCLInitOp : public framework::OperatorBase {
...
@@ -61,8 +61,7 @@ class NCCLInitOp : public framework::OperatorBase {
class
NCCLInitOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
NCCLInitOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Communicator"
).
front
();
ctx
->
SetOutputType
(
"Communicator"
,
framework
::
proto
::
VarType
::
RAW
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
RAW
);
}
}
};
};
...
...
paddle/fluid/operators/nce_op.cc
浏览文件 @
91ae7848
...
@@ -280,20 +280,20 @@ class NCEOpGrad : public framework::OperatorWithKernel {
...
@@ -280,20 +280,20 @@ class NCEOpGrad : public framework::OperatorWithKernel {
class
NCEOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
class
NCEOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
weight_grad
=
ctx
->
Output
(
framework
::
GradVarName
(
"Weight"
)).
front
(
);
auto
weight_grad
=
framework
::
GradVarName
(
"Weight"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
<<
" is set to SelectedRows"
;
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
weight_grad
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
Set
Output
Type
(
weight_grad
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
}
else
{
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
<<
" is set to LoDTensor"
;
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
weight_grad
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
Set
Output
Type
(
weight_grad
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
->
Set
DataType
(
weight_grad
,
ctx
->
GetDataType
(
ctx
->
Input
(
"Input"
)[
0
]
));
ctx
->
Set
OutputDataType
(
weight_grad
,
ctx
->
GetInputDataType
(
"Input"
));
}
}
};
};
...
...
paddle/fluid/operators/optimizers/momentum_op.cc
浏览文件 @
91ae7848
...
@@ -22,18 +22,15 @@ using Tensor = framework::Tensor;
...
@@ -22,18 +22,15 @@ using Tensor = framework::Tensor;
class
MomentumOpInferVarType
:
public
framework
::
VarTypeInference
{
class
MomentumOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
input_var
=
ctx
->
Input
(
"Param"
)[
0
];
auto
in_var_type
=
ctx
->
GetInputType
(
"Param"
);
for
(
auto
&
out_var
:
ctx
->
Output
(
"ParamOut"
))
{
PADDLE_ENFORCE_EQ
(
if
(
ctx
->
GetType
(
input_var
)
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
}
else
if
(
ctx
->
GetType
(
input_var
)
==
true
,
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
platform
::
errors
::
InvalidArgument
(
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
"Only support LodTensor and SelectedRows, Unexpected Input Type."
));
}
else
{
PADDLE_THROW
(
ctx
->
SetOutputType
(
"ParamOut"
,
in_var_type
,
framework
::
ALL_ELEMENTS
);
"Only support LodTensor and SelectedRows, Unexpected Input Type."
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/optimizers/sgd_op.cc
浏览文件 @
91ae7848
...
@@ -75,19 +75,15 @@ class SGDOp : public framework::OperatorWithKernel {
...
@@ -75,19 +75,15 @@ class SGDOp : public framework::OperatorWithKernel {
class
SGDOpInferVarType
:
public
framework
::
VarTypeInference
{
class
SGDOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
input_var_n
=
ctx
->
Input
(
"Param"
)[
0
];
auto
in_var_type
=
ctx
->
GetInputType
(
"Param"
);
auto
in_var_type
=
ctx
->
GetType
(
input_var_n
);
PADDLE_ENFORCE_EQ
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
PADDLE_ENFORCE
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input Var's type should be LoDtensor or SelectedRows,"
true
,
platform
::
errors
::
InvalidArgument
(
" but the received var(%s)'s type is %s"
,
"The input Var's type should be LoDtensor or "
input_var_n
,
in_var_type
);
"SelectedRows, but the received type is %s"
,
in_var_type
));
for
(
auto
&
out_var_n
:
ctx
->
Output
(
"ParamOut"
))
{
ctx
->
SetOutputType
(
"ParamOut"
,
in_var_type
,
framework
::
ALL_ELEMENTS
);
if
(
ctx
->
GetType
(
out_var_n
)
!=
in_var_type
)
{
ctx
->
SetType
(
out_var_n
,
in_var_type
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
91ae7848
...
@@ -422,9 +422,10 @@ Example:
...
@@ -422,9 +422,10 @@ Example:
class
PoolOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
class
PoolOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/print_op.cc
浏览文件 @
91ae7848
...
@@ -260,9 +260,7 @@ class PrintOpInferShape : public framework::InferShapeBase {
...
@@ -260,9 +260,7 @@ class PrintOpInferShape : public framework::InferShapeBase {
class
PrintOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
PrintOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"In"
)[
0
]);
ctx
->
SetOutputType
(
"Out"
,
ctx
->
GetInputType
(
"In"
));
auto
out_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetType
(
out_name
,
input_type
);
}
}
};
};
...
...
paddle/fluid/operators/py_func_op.cc
浏览文件 @
91ae7848
...
@@ -116,12 +116,11 @@ static void CallPythonFunc(py::object *callable,
...
@@ -116,12 +116,11 @@ static void CallPythonFunc(py::object *callable,
}
}
}
}
class
PyFuncOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
PyFuncOpVarTypeInference
:
public
framework
::
StaticGraph
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
bool
has_out
=
(
ctx
->
HasOutput
(
"Out"
)
&&
!
ctx
->
Output
(
"Out"
).
empty
());
bool
has_out
=
ctx
->
HasOutput
(
"Out"
);
bool
has_in
=
ctx
->
HasInput
(
"X"
);
bool
has_in
=
(
ctx
->
HasInput
(
"X"
)
&&
!
ctx
->
Input
(
"X"
).
empty
());
/**
/**
* X or Out can be empty, so that py_func can be more flexible
* X or Out can be empty, so that py_func can be more flexible
...
@@ -147,7 +146,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
...
@@ -147,7 +146,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
* the corresponding forward variable
* the corresponding forward variable
*/
*/
const
std
::
string
kGradVarSuffix
=
framework
::
kGradVarSuffix
;
const
std
::
string
kGradVarSuffix
=
framework
::
kGradVarSuffix
;
auto
&
out_var_names
=
ctx
->
Output
(
"Out"
);
auto
&
out_var_names
=
Output
(
ctx
,
"Out"
);
for
(
auto
&
out_var_name
:
out_var_names
)
{
for
(
auto
&
out_var_name
:
out_var_names
)
{
if
(
out_var_name
==
framework
::
kEmptyVarName
||
if
(
out_var_name
==
framework
::
kEmptyVarName
||
out_var_name
.
size
()
<
kGradVarSuffix
.
size
())
{
out_var_name
.
size
()
<
kGradVarSuffix
.
size
())
{
...
@@ -157,19 +156,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
...
@@ -157,19 +156,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
size_t
len
=
out_var_name
.
size
()
-
kGradVarSuffix
.
size
();
size_t
len
=
out_var_name
.
size
()
-
kGradVarSuffix
.
size
();
if
(
out_var_name
.
substr
(
len
)
==
kGradVarSuffix
)
{
if
(
out_var_name
.
substr
(
len
)
==
kGradVarSuffix
)
{
auto
fwd_var_name
=
out_var_name
.
substr
(
0
,
len
);
auto
fwd_var_name
=
out_var_name
.
substr
(
0
,
len
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasVar
(
out_var_name
),
true
,
OP_INOUT_CHECK
(
HasVar
(
ctx
,
out_var_name
),
"Var"
,
out_var_name
,
platform
::
errors
::
InvalidArgument
(
"py_func"
);
"Backward variable %s not found"
,
out_var_name
));
OP_INOUT_CHECK
(
HasVar
(
ctx
,
fwd_var_name
),
"Var"
,
fwd_var_name
,
PADDLE_ENFORCE_EQ
(
ctx
->
HasVar
(
fwd_var_name
),
true
,
"py_func"
);
platform
::
errors
::
InvalidArgument
(
"Backward variable %s not found"
,
fwd_var_name
));
VLOG
(
10
)
<<
"Infer var_desc of Output("
<<
out_var_name
<<
") as Input("
VLOG
(
10
)
<<
"Infer var_desc of Output("
<<
out_var_name
<<
") as Input("
<<
fwd_var_name
<<
")"
;
<<
fwd_var_name
<<
")"
;
ctx
->
SetShape
(
out_var_name
,
ctx
->
GetShape
(
fwd_var_name
));
SetShape
(
ctx
,
out_var_name
,
GetShape
(
ctx
,
fwd_var_name
));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
fwd_var_name
));
SetDataType
(
ctx
,
out_var_name
,
GetDataType
(
ctx
,
fwd_var_name
));
ctx
->
SetLoDLevel
(
out_var_name
,
ctx
->
GetLoDLevel
(
fwd_var_name
));
SetLoDLevel
(
ctx
,
out_var_name
,
GetLoDLevel
(
ctx
,
fwd_var_name
));
ctx
->
SetType
(
out_var_name
,
ctx
->
GetType
(
fwd_var_name
));
SetType
(
ctx
,
out_var_name
,
GetType
(
ctx
,
fwd_var_name
));
}
}
}
}
}
}
...
...
paddle/fluid/operators/randperm_op.cc
浏览文件 @
91ae7848
...
@@ -75,8 +75,7 @@ class RandpermOpVarTypeInference : public framework::VarTypeInference {
...
@@ -75,8 +75,7 @@ class RandpermOpVarTypeInference : public framework::VarTypeInference {
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetOutputDataType
(
"Out"
,
var_data_type
);
ctx
->
SetDataType
(
out_var_name
,
var_data_type
);
}
}
};
};
...
...
paddle/fluid/operators/reader/read_op.cc
浏览文件 @
91ae7848
...
@@ -70,18 +70,18 @@ class ReadInferShape : public framework::InferShapeBase {
...
@@ -70,18 +70,18 @@ class ReadInferShape : public framework::InferShapeBase {
}
}
};
};
class
ReadInferVarType
:
public
framework
::
VarTypeInference
{
class
ReadInferVarType
:
public
framework
::
StaticGraph
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
bool
infer_out
=
boost
::
get
<
bool
>
(
ctx
->
GetAttr
(
"infer_out"
));
bool
infer_out
=
boost
::
get
<
bool
>
(
ctx
->
GetAttr
(
"infer_out"
));
if
(
infer_out
)
{
if
(
infer_out
)
{
std
::
string
reader_name
=
ctx
->
Input
(
"Reader"
)[
0
];
std
::
string
reader_name
=
Input
(
ctx
,
"Reader"
)[
0
];
std
::
vector
<
std
::
string
>
out_names
=
ctx
->
Output
(
"Out"
);
auto
&
out_names
=
Output
(
ctx
,
"Out"
);
auto
dtypes
=
ctx
->
GetDataTypes
(
reader_name
);
auto
dtypes
=
GetDataTypes
(
ctx
,
reader_name
);
PADDLE_ENFORCE_EQ
(
dtypes
.
size
(),
out_names
.
size
());
PADDLE_ENFORCE_EQ
(
dtypes
.
size
(),
out_names
.
size
());
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
ctx
->
SetType
(
out_names
[
i
],
framework
::
proto
::
VarType
::
LOD_TENSOR
);
SetType
(
ctx
,
out_names
[
i
],
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetDataType
(
out_names
[
i
],
dtypes
[
i
]);
SetDataType
(
ctx
,
out_names
[
i
],
dtypes
[
i
]);
}
}
}
}
}
}
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
91ae7848
...
@@ -100,8 +100,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
...
@@ -100,8 +100,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
void
FileReaderInferVarType
::
operator
()(
void
FileReaderInferVarType
::
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
{
framework
::
InferVarTypeContext
*
ctx
)
const
{
std
::
string
reader_name
=
ctx
->
Output
(
"Out"
)[
0
];
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
READER
);
ctx
->
SetType
(
reader_name
,
framework
::
proto
::
VarType
::
READER
);
}
}
void
DecoratedReaderInferShape
::
operator
()(
void
DecoratedReaderInferShape
::
operator
()(
...
@@ -125,10 +124,8 @@ void DecoratedReaderInferShape::operator()(
...
@@ -125,10 +124,8 @@ void DecoratedReaderInferShape::operator()(
void
DecoratedReaderInferVarType
::
operator
()(
void
DecoratedReaderInferVarType
::
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
{
framework
::
InferVarTypeContext
*
ctx
)
const
{
const
std
::
string
&
in_reader_name
=
ctx
->
Input
(
"UnderlyingReader"
)[
0
];
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
READER
);
const
std
::
string
&
out_reader_name
=
ctx
->
Output
(
"Out"
)[
0
];
ctx
->
SetOutputDataTypes
(
"Out"
,
ctx
->
GetInputDataTypes
(
"UnderlyingReader"
));
ctx
->
SetType
(
out_reader_name
,
framework
::
proto
::
VarType
::
READER
);
ctx
->
SetDataTypes
(
out_reader_name
,
ctx
->
GetDataTypes
(
in_reader_name
));
}
}
void
DecoratedReaderMakerBase
::
Make
()
{
void
DecoratedReaderMakerBase
::
Make
()
{
...
...
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
浏览文件 @
91ae7848
...
@@ -58,8 +58,7 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
...
@@ -58,8 +58,7 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
auto
data_type
=
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
auto
data_type
=
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"out_dtype"
)));
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"out_dtype"
)));
if
(
data_type
>=
0
)
{
if
(
data_type
>=
0
)
{
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
->
SetOutputDataType
(
"Out"
,
data_type
);
ctx
->
SetDataType
(
out_var_name
,
data_type
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/save_combine_op.cc
浏览文件 @
91ae7848
...
@@ -85,9 +85,8 @@ to a file on disk.
...
@@ -85,9 +85,8 @@ to a file on disk.
class
SaveCombineOpInferVarType
:
public
framework
::
VarTypeInference
{
class
SaveCombineOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
->
Output
(
"Y"
))
{
ctx
->
SetOutputType
(
"Y"
,
framework
::
proto
::
VarType
::
RAW
,
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
RAW
);
framework
::
ALL_ELEMENTS
);
}
}
}
};
};
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
91ae7848
...
@@ -73,7 +73,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
...
@@ -73,7 +73,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
ctx
->
SetType
(
LOOKUP_TABLE_PATH
,
var_type
);
ctx
->
InsertVar
(
LOOKUP_TABLE_PATH
,
var_type
);
}
}
};
};
...
...
paddle/fluid/operators/scale_op.cc
浏览文件 @
91ae7848
...
@@ -82,13 +82,7 @@ $$Out = scale*(X + bias)$$
...
@@ -82,13 +82,7 @@ $$Out = scale*(X + bias)$$
class
ScaleOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
ScaleOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
in_var_name
=
ctx
->
Input
(
"X"
).
front
();
ctx
->
SyncTypeAndDataType
(
"X"
,
"Out"
);
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
if
(
in_var_name
!=
out_var_name
)
{
ctx
->
SetType
(
out_var_name
,
ctx
->
GetType
(
in_var_name
));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
in_var_name
));
}
}
}
};
};
...
...
paddle/fluid/operators/selu_op.cc
浏览文件 @
91ae7848
...
@@ -45,9 +45,10 @@ class SeluOp : public framework::OperatorWithKernel {
...
@@ -45,9 +45,10 @@ class SeluOp : public framework::OperatorWithKernel {
class
SeluOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
class
SeluOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/softmax_op.cc
浏览文件 @
91ae7848
...
@@ -145,9 +145,10 @@ For each row $i$ and each column $j$ in the matrix, we have:
...
@@ -145,9 +145,10 @@ For each row $i$ and each column $j$ in the matrix, we have:
class
SoftmaxOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
class
SoftmaxOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
...
paddle/fluid/operators/split_selected_rows_op.cc
浏览文件 @
91ae7848
...
@@ -64,9 +64,8 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
...
@@ -64,9 +64,8 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
class
SplitSelectedRowsOpInferVarType
:
public
framework
::
VarTypeInference
{
class
SplitSelectedRowsOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
,
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
framework
::
ALL_ELEMENTS
);
}
}
}
};
};
...
...
paddle/fluid/operators/sum_op.cc
浏览文件 @
91ae7848
...
@@ -210,43 +210,36 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -210,43 +210,36 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
class
SumOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
SumOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
->
Input
(
"X"
);
if
(
!
ctx
->
IsDygraph
())
{
auto
var_type
=
framework
::
proto
::
VarType
::
SELECTED_ROWS
;
auto
var_type
=
framework
::
proto
::
VarType
::
SELECTED_ROWS
;
for
(
auto
&
name
:
ctx
->
Input
(
"X"
))
{
if
(
VLOG_IS_ON
(
10
))
{
VLOG
(
10
)
<<
name
<<
" "
<<
ctx
->
GetType
(
name
);
for
(
size_t
ind
=
0
;
ind
<
ctx
->
InputSize
(
"X"
);
++
ind
)
{
VLOG
(
10
)
<<
ctx
->
InputVarName
(
"X"
,
ind
)
<<
" "
<<
ctx
->
GetInputType
(
"X"
,
ind
);
}
}
}
bool
any_input_is_lod_tensor
=
std
::
any_of
(
if
(
ctx
->
InputTypeAnyOf
(
"X"
,
inputs
.
begin
(),
inputs
.
end
(),
[
ctx
](
const
std
::
string
&
name
)
{
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
))
{
return
ctx
->
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR
;
if
(
!
ctx
->
InputTypeAllOf
(
"X"
,
});
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
))
{
auto
is_tensor_array
=
[
ctx
](
const
std
::
string
&
name
)
{
return
ctx
->
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
};
bool
any_input_is_tensor_array
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
is_tensor_array
);
bool
all_inputs_are_tensor_array
=
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
is_tensor_array
);
if
(
any_input_is_tensor_array
)
{
if
(
!
all_inputs_are_tensor_array
)
{
std
::
ostringstream
os
;
std
::
ostringstream
os
;
for
(
auto
&
each
:
inputs
)
{
for
(
size_t
ind
=
0
;
ind
<
ctx
->
InputSize
(
"X"
);
++
ind
)
{
os
<<
" "
<<
each
<<
" type is "
<<
ctx
->
GetType
(
each
)
<<
"
\n
"
;
os
<<
" "
<<
ctx
->
InputVarName
(
"X"
,
ind
)
<<
" type is "
<<
ctx
->
GetInputType
(
"X"
,
ind
)
<<
"
\n
"
;
}
}
PADDLE_ENFORCE_EQ
(
all_inputs_are_tensor_array
,
true
,
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Not all inputs are tensor array:
\n
%s"
,
os
.
str
(
));
"Not all inputs are tensor array:
\n
%s"
,
os
.
str
()
));
}
}
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
}
else
if
(
any_input_is_lod_tensor
)
{
}
else
if
(
ctx
->
InputTypeAnyOf
(
"X"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
))
{
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR
;
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR
;
}
}
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
(
);
ctx
->
SetOutputType
(
"Out"
,
var_type
);
ctx
->
SetType
(
out_var_name
,
var_type
);
ctx
->
SetOutputDataType
(
"Out"
,
ctx
->
GetInputDataType
(
"X"
)
);
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
inputs
.
front
()));
}
}
}
};
};
...
...
paddle/fluid/operators/tensor_array_to_tensor_op.cc
浏览文件 @
91ae7848
...
@@ -213,9 +213,9 @@ class LoDTensorArray2TensorGradInferVarType
...
@@ -213,9 +213,9 @@ class LoDTensorArray2TensorGradInferVarType
:
public
framework
::
VarTypeInference
{
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
out_var
:
ctx
->
Output
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputType
(
framework
::
GradVarName
(
"X"
),
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
,
}
framework
::
ALL_ELEMENTS
);
}
}
};
};
...
...
paddle/fluid/operators/uniform_random_op.cc
浏览文件 @
91ae7848
...
@@ -232,15 +232,13 @@ uniform distribution. The random result is in set [min, max).
...
@@ -232,15 +232,13 @@ uniform distribution. The random result is in set [min, max).
class
UniformRandomOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
UniformRandomOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
if
(
ctx
->
GetType
(
out_var_name
)
!=
if
(
ctx
->
GetOutputType
(
"Out"
)
!=
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
ctx
->
SetOutputType
(
"Out"
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
->
Set
DataType
(
out_var_name
,
var_data_type
);
ctx
->
Set
OutputDataType
(
"Out"
,
var_data_type
);
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录