Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
36dce65b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
36dce65b
编写于
3月 18, 2019
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Take DataType and VarType apart
test=develop
上级
db0c9708
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
53 addition
and
36 deletion
+53
-36
paddle/fluid/framework/details/op_registry.h
paddle/fluid/framework/details/op_registry.h
+2
-2
paddle/fluid/framework/ir/graph_test.cc
paddle/fluid/framework/ir/graph_test.cc
+1
-1
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+1
-1
paddle/fluid/framework/type_defs.h
paddle/fluid/framework/type_defs.h
+1
-1
paddle/fluid/framework/var_type_inference_test.cc
paddle/fluid/framework/var_type_inference_test.cc
+1
-1
paddle/fluid/imperative/layer.cc
paddle/fluid/imperative/layer.cc
+13
-13
paddle/fluid/imperative/layer.h
paddle/fluid/imperative/layer.h
+29
-13
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+2
-1
paddle/fluid/operators/sum_op.cc
paddle/fluid/operators/sum_op.cc
+2
-2
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-1
未找到文件。
paddle/fluid/framework/details/op_registry.h
浏览文件 @
36dce65b
...
@@ -129,9 +129,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
...
@@ -129,9 +129,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
template
<
typename
T
>
template
<
typename
T
>
struct
OpInfoFiller
<
T
,
kVarTypeInference
>
{
struct
OpInfoFiller
<
T
,
kVarTypeInference
>
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
info
->
infer_var_type_
=
[](
InferVarTypeContext
&
context
)
{
info
->
infer_var_type_
=
[](
InferVarTypeContext
*
context
)
{
T
inference
;
T
inference
;
inference
(
context
);
inference
(
*
context
);
};
};
}
}
};
};
...
...
paddle/fluid/framework/ir/graph_test.cc
浏览文件 @
36dce65b
...
@@ -48,7 +48,7 @@ class SumOpVarTypeInference : public VarTypeInference {
...
@@ -48,7 +48,7 @@ class SumOpVarTypeInference : public VarTypeInference {
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
bool
any_input_is_lod_tensor
=
std
::
any_of
(
bool
any_input_is_lod_tensor
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
ctx
](
const
std
::
string
&
name
)
{
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
.
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
return
ctx
.
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
});
});
if
(
any_input_is_lod_tensor
)
{
if
(
any_input_is_lod_tensor
)
{
...
...
paddle/fluid/framework/op_desc.cc
浏览文件 @
36dce65b
...
@@ -679,7 +679,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
...
@@ -679,7 +679,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
this
->
Type
());
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
this
->
Type
());
if
(
info
.
infer_var_type_
)
{
if
(
info
.
infer_var_type_
)
{
InferVarTypeContext
context
(
this
,
block
);
InferVarTypeContext
context
(
this
,
block
);
info
.
infer_var_type_
(
context
);
info
.
infer_var_type_
(
&
context
);
}
}
}
}
...
...
paddle/fluid/framework/type_defs.h
浏览文件 @
36dce65b
...
@@ -54,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
...
@@ -54,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
const
std
::
vector
<
BlockDesc
*>&
grad_block
)
>
;
const
std
::
vector
<
BlockDesc
*>&
grad_block
)
>
;
using
InferVarTypeFN
=
using
InferVarTypeFN
=
std
::
function
<
void
(
framework
::
InferVarTypeContext
&
/*context*/
)
>
;
std
::
function
<
void
(
framework
::
InferVarTypeContext
*
/*context*/
)
>
;
using
InferShapeFN
=
std
::
function
<
void
(
InferShapeContext
*
)
>
;
using
InferShapeFN
=
std
::
function
<
void
(
InferShapeContext
*
)
>
;
...
...
paddle/fluid/framework/var_type_inference_test.cc
浏览文件 @
36dce65b
...
@@ -49,7 +49,7 @@ class SumOpVarTypeInference : public VarTypeInference {
...
@@ -49,7 +49,7 @@ class SumOpVarTypeInference : public VarTypeInference {
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
auto
default_var_type
=
proto
::
VarType
::
SELECTED_ROWS
;
bool
any_input_is_lod_tensor
=
std
::
any_of
(
bool
any_input_is_lod_tensor
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
ctx
](
const
std
::
string
&
name
)
{
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
.
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
return
ctx
.
GetType
(
name
)
==
proto
::
VarType
::
LOD_TENSOR
;
});
});
if
(
any_input_is_lod_tensor
)
{
if
(
any_input_is_lod_tensor
)
{
...
...
paddle/fluid/imperative/layer.cc
浏览文件 @
36dce65b
...
@@ -243,12 +243,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
...
@@ -243,12 +243,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
auto
&
outputs
=
tmp_grad_outputs
[
k
][
it
.
first
];
auto
&
outputs
=
tmp_grad_outputs
[
k
][
it
.
first
];
outputs
.
reserve
(
it
.
second
.
size
());
outputs
.
reserve
(
it
.
second
.
size
());
for
(
size_t
i
=
0
;
i
<
it
.
second
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
it
.
second
.
size
();
++
i
)
{
VarBase
*
origin_grad_var_base
=
it
.
second
[
i
];
// Allocate a new variable
// Allocate a new variable
Var
iable
*
tmp_var
=
new
framework
::
Variable
();
Var
Base
*
tmp_grad_var_base
=
new
VarBase
(
tmp_var
->
GetMutable
<
framework
::
LoDTensor
>
();
string
::
Sprintf
(
"%s@IGrad"
,
origin_grad_var_base
->
Name
()),
VarBase
*
tmp_var_base
=
origin_grad_var_base
->
DataType
(),
origin_grad_var_base
->
Dims
(),
new
VarBase
(
it
.
second
[
i
]
->
Name
(),
tmp_var
,
nullptr
,
tru
e
);
place_
,
true
,
fals
e
);
outputs
.
emplace_back
(
tmp_var_base
);
outputs
.
emplace_back
(
tmp_
grad_
var_base
);
}
}
}
}
...
@@ -259,13 +261,12 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
...
@@ -259,13 +261,12 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
std
::
unique_ptr
<
framework
::
OperatorBase
>
opbase
=
std
::
unique_ptr
<
framework
::
OperatorBase
>
opbase
=
framework
::
OpRegistry
::
CreateOp
(
*
grad_op_desc
);
framework
::
OpRegistry
::
CreateOp
(
*
grad_op_desc
);
// auto& info =
auto
&
info
=
framework
::
OpInfoMap
::
Instance
().
Get
(
grad_op_desc
->
Type
());
// framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
if
(
info
.
infer_var_type_
)
{
// if (info.infer_var_type_) {
RuntimeInferVarTypeContext
infer_var_type_ctx
(
// framework::RuntimeInferVarTypeContext infer_var_type_ctx(
&
grad_input_vars_
[
k
],
&
tmp_grad_outputs
[
k
],
&
attrs_
);
// this, &grad_inputs, &outputs, &attrs_map);
info
.
infer_var_type_
(
&
infer_var_type_ctx
);
// info.infer_var_type_(infer_var_type_ctx);
}
// }
framework
::
OperatorWithKernel
*
op_kernel
=
framework
::
OperatorWithKernel
*
op_kernel
=
dynamic_cast
<
framework
::
OperatorWithKernel
*>
(
opbase
.
get
());
dynamic_cast
<
framework
::
OperatorWithKernel
*>
(
opbase
.
get
());
...
@@ -298,7 +299,6 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
...
@@ -298,7 +299,6 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
}
}
framework
::
RuntimeContext
ctx
(
grad_invars_map
,
grad_outvars_map
);
framework
::
RuntimeContext
ctx
(
grad_invars_map
,
grad_outvars_map
);
framework
::
Scope
scope
;
framework
::
Scope
scope
;
PreparedOp
p
=
PreparedOp
::
Prepare
(
ctx
,
*
op_kernel
,
place_
);
PreparedOp
p
=
PreparedOp
::
Prepare
(
ctx
,
*
op_kernel
,
place_
);
p
.
op
.
RuntimeInferShape
(
scope
,
place_
,
ctx
);
p
.
op
.
RuntimeInferShape
(
scope
,
place_
,
ctx
);
...
...
paddle/fluid/imperative/layer.h
浏览文件 @
36dce65b
...
@@ -137,13 +137,13 @@ class VarBase {
...
@@ -137,13 +137,13 @@ class VarBase {
persistable
)
{}
persistable
)
{}
private:
private:
// TODO(minqiyang): need support SelectedRows
VarBase
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
dtype
,
VarBase
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
dtype
,
const
framework
::
DDim
&
shape
,
const
platform
::
Place
&
place
,
const
framework
::
DDim
&
shape
,
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
,
VarBase
*
grad
,
bool
stop_gradient
,
framework
::
Variable
*
var
,
VarBase
*
grad
,
bool
stop_gradient
,
bool
persistable
)
bool
persistable
)
:
name_
(
name
),
:
name_
(
name
),
dtype_
(
dtype
),
type_
(
framework
::
proto
::
VarType
::
LOD_TENSOR
),
place_
(
place
),
var_
(
var
),
var_
(
var
),
grads_
(
grad
),
grads_
(
grad
),
stop_gradient_
(
stop_gradient
),
stop_gradient_
(
stop_gradient
),
...
@@ -153,10 +153,12 @@ class VarBase {
...
@@ -153,10 +153,12 @@ class VarBase {
pre_op_out_idx_
(
-
1
)
{
pre_op_out_idx_
(
-
1
)
{
if
(
!
var_
)
{
if
(
!
var_
)
{
var_
=
new
framework
::
Variable
();
var_
=
new
framework
::
Variable
();
}
auto
tensor
=
var_
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
tensor
=
var_
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
(
shape
);
tensor
->
Resize
(
shape
);
tensor
->
mutable_data
(
place_
,
dtype_
);
tensor
->
mutable_data
(
place
,
dtype
);
}
VLOG
(
10
)
<<
"create varbase: "
<<
name_
<<
" type: "
<<
dtype
<<
" place: "
<<
place
;
}
}
public:
public:
...
@@ -186,11 +188,23 @@ class VarBase {
...
@@ -186,11 +188,23 @@ class VarBase {
}
}
}
}
inline
void
SetDType
(
framework
::
proto
::
VarType
::
Type
type
)
{
inline
framework
::
DDim
Dims
()
const
{
return
var_
->
Get
<
framework
::
LoDTensor
>
().
dims
();
}
// data type. e.g.. FP32
inline
void
SetDataType
(
framework
::
proto
::
VarType
::
Type
type
)
{
auto
tensor
=
var_
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
tensor
=
var_
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
mutable_data
(
place_
,
dtype_
);
tensor
->
mutable_data
(
place_
,
type
);
}
}
inline
framework
::
proto
::
VarType
::
Type
DType
()
const
{
return
dtype_
;
}
inline
framework
::
proto
::
VarType
::
Type
DataType
()
const
{
auto
tensor
=
var_
->
Get
<
framework
::
LoDTensor
>
();
return
tensor
.
type
();
}
// tensor type. e.g.. LoDTensor
inline
void
SetType
(
framework
::
proto
::
VarType
::
Type
type
)
{
type_
=
type
;
}
inline
framework
::
proto
::
VarType
::
Type
Type
()
const
{
return
type_
;
}
inline
void
SetStopGradient
(
bool
stop_gradient
)
{
inline
void
SetStopGradient
(
bool
stop_gradient
)
{
stop_gradient_
=
stop_gradient
;
stop_gradient_
=
stop_gradient
;
...
@@ -244,7 +258,7 @@ class VarBase {
...
@@ -244,7 +258,7 @@ class VarBase {
}
}
std
::
string
name_
;
std
::
string
name_
;
framework
::
proto
::
VarType
::
Type
d
type_
;
framework
::
proto
::
VarType
::
Type
type_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
framework
::
Variable
*
var_
;
framework
::
Variable
*
var_
;
...
@@ -339,6 +353,8 @@ class PYBIND11_HIDDEN OpBase {
...
@@ -339,6 +353,8 @@ class PYBIND11_HIDDEN OpBase {
std
::
vector
<
VarBasePtrMap
>
grad_output_vars_
;
std
::
vector
<
VarBasePtrMap
>
grad_output_vars_
;
std
::
vector
<
py
::
object
>
backward_hooks_
;
std
::
vector
<
py
::
object
>
backward_hooks_
;
framework
::
AttributeMap
attrs_
;
};
};
class
Layer
{
class
Layer
{
...
@@ -437,22 +453,22 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
...
@@ -437,22 +453,22 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
framework
::
proto
::
VarType
::
Type
GetType
(
framework
::
proto
::
VarType
::
Type
GetType
(
const
std
::
string
&
name
)
const
override
{
const
std
::
string
&
name
)
const
override
{
return
var_set_
.
at
(
name
)
->
D
Type
();
return
var_set_
.
at
(
name
)
->
Type
();
}
}
void
SetType
(
const
std
::
string
&
name
,
void
SetType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
override
{
framework
::
proto
::
VarType
::
Type
type
)
override
{
var_set_
[
name
]
->
Set
D
Type
(
type
);
var_set_
[
name
]
->
SetType
(
type
);
}
}
framework
::
proto
::
VarType
::
Type
GetDataType
(
framework
::
proto
::
VarType
::
Type
GetDataType
(
const
std
::
string
&
name
)
const
override
{
const
std
::
string
&
name
)
const
override
{
return
var_set_
.
at
(
name
)
->
DType
();
return
var_set_
.
at
(
name
)
->
D
ata
Type
();
}
}
void
SetDataType
(
const
std
::
string
&
name
,
void
SetDataType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
override
{
framework
::
proto
::
VarType
::
Type
type
)
override
{
var_set_
[
name
]
->
SetDType
(
type
);
var_set_
[
name
]
->
SetD
ata
Type
(
type
);
}
}
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
GetDataTypes
(
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
GetDataTypes
(
...
...
paddle/fluid/imperative/tracer.cc
浏览文件 @
36dce65b
...
@@ -232,7 +232,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -232,7 +232,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
if
(
info
.
infer_var_type_
)
{
if
(
info
.
infer_var_type_
)
{
RuntimeInferVarTypeContext
infer_var_type_ctx
(
&
inputs
,
&
outputs
,
RuntimeInferVarTypeContext
infer_var_type_ctx
(
&
inputs
,
&
outputs
,
&
attrs_map
);
&
attrs_map
);
info
.
infer_var_type_
(
infer_var_type_ctx
);
info
.
infer_var_type_
(
&
infer_var_type_ctx
);
}
}
// TODO(minqiyang): Support infer var type in imperative mode
// TODO(minqiyang): Support infer var type in imperative mode
...
@@ -259,6 +259,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -259,6 +259,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
VLOG
(
5
)
<<
"start construct backward op"
;
VLOG
(
5
)
<<
"start construct backward op"
;
// construct grad op descs
// construct grad op descs
op
->
attrs_
=
attrs_map
;
std
::
unique_ptr
<
framework
::
OpDesc
>
fwd_op_desc
(
new
framework
::
OpDesc
(
std
::
unique_ptr
<
framework
::
OpDesc
>
fwd_op_desc
(
new
framework
::
OpDesc
(
op
->
Type
(),
invars_name_map
,
outvars_name_map
,
attrs_map
));
op
->
Type
(),
invars_name_map
,
outvars_name_map
,
attrs_map
));
std
::
unique_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
string
>>
grad_to_var
(
std
::
unique_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
string
>>
grad_to_var
(
...
...
paddle/fluid/operators/sum_op.cc
浏览文件 @
36dce65b
...
@@ -168,11 +168,11 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
...
@@ -168,11 +168,11 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
}
}
bool
any_input_is_lod_tensor
=
std
::
any_of
(
bool
any_input_is_lod_tensor
=
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
ctx
](
const
std
::
string
&
name
)
{
inputs
.
begin
(),
inputs
.
end
(),
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
.
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR
;
return
ctx
.
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR
;
});
});
auto
is_tensor_array
=
[
ctx
](
const
std
::
string
&
name
)
{
auto
is_tensor_array
=
[
&
ctx
](
const
std
::
string
&
name
)
{
return
ctx
.
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
return
ctx
.
GetType
(
name
)
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
;
};
};
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
36dce65b
...
@@ -194,7 +194,7 @@ PYBIND11_MODULE(core, m) {
...
@@ -194,7 +194,7 @@ PYBIND11_MODULE(core, m) {
.
def_property
(
"name"
,
&
imperative
::
VarBase
::
Name
,
.
def_property
(
"name"
,
&
imperative
::
VarBase
::
Name
,
&
imperative
::
VarBase
::
SetName
)
&
imperative
::
VarBase
::
SetName
)
.
def_property_readonly
(
"shape"
,
&
imperative
::
VarBase
::
Shape
)
.
def_property_readonly
(
"shape"
,
&
imperative
::
VarBase
::
Shape
)
.
def_property_readonly
(
"dtype"
,
&
imperative
::
VarBase
::
DType
)
.
def_property_readonly
(
"dtype"
,
&
imperative
::
VarBase
::
D
ata
Type
)
.
def_property
(
"persistable"
,
&
imperative
::
VarBase
::
IsPersistable
,
.
def_property
(
"persistable"
,
&
imperative
::
VarBase
::
IsPersistable
,
&
imperative
::
VarBase
::
SetPersistable
)
&
imperative
::
VarBase
::
SetPersistable
)
.
def_property
(
"stop_gradient"
,
&
imperative
::
VarBase
::
IsStopGradient
,
.
def_property
(
"stop_gradient"
,
&
imperative
::
VarBase
::
IsStopGradient
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录