Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b40e41fb
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b40e41fb
编写于
3月 18, 2019
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Polish code style
test=develop
上级
36dce65b
变更
40
隐藏空白更改
内联
并排
Showing
40 changed file
with
192 addition
and
192 deletion
+192
-192
paddle/fluid/framework/details/graph_test_base.h
paddle/fluid/framework/details/graph_test_base.h
+5
-5
paddle/fluid/framework/details/op_registry.h
paddle/fluid/framework/details/op_registry.h
+1
-1
paddle/fluid/framework/ir/graph_test.cc
paddle/fluid/framework/ir/graph_test.cc
+6
-6
paddle/fluid/framework/var_type_inference.h
paddle/fluid/framework/var_type_inference.h
+6
-6
paddle/fluid/framework/var_type_inference_test.cc
paddle/fluid/framework/var_type_inference_test.cc
+5
-5
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+4
-5
paddle/fluid/imperative/tracer.h
paddle/fluid/imperative/tracer.h
+1
-1
paddle/fluid/operators/beam_search_decode_op.cc
paddle/fluid/operators/beam_search_decode_op.cc
+5
-5
paddle/fluid/operators/beam_search_op.cc
paddle/fluid/operators/beam_search_op.cc
+5
-5
paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
...fluid/operators/controlflow/tensor_array_read_write_op.cc
+6
-6
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+7
-7
paddle/fluid/operators/distributed_ops/fake_init_op.cc
paddle/fluid/operators/distributed_ops/fake_init_op.cc
+1
-1
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
+4
-4
paddle/fluid/operators/distributed_ops/split_ids_op.cc
paddle/fluid/operators/distributed_ops/split_ids_op.cc
+4
-4
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+4
-4
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
+6
-6
paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
+5
-5
paddle/fluid/operators/hierarchical_sigmoid_op.cc
paddle/fluid/operators/hierarchical_sigmoid_op.cc
+9
-9
paddle/fluid/operators/lod_rank_table_op.cc
paddle/fluid/operators/lod_rank_table_op.cc
+3
-3
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
+3
-3
paddle/fluid/operators/lookup_table_op.cc
paddle/fluid/operators/lookup_table_op.cc
+6
-6
paddle/fluid/operators/nccl/nccl_op.cc
paddle/fluid/operators/nccl/nccl_op.cc
+3
-3
paddle/fluid/operators/nce_op.cc
paddle/fluid/operators/nce_op.cc
+6
-6
paddle/fluid/operators/ngraph/ngraph_engine_op.cc
paddle/fluid/operators/ngraph/ngraph_engine_op.cc
+1
-1
paddle/fluid/operators/optimizers/lars_momentum_op.cc
paddle/fluid/operators/optimizers/lars_momentum_op.cc
+1
-1
paddle/fluid/operators/optimizers/momentum_op.cc
paddle/fluid/operators/optimizers/momentum_op.cc
+7
-7
paddle/fluid/operators/optimizers/sgd_op.cc
paddle/fluid/operators/optimizers/sgd_op.cc
+6
-6
paddle/fluid/operators/py_func_op.cc
paddle/fluid/operators/py_func_op.cc
+12
-12
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+7
-7
paddle/fluid/operators/reader/read_op.cc
paddle/fluid/operators/reader/read_op.cc
+7
-7
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+8
-8
paddle/fluid/operators/reader/reader_op_registry.h
paddle/fluid/operators/reader/reader_op_registry.h
+2
-2
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+3
-3
paddle/fluid/operators/scale_op.cc
paddle/fluid/operators/scale_op.cc
+5
-5
paddle/fluid/operators/split_selected_rows_op.cc
paddle/fluid/operators/split_selected_rows_op.cc
+3
-3
paddle/fluid/operators/sum_op.cc
paddle/fluid/operators/sum_op.cc
+12
-12
paddle/fluid/operators/tensor_array_to_tensor_op.cc
paddle/fluid/operators/tensor_array_to_tensor_op.cc
+3
-3
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
+1
-1
paddle/fluid/operators/uniform_random_op.cc
paddle/fluid/operators/uniform_random_op.cc
+7
-6
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+2
-2
未找到文件。
paddle/fluid/framework/details/graph_test_base.h
浏览文件 @
b40e41fb
...
@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
...
@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
class
DummyVarTypeInference
:
public
VarTypeInference
{
class
DummyVarTypeInference
:
public
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
.
Input
(
"X"
);
auto
&
inputs
=
ctx
->
Input
(
"X"
);
auto
type
=
ctx
.
GetType
(
inputs
.
front
());
auto
type
=
ctx
->
GetType
(
inputs
.
front
());
auto
out_var_name
=
ctx
.
Output
(
"Out"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
.
SetType
(
out_var_name
,
type
);
ctx
->
SetType
(
out_var_name
,
type
);
}
}
};
};
...
...
paddle/fluid/framework/details/op_registry.h
浏览文件 @
b40e41fb
...
@@ -131,7 +131,7 @@ struct OpInfoFiller<T, kVarTypeInference> {
...
@@ -131,7 +131,7 @@ struct OpInfoFiller<T, kVarTypeInference> {
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
info
->
infer_var_type_
=
[](
InferVarTypeContext
*
context
)
{
info
->
infer_var_type_
=
[](
InferVarTypeContext
*
context
)
{
T
inference
;
T
inference
;
inference
(
*
context
);
inference
(
context
);
};
};
}
}
};
};
...
...
paddle/fluid/framework/ir/graph_test.cc
浏览文件 @
b40e41fb
...
@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
...
@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class
SumOpVarTypeInference
:
public
VarTypeInference
{
class
SumOpVarTypeInference
:
public
VarTypeInference
{
public:
public:
void
operator
()(
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
.
Input
(
"X"
);
auto
&
inputs
=
ctx
->
Input
(
"X"
);
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
bool
any_input_is_lod_tensor
=
std
::
any_of
(
bool
any_input_is_lod_tensor
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
.
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
return
ctx
->
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
});
});
if
(
any_input_is_lod_tensor
)
{
if
(
any_input_is_lod_tensor
)
{
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
}
}
auto
out_var_name
=
ctx
.
Output
(
"Out"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
.
SetType
(
out_var_name
,
default_var_type
);
ctx
->
SetType
(
out_var_name
,
default_var_type
);
}
}
};
};
...
@@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
...
@@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
class
DummyOpVarTypeInference
:
public
VarTypeInference
{
class
DummyOpVarTypeInference
:
public
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{}
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{}
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/var_type_inference.h
浏览文件 @
b40e41fb
...
@@ -126,20 +126,20 @@ class InferVarTypeContext {
...
@@ -126,20 +126,20 @@ class InferVarTypeContext {
class
VarTypeInference
{
class
VarTypeInference
{
public:
public:
virtual
~
VarTypeInference
()
{}
virtual
~
VarTypeInference
()
{}
virtual
void
operator
()(
InferVarTypeContext
&
context
)
const
=
0
;
// NOLINT
virtual
void
operator
()(
InferVarTypeContext
*
context
)
const
=
0
;
// NOLINT
};
};
class
PassInDtypeAndVarTypeToOutput
:
public
framework
::
VarTypeInference
{
class
PassInDtypeAndVarTypeToOutput
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
final
{
// NOLINT
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
final
{
// NOLINT
auto
in_out_var_names
=
this
->
GetInputOutputWithSameType
();
auto
in_out_var_names
=
this
->
GetInputOutputWithSameType
();
for
(
auto
&
i_o_n
:
in_out_var_names
)
{
for
(
auto
&
i_o_n
:
in_out_var_names
)
{
auto
&
x_name
=
ctx
.
Input
(
i_o_n
.
first
).
at
(
0
);
auto
&
x_name
=
ctx
->
Input
(
i_o_n
.
first
).
at
(
0
);
auto
&
out_name
=
ctx
.
Output
(
i_o_n
.
second
).
at
(
0
);
auto
&
out_name
=
ctx
->
Output
(
i_o_n
.
second
).
at
(
0
);
ctx
.
SetType
(
out_name
,
ctx
.
GetType
(
x_name
));
ctx
->
SetType
(
out_name
,
ctx
->
GetType
(
x_name
));
ctx
.
SetDataType
(
out_name
,
ctx
.
GetDataType
(
x_name
));
ctx
->
SetDataType
(
out_name
,
ctx
->
GetDataType
(
x_name
));
}
}
}
}
...
...
paddle/fluid/framework/var_type_inference_test.cc
浏览文件 @
b40e41fb
...
@@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
...
@@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class
SumOpVarTypeInference
:
public
VarTypeInference
{
class
SumOpVarTypeInference
:
public
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
.
Input
(
"X"
);
auto
&
inputs
=
ctx
->
Input
(
"X"
);
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
bool
any_input_is_lod_tensor
=
std
::
any_of
(
bool
any_input_is_lod_tensor
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
.
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
return
ctx
->
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
});
});
if
(
any_input_is_lod_tensor
)
{
if
(
any_input_is_lod_tensor
)
{
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
default_var_type
=
proto
::
VarType
::
LOD_TENSOR
;
}
}
auto
out_var_name
=
ctx
.
Output
(
"Out"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
.
SetType
(
out_var_name
,
default_var_type
);
ctx
->
SetType
(
out_var_name
,
default_var_type
);
}
}
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/imperative/tracer.cc
浏览文件 @
b40e41fb
...
@@ -161,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
...
@@ -161,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
}
}
std
::
set
<
std
::
string
>
Tracer
::
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
std
::
set
<
std
::
string
>
Tracer
::
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
VarBasePtrMap
&
outputs
,
VarBasePtrMap
*
outputs
,
framework
::
AttributeMap
attrs_map
,
framework
::
AttributeMap
attrs_map
,
const
platform
::
Place
expected_place
,
const
platform
::
Place
expected_place
,
const
bool
stop_gradient
)
{
const
bool
stop_gradient
)
{
...
@@ -195,7 +195,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -195,7 +195,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
}
}
}
}
op
->
output_vars_
=
outputs
;
op
->
output_vars_
=
*
outputs
;
for
(
auto
it
:
op
->
output_vars_
)
{
for
(
auto
it
:
op
->
output_vars_
)
{
auto
&
outvars
=
outvars_map
[
it
.
first
];
auto
&
outvars
=
outvars_map
[
it
.
first
];
const
std
::
vector
<
VarBase
*>&
outputs
=
it
.
second
;
const
std
::
vector
<
VarBase
*>&
outputs
=
it
.
second
;
...
@@ -218,7 +218,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -218,7 +218,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework
::
VariableNameMap
invars_name_map
=
framework
::
VariableNameMap
invars_name_map
=
CreateInputVarNameMap
(
op
,
inputs
);
CreateInputVarNameMap
(
op
,
inputs
);
framework
::
VariableNameMap
outvars_name_map
=
framework
::
VariableNameMap
outvars_name_map
=
CreateOutputVarNameMap
(
op
,
outputs
);
CreateOutputVarNameMap
(
op
,
*
outputs
);
auto
&
info
=
framework
::
OpInfoMap
::
Instance
().
Get
(
op
->
Type
());
auto
&
info
=
framework
::
OpInfoMap
::
Instance
().
Get
(
op
->
Type
());
if
(
info
.
Checker
()
!=
nullptr
)
{
if
(
info
.
Checker
()
!=
nullptr
)
{
...
@@ -230,8 +230,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -230,8 +230,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
outvars_name_map
,
attrs_map
);
outvars_name_map
,
attrs_map
);
if
(
info
.
infer_var_type_
)
{
if
(
info
.
infer_var_type_
)
{
RuntimeInferVarTypeContext
infer_var_type_ctx
(
&
inputs
,
&
outputs
,
RuntimeInferVarTypeContext
infer_var_type_ctx
(
&
inputs
,
outputs
,
&
attrs_map
);
&
attrs_map
);
info
.
infer_var_type_
(
&
infer_var_type_ctx
);
info
.
infer_var_type_
(
&
infer_var_type_ctx
);
}
}
...
...
paddle/fluid/imperative/tracer.h
浏览文件 @
b40e41fb
...
@@ -48,7 +48,7 @@ class Tracer {
...
@@ -48,7 +48,7 @@ class Tracer {
virtual
~
Tracer
()
{}
virtual
~
Tracer
()
{}
std
::
set
<
std
::
string
>
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
std
::
set
<
std
::
string
>
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
VarBasePtrMap
&
outputs
,
// NOLINT
VarBasePtrMap
*
outputs
,
// NOLINT
framework
::
AttributeMap
attrs_map
,
framework
::
AttributeMap
attrs_map
,
const
platform
::
Place
expected_place
,
const
platform
::
Place
expected_place
,
const
bool
stop_gradient
=
false
);
const
bool
stop_gradient
=
false
);
...
...
paddle/fluid/operators/beam_search_decode_op.cc
浏览文件 @
b40e41fb
...
@@ -203,12 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
...
@@ -203,12 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class
BeamSearchDecodeInferVarType
:
public
framework
::
VarTypeInference
{
class
BeamSearchDecodeInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
.
Output
(
"SentenceIds"
))
{
for
(
auto
&
o
:
ctx
->
Output
(
"SentenceIds"
))
{
ctx
.
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
for
(
auto
&
o
:
ctx
.
Output
(
"SentenceScores"
))
{
for
(
auto
&
o
:
ctx
->
Output
(
"SentenceScores"
))
{
ctx
.
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/beam_search_op.cc
浏览文件 @
b40e41fb
...
@@ -120,12 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel {
...
@@ -120,12 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel {
class
BeamSearchInferVarType
:
public
framework
::
VarTypeInference
{
class
BeamSearchInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
.
Output
(
"selected_ids"
))
{
for
(
auto
&
o
:
ctx
->
Output
(
"selected_ids"
))
{
ctx
.
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
for
(
auto
&
o
:
ctx
.
Output
(
"selected_scores"
))
{
for
(
auto
&
o
:
ctx
->
Output
(
"selected_scores"
))
{
ctx
.
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
浏览文件 @
b40e41fb
...
@@ -100,13 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
...
@@ -100,13 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
class
WriteToArrayInferVarType
:
public
framework
::
VarTypeInference
{
class
WriteToArrayInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_name
=
ctx
.
Input
(
"X"
)[
0
];
auto
x_name
=
ctx
->
Input
(
"X"
)[
0
];
auto
out_name
=
ctx
.
Output
(
"Out"
)[
0
];
auto
out_name
=
ctx
->
Output
(
"Out"
)[
0
];
VLOG
(
10
)
<<
"Set Variable "
<<
out_name
<<
" as LOD_TENSOR_ARRAY"
;
VLOG
(
10
)
<<
"Set Variable "
<<
out_name
<<
" as LOD_TENSOR_ARRAY"
;
ctx
.
SetType
(
out_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
ctx
->
SetType
(
out_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
if
(
ctx
.
HasVar
(
x_name
))
{
if
(
ctx
->
HasVar
(
x_name
))
{
ctx
.
SetDataType
(
out_name
,
ctx
.
GetDataType
(
x_name
));
ctx
->
SetDataType
(
out_name
,
ctx
->
GetDataType
(
x_name
));
}
}
}
}
};
};
...
...
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
b40e41fb
...
@@ -365,16 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
...
@@ -365,16 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
class
WhileGradOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
WhileGradOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
p_names
=
ctx
.
Input
(
kX
);
auto
p_names
=
ctx
->
Input
(
kX
);
auto
pg_ig_names
=
ctx
.
Output
(
framework
::
GradVarName
(
kX
));
auto
pg_ig_names
=
ctx
->
Output
(
framework
::
GradVarName
(
kX
));
for
(
size_t
i
=
0
;
i
<
p_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
p_names
.
size
();
++
i
)
{
if
(
ctx
.
HasVar
(
pg_ig_names
[
i
]))
{
if
(
ctx
->
HasVar
(
pg_ig_names
[
i
]))
{
VLOG
(
5
)
<<
"Setting "
<<
pg_ig_names
[
i
]
<<
" following "
<<
p_names
[
i
]
VLOG
(
5
)
<<
"Setting "
<<
pg_ig_names
[
i
]
<<
" following "
<<
p_names
[
i
]
<<
" type: "
<<
ctx
.
GetType
(
p_names
[
i
]);
<<
" type: "
<<
ctx
->
GetType
(
p_names
[
i
]);
ctx
.
SetType
(
pg_ig_names
[
i
],
ctx
.
GetType
(
p_names
[
i
]));
ctx
->
SetType
(
pg_ig_names
[
i
],
ctx
->
GetType
(
p_names
[
i
]));
ctx
.
SetDataType
(
pg_ig_names
[
i
],
ctx
.
GetDataType
(
p_names
[
i
]));
ctx
->
SetDataType
(
pg_ig_names
[
i
],
ctx
->
GetDataType
(
p_names
[
i
]));
}
}
}
}
}
}
...
...
paddle/fluid/operators/distributed_ops/fake_init_op.cc
浏览文件 @
b40e41fb
...
@@ -56,7 +56,7 @@ class FakeInitOp : public framework::OperatorBase {
...
@@ -56,7 +56,7 @@ class FakeInitOp : public framework::OperatorBase {
class
FakeInitOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
FakeInitOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{}
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{}
};
};
class
FakeInitOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
FakeInitOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
浏览文件 @
b40e41fb
...
@@ -114,10 +114,10 @@ class MergeIdsOp : public framework::OperatorWithKernel {
...
@@ -114,10 +114,10 @@ class MergeIdsOp : public framework::OperatorWithKernel {
class
MergeIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
class
MergeIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
input_type
=
ctx
.
GetType
(
ctx
.
Input
(
"Ids"
)[
0
]);
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"Ids"
)[
0
]);
for
(
auto
&
out_var
:
ctx
.
Output
(
"Out"
))
{
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
.
SetType
(
out_var
,
input_type
);
ctx
->
SetType
(
out_var
,
input_type
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/distributed_ops/split_ids_op.cc
浏览文件 @
b40e41fb
...
@@ -73,10 +73,10 @@ class SplitIdsOp : public framework::OperatorWithKernel {
...
@@ -73,10 +73,10 @@ class SplitIdsOp : public framework::OperatorWithKernel {
class
SplitIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
class
SplitIdsOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
input_type
=
ctx
.
GetType
(
ctx
.
Input
(
"Ids"
)[
0
]);
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"Ids"
)[
0
]);
for
(
auto
&
out_var
:
ctx
.
Output
(
"Out"
))
{
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
.
SetType
(
out_var
,
input_type
);
ctx
->
SetType
(
out_var
,
input_type
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
b40e41fb
...
@@ -39,11 +39,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
...
@@ -39,11 +39,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
class
FillConstantOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
FillConstantOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
.
GetAttr
(
"dtype"
)));
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
auto
&
out_var_name
=
ctx
.
Output
(
"Out"
).
front
();
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
.
SetDataType
(
out_var_name
,
data_type
);
ctx
->
SetDataType
(
out_var_name
,
data_type
);
}
}
};
};
...
...
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
浏览文件 @
b40e41fb
...
@@ -137,20 +137,20 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
...
@@ -137,20 +137,20 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
class
FusedEmbeddingSeqPoolOpGradVarTypeInference
class
FusedEmbeddingSeqPoolOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
.
Output
(
framework
::
GradVarName
(
"W"
)).
front
();
auto
out_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
();
auto
attr
=
ctx
.
GetAttr
(
"is_sparse"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
}
else
{
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
VLOG
(
3
)
<<
"fused_embedding_seq_pool_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
.
SetDataType
(
out_var_name
,
ctx
.
GetDataType
(
ctx
.
Input
(
"W"
)[
0
]));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]));
}
}
};
};
...
...
paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
浏览文件 @
b40e41fb
...
@@ -81,12 +81,12 @@ GetTensorFromSelectedRows is used to get the tensor from SelectedRows.
...
@@ -81,12 +81,12 @@ GetTensorFromSelectedRows is used to get the tensor from SelectedRows.
class
GetTensorFromSelectedRowsOpVarTypeInference
class
GetTensorFromSelectedRowsOpVarTypeInference
:
public
framework
::
VarTypeInference
{
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
{
// NOLINT
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
{
// NOLINT
auto
out_var_name
=
ctx
.
Output
(
"Out"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
auto
in_var_name
=
ctx
.
Input
(
"X"
).
front
();
auto
in_var_name
=
ctx
->
Input
(
"X"
).
front
();
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
.
SetDataType
(
out_var_name
,
ctx
.
GetDataType
(
in_var_name
));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
in_var_name
));
}
}
};
};
...
...
paddle/fluid/operators/hierarchical_sigmoid_op.cc
浏览文件 @
b40e41fb
...
@@ -197,32 +197,32 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
...
@@ -197,32 +197,32 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
class
HierarchicalSigmoidGradOpGradVarTypeInference
class
HierarchicalSigmoidGradOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
w_grad_var_name
=
ctx
.
Output
(
framework
::
GradVarName
(
"W"
)).
front
();
auto
w_grad_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
();
auto
bias_grad_var_name_vec
=
ctx
.
Output
(
framework
::
GradVarName
(
"Bias"
));
auto
bias_grad_var_name_vec
=
ctx
->
Output
(
framework
::
GradVarName
(
"Bias"
));
std
::
string
bias_grad_var_name
;
std
::
string
bias_grad_var_name
;
bool
hasBias
=
false
;
bool
hasBias
=
false
;
if
(
bias_grad_var_name_vec
.
size
())
{
if
(
bias_grad_var_name_vec
.
size
())
{
hasBias
=
true
;
hasBias
=
true
;
bias_grad_var_name
=
ctx
.
Output
(
framework
::
GradVarName
(
"Bias"
)).
front
();
bias_grad_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"Bias"
)).
front
();
}
}
auto
attr
=
ctx
.
GetAttr
(
"is_sparse"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
if
(
is_sparse
)
{
VLOG
(
30
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
30
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
<<
" is set to SelectedRows"
;
ctx
.
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
}
else
{
VLOG
(
30
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
30
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
<<
" is set to LoDTensor"
;
ctx
.
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
if
(
hasBias
)
{
if
(
hasBias
)
{
VLOG
(
30
)
<<
"hierarchical_sigmoid_grad op "
VLOG
(
30
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"Bias"
)
<<
" is set to LoDTensor"
;
<<
framework
::
GradVarName
(
"Bias"
)
<<
" is set to LoDTensor"
;
ctx
.
SetType
(
bias_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
bias_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
.
SetDataType
(
w_grad_var_name
,
ctx
.
GetDataType
(
ctx
.
Input
(
"W"
)[
0
]));
ctx
->
SetDataType
(
w_grad_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]));
}
}
};
};
...
...
paddle/fluid/operators/lod_rank_table_op.cc
浏览文件 @
b40e41fb
...
@@ -64,9 +64,9 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
...
@@ -64,9 +64,9 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
class
LoDRankTableInferVarType
:
public
framework
::
VarTypeInference
{
class
LoDRankTableInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
.
Output
(
"Out"
))
{
for
(
auto
&
o
:
ctx
->
Output
(
"Out"
))
{
ctx
.
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_RANK_TABLE
);
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
LOD_RANK_TABLE
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/lod_tensor_to_array_op.cc
浏览文件 @
b40e41fb
...
@@ -201,9 +201,9 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
...
@@ -201,9 +201,9 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
class
LoDTensorToArrayInferVarType
:
public
framework
::
VarTypeInference
{
class
LoDTensorToArrayInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
out_var
:
ctx
.
Output
(
"Out"
))
{
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
.
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/lookup_table_op.cc
浏览文件 @
b40e41fb
...
@@ -147,20 +147,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
...
@@ -147,20 +147,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
class
LookupTableOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
class
LookupTableOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
.
Output
(
framework
::
GradVarName
(
"W"
)).
front
();
auto
out_var_name
=
ctx
->
Output
(
framework
::
GradVarName
(
"W"
)).
front
();
auto
attr
=
ctx
.
GetAttr
(
"is_sparse"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
<<
" is set to SelectedRows"
;
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
}
else
{
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
<<
" is set to LoDTensor"
;
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
.
SetDataType
(
out_var_name
,
ctx
.
GetDataType
(
ctx
.
Input
(
"W"
)[
0
]));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]));
}
}
};
};
...
...
paddle/fluid/operators/nccl/nccl_op.cc
浏览文件 @
b40e41fb
...
@@ -60,9 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
...
@@ -60,9 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
class
NCCLInitOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
NCCLInitOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
.
Output
(
"Communicator"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Communicator"
).
front
();
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
RAW
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
RAW
);
}
}
};
};
...
...
paddle/fluid/operators/nce_op.cc
浏览文件 @
b40e41fb
...
@@ -237,21 +237,21 @@ class NCEOpGrad : public framework::OperatorWithKernel {
...
@@ -237,21 +237,21 @@ class NCEOpGrad : public framework::OperatorWithKernel {
class
NCEOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
class
NCEOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
weight_grad
=
ctx
.
Output
(
framework
::
GradVarName
(
"Weight"
)).
front
();
auto
weight_grad
=
ctx
->
Output
(
framework
::
GradVarName
(
"Weight"
)).
front
();
auto
attr
=
ctx
.
GetAttr
(
"is_sparse"
);
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
<<
" is set to SelectedRows"
;
<<
" is set to SelectedRows"
;
ctx
.
SetType
(
weight_grad
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetType
(
weight_grad
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
}
else
{
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
VLOG
(
3
)
<<
"nce_op_grad op "
<<
weight_grad
<<
" and "
<<
" is set to LoDTensor"
;
<<
" is set to LoDTensor"
;
ctx
.
SetType
(
weight_grad
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
weight_grad
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
.
SetDataType
(
weight_grad
,
ctx
.
GetDataType
(
ctx
.
Input
(
"Input"
)[
0
]));
ctx
->
SetDataType
(
weight_grad
,
ctx
->
GetDataType
(
ctx
->
Input
(
"Input"
)[
0
]));
}
}
};
};
...
...
paddle/fluid/operators/ngraph/ngraph_engine_op.cc
浏览文件 @
b40e41fb
...
@@ -37,7 +37,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -37,7 +37,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
class
NgraphEngineInferVarType
:
public
framework
::
VarTypeInference
{
class
NgraphEngineInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{}
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/optimizers/lars_momentum_op.cc
浏览文件 @
b40e41fb
...
@@ -72,7 +72,7 @@ use L2 regularizers in case of using LARS.
...
@@ -72,7 +72,7 @@ use L2 regularizers in case of using LARS.
class
LarsMomentumOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
LarsMomentumOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{}
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/operators/optimizers/momentum_op.cc
浏览文件 @
b40e41fb
...
@@ -21,14 +21,14 @@ using Tensor = framework::Tensor;
...
@@ -21,14 +21,14 @@ using Tensor = framework::Tensor;
class
MomentumOpInferVarType
:
public
framework
::
VarTypeInference
{
class
MomentumOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
input_var
=
ctx
.
Input
(
"Param"
)[
0
];
auto
&
input_var
=
ctx
->
Input
(
"Param"
)[
0
];
for
(
auto
&
out_var
:
ctx
.
Output
(
"ParamOut"
))
{
for
(
auto
&
out_var
:
ctx
->
Output
(
"ParamOut"
))
{
if
(
ctx
.
GetType
(
input_var
)
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
if
(
ctx
->
GetType
(
input_var
)
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
ctx
.
SetType
(
out_var
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
if
(
ctx
.
GetType
(
input_var
)
==
}
else
if
(
ctx
->
GetType
(
input_var
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
ctx
.
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
else
{
}
else
{
PADDLE_THROW
(
PADDLE_THROW
(
"Only support LodTensor and SelectedRows, Unexpected Input Type."
);
"Only support LodTensor and SelectedRows, Unexpected Input Type."
);
...
...
paddle/fluid/operators/optimizers/sgd_op.cc
浏览文件 @
b40e41fb
...
@@ -50,18 +50,18 @@ class SGDOp : public framework::OperatorWithKernel {
...
@@ -50,18 +50,18 @@ class SGDOp : public framework::OperatorWithKernel {
class
SGDOpInferVarType
:
public
framework
::
VarTypeInference
{
class
SGDOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
input_var_n
=
ctx
.
Input
(
"Param"
)[
0
];
auto
&
input_var_n
=
ctx
->
Input
(
"Param"
)[
0
];
auto
in_var_type
=
ctx
.
GetType
(
input_var_n
);
auto
in_var_type
=
ctx
->
GetType
(
input_var_n
);
PADDLE_ENFORCE
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
PADDLE_ENFORCE
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input Var's type should be LoDtensor or SelectedRows,"
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s"
,
" but the received var(%s)'s type is %s"
,
input_var_n
,
in_var_type
);
input_var_n
,
in_var_type
);
for
(
auto
&
out_var_n
:
ctx
.
Output
(
"ParamOut"
))
{
for
(
auto
&
out_var_n
:
ctx
->
Output
(
"ParamOut"
))
{
if
(
ctx
.
GetType
(
out_var_n
)
!=
in_var_type
)
{
if
(
ctx
->
GetType
(
out_var_n
)
!=
in_var_type
)
{
ctx
.
SetType
(
out_var_n
,
in_var_type
);
ctx
->
SetType
(
out_var_n
,
in_var_type
);
}
}
}
}
}
}
...
...
paddle/fluid/operators/py_func_op.cc
浏览文件 @
b40e41fb
...
@@ -96,10 +96,10 @@ static void CallPythonFunc(py::object *callable,
...
@@ -96,10 +96,10 @@ static void CallPythonFunc(py::object *callable,
class
PyFuncOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
PyFuncOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
bool
has_out
=
(
ctx
.
HasOutput
(
"Out"
)
&&
!
ctx
.
Output
(
"Out"
).
empty
());
bool
has_out
=
(
ctx
->
HasOutput
(
"Out"
)
&&
!
ctx
->
Output
(
"Out"
).
empty
());
bool
has_in
=
(
ctx
.
HasInput
(
"X"
)
&&
!
ctx
.
Input
(
"Out"
).
empty
());
bool
has_in
=
(
ctx
->
HasInput
(
"X"
)
&&
!
ctx
->
Input
(
"Out"
).
empty
());
/**
/**
* X or Out can be empty, so that py_func can be more flexible
* X or Out can be empty, so that py_func can be more flexible
...
@@ -107,8 +107,8 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
...
@@ -107,8 +107,8 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
*/
*/
PADDLE_ENFORCE
(
has_in
||
has_out
,
"Input(X) or Output(Out) must exist"
);
PADDLE_ENFORCE
(
has_in
||
has_out
,
"Input(X) or Output(Out) must exist"
);
PADDLE_ENFORCE_GE
(
boost
::
get
<
int
>
(
ctx
.
GetAttr
(
kForwardPythonCallableId
)),
0
,
PADDLE_ENFORCE_GE
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
kForwardPythonCallableId
))
,
"Function id cannot be less than 0"
);
0
,
"Function id cannot be less than 0"
);
if
(
!
has_out
)
return
;
if
(
!
has_out
)
return
;
...
@@ -118,7 +118,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
...
@@ -118,7 +118,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
* the corresponding forward variable
* the corresponding forward variable
*/
*/
const
std
::
string
kGradVarSuffix
=
framework
::
kGradVarSuffix
;
const
std
::
string
kGradVarSuffix
=
framework
::
kGradVarSuffix
;
auto
&
out_var_names
=
ctx
.
Output
(
"Out"
);
auto
&
out_var_names
=
ctx
->
Output
(
"Out"
);
for
(
auto
&
out_var_name
:
out_var_names
)
{
for
(
auto
&
out_var_name
:
out_var_names
)
{
if
(
out_var_name
==
framework
::
kEmptyVarName
||
if
(
out_var_name
==
framework
::
kEmptyVarName
||
out_var_name
.
size
()
<
kGradVarSuffix
.
size
())
{
out_var_name
.
size
()
<
kGradVarSuffix
.
size
())
{
...
@@ -128,17 +128,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
...
@@ -128,17 +128,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
size_t
len
=
out_var_name
.
size
()
-
kGradVarSuffix
.
size
();
size_t
len
=
out_var_name
.
size
()
-
kGradVarSuffix
.
size
();
if
(
out_var_name
.
substr
(
len
)
==
kGradVarSuffix
)
{
if
(
out_var_name
.
substr
(
len
)
==
kGradVarSuffix
)
{
auto
fwd_var_name
=
out_var_name
.
substr
(
0
,
len
);
auto
fwd_var_name
=
out_var_name
.
substr
(
0
,
len
);
PADDLE_ENFORCE
(
ctx
.
HasVar
(
out_var_name
),
PADDLE_ENFORCE
(
ctx
->
HasVar
(
out_var_name
),
"Backward variable %s not found"
,
out_var_name
);
"Backward variable %s not found"
,
out_var_name
);
PADDLE_ENFORCE
(
ctx
.
HasVar
(
fwd_var_name
),
PADDLE_ENFORCE
(
ctx
->
HasVar
(
fwd_var_name
),
"Backward variable %s not found"
,
fwd_var_name
);
"Backward variable %s not found"
,
fwd_var_name
);
VLOG
(
10
)
<<
"Infer var_desc of Output("
<<
out_var_name
<<
") as Input("
VLOG
(
10
)
<<
"Infer var_desc of Output("
<<
out_var_name
<<
") as Input("
<<
fwd_var_name
<<
")"
;
<<
fwd_var_name
<<
")"
;
ctx
.
SetShape
(
out_var_name
,
ctx
.
GetShape
(
fwd_var_name
));
ctx
->
SetShape
(
out_var_name
,
ctx
->
GetShape
(
fwd_var_name
));
ctx
.
SetDataType
(
out_var_name
,
ctx
.
GetDataType
(
fwd_var_name
));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
fwd_var_name
));
ctx
.
SetLoDLevel
(
out_var_name
,
ctx
.
GetLoDLevel
(
fwd_var_name
));
ctx
->
SetLoDLevel
(
out_var_name
,
ctx
->
GetLoDLevel
(
fwd_var_name
));
ctx
.
SetType
(
out_var_name
,
ctx
.
GetType
(
fwd_var_name
));
ctx
->
SetType
(
out_var_name
,
ctx
->
GetType
(
fwd_var_name
));
}
}
}
}
}
}
...
...
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
b40e41fb
...
@@ -123,22 +123,22 @@ class CustomReaderInferShape : public framework::InferShapeBase {
...
@@ -123,22 +123,22 @@ class CustomReaderInferShape : public framework::InferShapeBase {
class
CustomReaderInferVarType
:
public
framework
::
VarTypeInference
{
class
CustomReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
out_var_name
=
ctx
.
Output
(
"Out"
)[
0
];
auto
&
out_var_name
=
ctx
->
Output
(
"Out"
)[
0
];
PADDLE_ENFORCE
(
ctx
.
HasVar
(
out_var_name
));
PADDLE_ENFORCE
(
ctx
->
HasVar
(
out_var_name
));
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
READER
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
READER
);
auto
sink_var_names
=
auto
sink_var_names
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
ctx
.
GetAttr
(
"sink_var_names"
));
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
ctx
->
GetAttr
(
"sink_var_names"
));
const
auto
*
sub_block
=
const
auto
*
sub_block
=
boost
::
get
<
framework
::
BlockDesc
*>
(
ctx
.
GetAttr
(
"sub_block"
));
boost
::
get
<
framework
::
BlockDesc
*>
(
ctx
->
GetAttr
(
"sub_block"
));
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
res_data_types
;
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
res_data_types
;
for
(
const
std
::
string
&
var_name
:
sink_var_names
)
{
for
(
const
std
::
string
&
var_name
:
sink_var_names
)
{
framework
::
VarDesc
*
var
=
sub_block
->
FindVar
(
var_name
);
framework
::
VarDesc
*
var
=
sub_block
->
FindVar
(
var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
);
PADDLE_ENFORCE_NOT_NULL
(
var
);
res_data_types
.
emplace_back
(
var
->
GetDataType
());
res_data_types
.
emplace_back
(
var
->
GetDataType
());
}
}
ctx
.
SetDataTypes
(
out_var_name
,
res_data_types
);
ctx
->
SetDataTypes
(
out_var_name
,
res_data_types
);
}
}
};
};
...
...
paddle/fluid/operators/reader/read_op.cc
浏览文件 @
b40e41fb
...
@@ -51,16 +51,16 @@ class ReadInferShape : public framework::InferShapeBase {
...
@@ -51,16 +51,16 @@ class ReadInferShape : public framework::InferShapeBase {
class
ReadInferVarType
:
public
framework
::
VarTypeInference
{
class
ReadInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
bool
infer_out
=
boost
::
get
<
bool
>
(
ctx
.
GetAttr
(
"infer_out"
));
bool
infer_out
=
boost
::
get
<
bool
>
(
ctx
->
GetAttr
(
"infer_out"
));
if
(
infer_out
)
{
if
(
infer_out
)
{
std
::
string
reader_name
=
ctx
.
Input
(
"Reader"
)[
0
];
std
::
string
reader_name
=
ctx
->
Input
(
"Reader"
)[
0
];
std
::
vector
<
std
::
string
>
out_names
=
ctx
.
Output
(
"Out"
);
std
::
vector
<
std
::
string
>
out_names
=
ctx
->
Output
(
"Out"
);
auto
dtypes
=
ctx
.
GetDataTypes
(
reader_name
);
auto
dtypes
=
ctx
->
GetDataTypes
(
reader_name
);
PADDLE_ENFORCE_EQ
(
dtypes
.
size
(),
out_names
.
size
());
PADDLE_ENFORCE_EQ
(
dtypes
.
size
(),
out_names
.
size
());
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
ctx
.
SetType
(
out_names
[
i
],
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
->
SetType
(
out_names
[
i
],
framework
::
proto
::
VarType
::
LOD_TENSOR
);
ctx
.
SetDataType
(
out_names
[
i
],
dtypes
[
i
]);
ctx
->
SetDataType
(
out_names
[
i
],
dtypes
[
i
]);
}
}
}
}
}
}
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
b40e41fb
...
@@ -99,9 +99,9 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
...
@@ -99,9 +99,9 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
}
}
void
FileReaderInferVarType
::
operator
()(
void
FileReaderInferVarType
::
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
{
framework
::
InferVarTypeContext
*
ctx
)
const
{
std
::
string
reader_name
=
ctx
.
Output
(
"Out"
)[
0
];
std
::
string
reader_name
=
ctx
->
Output
(
"Out"
)[
0
];
ctx
.
SetType
(
reader_name
,
framework
::
proto
::
VarType
::
READER
);
ctx
->
SetType
(
reader_name
,
framework
::
proto
::
VarType
::
READER
);
}
}
void
DecoratedReaderInferShape
::
operator
()(
void
DecoratedReaderInferShape
::
operator
()(
...
@@ -124,11 +124,11 @@ void DecoratedReaderInferShape::operator()(
...
@@ -124,11 +124,11 @@ void DecoratedReaderInferShape::operator()(
}
}
void
DecoratedReaderInferVarType
::
operator
()(
void
DecoratedReaderInferVarType
::
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
{
framework
::
InferVarTypeContext
*
ctx
)
const
{
const
std
::
string
&
in_reader_name
=
ctx
.
Input
(
"UnderlyingReader"
)[
0
];
const
std
::
string
&
in_reader_name
=
ctx
->
Input
(
"UnderlyingReader"
)[
0
];
const
std
::
string
&
out_reader_name
=
ctx
.
Output
(
"Out"
)[
0
];
const
std
::
string
&
out_reader_name
=
ctx
->
Output
(
"Out"
)[
0
];
ctx
.
SetType
(
out_reader_name
,
framework
::
proto
::
VarType
::
READER
);
ctx
->
SetType
(
out_reader_name
,
framework
::
proto
::
VarType
::
READER
);
ctx
.
SetDataTypes
(
out_reader_name
,
ctx
.
GetDataTypes
(
in_reader_name
));
ctx
->
SetDataTypes
(
out_reader_name
,
ctx
->
GetDataTypes
(
in_reader_name
));
}
}
void
DecoratedReaderMakerBase
::
Make
()
{
void
DecoratedReaderMakerBase
::
Make
()
{
...
...
paddle/fluid/operators/reader/reader_op_registry.h
浏览文件 @
b40e41fb
...
@@ -61,7 +61,7 @@ class FileReaderInferShape : public framework::InferShapeBase {
...
@@ -61,7 +61,7 @@ class FileReaderInferShape : public framework::InferShapeBase {
class
FileReaderInferVarType
:
public
framework
::
VarTypeInference
{
class
FileReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
;
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
;
};
};
// general infershape for decorated reader
// general infershape for decorated reader
...
@@ -73,7 +73,7 @@ class DecoratedReaderInferShape : public framework::InferShapeBase {
...
@@ -73,7 +73,7 @@ class DecoratedReaderInferShape : public framework::InferShapeBase {
// general var type inference for decorated reader
// general var type inference for decorated reader
class
DecoratedReaderInferVarType
:
public
framework
::
VarTypeInference
{
class
DecoratedReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
;
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
;
};
};
class
DecoratedReaderMakerBase
:
public
framework
::
OpProtoAndCheckerMaker
{
class
DecoratedReaderMakerBase
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
b40e41fb
...
@@ -159,9 +159,9 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
...
@@ -159,9 +159,9 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
class
SaveOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
SaveOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
.
Output
(
LOOKUP_TABLE_PATH
).
front
();
auto
out_var_name
=
ctx
->
Output
(
LOOKUP_TABLE_PATH
).
front
();
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
RAW
);
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
RAW
);
}
}
};
};
...
...
paddle/fluid/operators/scale_op.cc
浏览文件 @
b40e41fb
...
@@ -70,13 +70,13 @@ $$Out = scale*(X + bias)$$
...
@@ -70,13 +70,13 @@ $$Out = scale*(X + bias)$$
class
ScaleOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
ScaleOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
in_var_name
=
ctx
.
Input
(
"X"
).
front
();
auto
&
in_var_name
=
ctx
->
Input
(
"X"
).
front
();
auto
out_var_name
=
ctx
.
Output
(
"Out"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
if
(
in_var_name
!=
out_var_name
)
{
if
(
in_var_name
!=
out_var_name
)
{
ctx
.
SetType
(
out_var_name
,
ctx
.
GetType
(
in_var_name
));
ctx
->
SetType
(
out_var_name
,
ctx
->
GetType
(
in_var_name
));
ctx
.
SetDataType
(
out_var_name
,
ctx
.
GetDataType
(
in_var_name
));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
in_var_name
));
}
}
}
}
};
};
...
...
paddle/fluid/operators/split_selected_rows_op.cc
浏览文件 @
b40e41fb
...
@@ -62,9 +62,9 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
...
@@ -62,9 +62,9 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
class
SplitSelectedRowsOpInferVarType
:
public
framework
::
VarTypeInference
{
class
SplitSelectedRowsOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
out_var
:
ctx
.
Output
(
"Out"
))
{
for
(
auto
&
out_var
:
ctx
->
Output
(
"Out"
))
{
ctx
.
SetType
(
out_var
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/sum_op.cc
浏览文件 @
b40e41fb
...
@@ -160,20 +160,20 @@ the LoD information with the first input.
...
@@ -160,20 +160,20 @@ the LoD information with the first input.
class
SumOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
SumOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
inputs
=
ctx
.
Input
(
"X"
);
auto
&
inputs
=
ctx
->
Input
(
"X"
);
auto
var_type
=
framework
::
proto
::
VarType
::
SELECTED_ROWS
;
auto
var_type
=
framework
::
proto
::
VarType
::
SELECTED_ROWS
;
for
(
auto
&
name
:
ctx
.
Input
(
"X"
))
{
for
(
auto
&
name
:
ctx
->
Input
(
"X"
))
{
VLOG
(
10
)
<<
name
<<
" "
<<
ctx
.
GetType
(
name
);
VLOG
(
10
)
<<
name
<<
" "
<<
ctx
->
GetType
(
name
);
}
}
bool
any_input_is_lod_tensor
=
std
::
any_of
(
bool
any_input_is_lod_tensor
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
inputs
.
begin
(),
inputs
.
end
(),
[
ctx
](
const
std
::
string
&
name
)
{
return
ctx
.
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR
;
return
ctx
->
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR
;
});
});
auto
is_tensor_array
=
[
&
ctx
](
const
std
::
string
&
name
)
{
auto
is_tensor_array
=
[
ctx
](
const
std
::
string
&
name
)
{
return
ctx
.
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
return
ctx
->
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
};
};
bool
any_input_is_tensor_array
=
bool
any_input_is_tensor_array
=
...
@@ -185,7 +185,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
...
@@ -185,7 +185,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
if
(
!
all_inputs_are_tensor_array
)
{
if
(
!
all_inputs_are_tensor_array
)
{
std
::
ostringstream
os
;
std
::
ostringstream
os
;
for
(
auto
&
each
:
inputs
)
{
for
(
auto
&
each
:
inputs
)
{
os
<<
" "
<<
each
<<
" type is "
<<
ctx
.
GetType
(
each
)
<<
"
\n
"
;
os
<<
" "
<<
each
<<
" type is "
<<
ctx
->
GetType
(
each
)
<<
"
\n
"
;
}
}
PADDLE_ENFORCE
(
all_inputs_are_tensor_array
,
PADDLE_ENFORCE
(
all_inputs_are_tensor_array
,
"Not all inputs are tensor array:
\n
%s"
,
os
.
str
());
"Not all inputs are tensor array:
\n
%s"
,
os
.
str
());
...
@@ -195,9 +195,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
...
@@ -195,9 +195,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR
;
var_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR
;
}
}
auto
out_var_name
=
ctx
.
Output
(
"Out"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
ctx
.
SetType
(
out_var_name
,
var_type
);
ctx
->
SetType
(
out_var_name
,
var_type
);
ctx
.
SetDataType
(
out_var_name
,
ctx
.
GetDataType
(
inputs
.
front
()));
ctx
->
SetDataType
(
out_var_name
,
ctx
->
GetDataType
(
inputs
.
front
()));
}
}
};
};
...
...
paddle/fluid/operators/tensor_array_to_tensor_op.cc
浏览文件 @
b40e41fb
...
@@ -177,9 +177,9 @@ class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
...
@@ -177,9 +177,9 @@ class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
class
LoDTensorArray2TensorGradInferVarType
class
LoDTensorArray2TensorGradInferVarType
:
public
framework
::
VarTypeInference
{
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
out_var
:
ctx
.
Output
(
framework
::
GradVarName
(
"X"
)))
{
for
(
auto
&
out_var
:
ctx
->
Output
(
framework
::
GradVarName
(
"X"
)))
{
ctx
.
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
ctx
->
SetType
(
out_var
,
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
浏览文件 @
b40e41fb
...
@@ -46,7 +46,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -46,7 +46,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
class
TensorRTEngineInferVarType
:
public
framework
::
VarTypeInference
{
class
TensorRTEngineInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{}
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/uniform_random_op.cc
浏览文件 @
b40e41fb
...
@@ -112,15 +112,16 @@ uniform distribution. The random result is in set [min, max].
...
@@ -112,15 +112,16 @@ uniform distribution. The random result is in set [min, max].
class
UniformRandomOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
UniformRandomOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
.
Output
(
"Out"
).
front
();
auto
out_var_name
=
ctx
->
Output
(
"Out"
).
front
();
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
ctx
.
GetAttr
(
"dtype"
)));
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
"dtype"
)));
if
(
ctx
.
GetType
(
out_var_name
)
!=
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
if
(
ctx
->
GetType
(
out_var_name
)
!=
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
ctx
.
SetDataType
(
out_var_name
,
var_data_type
);
ctx
->
SetDataType
(
out_var_name
,
var_data_type
);
}
}
};
};
...
...
paddle/fluid/pybind/imperative.cc
浏览文件 @
b40e41fb
...
@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
...
@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
.
def
(
"trace"
,
.
def
(
"trace"
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
const
imperative
::
VarBasePtrMap
&
inputs
,
const
imperative
::
VarBasePtrMap
&
inputs
,
imperative
::
VarBasePtrMap
&
outputs
,
imperative
::
VarBasePtrMap
*
outputs
,
framework
::
AttributeMap
attrs_map
,
framework
::
AttributeMap
attrs_map
,
const
platform
::
CPUPlace
expected_place
,
const
platform
::
CPUPlace
expected_place
,
const
bool
stop_gradient
=
false
)
{
const
bool
stop_gradient
=
false
)
{
...
@@ -48,7 +48,7 @@ void BindTracer(pybind11::module* m) {
...
@@ -48,7 +48,7 @@ void BindTracer(pybind11::module* m) {
.
def
(
"trace"
,
.
def
(
"trace"
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
const
imperative
::
VarBasePtrMap
&
inputs
,
const
imperative
::
VarBasePtrMap
&
inputs
,
imperative
::
VarBasePtrMap
&
outputs
,
imperative
::
VarBasePtrMap
*
outputs
,
framework
::
AttributeMap
attrs_map
,
framework
::
AttributeMap
attrs_map
,
const
platform
::
CUDAPlace
expected_place
,
const
platform
::
CUDAPlace
expected_place
,
const
bool
stop_gradient
=
false
)
{
const
bool
stop_gradient
=
false
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录