Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
efd37269
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
efd37269
编写于
12月 25, 2017
作者:
Q
QI JUN
提交者:
GitHub
12月 25, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove unused place (#6972)
* remove unused place * fix ci
上级
ebf0d9e7
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
9 addition
and
73 deletion
+9
-73
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+0
-8
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+0
-7
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+3
-5
paddle/platform/device_context.h
paddle/platform/device_context.h
+1
-5
paddle/platform/device_context_test.cu
paddle/platform/device_context_test.cu
+2
-3
paddle/platform/place.cc
paddle/platform/place.cc
+1
-8
paddle/platform/place.h
paddle/platform/place.h
+1
-25
paddle/platform/place_test.cc
paddle/platform/place_test.cc
+1
-12
未找到文件。
paddle/operators/math/math_function.cc
浏览文件 @
efd37269
...
@@ -277,14 +277,6 @@ void set_constant_with_place<platform::CPUPlace>(
...
@@ -277,14 +277,6 @@ void set_constant_with_place<platform::CPUPlace>(
TensorSetConstantCPU
(
tensor
,
value
));
TensorSetConstantCPU
(
tensor
,
value
));
}
}
template
<
>
void
set_constant_with_place
<
platform
::
MKLDNNPlace
>
(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
float
value
)
{
framework
::
VisitDataType
(
framework
::
ToDataType
(
tensor
->
type
()),
TensorSetConstantCPU
(
tensor
,
value
));
}
struct
TensorSetConstantWithPlace
:
public
boost
::
static_visitor
<
void
>
{
struct
TensorSetConstantWithPlace
:
public
boost
::
static_visitor
<
void
>
{
TensorSetConstantWithPlace
(
const
platform
::
DeviceContext
&
context
,
TensorSetConstantWithPlace
(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
float
value
)
framework
::
Tensor
*
tensor
,
float
value
)
...
...
paddle/operators/math/math_function.cu
浏览文件 @
efd37269
...
@@ -273,13 +273,6 @@ void set_constant_with_place<platform::CUDAPlace>(
...
@@ -273,13 +273,6 @@ void set_constant_with_place<platform::CUDAPlace>(
TensorSetConstantGPU
(
context
,
tensor
,
value
));
TensorSetConstantGPU
(
context
,
tensor
,
value
));
}
}
template
<
>
void
set_constant_with_place
<
platform
::
CUDNNPlace
>
(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
float
value
)
{
set_constant_with_place
<
platform
::
CUDAPlace
>
(
context
,
tensor
,
value
);
}
template
struct
RowwiseAdd
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
RowwiseAdd
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
RowwiseAdd
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
RowwiseAdd
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
ColwiseSum
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
ColwiseSum
<
platform
::
CUDADeviceContext
,
float
>;
...
...
paddle/platform/device_context.cc
浏览文件 @
efd37269
...
@@ -178,20 +178,18 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
...
@@ -178,20 +178,18 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
CUDNNDeviceContext
::
CUDNNDeviceContext
(
CUD
NN
Place
place
)
CUDNNDeviceContext
::
CUDNNDeviceContext
(
CUD
A
Place
place
)
:
CUDADeviceContext
(
place
)
,
place_
(
place
)
{
:
CUDADeviceContext
(
place
)
{
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream
()));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream
()));
}
}
CUDNNDeviceContext
::~
CUDNNDeviceContext
()
{
CUDNNDeviceContext
::~
CUDNNDeviceContext
()
{
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
boost
::
get
<
CUDAPlace
>
(
GetPlace
())
.
device
);
Wait
();
Wait
();
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
}
}
Place
CUDNNDeviceContext
::
GetPlace
()
const
{
return
CUDNNPlace
();
}
cudnnHandle_t
CUDNNDeviceContext
::
cudnn_handle
()
const
{
return
cudnn_handle_
;
}
cudnnHandle_t
CUDNNDeviceContext
::
cudnn_handle
()
const
{
return
cudnn_handle_
;
}
#endif
#endif
...
...
paddle/platform/device_context.h
浏览文件 @
efd37269
...
@@ -92,18 +92,14 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -92,18 +92,14 @@ class CUDADeviceContext : public DeviceContext {
class
CUDNNDeviceContext
:
public
CUDADeviceContext
{
class
CUDNNDeviceContext
:
public
CUDADeviceContext
{
public:
public:
explicit
CUDNNDeviceContext
(
CUD
NN
Place
place
);
explicit
CUDNNDeviceContext
(
CUD
A
Place
place
);
virtual
~
CUDNNDeviceContext
();
virtual
~
CUDNNDeviceContext
();
/*! \brief Return place in the device context. */
Place
GetPlace
()
const
final
;
/*! \brief Return cudnn handle in the device context. */
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
cudnnHandle_t
cudnn_handle
()
const
;
private:
private:
cudnnHandle_t
cudnn_handle_
;
cudnnHandle_t
cudnn_handle_
;
CUDNNPlace
place_
;
};
};
#endif
#endif
...
...
paddle/platform/device_context_test.cu
浏览文件 @
efd37269
...
@@ -51,12 +51,11 @@ TEST(Device, CUDADeviceContext) {
...
@@ -51,12 +51,11 @@ TEST(Device, CUDADeviceContext) {
TEST
(
Device
,
CUDNNDeviceContext
)
{
TEST
(
Device
,
CUDNNDeviceContext
)
{
using
paddle
::
platform
::
CUDNNDeviceContext
;
using
paddle
::
platform
::
CUDNNDeviceContext
;
using
paddle
::
platform
::
CUD
NN
Place
;
using
paddle
::
platform
::
CUD
A
Place
;
if
(
paddle
::
platform
::
dynload
::
HasCUDNN
())
{
if
(
paddle
::
platform
::
dynload
::
HasCUDNN
())
{
int
count
=
paddle
::
platform
::
GetCUDADeviceCount
();
int
count
=
paddle
::
platform
::
GetCUDADeviceCount
();
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
CUDNNDeviceContext
*
device_context
=
CUDNNDeviceContext
*
device_context
=
new
CUDNNDeviceContext
(
CUDAPlace
(
i
));
new
CUDNNDeviceContext
(
CUDNNPlace
(
i
));
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
ASSERT_NE
(
nullptr
,
cudnn_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
...
...
paddle/platform/place.cc
浏览文件 @
efd37269
...
@@ -23,7 +23,6 @@ class PlacePrinter : public boost::static_visitor<> {
...
@@ -23,7 +23,6 @@ class PlacePrinter : public boost::static_visitor<> {
public:
public:
explicit
PlacePrinter
(
std
::
ostream
&
os
)
:
os_
(
os
)
{}
explicit
PlacePrinter
(
std
::
ostream
&
os
)
:
os_
(
os
)
{}
void
operator
()(
const
CPUPlace
&
)
{
os_
<<
"CPUPlace"
;
}
void
operator
()(
const
CPUPlace
&
)
{
os_
<<
"CPUPlace"
;
}
void
operator
()(
const
MKLDNNPlace
&
)
{
os_
<<
"MKLDNNPlace"
;
}
void
operator
()(
const
CUDAPlace
&
p
)
{
void
operator
()(
const
CUDAPlace
&
p
)
{
os_
<<
"CUDAPlace("
<<
p
.
device
<<
")"
;
os_
<<
"CUDAPlace("
<<
p
.
device
<<
")"
;
}
}
...
@@ -41,18 +40,12 @@ const Place &get_place() { return the_default_place; }
...
@@ -41,18 +40,12 @@ const Place &get_place() { return the_default_place; }
const
CUDAPlace
default_gpu
()
{
return
CUDAPlace
(
0
);
}
const
CUDAPlace
default_gpu
()
{
return
CUDAPlace
(
0
);
}
const
CPUPlace
default_cpu
()
{
return
CPUPlace
();
}
const
CPUPlace
default_cpu
()
{
return
CPUPlace
();
}
const
MKLDNNPlace
default_mkldnn
()
{
return
MKLDNNPlace
();
}
bool
is_gpu_place
(
const
Place
&
p
)
{
bool
is_gpu_place
(
const
Place
&
p
)
{
return
boost
::
apply_visitor
(
IsCUDAPlace
(),
p
);
return
boost
::
apply_visitor
(
IsCUDAPlace
(),
p
);
}
}
bool
is_cpu_place
(
const
Place
&
p
)
{
return
!
is_gpu_place
(
p
)
&&
!
is_mkldnn_place
(
p
);
}
bool
is_mkldnn_place
(
const
Place
&
p
)
{
bool
is_cpu_place
(
const
Place
&
p
)
{
return
!
is_gpu_place
(
p
);
}
return
boost
::
apply_visitor
(
IsMKLDNNPlace
(),
p
);
}
bool
places_are_same_class
(
const
Place
&
p1
,
const
Place
&
p2
)
{
bool
places_are_same_class
(
const
Place
&
p1
,
const
Place
&
p2
)
{
return
p1
.
which
()
==
p2
.
which
();
return
p1
.
which
()
==
p2
.
which
();
...
...
paddle/platform/place.h
浏览文件 @
efd37269
...
@@ -31,14 +31,6 @@ struct CPUPlace {
...
@@ -31,14 +31,6 @@ struct CPUPlace {
inline
bool
operator
!=
(
const
CPUPlace
&
)
const
{
return
false
;
}
inline
bool
operator
!=
(
const
CPUPlace
&
)
const
{
return
false
;
}
};
};
struct
MKLDNNPlace
{
MKLDNNPlace
()
{}
// needed for variant equality comparison
inline
bool
operator
==
(
const
MKLDNNPlace
&
)
const
{
return
true
;
}
inline
bool
operator
!=
(
const
MKLDNNPlace
&
)
const
{
return
false
;
}
};
struct
CUDAPlace
{
struct
CUDAPlace
{
CUDAPlace
()
:
CUDAPlace
(
0
)
{}
CUDAPlace
()
:
CUDAPlace
(
0
)
{}
explicit
CUDAPlace
(
int
d
)
:
device
(
d
)
{}
explicit
CUDAPlace
(
int
d
)
:
device
(
d
)
{}
...
@@ -53,37 +45,21 @@ struct CUDAPlace {
...
@@ -53,37 +45,21 @@ struct CUDAPlace {
int
device
;
int
device
;
};
};
struct
CUDNNPlace
:
public
CUDAPlace
{
CUDNNPlace
()
:
CUDAPlace
()
{}
explicit
CUDNNPlace
(
int
d
)
:
CUDAPlace
(
d
)
{}
};
struct
IsCUDAPlace
:
public
boost
::
static_visitor
<
bool
>
{
struct
IsCUDAPlace
:
public
boost
::
static_visitor
<
bool
>
{
bool
operator
()(
const
CPUPlace
&
)
const
{
return
false
;
}
bool
operator
()(
const
CPUPlace
&
)
const
{
return
false
;
}
bool
operator
()(
const
MKLDNNPlace
&
)
const
{
return
false
;
}
bool
operator
()(
const
CUDAPlace
&
gpu
)
const
{
return
true
;
}
bool
operator
()(
const
CUDAPlace
&
gpu
)
const
{
return
true
;
}
bool
operator
()(
const
CUDNNPlace
&
)
const
{
return
true
;
}
};
struct
IsMKLDNNPlace
:
public
boost
::
static_visitor
<
bool
>
{
bool
operator
()(
const
MKLDNNPlace
&
)
const
{
return
true
;
}
bool
operator
()(
const
CPUPlace
&
)
const
{
return
false
;
}
bool
operator
()(
const
CUDAPlace
&
)
const
{
return
false
;
}
bool
operator
()(
const
CUDNNPlace
&
)
const
{
return
false
;
}
};
};
typedef
boost
::
variant
<
CUD
NNPlace
,
CUDAPlace
,
CPUPlace
,
MKLDNN
Place
>
Place
;
typedef
boost
::
variant
<
CUD
APlace
,
CPU
Place
>
Place
;
void
set_place
(
const
Place
&
);
void
set_place
(
const
Place
&
);
const
Place
&
get_place
();
const
Place
&
get_place
();
const
CUDAPlace
default_gpu
();
const
CUDAPlace
default_gpu
();
const
CPUPlace
default_cpu
();
const
CPUPlace
default_cpu
();
const
MKLDNNPlace
default_mkldnn
();
bool
is_gpu_place
(
const
Place
&
);
bool
is_gpu_place
(
const
Place
&
);
bool
is_cpu_place
(
const
Place
&
);
bool
is_cpu_place
(
const
Place
&
);
bool
is_mkldnn_place
(
const
Place
&
);
bool
places_are_same_class
(
const
Place
&
,
const
Place
&
);
bool
places_are_same_class
(
const
Place
&
,
const
Place
&
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Place
&
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Place
&
);
...
...
paddle/platform/place_test.cc
浏览文件 @
efd37269
...
@@ -5,37 +5,26 @@
...
@@ -5,37 +5,26 @@
TEST
(
Place
,
Equality
)
{
TEST
(
Place
,
Equality
)
{
paddle
::
platform
::
CPUPlace
cpu
;
paddle
::
platform
::
CPUPlace
cpu
;
paddle
::
platform
::
CUDAPlace
g0
(
0
),
g1
(
1
),
gg0
(
0
);
paddle
::
platform
::
CUDAPlace
g0
(
0
),
g1
(
1
),
gg0
(
0
);
paddle
::
platform
::
CUDNNPlace
d0
(
0
),
d1
(
1
),
dd0
(
0
);
EXPECT_EQ
(
cpu
,
cpu
);
EXPECT_EQ
(
cpu
,
cpu
);
EXPECT_EQ
(
g0
,
g0
);
EXPECT_EQ
(
g0
,
g0
);
EXPECT_EQ
(
g1
,
g1
);
EXPECT_EQ
(
g1
,
g1
);
EXPECT_EQ
(
g0
,
gg0
);
EXPECT_EQ
(
g0
,
gg0
);
EXPECT_EQ
(
d0
,
dd0
);
EXPECT_NE
(
g0
,
g1
);
EXPECT_NE
(
g0
,
g1
);
EXPECT_NE
(
d0
,
d1
);
EXPECT_TRUE
(
paddle
::
platform
::
places_are_same_class
(
g0
,
gg0
));
EXPECT_TRUE
(
paddle
::
platform
::
places_are_same_class
(
g0
,
gg0
));
EXPECT_FALSE
(
paddle
::
platform
::
places_are_same_class
(
g0
,
cpu
));
EXPECT_FALSE
(
paddle
::
platform
::
places_are_same_class
(
g0
,
cpu
));
EXPECT_TRUE
(
paddle
::
platform
::
is_gpu_place
(
d0
));
EXPECT_FALSE
(
paddle
::
platform
::
places_are_same_class
(
g0
,
d0
));
}
}
TEST
(
Place
,
Default
)
{
TEST
(
Place
,
Default
)
{
EXPECT_TRUE
(
paddle
::
platform
::
is_gpu_place
(
paddle
::
platform
::
get_place
()));
EXPECT_TRUE
(
paddle
::
platform
::
is_gpu_place
(
paddle
::
platform
::
get_place
()));
EXPECT_TRUE
(
paddle
::
platform
::
is_gpu_place
(
paddle
::
platform
::
default_gpu
()));
EXPECT_TRUE
(
paddle
::
platform
::
is_gpu_place
(
paddle
::
platform
::
default_gpu
()));
EXPECT_TRUE
(
paddle
::
platform
::
is_cpu_place
(
paddle
::
platform
::
default_cpu
()));
EXPECT_TRUE
(
paddle
::
platform
::
is_cpu_place
(
paddle
::
platform
::
default_cpu
()));
EXPECT_TRUE
(
paddle
::
platform
::
is_mkldnn_place
(
paddle
::
platform
::
default_mkldnn
()));
EXPECT_FALSE
(
paddle
::
platform
::
is_cpu_place
(
paddle
::
platform
::
get_place
()));
paddle
::
platform
::
set_place
(
paddle
::
platform
::
CPUPlace
());
paddle
::
platform
::
set_place
(
paddle
::
platform
::
CPUPlace
());
EXPECT_TRUE
(
paddle
::
platform
::
is_cpu_place
(
paddle
::
platform
::
get_place
()));
EXPECT_TRUE
(
paddle
::
platform
::
is_cpu_place
(
paddle
::
platform
::
get_place
()));
paddle
::
platform
::
set_place
(
paddle
::
platform
::
MKLDNNPlace
());
EXPECT_FALSE
(
paddle
::
platform
::
is_cpu_place
(
paddle
::
platform
::
get_place
()));
EXPECT_TRUE
(
paddle
::
platform
::
is_mkldnn_place
(
paddle
::
platform
::
get_place
()));
}
}
TEST
(
Place
,
Print
)
{
TEST
(
Place
,
Print
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录