Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
60f70b17
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
60f70b17
编写于
11月 05, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test=develop
上级
cc02353d
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
120 addition
and
183 deletion
+120
-183
CMakeLists.txt
CMakeLists.txt
+1
-3
paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
+0
-1
paddle/fluid/inference/api/demo_ci/test.cc
paddle/fluid/inference/api/demo_ci/test.cc
+0
-99
paddle/fluid/inference/api/helper.h
paddle/fluid/inference/api/helper.h
+1
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-1
paddle/fluid/operators/batch_norm_op.cu.cc
paddle/fluid/operators/batch_norm_op.cu.cc
+0
-21
paddle/fluid/operators/conv_cudnn_op.cu.cc
paddle/fluid/operators/conv_cudnn_op.cu.cc
+7
-7
paddle/fluid/operators/fetch_op.cc
paddle/fluid/operators/fetch_op.cc
+0
-2
paddle/fluid/operators/label_smooth_op.cc
paddle/fluid/operators/label_smooth_op.cc
+1
-1
paddle/fluid/operators/load_combine_op.cc
paddle/fluid/operators/load_combine_op.cc
+27
-24
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+24
-4
paddle/fluid/operators/save_combine_op.cc
paddle/fluid/operators/save_combine_op.cc
+24
-5
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+34
-9
paddle/fluid/platform/init.cc
paddle/fluid/platform/init.cc
+0
-5
未找到文件。
CMakeLists.txt
浏览文件 @
60f70b17
...
@@ -212,6 +212,7 @@ endif()
...
@@ -212,6 +212,7 @@ endif()
include
(
external/threadpool
)
include
(
external/threadpool
)
include
(
flags
)
# set paddle compile flags
include
(
cudnn
)
# set cudnn libraries, must before configure
include
(
cudnn
)
# set cudnn libraries, must before configure
include
(
configure
)
# add paddle env configuration
include
(
configure
)
# add paddle env configuration
...
@@ -225,9 +226,6 @@ elseif()
...
@@ -225,9 +226,6 @@ elseif()
set
(
WITH_ANAKIN OFF CACHE STRING
"Anakin is used in MKL only now."
FORCE
)
set
(
WITH_ANAKIN OFF CACHE STRING
"Anakin is used in MKL only now."
FORCE
)
endif
()
endif
()
include
(
flags
)
# set paddle compile flags
include
(
cudnn
)
# set cudnn libraries, must before configure
include
(
configure
)
# add paddle env configuration
include
(
generic
)
# simplify cmake module
include
(
generic
)
# simplify cmake module
include
(
package
)
# set paddle packages
include
(
package
)
# set paddle packages
include
(
ccache
)
# set ccache for compilation
include
(
ccache
)
# set ccache for compilation
...
...
paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
浏览文件 @
60f70b17
...
@@ -135,7 +135,6 @@ void MainThreads(int num_threads, bool use_gpu) {
...
@@ -135,7 +135,6 @@ void MainThreads(int num_threads, bool use_gpu) {
}
// namespace paddle
}
// namespace paddle
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
FLAGS_dirname
=
"./word2vec.inference.model"
;
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
paddle
::
demo
::
Main
(
false
/* use_gpu*/
);
paddle
::
demo
::
Main
(
false
/* use_gpu*/
);
paddle
::
demo
::
MainThreads
(
1
,
false
/* use_gpu*/
);
paddle
::
demo
::
MainThreads
(
1
,
false
/* use_gpu*/
);
...
...
paddle/fluid/inference/api/demo_ci/test.cc
已删除
100644 → 0
浏览文件 @
cc02353d
#include<windows.h>
#include <fstream>
#include "inference_icnet.h"
#include <thread>
#include <vector>
#include <string>
#include <iostream>
#include <sstream>
using
namespace
std
;
template
<
class
Type
>
Type
stringToNum
(
const
string
&
str
)
{
istringstream
iss
(
str
);
Type
num
;
iss
>>
num
;
return
num
;
}
void
test_imgs
()
{
void
*
h
=
init_predictor
(
"./lb/__model__"
,
"./lb/__params__"
,
0.3
f
,
true
,
0
);
std
::
ifstream
infile
(
"new_file.list"
);
std
::
ofstream
ofs
(
"./1.png.output.txt"
);
std
::
string
temp_s
;
std
::
vector
<
std
::
string
>
all_files
;
while
(
!
infile
.
eof
())
{
infile
>>
temp_s
;
all_files
.
push_back
(
temp_s
);
}
// size_t file_num = all_files.size();
infile
.
close
();
// =============read file list =============
for
(
size_t
f_k
=
0
;
f_k
<
1
;
f_k
++
)
{
// std::string path = "D:\\Paddle\\paddle\\fluid\\inference\\api\\demo_ci\\build\\Release\\";
// std::ifstream in_img(path + all_files[f_k]);
std
::
string
mypath
=
"D:
\\
Paddle
\\
paddle
\\
fluid
\\
inference
\\
api
\\
demo_ci
\\
build
\\
Release
\\
1.png.txt"
;
std
::
cout
<<
"file"
<<
mypath
<<
std
::
endl
;
std
::
ifstream
in_img
(
mypath
);
//std::cout << path + all_files[f_k] << std::endl;
double
temp_v
;
const
int
size
=
3
*
449
*
581
*
1
;
float
*
data
=
new
float
[
size
];
std
::
string
value
;
if
(
!
in_img
.
is_open
())
{
cout
<<
"open failed"
<<
endl
;
}
double
sum_input
=
.0
;
for
(
auto
i
=
0
;
i
<
size
;
i
++
)
{
getline
(
in_img
,
value
,
'\n'
);
double
v
=
stringToNum
<
double
>
(
value
);
data
[
i
]
=
static_cast
<
float
>
(
v
);
sum_input
+=
v
;
}
std
::
cout
<<
"sum_input"
<<
sum_input
<<
std
::
endl
;
in_img
.
close
();
const
int
SIZE
=
449
*
581
*
1
;
int64_t
*
p
=
new
int64_t
[
SIZE
]();
int
out_size
=
0
;
//memset(p, 0, size);
predict
(
h
,
data
,
3
,
449
,
581
,
&
p
,
&
out_size
,
1
);
std
::
cout
<<
"out_size = "
<<
out_size
<<
std
::
endl
;
double
out_sum
=
.0
;
for
(
auto
i
=
0
;
i
<
out_size
/
sizeof
(
int64_t
);
i
++
)
{
out_sum
+=
p
[
i
];
ofs
<<
p
[
i
]
<<
" "
;
}
ofs
.
close
();
std
::
cout
<<
"inferece out sum"
<<
out_sum
<<
std
::
endl
;
delete
p
;
}
destory_predictor
(
h
);
}
int
main
(
int
argc
,
char
**
argv
)
{
//if (true) {
// std::thread t1(func, init_predictor("./infer_model/__model__", "./infer_model/__params__", 0.1f, true, 0));
// std::thread t2(func, init_predictor("./infer_model/__model__", "./infer_model/__params__", 0.1f, true, 0));
// //std::thread t3(func, init_predictor("./infer_model/__model__", "./infer_model/__params__", 0.1f, true, 0));
// //std::thread t4(func, init_predictor("./infer_model/__model__", "./infer_model/__params__", 0.1f, true, 0));
// t1.join();
// t2.join();
// //t3.join();
// //t4.join();
// //Sleep(1);
//}
test_imgs
();
return
0
;
}
paddle/fluid/inference/api/helper.h
浏览文件 @
60f70b17
...
@@ -97,7 +97,7 @@ static void TensorAssignData(PaddleTensor *tensor,
...
@@ -97,7 +97,7 @@ static void TensorAssignData(PaddleTensor *tensor,
}
}
template
<
typename
T
>
template
<
typename
T
>
static
int
ZeroCopyTensorAssignData
(
paddle
::
ZeroCopyTensor
*
tensor
,
static
int
ZeroCopyTensorAssignData
(
ZeroCopyTensor
*
tensor
,
const
std
::
vector
<
std
::
vector
<
T
>>
&
data
)
{
const
std
::
vector
<
std
::
vector
<
T
>>
&
data
)
{
int
size
{
0
};
int
size
{
0
};
auto
*
ptr
=
tensor
->
mutable_data
<
T
>
(
PaddlePlace
::
kCPU
);
auto
*
ptr
=
tensor
->
mutable_data
<
T
>
(
PaddlePlace
::
kCPU
);
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
60f70b17
...
@@ -291,7 +291,7 @@ op_library(gru_op DEPS sequence2batch gru_compute)
...
@@ -291,7 +291,7 @@ op_library(gru_op DEPS sequence2batch gru_compute)
op_library
(
recurrent_op DEPS executor
)
op_library
(
recurrent_op DEPS executor
)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
op_library
(
cos_sim_op DEPS cos_sim_functor
)
op_library
(
cos_sim_op DEPS cos_sim_functor
)
op_library
(
parallel_do_op DEPS executor
glog
)
op_library
(
parallel_do_op DEPS executor
)
op_library
(
unsqueeze_op DEPS reshape_op
)
op_library
(
unsqueeze_op DEPS reshape_op
)
op_library
(
squeeze_op DEPS reshape_op
)
op_library
(
squeeze_op DEPS reshape_op
)
op_library
(
extract_rows_op DEPS memory
)
op_library
(
extract_rows_op DEPS memory
)
...
...
paddle/fluid/operators/batch_norm_op.cu.cc
浏览文件 @
60f70b17
...
@@ -141,27 +141,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
...
@@ -141,27 +141,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
(),
epsilon
));
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
(),
epsilon
));
VLOG
(
3
)
<<
"before tensor copy"
;
Tensor
mean_
,
var_
,
x_
,
y_
;
framework
::
TensorCopy
(
*
est_mean
,
platform
::
CPUPlace
(),
dev_ctx
,
&
mean_
);
framework
::
TensorCopy
(
*
est_var
,
platform
::
CPUPlace
(),
dev_ctx
,
&
var_
);
framework
::
TensorCopy
(
*
x
,
platform
::
CPUPlace
(),
dev_ctx
,
&
x_
);
framework
::
TensorCopy
(
*
y
,
platform
::
CPUPlace
(),
dev_ctx
,
&
y_
);
VLOG
(
3
)
<<
"after tensor copy"
;
auto
check_tensor
=
[
&
](
const
Tensor
&
check
)
{
float
sum
=
.0
;
for
(
size_t
i
=
0
;
i
<
check
.
numel
();
++
i
)
{
sum
+=
check
.
data
<
float
>
()[
i
];
}
return
sum
;
};
VLOG
(
3
)
<<
"BatchNormKernel"
;
VLOG
(
3
)
<<
"mean"
<<
check_tensor
(
mean_
);
VLOG
(
3
)
<<
"var"
<<
check_tensor
(
var_
);
VLOG
(
3
)
<<
"x"
<<
check_tensor
(
x_
);
VLOG
(
3
)
<<
"y"
<<
check_tensor
(
y_
);
}
else
{
}
else
{
// Run training mode.
// Run training mode.
// obtain running mean and running inv var, and see if we need to
// obtain running mean and running inv var, and see if we need to
...
...
paddle/fluid/operators/conv_cudnn_op.cu.cc
浏览文件 @
60f70b17
...
@@ -43,7 +43,6 @@ template <typename T>
...
@@ -43,7 +43,6 @@ template <typename T>
class
CUDNNConvOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
CUDNNConvOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
VLOG
(
3
)
<<
"inside cudnn"
;
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace."
);
"It must use CUDAPlace."
);
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
...
@@ -60,7 +59,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -60,7 +59,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
filter_data
=
filter
->
data
<
T
>
();
const
T
*
filter_data
=
filter
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
VLOG
(
3
)
<<
"get all inputs"
;
// ------------------- cudnn descriptors ---------------------
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor
input_desc
;
ScopedTensorDescriptor
input_desc
;
ScopedTensorDescriptor
output_desc
;
ScopedTensorDescriptor
output_desc
;
...
@@ -73,7 +72,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -73,7 +72,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnnConvolutionDescriptor_t
cudnn_conv_desc
=
cudnnConvolutionDescriptor_t
cudnn_conv_desc
=
conv_desc
.
descriptor
<
T
>
(
paddings
,
strides
,
dilations
);
conv_desc
.
descriptor
<
T
>
(
paddings
,
strides
,
dilations
);
VLOG
(
3
)
<<
"create tensor descriptor"
;
#if CUDNN_VERSION_MIN(7, 0, 1)
#if CUDNN_VERSION_MIN(7, 0, 1)
// cudnn 7 can support groups, no need to do it mannually
// cudnn 7 can support groups, no need to do it mannually
// FIXME(typhoonzero): find a better way to disable groups
// FIXME(typhoonzero): find a better way to disable groups
...
@@ -82,7 +81,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -82,7 +81,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc
,
groups
));
cudnn_conv_desc
,
groups
));
groups
=
1
;
groups
=
1
;
#endif
#endif
VLOG
(
3
)
<<
"before create tensor descriptor"
;
cudnnTensorDescriptor_t
cudnn_input_desc
=
input_desc
.
descriptor
<
T
>
(
cudnnTensorDescriptor_t
cudnn_input_desc
=
input_desc
.
descriptor
<
T
>
(
layout
,
framework
::
vectorize2int
(
input
->
dims
()),
groups
);
layout
,
framework
::
vectorize2int
(
input
->
dims
()),
groups
);
cudnnTensorDescriptor_t
cudnn_output_desc
=
output_desc
.
descriptor
<
T
>
(
cudnnTensorDescriptor_t
cudnn_output_desc
=
output_desc
.
descriptor
<
T
>
(
...
@@ -112,7 +111,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -112,7 +111,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
output_height
=
output
->
dims
()[
2
];
output_height
=
output
->
dims
()[
2
];
output_width
=
output
->
dims
()[
3
];
output_width
=
output
->
dims
()[
3
];
}
}
VLOG
(
3
)
<<
"after create tensor descriptor"
;
int
group_offset_in
=
int
group_offset_in
=
input_channels
/
groups
*
input_height
*
input_width
*
input_depth
;
input_channels
/
groups
*
input_height
*
input_width
*
input_depth
;
int
group_offset_out
=
int
group_offset_out
=
...
@@ -129,7 +128,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -129,7 +128,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
handle
=
dev_ctx
.
cudnn_handle
();
VLOG
(
3
)
<<
"set cudnn algorithm"
;
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
cudnn_output_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
cudnn_output_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
...
@@ -150,7 +148,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -150,7 +148,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc
,
CUDNN_DEFAULT_MATH
));
cudnn_conv_desc
,
CUDNN_DEFAULT_MATH
));
}
}
#endif
#endif
VLOG
(
3
)
<<
"before get workspace"
;
// get workspace size able to allocate
// get workspace size able to allocate
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardWorkspaceSize
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardWorkspaceSize
(
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
...
@@ -159,6 +157,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -159,6 +157,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// the limit because the algo is overrided to use tensor core.
// the limit because the algo is overrided to use tensor core.
PADDLE_ENFORCE_LE
(
workspace_size_in_bytes
,
workspace_size_limit
,
PADDLE_ENFORCE_LE
(
workspace_size_in_bytes
,
workspace_size_limit
,
"workspace_size to be allocated exceeds the limit"
);
"workspace_size to be allocated exceeds the limit"
);
// ------------------- cudnn conv forward ---------------------
// ------------------- cudnn conv forward ---------------------
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
...
@@ -312,6 +311,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -312,6 +311,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_filter_desc
,
filter_algo
,
&
tmp_size
));
cudnn_filter_desc
,
filter_algo
,
&
tmp_size
));
workspace_size_in_bytes
=
std
::
max
(
workspace_size_in_bytes
,
tmp_size
);
workspace_size_in_bytes
=
std
::
max
(
workspace_size_in_bytes
,
tmp_size
);
}
}
// ------------------- cudnn conv backward data ---------------------
// ------------------- cudnn conv backward data ---------------------
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
if
(
input_grad
)
{
if
(
input_grad
)
{
...
...
paddle/fluid/operators/fetch_op.cc
浏览文件 @
60f70b17
...
@@ -42,8 +42,6 @@ class FetchOp : public framework::OperatorBase {
...
@@ -42,8 +42,6 @@ class FetchOp : public framework::OperatorBase {
"Cannot find out_var in scope, out_var_name is %s"
,
"Cannot find out_var in scope, out_var_name is %s"
,
out_name
);
out_name
);
VLOG
(
3
)
<<
"fetch_var ptr "
<<
fetch_var
<<
" is "
<<
(
fetch_var
==
nullptr
);
VLOG
(
3
)
<<
"out_var ptr "
<<
out_var
<<
" is "
<<
(
out_var
==
nullptr
);
auto
col
=
static_cast
<
size_t
>
(
Attr
<
int
>
(
"col"
));
auto
col
=
static_cast
<
size_t
>
(
Attr
<
int
>
(
"col"
));
auto
*
fetch_list
=
out_var
->
GetMutable
<
framework
::
FeedFetchList
>
();
auto
*
fetch_list
=
out_var
->
GetMutable
<
framework
::
FeedFetchList
>
();
...
...
paddle/fluid/operators/label_smooth_op.cc
浏览文件 @
60f70b17
...
@@ -34,7 +34,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
...
@@ -34,7 +34,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
if
(
ctx
->
HasInput
(
"PriorDist"
))
{
if
(
ctx
->
HasInput
(
"PriorDist"
))
{
auto
noise_dims
=
ctx
->
GetInputDim
(
"PriorDist"
);
auto
noise_dims
=
ctx
->
GetInputDim
(
"PriorDist"
);
int64_t
noise_numel
=
paddle
::
framework
::
product
(
noise_dims
);
auto
noise_numel
=
paddle
::
framework
::
product
(
noise_dims
);
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
in_dims
[
1
]
==
noise_numel
,
in_dims
[
1
]
==
noise_numel
,
"The number of elements in Input(PriorDist) must be equal to the "
"The number of elements in Input(PriorDist) must be equal to the "
...
...
paddle/fluid/operators/load_combine_op.cc
浏览文件 @
60f70b17
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <fstream>
#include <fstream>
#include <
vector
>
#include <
memory
>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
...
@@ -33,10 +33,15 @@ class LoadCombineOp : public framework::OperatorBase {
...
@@ -33,10 +33,15 @@ class LoadCombineOp : public framework::OperatorBase {
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
load_as_fp16
=
Attr
<
bool
>
(
"load_as_fp16"
);
auto
load_as_fp16
=
Attr
<
bool
>
(
"load_as_fp16"
);
auto
format
=
Attr
<
std
::
string
>
(
"format"
);
std
::
ifstream
fin
(
filename
,
std
::
ios_base
::
in
|
std
::
ios_base
::
binary
);
std
::
unique_ptr
<
std
::
ifstream
>
fin
;
//std::ifstream fin(filename, std::ios_base::in);
if
(
format
==
"windows"
)
{
PADDLE_ENFORCE
(
!
fin
.
bad
(),
fin
.
reset
(
new
std
::
ifstream
(
filename
,
std
::
ios_base
::
in
|
std
::
ios_base
::
binary
));
}
else
{
fin
.
reset
(
new
std
::
ifstream
(
filename
));
}
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
*
fin
),
"Cannot open file %s for load_combine op"
,
filename
);
"Cannot open file %s for load_combine op"
,
filename
);
auto
out_var_names
=
Outputs
(
"Out"
);
auto
out_var_names
=
Outputs
(
"Out"
);
...
@@ -48,32 +53,20 @@ class LoadCombineOp : public framework::OperatorBase {
...
@@ -48,32 +53,20 @@ class LoadCombineOp : public framework::OperatorBase {
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
out_var_names
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
out_var_names
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"load variable "
<<
out_var_names
[
i
];
auto
*
out_var
=
scope
.
FindVar
(
out_var_names
[
i
]);
auto
*
out_var
=
scope
.
FindVar
(
out_var_names
[
i
]);
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output variable %s cannot be found"
,
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output variable %s cannot be found"
,
out_var_names
[
i
]);
out_var_names
[
i
]);
auto
*
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
VLOG
(
3
)
<<
"Get Tensor"
;
// Error checking
// Error checking
PADDLE_ENFORCE
(
!
fin
.
bad
(
),
"Cannot read more from file %s"
,
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
*
fin
),
"Cannot read more from file %s"
,
filename
);
filename
);
VLOG
(
3
)
<<
"before deserialization"
;
// Get data from fin to tensor
// Get data from fin to tensor
DeserializeFromStream
(
fin
,
tensor
,
dev_ctx
);
DeserializeFromStream
(
*
fin
,
tensor
,
dev_ctx
);
// VLOG(3) << "after deserialization";
// framework::Tensor check;
// framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check);
// float sum = .0;
// for(size_t i=0; i < check.numel(); ++i) {
// if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) {
// sum += static_cast<float>(check.data<int64_t>()[i]);
// } else {
// sum += check.data<float>()[i];
// }
// }
// VLOG(3) << "sum result" << sum;
auto
in_dtype
=
framework
::
ToDataType
(
tensor
->
type
());
auto
in_dtype
=
framework
::
ToDataType
(
tensor
->
type
());
auto
out_dtype
=
auto
out_dtype
=
load_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
load_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
...
@@ -93,9 +86,7 @@ class LoadCombineOp : public framework::OperatorBase {
...
@@ -93,9 +86,7 @@ class LoadCombineOp : public framework::OperatorBase {
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
set_lod
(
fp16_tensor
.
lod
());
tensor
->
set_lod
(
fp16_tensor
.
lod
());
tensor
->
ShareDataWith
(
fp16_tensor
);
tensor
->
ShareDataWith
(
fp16_tensor
);
}
}
VLOG
(
3
)
<<
"load "
<<
out_var_names
[
i
]
<<
" finished"
;
}
}
}
}
};
};
...
@@ -119,6 +110,18 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -119,6 +110,18 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"LoDTensors will be loaded from
\"
file_path
\"
."
)
"LoDTensors will be loaded from
\"
file_path
\"
."
)
.
AddCustomChecker
(
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddAttr
<
std
::
string
>
(
"format"
,
R"DOC((windows|linux)" "saved model file format
windows and linux file newline symbol is
different. windows(newline is \n\r) or linux(newline is \r)
So if you set attribute format to windows, then we saved model file in binary.
It can be used both linux and windows. If you set format to linux,
it will save file in normal file, newline symbol is \r. Need to note
that these two format is not inter-compatible.)DOC"
)
.
SetDefault
(
"linux"
)
.
AddCustomChecker
([](
const
std
::
string
&
s
)
{
return
s
==
"windows"
||
s
==
"linux"
;
});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
LoadCombine Operator.
LoadCombine Operator.
...
...
paddle/fluid/operators/load_op.cc
浏览文件 @
60f70b17
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <fstream>
#include <fstream>
#include <memory>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
@@ -34,8 +35,15 @@ class LoadOp : public framework::OperatorBase {
...
@@ -34,8 +35,15 @@ class LoadOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
// it to save an output stream.
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
);
auto
format
=
Attr
<
std
::
string
>
(
"format"
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
std
::
unique_ptr
<
std
::
ifstream
>
fin
;
if
(
format
==
"windows"
)
{
fin
.
reset
(
new
std
::
ifstream
(
filename
,
std
::
ios_base
::
in
|
std
::
ios_base
::
binary
));
}
else
{
fin
.
reset
(
new
std
::
ifstream
(
filename
));
}
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
*
fin
),
"Cannot open file %s for load op"
,
filename
);
filename
);
auto
out_var_name
=
Output
(
"Out"
);
auto
out_var_name
=
Output
(
"Out"
);
...
@@ -44,9 +52,9 @@ class LoadOp : public framework::OperatorBase {
...
@@ -44,9 +52,9 @@ class LoadOp : public framework::OperatorBase {
out_var_name
);
out_var_name
);
if
(
out_var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
out_var
->
IsType
<
framework
::
LoDTensor
>
())
{
LoadLodTensor
(
fin
,
place
,
out_var
);
LoadLodTensor
(
*
fin
,
place
,
out_var
);
}
else
if
(
out_var
->
IsType
<
framework
::
SelectedRows
>
())
{
}
else
if
(
out_var
->
IsType
<
framework
::
SelectedRows
>
())
{
LoadSelectedRows
(
fin
,
place
,
out_var
);
LoadSelectedRows
(
*
fin
,
place
,
out_var
);
}
else
{
}
else
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
false
,
false
,
...
@@ -110,6 +118,18 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -110,6 +118,18 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
R"(Variable will be loaded from "file_path")"
)
R"(Variable will be loaded from "file_path")"
)
.
AddCustomChecker
(
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddAttr
<
std
::
string
>
(
"format"
,
R"DOC((windows|linux)" "saved model file format
windows and linux file newline symbol is
different. windows(newline is \n\r) or linux(newline is \r)
So if you set attribute format to windows, then we saved model file in binary.
It can be used both linux and windows. If you set format to linux,
it will save file in normal file, newline symbol is \r. Need to note
that these two format is not inter-compatible.)DOC"
)
.
SetDefault
(
"linux"
)
.
AddCustomChecker
([](
const
std
::
string
&
s
)
{
return
s
==
"windows"
||
s
==
"linux"
;
});
AddComment
(
AddComment
(
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"file."
);
"file."
);
...
...
paddle/fluid/operators/save_combine_op.cc
浏览文件 @
60f70b17
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include <stdint.h>
#include <stdint.h>
#include <fstream>
#include <fstream>
#include <memory>
#include <numeric>
#include <numeric>
#include <sstream>
#include <sstream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
...
@@ -41,6 +42,7 @@ class SaveCombineOp : public framework::OperatorBase {
...
@@ -41,6 +42,7 @@ class SaveCombineOp : public framework::OperatorBase {
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
auto
format
=
Attr
<
std
::
string
>
(
"format"
);
bool
is_present
=
FileExists
(
filename
);
bool
is_present
=
FileExists
(
filename
);
if
(
is_present
&&
!
overwrite
)
{
if
(
is_present
&&
!
overwrite
)
{
...
@@ -49,8 +51,14 @@ class SaveCombineOp : public framework::OperatorBase {
...
@@ -49,8 +51,14 @@ class SaveCombineOp : public framework::OperatorBase {
}
}
MkDirRecursively
(
DirName
(
filename
).
c_str
());
MkDirRecursively
(
DirName
(
filename
).
c_str
());
std
::
ofstream
fout
(
filename
,
std
::
ios_base
::
out
|
std
::
ios_base
::
binary
);
std
::
unique_ptr
<
std
::
ofstream
>
fout
;
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
if
(
format
==
"windows"
)
{
fout
.
reset
(
new
std
::
ofstream
(
filename
,
std
::
ios_base
::
out
|
std
::
ios_base
::
binary
));
}
else
{
fout
.
reset
(
new
std
::
ofstream
(
filename
));
}
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
*
fout
),
"Cannot open %s to write"
,
filename
);
filename
);
auto
inp_var_names
=
Inputs
(
"X"
);
auto
inp_var_names
=
Inputs
(
"X"
);
...
@@ -86,12 +94,11 @@ class SaveCombineOp : public framework::OperatorBase {
...
@@ -86,12 +94,11 @@ class SaveCombineOp : public framework::OperatorBase {
// copy LoD info to the new tensor
// copy LoD info to the new tensor
out
.
set_lod
(
tensor
.
lod
());
out
.
set_lod
(
tensor
.
lod
());
framework
::
TransDataType
(
in_kernel_type
,
out_kernel_type
,
tensor
,
&
out
);
framework
::
TransDataType
(
in_kernel_type
,
out_kernel_type
,
tensor
,
&
out
);
framework
::
SerializeToStream
(
fout
,
out
,
dev_ctx
);
framework
::
SerializeToStream
(
*
fout
,
out
,
dev_ctx
);
}
else
{
}
else
{
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
framework
::
SerializeToStream
(
*
fout
,
tensor
,
dev_ctx
);
}
}
}
}
fout
.
close
();
}
}
};
};
...
@@ -124,6 +131,18 @@ to a file on disk.
...
@@ -124,6 +131,18 @@ to a file on disk.
"The
\"
file_path
\"
where the LoDTensor variables will be saved."
)
"The
\"
file_path
\"
where the LoDTensor variables will be saved."
)
.
AddCustomChecker
(
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddAttr
<
std
::
string
>
(
"format"
,
R"DOC((windows|linux)" "saved model file format
windows and linux file newline symbol is
different. windows(newline is \n\r) or linux(newline is \r)
So if you set attribute format to windows, then we saved model file in binary.
It can be used both linux and windows. If you set format to linux,
it will save file in normal file, newline symbol is \r. Need to note
that these two format is not inter-compatible.)DOC"
)
.
SetDefault
(
"linux"
)
.
AddCustomChecker
([](
const
std
::
string
&
s
)
{
return
s
==
"windows"
||
s
==
"linux"
;
});
}
}
};
};
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
60f70b17
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include <stdint.h>
#include <stdint.h>
#include <fstream>
#include <fstream>
#include <memory>
#include <numeric>
#include <numeric>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
...
@@ -64,6 +65,7 @@ class SaveOp : public framework::OperatorBase {
...
@@ -64,6 +65,7 @@ class SaveOp : public framework::OperatorBase {
framework
::
Variable
*
var
)
const
{
framework
::
Variable
*
var
)
const
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
format
=
Attr
<
std
::
string
>
(
"format"
);
if
(
FileExists
(
filename
)
&&
!
overwrite
)
{
if
(
FileExists
(
filename
)
&&
!
overwrite
)
{
PADDLE_THROW
(
"%s is existed, cannot save to it when overwrite=false"
,
PADDLE_THROW
(
"%s is existed, cannot save to it when overwrite=false"
,
...
@@ -80,8 +82,14 @@ class SaveOp : public framework::OperatorBase {
...
@@ -80,8 +82,14 @@ class SaveOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
// it to save an output stream.
std
::
ofstream
fout
(
filename
);
std
::
unique_ptr
<
std
::
ofstream
>
fout
;
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
if
(
format
==
"windows"
)
{
fout
.
reset
(
new
std
::
ofstream
(
filename
,
std
::
ios_base
::
out
|
std
::
ios_base
::
binary
));
}
else
{
fout
.
reset
(
new
std
::
ofstream
(
filename
));
}
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
*
fout
),
"Cannot open %s to write"
,
filename
);
filename
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
...
@@ -95,11 +103,10 @@ class SaveOp : public framework::OperatorBase {
...
@@ -95,11 +103,10 @@ class SaveOp : public framework::OperatorBase {
framework
::
TransDataType
(
in_kernel_type
,
out_kernel_type
,
tensor
,
&
out
);
framework
::
TransDataType
(
in_kernel_type
,
out_kernel_type
,
tensor
,
&
out
);
// copy LoD info to the new tensor
// copy LoD info to the new tensor
out
.
set_lod
(
tensor
.
lod
());
out
.
set_lod
(
tensor
.
lod
());
framework
::
SerializeToStream
(
fout
,
out
,
dev_ctx
);
framework
::
SerializeToStream
(
*
fout
,
out
,
dev_ctx
);
}
else
{
}
else
{
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
framework
::
SerializeToStream
(
*
fout
,
tensor
,
dev_ctx
);
}
}
fout
.
close
();
}
}
void
SaveSelectedRows
(
const
framework
::
Scope
&
scope
,
void
SaveSelectedRows
(
const
framework
::
Scope
&
scope
,
...
@@ -110,6 +117,7 @@ class SaveOp : public framework::OperatorBase {
...
@@ -110,6 +117,7 @@ class SaveOp : public framework::OperatorBase {
lt_var
!=
nullptr
,
lt_var
!=
nullptr
,
"Can not find variable kLookupTablePath for SaveSelectedRows"
);
"Can not find variable kLookupTablePath for SaveSelectedRows"
);
std
::
string
filename
=
lt_var
->
data
();
std
::
string
filename
=
lt_var
->
data
();
auto
format
=
Attr
<
std
::
string
>
(
"format"
);
VLOG
(
4
)
<<
"SaveSelectedRows get File name: "
<<
filename
;
VLOG
(
4
)
<<
"SaveSelectedRows get File name: "
<<
filename
;
MkDirRecursively
(
DirName
(
filename
).
c_str
());
MkDirRecursively
(
DirName
(
filename
).
c_str
());
...
@@ -122,11 +130,16 @@ class SaveOp : public framework::OperatorBase {
...
@@ -122,11 +130,16 @@ class SaveOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
// it to save an output stream.
std
::
ofstream
fout
(
filename
);
std
::
unique_ptr
<
std
::
ofstream
>
fout
;
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
if
(
format
==
"windows"
)
{
fout
.
reset
(
new
std
::
ofstream
(
filename
,
std
::
ios_base
::
out
|
std
::
ios_base
::
binary
));
}
else
{
fout
.
reset
(
new
std
::
ofstream
(
filename
));
}
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
*
fout
),
"Cannot open %s to write"
,
filename
);
filename
);
framework
::
SerializeToStream
(
fout
,
selectedRows
,
dev_ctx
);
framework
::
SerializeToStream
(
*
fout
,
selectedRows
,
dev_ctx
);
fout
.
close
();
}
}
};
};
...
@@ -154,6 +167,18 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
...
@@ -154,6 +167,18 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
"The
\"
file_path
\"
where the variable will be saved."
)
"The
\"
file_path
\"
where the variable will be saved."
)
.
AddCustomChecker
(
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddAttr
<
std
::
string
>
(
"format"
,
R"DOC((windows|linux)" "saved model file format
windows and linux file newline symbol is
different. windows(newline is \n\r) or linux(newline is \r)
So if you set attribute format to windows, then we saved model file in binary.
It can be used both linux and windows. If you set format to linux,
it will save file in normal file, newline symbol is \r. Need to note
that these two format is not inter-compatible.)DOC"
)
.
SetDefault
(
"linux"
)
.
AddCustomChecker
([](
const
std
::
string
&
s
)
{
return
s
==
"windows"
||
s
==
"linux"
;
});
}
}
};
};
...
...
paddle/fluid/platform/init.cc
浏览文件 @
60f70b17
...
@@ -94,9 +94,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
...
@@ -94,9 +94,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
int
count
=
0
;
int
count
=
0
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
try
{
try
{
VLOG
(
3
)
<<
"get cuda count"
;
count
=
platform
::
GetCUDADeviceCount
();
count
=
platform
::
GetCUDADeviceCount
();
VLOG
(
3
)
<<
"get cuda pass"
;
}
catch
(
const
std
::
exception
&
exp
)
{
}
catch
(
const
std
::
exception
&
exp
)
{
LOG
(
WARNING
)
<<
"Compiled with WITH_GPU, but no GPU found in runtime."
;
LOG
(
WARNING
)
<<
"Compiled with WITH_GPU, but no GPU found in runtime."
;
}
}
...
@@ -109,14 +107,11 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
...
@@ -109,14 +107,11 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
}
}
places
.
emplace_back
(
platform
::
CUDAPlace
(
devices
[
i
]));
places
.
emplace_back
(
platform
::
CUDAPlace
(
devices
[
i
]));
}
}
VLOG
(
3
)
<<
"before p2p"
;
if
(
init_p2p
)
{
if
(
init_p2p
)
{
InitP2P
(
devices
);
InitP2P
(
devices
);
}
}
VLOG
(
3
)
<<
"p2p pass"
;
places
.
emplace_back
(
platform
::
CPUPlace
());
places
.
emplace_back
(
platform
::
CPUPlace
());
platform
::
DeviceContextPool
::
Init
(
places
);
platform
::
DeviceContextPool
::
Init
(
places
);
VLOG
(
3
)
<<
"init pass"
;
#ifndef PADDLE_WITH_MKLDNN
#ifndef PADDLE_WITH_MKLDNN
platform
::
SetNumThreads
(
FLAGS_paddle_num_threads
);
platform
::
SetNumThreads
(
FLAGS_paddle_num_threads
);
#endif
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录