Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
881e063e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
881e063e
编写于
5月 05, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments
上级
ff599b92
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
46 addition
and
59 deletion
+46
-59
paddle/fluid/framework/details/broadcast_op_handle.cc
paddle/fluid/framework/details/broadcast_op_handle.cc
+16
-19
paddle/fluid/framework/details/gather_op_handle.cc
paddle/fluid/framework/details/gather_op_handle.cc
+14
-13
paddle/fluid/framework/details/reduce_op_handle.cc
paddle/fluid/framework/details/reduce_op_handle.cc
+13
-11
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+0
-4
paddle/fluid/framework/details/var_handle.h
paddle/fluid/framework/details/var_handle.h
+1
-10
paddle/fluid/framework/details/variable_visitor.cc
paddle/fluid/framework/details/variable_visitor.cc
+2
-2
未找到文件。
paddle/fluid/framework/details/broadcast_op_handle.cc
浏览文件 @
881e063e
...
...
@@ -53,42 +53,39 @@ void BroadcastOpHandle::RunImpl() {
Tensor
&
in_tensor
=
VariableVisitor
::
GetMutableTensor
(
in_var
);
// NOTE(zcd): the Place of input can get from in_tensor and in_var_handle ,
// maybe they are different, because the Place that getting from in_tensor is
// determined at runtime, the other is determined at building SSA graph stage.
// If they are different, DataTransform should be applied. Currently, it has
// not been done yet.
// NOTE: The tensors' Place of input and output must be all on GPU or all on
// CPU.
for
(
auto
*
out_var_handle
:
out_var_handles
)
{
if
(
*
out_var_handle
==
*
in_var_handle
)
{
if
(
out_var_handle
->
IsTheSameVar
(
*
in_var_handle
)
)
{
continue
;
}
auto
&
out_p
=
out_var_handle
->
place_
;
auto
t_
out_p
=
out_var_handle
->
place_
;
auto
*
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
PADDLE_ENFORCE_EQ
(
out_p
.
which
(),
in_tensor
.
place
().
which
(),
"Currently, Places of input and output must be all on CPU "
"or all on GPU."
);
if
(
platform
::
is_gpu_place
(
in_tensor
.
place
()))
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
t_out_p
),
"Places of input and output must be all on GPU."
);
}
else
{
t_out_p
=
platform
::
CPUPlace
();
}
VariableVisitor
::
ShareDimsAndLoD
(
*
in_var
,
out_var
);
VariableVisitor
::
GetMutableTensor
(
out_var
).
mutable_data
(
out_p
,
VariableVisitor
::
GetMutableTensor
(
out_var
).
mutable_data
(
t_
out_p
,
in_tensor
.
type
());
}
if
(
platform
::
is_cpu_place
(
in_tensor
.
place
()))
{
for
(
auto
*
out_var_handle
:
out_var_handles
)
{
if
(
*
out_var_handle
==
*
in_var_handle
)
{
if
(
out_var_handle
->
IsTheSameVar
(
*
in_var_handle
)
)
{
continue
;
}
auto
&
out_p
=
out_var_handle
->
place_
;
auto
dev_ctx
=
dev_ctxes_
.
at
(
out_p
);
auto
*
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
RunAndRecordEvent
(
out_p
,
[
in_tensor
,
out_var
,
dev_ctx
,
out_p
]
{
RunAndRecordEvent
(
out_p
,
[
in_tensor
,
out_var
]
{
paddle
::
framework
::
TensorCopy
(
in_tensor
,
out_p
,
*
dev_ctx
,
in_tensor
,
platform
::
CPUPlace
()
,
&
VariableVisitor
::
GetMutableTensor
(
out_var
));
});
}
...
...
@@ -134,8 +131,8 @@ void BroadcastOpHandle::RunImpl() {
call
();
}
}
// TODO(zcd): Maybe the unequal operator is not appropriate here.
if
(
*
out_handle
!=
*
in_var_handle
)
{
if
(
!
out_handle
->
IsTheSameVar
(
*
in_var_handle
)
)
{
auto
out_var
=
var_scopes
.
at
(
in_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handles
[
0
]
->
name_
);
paddle
::
framework
::
TensorCopy
(
...
...
paddle/fluid/framework/details/gather_op_handle.cc
浏览文件 @
881e063e
...
...
@@ -75,14 +75,15 @@ void GatherOpHandle::RunImpl() {
in_tensors
.
emplace_back
(
in_sr_value
.
value
());
}
// TODO(zcd): The Place of var_handle is determined at building SSA graph
// stage, while the Place of var is determined at runtime. If they are
// different, DataTransform should be applied. Currently, it has not been done
// yet.
auto
&
out_place
=
out_var_handle
->
place_
;
PADDLE_ENFORCE_EQ
(
out_place
.
which
(),
pre_in_value
.
place
().
which
(),
"Currently, Places of input and output must be all on CPU "
"or all on GPU."
);
// NOTE: The Places of all input tensor must be all on CPU or all on GPU.
platform
::
Place
t_out_p
=
out_var_handle
->
place_
;
if
(
platform
::
is_gpu_place
(
pre_in_value
.
place
()))
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
t_out_p
),
"Places of input and output must be all on GPU."
);
}
else
{
t_out_p
=
platform
::
CPUPlace
();
}
auto
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
...
...
@@ -93,18 +94,18 @@ void GatherOpHandle::RunImpl() {
DDim
out_dim
=
pre_in_value
.
GetCompleteDims
();
out_dim
[
0
]
=
static_cast
<
int64_t
>
(
rows
);
out_value
->
mutable_value
()
->
Resize
(
out_dim
).
mutable_data
(
out_place
,
pre_in_value
.
value
().
type
());
t_out_p
,
pre_in_value
.
value
().
type
());
Tensor
*
out_tensor
=
out_value
->
mutable_value
();
// copy
auto
dev_ctx
=
dev_ctxes_
[
out_place
];
RunAndRecordEvent
(
out_place
,
[
in_tensors
,
out_tensor
,
&
dev_ctx
,
out_place
]
{
auto
dev_ctx
=
dev_ctxes_
[
out_var_handle
->
place_
];
RunAndRecordEvent
(
out_var_handle
->
place_
,
[
in_tensors
,
out_tensor
,
&
dev_ctx
,
t_out_p
]
{
int
s
=
0
,
e
=
0
;
for
(
size_t
j
=
0
;
j
<
in_tensors
.
size
();
++
j
)
{
e
+=
in_tensors
[
j
].
dims
()[
0
];
auto
sub_out
=
out_tensor
->
Slice
(
s
,
e
);
paddle
::
framework
::
TensorCopy
(
in_tensors
[
j
],
out_place
,
*
dev_ctx
,
&
sub_out
);
paddle
::
framework
::
TensorCopy
(
in_tensors
[
j
],
t_out_p
,
*
dev_ctx
,
&
sub_out
);
s
=
e
;
}
});
...
...
paddle/fluid/framework/details/reduce_op_handle.cc
浏览文件 @
881e063e
...
...
@@ -53,6 +53,7 @@ void ReduceOpHandle::RunImpl() {
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated
(
in_var_handles
);
// NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std
::
vector
<
platform
::
Place
>
in_places
;
// used to get dev_ctx
for
(
auto
*
in_handle
:
in_var_handles
)
{
in_places
.
emplace_back
(
in_handle
->
place_
);
...
...
@@ -66,22 +67,23 @@ void ReduceOpHandle::RunImpl() {
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
// TODO(zcd): The Place of var_handle is determined at building SSA graph
// stage, while the Place of var is determined at runtime. If they are
// different, DataTransform should be applied. Currently, it has not been done
// yet.
PADDLE_ENFORCE_EQ
(
VariableVisitor
::
GetMutableTensor
(
pre_in_var
).
place
().
which
(),
out_var_handle
->
place_
.
which
(),
"Currently, Places of input and output must be all on CPU or all on "
"GPU."
);
// NOTE: The tensors' Place of input and output must be all on GPU or all on
// CPU.
auto
in_p
=
VariableVisitor
::
GetMutableTensor
(
pre_in_var
).
place
();
platform
::
Place
t_out_p
;
if
(
platform
::
is_gpu_place
(
in_p
))
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
out_var_handle
->
place_
),
"Places of input and output must be all on GPU."
);
t_out_p
=
out_var_handle
->
place_
;
}
else
{
t_out_p
=
platform
::
CPUPlace
();
}
if
(
pre_in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
vector
<
const
SelectedRows
*>
in_selected_rows
=
GetInputValues
<
SelectedRows
>
(
in_var_handles
,
var_scopes
);
GatherSelectedRows
(
in_selected_rows
,
in_places
,
dev_ctxes_
,
out_var_handle
->
place_
,
GatherSelectedRows
(
in_selected_rows
,
in_places
,
dev_ctxes_
,
t_out_p
,
out_var
->
GetMutable
<
framework
::
SelectedRows
>
());
}
else
{
std
::
vector
<
const
LoDTensor
*>
lod_tensors
=
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
881e063e
...
...
@@ -48,10 +48,6 @@ class SSAGraphBuilder {
const
platform
::
Place
&
place
,
size_t
place_offset
);
static
VarHandle
*
GetLatestVarHandle
(
SSAGraph
*
graph
,
const
std
::
string
&
each_var_name
,
size_t
place_offset
);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static
void
CreateOpOutput
(
SSAGraph
*
graph
,
OpHandleBase
*
op_handle
,
...
...
paddle/fluid/framework/details/var_handle.h
浏览文件 @
881e063e
...
...
@@ -62,19 +62,10 @@ struct VarHandle : public VarHandleBase {
std
::
string
name_
;
platform
::
Place
place_
;
// NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four
// member variables(version_, scope_id_, name_, place_) must be equal. But
// sometimes judging whether the two var_handle is equal is actually to
// determine whether the two Variables that represented by var_handle is the
// same. And the same Variable may have many different var_handles, the
// version_ of these var_handles is different. So I don't take care of
// version_ temporarily when overloading equal.
bool
operator
==
(
const
VarHandle
&
o
)
const
{
bool
IsTheSameVar
(
const
VarHandle
&
o
)
const
{
return
o
.
generated_op_
==
generated_op_
&&
o
.
name_
==
name_
&&
o
.
scope_idx_
==
scope_idx_
;
}
bool
operator
!=
(
const
VarHandle
&
o
)
const
{
return
!
this
->
operator
==
(
o
);
}
};
// Dummy Variable. It is used to represent dependencies between operators
...
...
paddle/fluid/framework/details/variable_visitor.cc
浏览文件 @
881e063e
...
...
@@ -88,7 +88,7 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
VisitVariable
(
src
,
&
visitor
);
}
struct
Enforce
EqualShapeAndDType
Visitor
{
struct
Enforce
ShapeAndDTypeEQ
Visitor
{
const
Variable
*
trg_
;
void
operator
()(
const
LoDTensor
&
src
)
{
...
...
@@ -130,7 +130,7 @@ struct EnforceEqualShapeAndDTypeVisitor {
void
VariableVisitor
::
EnforceShapeAndDTypeEQ
(
const
Variable
&
var1
,
const
Variable
&
var2
)
{
Enforce
EqualShapeAndDType
Visitor
visitor
{
&
var1
};
Enforce
ShapeAndDTypeEQ
Visitor
visitor
{
&
var1
};
VisitVariable
(
var2
,
&
visitor
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录