Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
6db96ec2
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6db96ec2
编写于
4月 11, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments
上级
8eaec5dd
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
27 addition
and
27 deletion
+27
-27
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+3
-3
paddle/fluid/framework/details/broadcast_op_handle.cc
paddle/fluid/framework/details/broadcast_op_handle.cc
+6
-6
paddle/fluid/framework/details/broadcast_op_handle.h
paddle/fluid/framework/details/broadcast_op_handle.h
+5
-5
paddle/fluid/framework/details/broadcast_op_handle_test.cc
paddle/fluid/framework/details/broadcast_op_handle_test.cc
+13
-13
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
6db96ec2
...
...
@@ -5,7 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
if
(
WITH_GPU
)
nv_library
(
nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda
)
nv_library
(
broad
_cast_op_handle SRCS broad_
cast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
nv_library
(
broad
cast_op_handle SRCS broad
cast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
endif
()
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
...
...
@@ -15,8 +15,8 @@ cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
if
(
WITH_GPU
)
set
(
multi_devices_graph_builder_deps nccl_all_reduce_op_handle
)
nv_test
(
broad
_cast_op_test SRCS broad_
cast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
device_context broad
_
cast_op_handle
)
nv_test
(
broad
cast_op_test SRCS broad
cast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
device_context broadcast_op_handle
)
else
()
set
(
multi_devices_graph_builder_deps
)
endif
()
...
...
paddle/fluid/framework/details/broad
_
cast_op_handle.cc
→
paddle/fluid/framework/details/broadcast_op_handle.cc
浏览文件 @
6db96ec2
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/broad
_
cast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -28,16 +28,16 @@ Tensor *GetTensorFromVar(Variable *in_var) {
}
return
nullptr
;
}
B
CastOpHandle
::
BC
astOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
ContextMap
&
ctxs
)
B
roadcastOpHandle
::
Broadc
astOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
ContextMap
&
ctxs
)
:
local_scopes_
(
local_scopes
),
places_
(
places
),
ctxs_
(
ctxs
)
{
for
(
auto
&
p
:
places_
)
{
this
->
dev_ctxes_
[
p
]
=
ctxs_
.
DevCtx
(
p
);
}
}
void
B
C
astOpHandle
::
RunImpl
()
{
void
B
roadc
astOpHandle
::
RunImpl
()
{
PADDLE_ENFORCE_EQ
(
this
->
inputs_
.
size
(),
1
);
PADDLE_ENFORCE_EQ
(
this
->
outputs_
.
size
(),
places_
.
size
());
...
...
@@ -97,7 +97,7 @@ void BCastOpHandle::RunImpl() {
}
}
std
::
string
B
C
astOpHandle
::
Name
()
const
{
return
"broadcast"
;
}
std
::
string
B
roadc
astOpHandle
::
Name
()
const
{
return
"broadcast"
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/broad
_
cast_op_handle.h
→
paddle/fluid/framework/details/broadcast_op_handle.h
浏览文件 @
6db96ec2
...
...
@@ -29,17 +29,17 @@ namespace framework {
namespace
details
{
/*
* Broad
C
ast the input to all scope.
* Broad
c
ast the input to all scope.
*
*/
struct
B
C
astOpHandle
:
public
OpHandleBase
{
struct
B
roadc
astOpHandle
:
public
OpHandleBase
{
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
const
platform
::
ContextMap
&
ctxs_
;
B
C
astOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
ContextMap
&
ctxs
);
B
roadc
astOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
ContextMap
&
ctxs
);
std
::
string
Name
()
const
override
;
...
...
paddle/fluid/framework/details/broad
_
cast_op_handle_test.cc
→
paddle/fluid/framework/details/broadcast_op_handle_test.cc
浏览文件 @
6db96ec2
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/broad
_
cast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "gtest/gtest.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -23,12 +23,12 @@ namespace p = paddle::platform;
// test data amount
const
f
::
DDim
kDims
=
{
20
,
20
};
class
Broad
C
astTester
:
public
::
testing
::
Test
{
class
Broad
c
astTester
:
public
::
testing
::
Test
{
public:
void
SetUp
()
override
{
int
count
=
p
::
GetCUDADeviceCount
();
if
(
count
<=
1
)
{
LOG
(
WARNING
)
<<
"Cannot test multi-gpu Broad
C
ast, because the CUDA "
LOG
(
WARNING
)
<<
"Cannot test multi-gpu Broad
c
ast, because the CUDA "
"device count is "
<<
count
;
exit
(
0
);
...
...
@@ -40,7 +40,7 @@ class BroadCastTester : public ::testing::Test {
}
template
<
class
T
>
void
Broad
C
astInitOp
(
int
gpu_id
=
0
)
{
void
Broad
c
astInitOp
(
int
gpu_id
=
0
)
{
for
(
size_t
j
=
0
;
j
<
gpu_list_
.
size
();
++
j
)
{
local_scope_
.
push_back
(
&
g_scope_
.
NewScope
());
auto
*
out_var
=
local_scope_
[
j
]
->
Var
(
"out"
);
...
...
@@ -50,7 +50,7 @@ class BroadCastTester : public ::testing::Test {
in_var
->
GetMutable
<
T
>
();
bc_op_handle_
=
new
f
::
details
::
B
C
astOpHandle
(
local_scope_
,
gpu_list_
,
*
ctxs_
);
new
f
::
details
::
B
roadc
astOpHandle
(
local_scope_
,
gpu_list_
,
*
ctxs_
);
f
::
details
::
VarHandle
*
in_var_handle
=
new
f
::
details
::
VarHandle
();
in_var_handle
->
place_
=
gpu_list_
[
gpu_id
];
...
...
@@ -68,7 +68,7 @@ class BroadCastTester : public ::testing::Test {
bc_op_handle_
->
AddOutput
(
out_var_handle
);
}
}
void
Broad
C
astDestroy
()
{
void
Broad
c
astDestroy
()
{
delete
ctxs_
;
for
(
auto
in
:
bc_op_handle_
->
inputs_
)
{
delete
in
;
...
...
@@ -84,12 +84,12 @@ class BroadCastTester : public ::testing::Test {
p
::
ContextMap
*
ctxs_
;
std
::
vector
<
f
::
Scope
*>
local_scope_
;
std
::
vector
<
p
::
Place
>
gpu_list_
;
f
::
details
::
B
C
astOpHandle
*
bc_op_handle_
;
f
::
details
::
B
roadc
astOpHandle
*
bc_op_handle_
;
};
TEST_F
(
Broad
CastTester
,
BroadC
astTestLodTensor
)
{
TEST_F
(
Broad
castTester
,
Broadc
astTestLodTensor
)
{
int
gpu_id
=
0
;
Broad
C
astInitOp
<
f
::
LoDTensor
>
(
gpu_id
);
Broad
c
astInitOp
<
f
::
LoDTensor
>
(
gpu_id
);
auto
in_var
=
local_scope_
[
gpu_id
]
->
Var
(
"input"
);
auto
in_lod_tensor
=
in_var
->
GetMutable
<
f
::
LoDTensor
>
();
...
...
@@ -122,12 +122,12 @@ TEST_F(BroadCastTester, BroadCastTestLodTensor) {
}
}
Broad
C
astDestroy
();
Broad
c
astDestroy
();
}
TEST_F
(
Broad
CastTester
,
BroadC
astTestSelectedRows
)
{
TEST_F
(
Broad
castTester
,
Broadc
astTestSelectedRows
)
{
int
gpu_id
=
0
;
Broad
C
astInitOp
<
f
::
SelectedRows
>
(
gpu_id
);
Broad
c
astInitOp
<
f
::
SelectedRows
>
(
gpu_id
);
auto
in_var
=
local_scope_
[
gpu_id
]
->
Var
(
"input"
);
auto
in_selected_rows
=
in_var
->
GetMutable
<
f
::
SelectedRows
>
();
...
...
@@ -170,5 +170,5 @@ TEST_F(BroadCastTester, BroadCastTestSelectedRows) {
}
}
Broad
C
astDestroy
();
Broad
c
astDestroy
();
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录