Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ea97e83e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ea97e83e
编写于
10月 13, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into optimize-pyreader
上级
fc77b504
abf1005b
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
592 addition
and
209 deletion
+592
-209
paddle/fluid/inference/api/demo_ci/run.sh
paddle/fluid/inference/api/demo_ci/run.sh
+3
-5
paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc
paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc
+82
-0
paddle/fluid/inference/api/demo_ci/utils.h
paddle/fluid/inference/api/demo_ci/utils.h
+59
-0
paddle/fluid/inference/api/demo_ci/vis_demo.cc
paddle/fluid/inference/api/demo_ci/vis_demo.cc
+14
-91
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+2
-1
paddle/fluid/operators/clip_by_norm_op.h
paddle/fluid/operators/clip_by_norm_op.h
+37
-3
paddle/fluid/operators/math/depthwise_conv.cu
paddle/fluid/operators/math/depthwise_conv.cu
+210
-65
paddle/fluid/operators/truncated_gaussian_random_op.cc
paddle/fluid/operators/truncated_gaussian_random_op.cc
+1
-1
paddle/fluid/operators/truncated_gaussian_random_op.cu
paddle/fluid/operators/truncated_gaussian_random_op.cu
+2
-1
paddle/fluid/operators/uniform_random_op.cc
paddle/fluid/operators/uniform_random_op.cc
+16
-16
paddle/fluid/platform/profiler.cc
paddle/fluid/platform/profiler.cc
+47
-18
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+44
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+5
-1
python/paddle/fluid/layers/ops.py
python/paddle/fluid/layers/ops.py
+12
-5
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+1
-1
python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py
python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py
+57
-0
未找到文件。
paddle/fluid/inference/api/demo_ci/run.sh
浏览文件 @
ea97e83e
...
...
@@ -100,19 +100,17 @@ for WITH_STATIC_LIB in ON OFF; do
rm
-rf
*
cmake ..
-DPADDLE_LIB
=
${
PADDLE_ROOT
}
/build/fluid_install_dir/
\
-DWITH_MKL
=
$TURN_ON_MKL
\
-DDEMO_NAME
=
vis
_demo
\
-DDEMO_NAME
=
trt_mobilenet
_demo
\
-DWITH_GPU
=
$TEST_GPU_CPU
\
-DWITH_STATIC_LIB
=
$WITH_STATIC_LIB
\
-DUSE_TENSORRT
=
$USE_TENSORRT
\
-DTENSORRT_INCLUDE_DIR
=
$TENSORRT_INCLUDE_DIR
\
-DTENSORRT_LIB_DIR
=
$TENSORRT_LIB_DIR
make
-j
./
vis
_demo
\
./
trt_mobilenet
_demo
\
--modeldir
=
$DATA_DIR
/mobilenet/model
\
--data
=
$DATA_DIR
/mobilenet/data.txt
\
--refer
=
$DATA_DIR
/mobilenet/result.txt
\
--use_gpu
=
true
\
--use_trt
=
true
--refer
=
$DATA_DIR
/mobilenet/result.txt
fi
done
set
+x
paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc
0 → 100644
浏览文件 @
ea97e83e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file contains demo of mobilenet for tensorrt.
*/
#include <gflags/gflags.h>
#include <glog/logging.h> // use glog instead of CHECK to avoid importing other paddle header files.
#include "paddle/fluid/inference/demo_ci/utils.h"
// Declaration only: DECLARE_double makes a flag that is defined in another
// translation unit visible here.
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
// --modeldir: directory of the inference model; Main() appends "/__params__"
// and "/__model__" to it.
DEFINE_string
(
modeldir
,
""
,
"Directory of the inference model."
);
// --refer: file holding the reference result that CheckOutput() compares
// the prediction against.
DEFINE_string
(
refer
,
""
,
"path to reference result for comparison."
);
// --data: input data file; Main() reads only its first line.
DEFINE_string
(
data
,
""
,
"path of data; each line is a record, format is "
"'<space splitted floats as data>
\t
<space splitted ints as shape'"
);
namespace
paddle
{
namespace
demo
{
/*
* Use the tensorrt fluid engine to inference the demo.
*/
void
Main
()
{
std
::
unique_ptr
<
PaddlePredictor
>
predictor
;
paddle
::
contrib
::
MixedRTConfig
config
;
config
.
param_file
=
FLAGS_modeldir
+
"/__params__"
;
config
.
prog_file
=
FLAGS_modeldir
+
"/__model__"
;
config
.
use_gpu
=
true
;
config
.
device
=
0
;
config
.
max_batch_size
=
1
;
config
.
fraction_of_gpu_memory
=
0.1
;
// set by yourself
predictor
=
CreatePaddlePredictor
<
paddle
::
contrib
::
MixedRTConfig
>
(
config
);
VLOG
(
3
)
<<
"begin to process data"
;
// Just a single batch of data.
std
::
string
line
;
std
::
ifstream
file
(
FLAGS_data
);
std
::
getline
(
file
,
line
);
auto
record
=
ProcessALine
(
line
);
file
.
close
();
// Inference.
PaddleTensor
input
;
input
.
shape
=
record
.
shape
;
input
.
data
=
PaddleBuf
(
record
.
data
.
data
(),
record
.
data
.
size
()
*
sizeof
(
float
));
input
.
dtype
=
PaddleDType
::
FLOAT32
;
VLOG
(
3
)
<<
"run executor"
;
std
::
vector
<
PaddleTensor
>
output
;
predictor
->
Run
({
input
},
&
output
,
1
);
VLOG
(
3
)
<<
"output.size "
<<
output
.
size
();
auto
&
tensor
=
output
.
front
();
VLOG
(
3
)
<<
"output: "
<<
SummaryTensor
(
tensor
);
// compare with reference result
CheckOutput
(
FLAGS_refer
,
tensor
);
}
}
// namespace demo
}
// namespace paddle
// Program entry point: parse the command-line flags, then run the demo once.
int main(int argc, char** argv) {
  // Third argument is forwarded unchanged from the original call (true).
  const bool third_arg = true;
  google::ParseCommandLineFlags(&argc, &argv, third_arg);

  paddle::demo::Main();
  return 0;
}
paddle/fluid/inference/api/demo_ci/utils.h
浏览文件 @
ea97e83e
...
...
@@ -14,6 +14,8 @@
#pragma once
#include <algorithm>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/paddle_inference_api.h"
...
...
@@ -21,6 +23,11 @@
namespace
paddle
{
namespace
demo
{
// One parsed line of a demo data file, as filled in by ProcessALine().
struct Record {
  std::vector<float> data;     // floats parsed from the first column
  std::vector<int32_t> shape;  // ints parsed from the second column
};
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
std
::
string
>*
pieces
)
{
pieces
->
clear
();
...
...
@@ -39,6 +46,58 @@ static void split(const std::string& str, char sep,
}
}
/*
 * Parse one line of a demo data file into a Record.
 *
 * The line must contain exactly two tab-separated columns: the first holds
 * space-separated floats (stored in Record::data), the second holds
 * space-separated ints (stored in Record::shape). Any other column count is
 * fatal (CHECK_EQ). std::stof / std::stoi throw on tokens they cannot parse.
 *
 * Marked inline because the definition lives in a header: without it, two
 * translation units of one binary that both include this header would each
 * emit a definition and fail to link.
 */
inline Record ProcessALine(const std::string& line) {
  VLOG(3) << "process a line";
  std::vector<std::string> columns;
  split(line, '\t', &columns);
  CHECK_EQ(columns.size(), 2UL)
      << "data format error, should be <data>\t<shape>";

  Record record;

  std::vector<std::string> data_strs;
  split(columns[0], ' ', &data_strs);
  record.data.reserve(data_strs.size());
  for (const auto& d : data_strs) {
    record.data.push_back(std::stof(d));
  }

  std::vector<std::string> shape_strs;
  split(columns[1], ' ', &shape_strs);
  record.shape.reserve(shape_strs.size());
  for (const auto& s : shape_strs) {
    record.shape.push_back(std::stoi(s));
  }

  VLOG(3) << "data size " << record.data.size();
  VLOG(3) << "data shape size " << record.shape.size();
  return record;
}
void
CheckOutput
(
const
std
::
string
&
referfile
,
const
PaddleTensor
&
output
)
{
std
::
string
line
;
std
::
ifstream
file
(
referfile
);
std
::
getline
(
file
,
line
);
auto
refer
=
ProcessALine
(
line
);
file
.
close
();
size_t
numel
=
output
.
data
.
length
()
/
PaddleDtypeSize
(
output
.
dtype
);
VLOG
(
3
)
<<
"predictor output numel "
<<
numel
;
VLOG
(
3
)
<<
"reference output numel "
<<
refer
.
data
.
size
();
CHECK_EQ
(
numel
,
refer
.
data
.
size
());
switch
(
output
.
dtype
)
{
case
PaddleDType
::
INT64
:
{
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
CHECK_EQ
(
static_cast
<
int64_t
*>
(
output
.
data
.
data
())[
i
],
refer
.
data
[
i
]);
}
break
;
}
case
PaddleDType
::
FLOAT32
:
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
CHECK_LT
(
fabs
(
static_cast
<
float
*>
(
output
.
data
.
data
())[
i
]
-
refer
.
data
[
i
]),
1e-5
);
}
break
;
}
}
/*
* Get a summary of a PaddleTensor content.
*/
...
...
paddle/fluid/inference/api/demo_ci/vis_demo.cc
浏览文件 @
ea97e83e
...
...
@@ -18,10 +18,6 @@ limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h> // use glog instead of CHECK to avoid importing other paddle header files.
#include <fstream>
#include <iostream>
// #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/inference/demo_ci/utils.h"
#ifdef PADDLE_WITH_CUDA
...
...
@@ -34,99 +30,28 @@ DEFINE_string(
"path of data; each line is a record, format is "
"'<space splitted floats as data>
\t
<space splitted ints as shape'"
);
DEFINE_bool
(
use_gpu
,
false
,
"Whether use gpu."
);
DEFINE_bool
(
use_trt
,
false
,
"Whether use trt."
);
namespace
paddle
{
namespace
demo
{
// One parsed line of a demo data file, as filled in by ProcessALine().
struct Record {
  std::vector<float> data;     // floats parsed from the first column
  std::vector<int32_t> shape;  // ints parsed from the second column
};
void
split
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
std
::
string
>*
pieces
);
/*
 * Turn one "<data>\t<shape>" line into a Record.
 *
 * The first tab-separated column is split on spaces and parsed as floats, the
 * second is split on spaces and parsed as ints. A line without exactly two
 * columns is fatal (CHECK_EQ).
 */
Record ProcessALine(const std::string& line) {
  VLOG(3) << "process a line";

  // `columns` keeps its name: glog prints the checked expression on failure.
  std::vector<std::string> columns;
  split(line, '\t', &columns);
  CHECK_EQ(columns.size(), 2UL)
      << "data format error, should be <data>\t<shape>";

  Record result;

  std::vector<std::string> value_tokens;
  split(columns[0], ' ', &value_tokens);
  for (size_t i = 0; i < value_tokens.size(); ++i) {
    result.data.push_back(std::stof(value_tokens[i]));
  }

  std::vector<std::string> dim_tokens;
  split(columns[1], ' ', &dim_tokens);
  for (size_t i = 0; i < dim_tokens.size(); ++i) {
    result.shape.push_back(std::stoi(dim_tokens[i]));
  }

  VLOG(3) << "data size " << result.data.size();
  VLOG(3) << "data shape size " << result.shape.size();
  return result;
}
void
CheckOutput
(
const
std
::
string
&
referfile
,
const
PaddleTensor
&
output
)
{
std
::
string
line
;
std
::
ifstream
file
(
referfile
);
std
::
getline
(
file
,
line
);
auto
refer
=
ProcessALine
(
line
);
file
.
close
();
size_t
numel
=
output
.
data
.
length
()
/
PaddleDtypeSize
(
output
.
dtype
);
VLOG
(
3
)
<<
"predictor output numel "
<<
numel
;
VLOG
(
3
)
<<
"reference output numel "
<<
refer
.
data
.
size
();
CHECK_EQ
(
numel
,
refer
.
data
.
size
());
switch
(
output
.
dtype
)
{
case
PaddleDType
::
INT64
:
{
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
CHECK_EQ
(
static_cast
<
int64_t
*>
(
output
.
data
.
data
())[
i
],
refer
.
data
[
i
]);
}
break
;
}
case
PaddleDType
::
FLOAT32
:
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
CHECK_LT
(
fabs
(
static_cast
<
float
*>
(
output
.
data
.
data
())[
i
]
-
refer
.
data
[
i
]),
1e-5
);
}
break
;
}
}
/*
* Use the native fluid engine to inference the demo.
*/
void
Main
(
bool
use_gpu
,
bool
use_trt
)
{
void
Main
(
bool
use_gpu
)
{
std
::
unique_ptr
<
PaddlePredictor
>
predictor
;
if
(
!
use_trt
)
{
NativeConfig
config
;
config
.
param_file
=
FLAGS_modeldir
+
"/__params__"
;
config
.
prog_file
=
FLAGS_modeldir
+
"/__model__"
;
config
.
use_gpu
=
use_gpu
;
config
.
device
=
0
;
if
(
FLAGS_use_gpu
)
{
config
.
fraction_of_gpu_memory
=
0.1
;
// set by yourself
}
VLOG
(
3
)
<<
"init predictor"
;
predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
}
else
{
paddle
::
contrib
::
MixedRTConfig
config
;
config
.
param_file
=
FLAGS_modeldir
+
"/__params__"
;
config
.
prog_file
=
FLAGS_modeldir
+
"/__model__"
;
config
.
use_gpu
=
true
;
config
.
device
=
0
;
config
.
max_batch_size
=
1
;
NativeConfig
config
;
config
.
param_file
=
FLAGS_modeldir
+
"/__params__"
;
config
.
prog_file
=
FLAGS_modeldir
+
"/__model__"
;
config
.
use_gpu
=
use_gpu
;
config
.
device
=
0
;
if
(
FLAGS_use_gpu
)
{
config
.
fraction_of_gpu_memory
=
0.1
;
// set by yourself
predictor
=
CreatePaddlePredictor
<
paddle
::
contrib
::
MixedRTConfig
>
(
config
);
}
VLOG
(
3
)
<<
"init predictor"
;
predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
VLOG
(
3
)
<<
"begin to process data"
;
// Just a single batch of data.
std
::
string
line
;
...
...
@@ -159,12 +84,10 @@ void Main(bool use_gpu, bool use_trt) {
int
main
(
int
argc
,
char
**
argv
)
{
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
if
(
FLAGS_use_gpu
&&
FLAGS_use_trt
)
{
paddle
::
demo
::
Main
(
true
/*use_gpu*/
,
true
);
}
else
if
(
FLAGS_use_gpu
)
{
paddle
::
demo
::
Main
(
true
/*use_gpu*/
,
false
);
if
(
FLAGS_use_gpu
)
{
paddle
::
demo
::
Main
(
true
/*use_gpu*/
);
}
else
{
paddle
::
demo
::
Main
(
false
/*use_gpu*/
,
false
/*use_tensorrt*/
);
paddle
::
demo
::
Main
(
false
/*use_gpu*/
);
}
return
0
;
}
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
ea97e83e
...
...
@@ -230,7 +230,7 @@ if(WITH_DISTRIBUTE)
op_library
(
${
dist_op
}
DEPS
${
DISTRIBUTE_DEPS
}
)
set_source_files_properties
(
${
dist_op
}
.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
endforeach
()
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL)
...
...
@@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
else
()
set
(
DEPS_OPS
${
DEPS_OPS
}
tensorrt_engine_op
)
endif
()
op_library
(
clip_by_norm_op DEPS selected_rows_functor selected_rows
)
op_library
(
sum_op DEPS selected_rows_functor
)
op_library
(
sgd_op DEPS selected_rows_functor
)
op_library
(
print_op DEPS lod_tensor
)
...
...
paddle/fluid/operators/clip_by_norm_op.h
浏览文件 @
ea97e83e
...
...
@@ -16,12 +16,15 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
SelectedRows
=
framework
::
SelectedRows
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
...
...
@@ -31,9 +34,40 @@ class ClipByNormKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
max_norm
=
context
.
Attr
<
T
>
(
"max_norm"
);
auto
*
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
in_var
=
context
.
InputVar
(
"X"
);
Tensor
*
output
=
nullptr
;
const
Tensor
*
input
=
nullptr
;
if
(
in_var
->
IsType
<
framework
::
LoDTensor
>
())
{
input
=
context
.
Input
<
Tensor
>
(
"X"
);
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
}
else
if
(
in_var
->
IsType
<
SelectedRows
>
())
{
auto
*
x
=
context
.
Input
<
SelectedRows
>
(
"X"
);
// merge ids in selected rows first
math
::
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
SelectedRows
*
merged_input
=
const_cast
<
framework
::
Scope
&>
(
context
.
scope
())
.
Var
()
->
GetMutable
<
SelectedRows
>
();
merge_func
(
context
.
template
device_context
<
DeviceContext
>(),
*
x
,
merged_input
);
input
=
&
(
merged_input
->
value
());
SelectedRows
*
output_selected_rows
=
context
.
Output
<
SelectedRows
>
(
"Out"
);
output_selected_rows
->
set_rows
(
merged_input
->
rows
());
output_selected_rows
->
set_height
(
merged_input
->
height
());
output
=
output_selected_rows
->
mutable_value
();
output
->
Resize
(
merged_input
->
value
().
dims
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
}
else
{
PADDLE_THROW
(
"Unexpected branch, input variable type is %s"
,
in_var
->
Type
().
name
());
}
PADDLE_ENFORCE_NOT_NULL
(
input
);
auto
x
=
EigenVector
<
T
>::
Flatten
(
*
input
);
auto
out
=
EigenVector
<
T
>::
Flatten
(
*
output
);
...
...
paddle/fluid/operators/math/depthwise_conv.cu
浏览文件 @
ea97e83e
...
...
@@ -46,17 +46,20 @@ __forceinline__ __device__ unsigned warp_id() {
return
ret
;
}
#define ARG_DEFINE_KernelDepthwiseConv \
const T *const input_data, const T *const filter_data, const int batch_size, \
const int output_channels, const int output_height, \
const int output_width, const int input_channels, \
const int input_height, const int input_width, \
const int filter_multiplier, const int filter_height, \
const int filter_width, const int stride_height, const int stride_width, \
const int padding_height, const int padding_width, \
const int dilate_height, const int dilate_width, T *const output_data
// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
template
<
typename
T
>
__device__
__inline__
void
KernelDepthwiseConv
(
const
T
*
const
input_data
,
const
T
*
const
filter_data
,
const
int
batch_size
,
const
int
output_channels
,
const
int
output_height
,
const
int
output_width
,
const
int
input_channels
,
const
int
input_height
,
const
int
input_width
,
const
int
filter_multiplier
,
const
int
filter_height
,
const
int
filter_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
const
int
dilate_height
,
const
int
dilate_width
,
T
*
const
output_data
)
{
__device__
__inline__
void
KernelDepthwiseConv
(
ARG_DEFINE_KernelDepthwiseConv
)
{
for
(
int
w_out
=
threadIdx
.
x
;
w_out
<
output_width
;
w_out
+=
blockDim
.
x
)
{
for
(
int
h_out
=
threadIdx
.
y
;
h_out
<
output_height
;
h_out
+=
blockDim
.
y
)
{
const
int
batch
=
blockIdx
.
y
;
...
...
@@ -97,42 +100,105 @@ __device__ __inline__ void KernelDepthwiseConv(
}
}
template
<
typename
T
,
int
c_filter_multiplier
,
int
c_stride
>
__global__
void
KernelDepthwiseConvSp
(
const
T
*
const
input_data
,
const
T
*
const
filter_data
,
const
int
batch_size
,
const
int
output_channels
,
const
int
output_height
,
const
int
output_width
,
const
int
input_channels
,
const
int
input_height
,
const
int
input_width
,
const
int
filter_multiplier
,
const
int
filter_height
,
const
int
filter_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
const
int
dilate_height
,
const
int
dilate_width
,
T
*
const
output_data
)
{
if
(
c_filter_multiplier
==
0
)
KernelDepthwiseConv
<
T
>
(
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
filter_multiplier
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
output_data
);
template
<
typename
T
,
int
c_filter
>
__device__
__inline__
void
KernelDepthwiseConvCFilter
(
ARG_DEFINE_KernelDepthwiseConv
)
{
const
int
kWeghtSize
=
c_filter
*
c_filter
;
T
r_weight
[
kWeghtSize
];
const
int
batch
=
blockIdx
.
y
;
const
int
c_out
=
blockIdx
.
x
;
const
T
*
weight
=
filter_data
+
c_out
*
c_filter
*
c_filter
;
for
(
int
i
=
0
;
i
<
c_filter
*
c_filter
;
i
++
)
r_weight
[
i
]
=
weight
[
i
];
else
KernelDepthwiseConv
<
T
>
(
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
c_filter_multiplier
,
filter_height
,
filter_height
,
c_stride
,
c_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
output_data
);
for
(
int
w_out
=
threadIdx
.
x
;
w_out
<
output_width
;
w_out
+=
blockDim
.
x
)
{
for
(
int
h_out
=
threadIdx
.
y
;
h_out
<
output_height
;
h_out
+=
blockDim
.
y
)
{
const
int
batch
=
blockIdx
.
y
;
const
int
c_out
=
blockIdx
.
x
;
const
int
c_in
=
c_out
/
filter_multiplier
;
T
value
=
0
;
const
int
h_in_start
=
-
padding_height
+
h_out
*
stride_height
;
const
int
w_in_start
=
-
padding_width
+
w_out
*
stride_width
;
const
int
h_in_end
=
h_in_start
+
c_filter
*
dilate_height
;
const
int
w_in_end
=
w_in_start
+
c_filter
*
dilate_width
;
const
int
in_offset
=
((
batch
*
input_channels
+
c_in
)
*
input_height
)
*
input_width
;
const
int
h_end
=
h_in_end
<
input_height
?
h_in_end
:
input_height
;
const
int
w_end
=
w_in_end
<
input_width
?
w_in_end
:
input_width
;
const
int
h_start
=
h_in_start
>
0
?
h_in_start
:
0
;
const
int
w_start
=
w_in_start
>
0
?
w_in_start
:
0
;
for
(
int
h_in
=
h_in_start
,
h_f
=
0
;
h_f
<
c_filter
;
h_in
+=
dilate_height
,
h_f
++
)
{
for
(
int
w_in
=
w_in_start
,
w_f
=
0
;
w_f
<
c_filter
;
w_in
+=
dilate_width
,
w_f
++
)
{
if
(
h_in
>=
0
&&
h_in
<
input_height
&&
w_in
>=
0
&&
w_in
<
input_width
)
{
const
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
input_data
[
offset
];
}
}
}
int
index
=
((
batch
*
gridDim
.
x
+
c_out
)
*
output_height
+
h_out
)
*
output_width
+
w_out
;
output_data
[
index
]
=
value
;
}
}
}
// Kernel entry point for the depthwise convolution forward pass. Picks one of
// two device implementations from the compile-time template arguments:
//   c_filter_multiplier == 0 : use the runtime filter_multiplier and strides;
//   otherwise                : use c_filter_multiplier and c_stride (for both
//                              stride directions) in place of the runtime
//                              values.
//   c_filter == -1           : call KernelDepthwiseConv;
//   otherwise                : call KernelDepthwiseConvCFilter<T, c_filter>.
//
// The runtime filter_width is forwarded in every branch. Passing
// filter_height in its place would give a wrong result for a non-square
// filter in the c_filter == -1 case.
template <typename T, int c_filter_multiplier, int c_stride, int c_filter>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  if (c_filter_multiplier == 0) {
    if (c_filter == -1)
      KernelDepthwiseConv<T>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          filter_multiplier, filter_height, filter_width, stride_height,
          stride_width, padding_height, padding_width, dilate_height,
          dilate_width, output_data);
    else
      KernelDepthwiseConvCFilter<T, c_filter>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          filter_multiplier, filter_height, filter_width, stride_height,
          stride_width, padding_height, padding_width, dilate_height,
          dilate_width, output_data);
  } else {
    if (c_filter == -1)
      KernelDepthwiseConv<T>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
          padding_height, padding_width, dilate_height, dilate_width,
          output_data);
    else
      KernelDepthwiseConvCFilter<T, c_filter>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
          padding_height, padding_width, dilate_height, dilate_width,
          output_data);
  }
}
// CUDA kernel to compute the depthwise convolution backprop w.r.t input.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad \
const T *const output_grad_data, const T *const filter_data, \
const int batch_size, const int output_channels, \
const int output_height, const int output_width, \
const int input_channels, const int input_height, const int input_width, \
const int filter_multiplier, const int filter_height, \
const int filter_width, const int stride_height, const int stride_width, \
const int padding_height, const int padding_width, \
const int dilate_height, const int dilate_width, \
T *const input_grad_data
template
<
typename
T
>
__device__
__inline__
void
KernelDepthwiseConvInputGrad
(
const
T
*
const
output_grad_data
,
const
T
*
const
filter_data
,
const
int
batch_size
,
const
int
output_channels
,
const
int
output_height
,
const
int
output_width
,
const
int
input_channels
,
const
int
input_height
,
const
int
input_width
,
const
int
filter_multiplier
,
const
int
filter_height
,
const
int
filter_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
const
int
dilate_height
,
const
int
dilate_width
,
T
*
const
input_grad_data
)
{
ARG_DEFINE_KernelDepthwiseConvInputGrad
)
{
for
(
int
w_in
=
threadIdx
.
x
;
w_in
<
input_width
;
w_in
+=
blockDim
.
x
)
{
for
(
int
h_in
=
threadIdx
.
y
;
h_in
<
input_height
;
h_in
+=
blockDim
.
y
)
{
const
int
batch
=
blockIdx
.
y
;
...
...
@@ -184,15 +250,67 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
}
}
template
<
typename
T
,
int
c_filter_multiplier
,
int
c_stride
>
// Input-gradient device routine for a square filter whose edge length
// c_filter is a template argument.
//
// Each thread first copies the filter weights of every output channel that
// belongs to input channel blockIdx.x into the local array r_weight, in
// reversed element order. It then walks the input positions assigned to it
// (stepping by blockDim.x / blockDim.y) and, for each one, sums
// output_grad * weight over all output positions that map onto it, writing
// the result to input_grad_data.
template <typename T, int c_filter, int c_filter_multiplier>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kFilterArea = c_filter * c_filter;
  const int kWeightSize = kFilterArea * c_filter_multiplier + 1;
  T r_weight[kWeightSize];

  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  // Load the weights once, reversed within each output channel's filter.
  for (int c_i = 0; c_i < filter_multiplier; c_i++) {
    const int c_out = c_in * filter_multiplier + c_i;
    const T* weight = filter_data + c_out * kFilterArea;
    for (int i = 0; i < kFilterArea; i++) {
      r_weight[i + c_i * kFilterArea] = weight[kFilterArea - i - 1];
    }
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int h_out_start =
          h_in - (c_filter - 1) * dilate_height + padding_height;
      const int w_out_start =
          w_in - (c_filter - 1) * dilate_width + padding_width;

      T value = 0;

      for (int c_i = 0; c_i < filter_multiplier; c_i++) {
        const int c_out = c_in * filter_multiplier + c_i;
        int h_out = h_out_start;
        for (int h_f = 0; h_f < c_filter; h_f++, h_out += dilate_height) {
          int w_out = w_out_start;
          for (int w_f = 0; w_f < c_filter; w_f++, w_out += dilate_width) {
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            // Only output positions that land exactly on the stride grid
            // and lie inside the output contribute.
            const bool on_stride_grid =
                h_out % stride_height == 0 && w_out % stride_width == 0;
            const bool inside_output = s_h_out >= 0 &&
                                       s_h_out < output_height &&
                                       s_w_out >= 0 && s_w_out < output_width;
            if (!(on_stride_grid && inside_output)) continue;

            const int output_grad_offset =
                ((batch * output_channels + c_out) * output_height + s_h_out) *
                    output_width +
                s_w_out;
            value += output_grad_data[output_grad_offset] *
                     r_weight[h_f * c_filter + w_f + c_i * kFilterArea];
          }
        }
      }

      const int index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;
      input_grad_data[index] = value;
    }
  }
}
template
<
typename
T
,
int
c_filter_multiplier
,
int
c_stride
,
int
c_filter
>
__global__
void
KernelDepthwiseConvInputGradSp
(
const
T
*
const
output_grad_data
,
const
T
*
const
filter_data
,
const
int
batch_size
,
const
int
output_channels
,
const
int
output_height
,
const
int
output_width
,
const
int
input_channels
,
const
int
input_height
,
const
int
input_width
,
const
int
filter_multiplier
,
const
int
filter_height
,
const
int
filter_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
const
int
dilate_height
,
const
int
dilate_width
,
T
*
const
input_grad_data
)
{
ARG_DEFINE_KernelDepthwiseConvInputGrad
)
{
if
(
c_filter_multiplier
==
0
)
KernelDepthwiseConvInputGrad
<
T
>
(
output_grad_data
,
filter_data
,
batch_size
,
output_channels
,
...
...
@@ -200,13 +318,20 @@ __global__ void KernelDepthwiseConvInputGradSp(
filter_multiplier
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
input_grad_data
);
else
else
if
(
c_filter
==
-
1
)
KernelDepthwiseConvInputGrad
<
T
>
(
output_grad_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
c_filter_multiplier
,
filter_height
,
filter_width
,
c_stride
,
c_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
input_grad_data
);
else
KernelDepthwiseConvInputGradCFilter
<
T
,
c_filter
,
c_filter_multiplier
>
(
output_grad_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
c_filter_multiplier
,
filter_height
,
filter_width
,
c_stride
,
c_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
input_grad_data
);
}
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
...
...
@@ -325,12 +450,14 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
dim3
threads
(
std
::
min
(
output_width
,
thread
),
blocks
,
1
);
dim3
grid
(
output_channels
,
batch_size
,
1
);
int
filter_multiplier
=
output_channels
/
input_channels
;
#define check_case(c_filter_multiplier, c_stride
)
\
#define check_case(c_filter_multiplier, c_stride
, c_filter)
\
if (c_filter_multiplier == 0 || \
filter_multiplier == c_filter_multiplier && \
stride_height == stride_width && stride_height == c_stride) { \
KernelDepthwiseConvSp<T, c_filter_multiplier, \
c_stride><<<grid, threads, 0, context.stream()>>>( \
stride_height == stride_width && stride_height == c_stride && \
(ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \
KernelDepthwiseConvSp<T, c_filter_multiplier, c_stride, \
c_filter><<<grid, threads, 0, context.stream()>>>( \
input_data, filter_data, batch_size, output_channels, output_height, \
output_width, input_channels, input_height, input_width, \
filter_multiplier, ksize_height, ksize_width, stride_height, \
...
...
@@ -338,11 +465,17 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
dilate_width, output_data); \
return; \
}
check_case
(
1
,
1
);
check_case
(
1
,
2
);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
check_case
(
0
,
0
);
check_case
(
1
,
1
,
3
);
check_case
(
1
,
1
,
5
);
check_case
(
1
,
1
,
-
1
);
check_case
(
1
,
2
,
3
);
check_case
(
1
,
2
,
5
);
check_case
(
1
,
2
,
-
1
);
check_case
(
0
,
0
,
3
);
check_case
(
0
,
0
,
5
);
check_case
(
0
,
0
,
-
1
);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
#undef check_case
}
};
...
...
@@ -384,13 +517,15 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
dim3
grid
(
input_channels
,
batch_size
,
1
);
int
filter_multiplier
=
output_channels
/
input_channels
;
#define check_case(c_filter_multiplier, c_stride
)
\
#define check_case(c_filter_multiplier, c_stride
, c_filter)
\
if (c_filter_multiplier == 0 || \
filter_multiplier == c_filter_multiplier && \
stride_height == stride_width && stride_height == c_stride) { \
stride_height == stride_width && stride_height == c_stride && \
(ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \
KernelDepthwiseConvInputGradSp< \
T, c_filter_multiplier,
\
c_
stride
><<<grid, threads, 0, context.stream()>>>( \
T, c_filter_multiplier,
c_stride,
\
c_
filter
><<<grid, threads, 0, context.stream()>>>( \
output_grad_data, filter_data, batch_size, output_channels, \
output_height, output_width, input_channels, input_height, \
input_width, filter_multiplier, ksize_height, ksize_width, \
...
...
@@ -398,11 +533,21 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
dilate_height, dilate_width, input_grad_data); \
return; \
}
check_case
(
1
,
1
);
check_case
(
1
,
2
);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
check_case
(
0
,
0
);
check_case
(
1
,
1
,
3
);
check_case
(
1
,
1
,
5
);
check_case
(
1
,
1
,
-
1
);
check_case
(
1
,
2
,
3
);
check_case
(
1
,
2
,
5
);
check_case
(
1
,
2
,
-
1
);
check_case
(
2
,
1
,
3
);
check_case
(
2
,
1
,
5
);
check_case
(
2
,
1
,
-
1
);
check_case
(
2
,
2
,
3
);
check_case
(
2
,
2
,
5
);
check_case
(
2
,
2
,
-
1
);
check_case
(
0
,
0
,
-
1
);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
#undef check_case
}
};
...
...
paddle/fluid/operators/truncated_gaussian_random_op.cc
浏览文件 @
ea97e83e
...
...
@@ -148,7 +148,7 @@ struct TruncatedNormal {
T
operator
()(
T
value
)
const
{
auto
p
=
a_normal_cdf
+
(
b_normal_cdf
-
a_normal_cdf
)
*
value
;
return
(
std
::
sqrt
(
2.0
)
*
Erfinv
(
2
*
p
-
1
)
+
mean
)
*
std
;
return
std
::
sqrt
(
2.0
)
*
Erfinv
(
2
*
p
-
1
)
*
std
+
mean
;
}
};
...
...
paddle/fluid/operators/truncated_gaussian_random_op.cu
浏览文件 @
ea97e83e
...
...
@@ -42,7 +42,7 @@ struct TruncatedNormal {
rng
.
discard
(
n
);
T
value
=
dist
(
rng
);
auto
p
=
a_normal_cdf
+
(
b_normal_cdf
-
a_normal_cdf
)
*
value
;
return
(
std
::
sqrt
(
2.0
)
*
erfinvf
(
2
*
p
-
1
)
+
mean
)
*
std
;
return
std
::
sqrt
(
2.0
)
*
erfinvf
(
2
*
p
-
1
)
*
std
+
mean
;
}
};
...
...
@@ -52,6 +52,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
if
(
seed
==
0
)
{
std
::
random_device
rd
;
...
...
paddle/fluid/operators/uniform_random_op.cc
浏览文件 @
ea97e83e
...
...
@@ -23,14 +23,14 @@ namespace operators {
template
<
typename
T
>
class
CPUUniformRandomKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
Tensor
*
tensor
=
nullptr
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
Tensor
*
tensor
=
nullptr
;
auto
out_var
=
ctx
.
OutputVar
(
"Out"
);
if
(
out_var
->
IsType
<
framework
::
LoDTensor
>
())
{
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
}
else
if
(
out_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
shape
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"shape"
);
auto
*
selected_rows
=
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
selected_rows
=
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
tensor
=
selected_rows
->
mutable_value
();
tensor
->
Resize
(
framework
::
make_ddim
(
shape
));
selected_rows
->
mutable_rows
()
->
reserve
(
shape
[
0
]);
...
...
@@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
"uniform_random_op's output only"
"supports SelectedRows and LoDTensor"
);
}
T
*
data
=
tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
data
=
tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"seed"
));
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
...
...
@@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UniformRandomOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
Attrs
().
Get
<
float
>
(
"min"
)
<
ctx
->
Attrs
().
Get
<
float
>
(
"max"
),
"uniform_random's min must less then max"
);
auto
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
auto
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
std
::
vector
<
int64_t
>
temp
;
temp
.
reserve
(
shape
.
size
());
for
(
auto
dim
:
shape
)
{
...
...
@@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
ctx
.
Attr
<
int
>
(
"dtype"
)),
ctx
.
GetPlace
());
...
...
@@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max].
class
UniformRandomOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"Out"
).
front
();
if
(
block
->
FindRecursiveOrCreateVar
(
out_var_name
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var_name
)
.
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
block
->
FindRecursiveOrCreateVar
(
out_var_name
)
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"dtype"
)));
auto
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
if
(
out_var
.
GetType
()
!=
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
out_var
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
out_var
.
SetDataType
(
var_data_type
);
}
};
...
...
paddle/fluid/platform/profiler.cc
浏览文件 @
ea97e83e
...
...
@@ -276,7 +276,7 @@ struct EventItem {
// Print results
void
PrintProfiler
(
const
std
::
vector
<
std
::
vector
<
EventItem
>>&
events_table
,
const
std
::
string
&
sorted_domain
,
const
size_t
name_width
,
const
size_t
data_width
,
double
total
)
{
const
size_t
data_width
,
bool
merge_thread
)
{
// Output header information
std
::
cout
<<
"
\n
------------------------->"
<<
" Profiling Report "
...
...
@@ -292,6 +292,10 @@ void PrintProfiler(const std::vector<std::vector<EventItem>>& events_table,
PADDLE_THROW
(
"Invalid profiler state"
,
g_state
);
}
if
(
merge_thread
)
{
std
::
cout
<<
"Note! This Report merge all thread info into one."
<<
std
::
endl
;
}
std
::
cout
<<
"Place: "
<<
place
<<
std
::
endl
;
std
::
cout
<<
"Time unit: ms"
<<
std
::
endl
;
std
::
cout
<<
"Sorted by "
<<
sorted_domain
...
...
@@ -312,8 +316,7 @@ void PrintProfiler(const std::vector<std::vector<EventItem>>& events_table,
<<
std
::
setw
(
data_width
)
<<
event_item
.
min_time
<<
std
::
setw
(
data_width
)
<<
event_item
.
max_time
<<
std
::
setw
(
data_width
)
<<
event_item
.
ave_time
<<
std
::
setw
(
data_width
)
<<
event_item
.
total_time
/
total
<<
std
::
endl
;
<<
std
::
setw
(
data_width
)
<<
event_item
.
ratio
<<
std
::
endl
;
}
}
std
::
cout
<<
std
::
endl
;
...
...
@@ -321,8 +324,10 @@ void PrintProfiler(const std::vector<std::vector<EventItem>>& events_table,
// Parse the event list and output the profiling report
void
ParseEvents
(
const
std
::
vector
<
std
::
vector
<
Event
>>&
events
,
bool
merge_thread
,
EventSortingKey
sorted_by
=
EventSortingKey
::
kDefault
)
{
if
(
g_state
==
ProfilerState
::
kDisabled
)
return
;
if
(
merge_thread
&&
events
.
size
()
<
2
)
return
;
std
::
string
sorted_domain
;
std
::
function
<
bool
(
const
EventItem
&
,
const
EventItem
&
)
>
sorted_func
;
...
...
@@ -361,34 +366,55 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
sorted_domain
=
"event first end time"
;
}
const
std
::
vector
<
std
::
vector
<
Event
>>*
analyze_events
;
std
::
vector
<
std
::
vector
<
Event
>>
merged_events_list
;
if
(
merge_thread
)
{
std
::
vector
<
Event
>
merged_events
;
for
(
int
i
=
0
;
i
<
events
.
size
();
++
i
)
{
for
(
int
j
=
0
;
j
<
events
[
i
].
size
();
++
j
)
{
merged_events
.
push_back
(
events
[
i
][
j
]);
}
}
merged_events_list
.
push_back
(
merged_events
);
analyze_events
=
&
merged_events_list
;
}
else
{
analyze_events
=
&
events
;
}
std
::
vector
<
std
::
vector
<
EventItem
>>
events_table
;
size_t
max_name_width
=
0
;
double
total
=
0.
;
// the total time
for
(
size_t
i
=
0
;
i
<
events
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
(
*
analyze_events
).
size
();
i
++
)
{
double
total
=
0.
;
// the total time in one thread
std
::
list
<
Event
>
pushed_events
;
std
::
vector
<
EventItem
>
event_items
;
std
::
unordered_map
<
std
::
string
,
int
>
event_idx
;
for
(
size_t
j
=
0
;
j
<
events
[
i
].
size
();
j
++
)
{
if
(
events
[
i
][
j
].
type
()
==
EventType
::
kPushRange
)
{
pushed_events
.
push_back
(
events
[
i
][
j
]);
}
else
if
(
events
[
i
][
j
].
type
()
==
EventType
::
kPopRange
)
{
for
(
size_t
j
=
0
;
j
<
(
*
analyze_events
)
[
i
].
size
();
j
++
)
{
if
(
(
*
analyze_events
)
[
i
][
j
].
type
()
==
EventType
::
kPushRange
)
{
pushed_events
.
push_back
(
(
*
analyze_events
)
[
i
][
j
]);
}
else
if
(
(
*
analyze_events
)
[
i
][
j
].
type
()
==
EventType
::
kPopRange
)
{
std
::
list
<
Event
>::
reverse_iterator
rit
=
pushed_events
.
rbegin
();
while
(
rit
!=
pushed_events
.
rend
()
&&
rit
->
name
()
!=
events
[
i
][
j
].
name
())
{
rit
->
name
()
!=
(
*
analyze_events
)
[
i
][
j
].
name
())
{
++
rit
;
}
if
(
rit
!=
pushed_events
.
rend
())
{
double
event_time
=
(
g_state
==
ProfilerState
::
kCUDA
||
g_state
==
ProfilerState
::
kAll
)
?
rit
->
CudaElapsedMs
(
events
[
i
][
j
])
:
rit
->
CpuElapsedMs
(
events
[
i
][
j
]);
?
rit
->
CudaElapsedMs
(
(
*
analyze_events
)
[
i
][
j
])
:
rit
->
CpuElapsedMs
(
(
*
analyze_events
)
[
i
][
j
]);
total
+=
event_time
;
std
::
string
event_name
=
"thread"
+
std
::
to_string
(
rit
->
thread_id
())
+
"::"
+
rit
->
name
();
max_name_width
=
std
::
max
(
max_name_width
,
event_name
.
size
());
std
::
string
event_name
;
if
(
merge_thread
)
{
event_name
=
rit
->
name
();
max_name_width
=
std
::
max
(
max_name_width
,
event_name
.
size
());
}
else
{
event_name
=
"thread"
+
std
::
to_string
(
rit
->
thread_id
())
+
"::"
+
rit
->
name
();
max_name_width
=
std
::
max
(
max_name_width
,
event_name
.
size
());
}
if
(
event_idx
.
find
(
event_name
)
==
event_idx
.
end
())
{
event_idx
[
event_name
]
=
event_items
.
size
();
...
...
@@ -413,7 +439,7 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
pushed_events
.
erase
((
++
rit
).
base
());
}
else
{
LOG
(
WARNING
)
<<
"Cannot find the push marker of event
\'
"
<<
events
[
i
][
j
].
name
()
<<
(
*
analyze_events
)
[
i
][
j
].
name
()
<<
"
\'
, which will be ignored in profiling report."
;
}
}
...
...
@@ -421,6 +447,7 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
// average time
for
(
auto
&
item
:
event_items
)
{
item
.
ave_time
=
item
.
total_time
/
item
.
calls
;
item
.
ratio
=
item
.
total_time
/
total
;
}
// sort
if
(
sorted_by
!=
EventSortingKey
::
kDefault
)
{
...
...
@@ -438,7 +465,8 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
}
// Print report
PrintProfiler
(
events_table
,
sorted_domain
,
max_name_width
+
4
,
12
,
total
);
PrintProfiler
(
events_table
,
sorted_domain
,
max_name_width
+
4
,
12
,
merge_thread
);
}
void
DisableProfiler
(
EventSortingKey
sorted_key
,
...
...
@@ -449,7 +477,8 @@ void DisableProfiler(EventSortingKey sorted_key,
Mark
(
"_stop_profiler_"
,
nullptr
);
std
::
vector
<
std
::
vector
<
Event
>>
all_events
=
GetAllEvents
();
ParseEvents
(
all_events
,
sorted_key
);
ParseEvents
(
all_events
,
true
,
sorted_key
);
ParseEvents
(
all_events
,
false
,
sorted_key
);
ResetProfiler
();
DeviceTracer
*
tracer
=
GetDeviceTracer
();
if
(
tracer
->
IsEnabled
())
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
ea97e83e
...
...
@@ -157,7 +157,50 @@ PYBIND11_PLUGIN(core) {
.
def
(
"_get_double_element"
,
TensorGetElement
<
double
>
)
.
def
(
"_dtype"
,
[](
Tensor
&
self
)
{
return
ToDataType
(
self
.
type
());
});
py
::
class_
<
LoDTensor
,
Tensor
>
(
m
,
"LoDTensor"
)
py
::
class_
<
LoDTensor
,
Tensor
>
(
m
,
"LoDTensor"
,
R"DOC(
LoDTensor is a Tensor with optional LoD information.
np.array(lod_tensor) can convert LoDTensor to numpy array.
lod_tensor.lod() can retrieve the LoD information.
LoD is short for Level of Details and is usually used for varied sequence
length. You can skip the following comment if you don't need optional LoD.
For example:
A LoDTensor X can look like the example below. It contains 2 sequences.
The first has length 2 and the second has length 3, as described by x.lod.
The first tensor dimension 5=2+3 is calculated from LoD if it's available.
It means the total number of sequence elements. In X, each element has 2
columns, hence [5, 2].
x.lod = [[2, 3]]
x.data = [[1, 2], [3, 4],
[5, 6], [7, 8], [9, 10]]
x.shape = [5, 2]
LoD can have multiple levels (for example, a paragraph can have multiple
sentences and a sentence can have multiple words). In the following
LoDTensor Y, the lod_level is 2. It means there are 2 sequences, the
first sequence length is 2 (has 2 sub-sequences), the second one's
length is 1. The first sequence's 2 sub-sequences have length 2 and 2,
respectively. And the second sequence's 1 sub-sequence has length 3.
y.lod = [[2 1], [2 2 3]]
y.shape = [2+2+3, ...]
Note:
In above description, LoD is length-based. In Paddle internal
implementation, lod is offset-based. Hence, internally,
y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (length-based
equivalent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]).
Sometimes LoD is called recursive_sequence_length to be more
self-explanatory. In this case, it must be length-based. Due to historical
reasons, when LoD is called lod in the public API, it might be offset-based.
Users should be careful about it.
)DOC"
)
.
def_buffer
(
[](
Tensor
&
self
)
->
py
::
buffer_info
{
return
CastToPyBuffer
(
self
);
})
.
def
(
"__init__"
,
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
ea97e83e
...
...
@@ -56,7 +56,11 @@ def data(name,
Args:
name(str): The name/alias of the function
shape(list): Tuple declaring the shape.
append_batch_size(bool): Whether or not to append the data as a batch.
append_batch_size(bool):
1. If true, it prepends -1 to the shape.
For example if shape=[1], the resulting shape is [-1, 1].
2. If shape contains -1, such as shape=[1, -1],
append_batch_size will be enforced to be False (ineffective).
dtype(int|float): The type of data : float32, float_16, int etc
type(VarType): The output type. By default it is LOD_TENSOR.
lod_level(int): The LoD Level. 0 means the input data is not a sequence.
...
...
python/paddle/fluid/layers/ops.py
浏览文件 @
ea97e83e
...
...
@@ -14,6 +14,8 @@
from
__future__
import
print_function
from
.layer_function_generator
import
generate_layer_fn
,
generate_layer_fn_noattr
from
..
import
core
from
..framework
import
convert_np_dtype_to_dtype_
__activations_noattr__
=
[
'sigmoid'
,
...
...
@@ -58,8 +60,11 @@ _uniform_random_ = generate_layer_fn('uniform_random')
def
uniform_random
(
shape
,
dtype
=
None
,
min
=
None
,
max
=
None
,
seed
=
None
):
locals_var
=
locals
().
keys
()
if
not
isinstance
(
dtype
,
core
.
VarDesc
.
VarType
):
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
kwargs
=
dict
()
for
name
in
locals
()
:
for
name
in
locals
_var
:
val
=
locals
()[
name
]
if
val
is
not
None
:
kwargs
[
name
]
=
val
...
...
@@ -78,8 +83,9 @@ _hard_shrink_ = generate_layer_fn('hard_shrink')
def
hard_shrink
(
x
,
threshold
=
None
):
locals_var
=
locals
().
keys
()
kwargs
=
dict
()
for
name
in
locals
()
:
for
name
in
locals
_var
:
val
=
locals
()[
name
]
if
val
is
not
None
:
kwargs
[
name
]
=
val
...
...
@@ -99,12 +105,12 @@ _cum_sum_ = generate_layer_fn('cumsum')
def
cumsum
(
x
,
axis
=
None
,
exclusive
=
None
,
reverse
=
None
):
locals_var
=
locals
().
keys
()
kwargs
=
dict
()
for
name
in
locals
()
:
for
name
in
locals
_var
:
val
=
locals
()[
name
]
if
val
is
not
None
:
kwargs
[
name
]
=
val
return
_cum_sum_
(
**
kwargs
)
...
...
@@ -121,8 +127,9 @@ _thresholded_relu_ = generate_layer_fn('thresholded_relu')
def
thresholded_relu
(
x
,
threshold
=
None
):
locals_var
=
locals
().
keys
()
kwargs
=
dict
()
for
name
in
locals
()
:
for
name
in
locals
_var
:
val
=
locals
()[
name
]
if
val
is
not
None
:
kwargs
[
name
]
=
val
...
...
python/paddle/fluid/layers/tensor.py
浏览文件 @
ea97e83e
...
...
@@ -100,7 +100,7 @@ def create_global_var(shape,
force_cpu
=
False
,
name
=
None
):
"""
Create a new
variabl
e in the global block(block 0).
Create a new
tensor variable with valu
e in the global block(block 0).
Args:
shape(list[int]): shape of the variable
...
...
python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py
浏览文件 @
ea97e83e
...
...
@@ -18,6 +18,9 @@ import unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
class
TestClipByNormOp
(
OpTest
):
def
setUp
(
self
):
...
...
@@ -62,5 +65,59 @@ class TestCase3(TestClipByNormOp):
self
.
max_norm
=
1.0
class
TestClipByNormOpWithSelectedRows
(
OpTest
):
def
check_with_place
(
self
,
place
):
self
.
config_test_case
()
scope
=
core
.
Scope
()
# set input
x_selected_rows
=
scope
.
var
(
'X'
).
get_selected_rows
()
x_selected_rows
.
set_rows
(
self
.
grad_rows
)
x_tensor
=
x_selected_rows
.
get_tensor
()
x_np
=
np
.
random
.
random
(
self
.
grad_shape
).
astype
(
"float32"
)
x_np
[
np
.
abs
(
x_np
)
<
self
.
max_relative_error
]
=
0.5
x_tensor
.
set
(
x_np
,
place
)
# set output
out_selected_rows
=
scope
.
var
(
'Out'
).
get_selected_rows
()
# run clip_by_norm_op
clip_by_norm_op
=
fluid
.
op
.
Operator
(
"clip_by_norm"
,
max_norm
=
self
.
max_norm
,
X
=
'X'
,
Out
=
'Out'
)
clip_by_norm_op
.
run
(
scope
,
place
)
# check output
self
.
assertEqual
(
out_selected_rows
.
rows
(),
self
.
grad_clipped_rows
)
out_tensor
=
out_selected_rows
.
get_tensor
()
y_np
=
np
.
zeros
(
self
.
grad_clipped_shape
)
y_np
[
0
]
=
np
.
sum
(
x_np
[
0
:
2
])
y_np
[
1
]
=
x_np
[
2
]
y_np
[
2
]
=
x_np
[
3
]
norm
=
np
.
sqrt
(
np
.
sum
(
np
.
square
(
y_np
)))
if
norm
>
self
.
max_norm
:
output
=
self
.
max_norm
*
y_np
/
norm
else
:
output
=
y_np
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
out_tensor
),
output
,
atol
=
1e-5
,
equal_nan
=
False
))
def
test_clip_by_norm_with_selected_ros
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
check_with_place
(
place
)
def
config_test_case
(
self
):
self
.
max_norm
=
1.0
self
.
max_relative_error
=
0.006
self
.
grad_shape
=
(
4
,
1
)
self
.
grad_clipped_shape
=
(
3
,
1
)
self
.
grad_rows
=
[
0
,
0
,
1
,
2
]
self
.
grad_clipped_rows
=
[
0
,
1
,
2
]
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录