Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
d0785ce9
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d0785ce9
编写于
5月 03, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into refine_code
上级
815d8884
4c58da2c
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
58 addition
and
40 deletion
+58
-40
paddle/cuda/include/hl_base.h
paddle/cuda/include/hl_base.h
+5
-0
paddle/cuda/src/hl_top_k.cu
paddle/cuda/src/hl_top_k.cu
+5
-4
paddle/fluid/operators/row_conv_op.cu
paddle/fluid/operators/row_conv_op.cu
+10
-2
paddle/fluid/operators/top_k_op.cu
paddle/fluid/operators/top_k_op.cu
+7
-1
paddle/fluid/platform/cuda_primitives.h
paddle/fluid/platform/cuda_primitives.h
+0
-1
paddle/function/RowConvOpGpu.cu
paddle/function/RowConvOpGpu.cu
+20
-15
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+11
-17
未找到文件。
paddle/cuda/include/hl_base.h
浏览文件 @
d0785ce9
...
...
@@ -229,6 +229,11 @@ extern __thread cudaStream_t default_stream;
// __shfl has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_sync
(
unsigned
,
T
val
,
int
src_line
,
int
width
)
{
...
...
paddle/cuda/src/hl_top_k.cu
浏览文件 @
d0785ce9
...
...
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "hl_sparse.ph"
#include "hl_top_k.h"
#include "
paddle/cuda/include/
hl_base.h"
#include "
paddle/cuda/include/
hl_sparse.ph"
#include "
paddle/cuda/include/
hl_top_k.h"
#include "paddle/utils/Logging.h"
// using namespace hppl;
...
...
@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
if
(
--
beamSize
==
0
)
break
;
__syncthreads
();
// NOTE(zcd): temporary solution
unsigned
mask
=
0u
;
// CREATE_SHFL_MASK(mask, tid < len
);
CREATE_SHFL_MASK
(
mask
,
true
);
if
(
tid
==
maxId
[
0
])
{
if
(
beam
<
maxLength
)
{
...
...
paddle/fluid/operators/row_conv_op.cu
浏览文件 @
d0785ce9
...
...
@@ -189,6 +189,10 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
}
__syncthreads
();
// NOTE(zcd): temporary solution
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
true
);
for
(
int
i
=
0
;
i
<
num_sequence
;
i
++
)
{
int
start
=
static_cast
<
int
>
(
batch_indices
[
i
]);
int
end
=
static_cast
<
int
>
(
batch_indices
[
i
+
1
]);
...
...
@@ -220,7 +224,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
for
(
int
offset
=
16
;
offset
>
0
;
offset
=
offset
/
2
)
{
// blockDim.x is 32.
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
offset
);
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
}
__syncthreads
();
...
...
@@ -251,6 +255,10 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
T
*
sh_in
=
mem
;
T
*
sh_dout
=
&
mem
[
block_x
*
block_y
];
// NOTE(zcd): temporary solution
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
true
);
for
(
int
i
=
0
;
i
<
num_sequence
;
i
++
)
{
int
start
=
static_cast
<
int
>
(
batch_indices
[
i
]);
int
end
=
static_cast
<
int
>
(
batch_indices
[
i
+
1
]);
...
...
@@ -276,7 +284,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
for
(
int
offset
=
16
;
offset
>
0
;
offset
=
offset
/
2
)
{
// blockDim.x is 32.
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
offset
);
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
}
__syncthreads
();
...
...
paddle/fluid/operators/top_k_op.cu
浏览文件 @
d0785ce9
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -235,8 +236,13 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
sh_topk
[
tid
]
=
topk
[
*
beam
];
}
}
// NOTE(zcd): temporary solution
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
true
);
if
(
maxid
[
0
]
/
32
==
warp
)
{
if
(
__shfl
(
*
beam
,
(
maxid
[
0
])
%
32
,
32
)
==
MaxLength
)
break
;
if
(
platform
::
__shfl_sync
(
mask
,
*
beam
,
(
maxid
[
0
])
%
32
,
32
)
==
MaxLength
)
break
;
}
}
}
...
...
paddle/fluid/platform/cuda_primitives.h
浏览文件 @
d0785ce9
...
...
@@ -65,6 +65,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
return
__longlong_as_double
(
old
);
}
#endif
}
// namespace platform
}
// namespace paddle
paddle/function/RowConvOpGpu.cu
浏览文件 @
d0785ce9
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "
RowConvOp
.h"
#include "
hl_base
.h"
#include "
paddle/cuda/include/hl_base
.h"
#include "
paddle/function/RowConvOp
.h"
namespace
paddle
{
...
...
@@ -94,7 +94,7 @@ __global__ void KeRowConv2(real* y,
}
template
<
>
void
RowConv
<
DEVICE_TYPE_GPU
>
(
GpuMatrix
&
out
,
void
RowConv
<
DEVICE_TYPE_GPU
>
(
GpuMatrix
&
out
,
// NOLINT
const
GpuMatrix
&
in
,
const
GpuMatrix
&
filter
,
const
GpuIVector
&
seq
)
{
...
...
@@ -144,6 +144,10 @@ __global__ void KeRowConvBwWeight(real* dw,
}
__syncthreads
();
// NOTE(zcd): temporary solution
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
true
);
for
(
int
i
=
0
;
i
<
numSeq
;
++
i
)
{
const
int
start
=
starts
[
i
];
const
int
end
=
starts
[
i
+
1
];
...
...
@@ -170,11 +174,10 @@ __global__ void KeRowConvBwWeight(real* dw,
real
val
=
sh_x
[
tidy
][
tidx
]
*
sh_dy
[
tidy
][
tidx
+
context
-
1
-
t
];
__syncthreads
();
// warp size and blockDim.x is 32.
val
+=
__shfl_down
(
val
,
16
);
val
+=
__shfl_down
(
val
,
8
);
val
+=
__shfl_down
(
val
,
4
);
val
+=
__shfl_down
(
val
,
2
);
val
+=
__shfl_down
(
val
,
1
);
for
(
int
offset
=
16
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
__syncthreads
();
if
(
tidx
==
0
)
{
sh_dw
[
t
][
tidy
]
+=
val
;
...
...
@@ -205,6 +208,10 @@ __global__ void KeRowConvBwWeight2(real* dw,
__shared__
real
sh_x
[
BLOCK_H
][
BLOCK_W
];
__shared__
real
sh_dy
[
BLOCK_H
][
BLOCK_W
];
// NOTE(zcd): temporary solution
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
true
);
for
(
int
i
=
0
;
i
<
numSeq
;
++
i
)
{
const
int
start
=
starts
[
i
];
const
int
end
=
starts
[
i
+
1
];
...
...
@@ -230,11 +237,9 @@ __global__ void KeRowConvBwWeight2(real* dw,
real
val
=
sh_x
[
tidy
][
tidx
]
*
sh_dy
[
tidy
][
tidx
];
__syncthreads
();
// warp size and blockDim.x is 32.
val
+=
__shfl_down
(
val
,
16
);
val
+=
__shfl_down
(
val
,
8
);
val
+=
__shfl_down
(
val
,
4
);
val
+=
__shfl_down
(
val
,
2
);
val
+=
__shfl_down
(
val
,
1
);
for
(
int
offset
=
16
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
__syncthreads
();
if
(
tidx
==
0
&&
(
gidx
+
tidy
)
<
width
)
{
...
...
@@ -323,8 +328,8 @@ template <>
void
RowConvGrad
<
DEVICE_TYPE_GPU
>
(
const
GpuMatrix
&
outG
,
const
GpuMatrix
&
in
,
const
GpuMatrix
&
filter
,
GpuMatrix
&
inG
,
GpuMatrix
&
filterG
,
GpuMatrix
&
inG
,
// NOLINT
GpuMatrix
&
filterG
,
// NOLINT
const
GpuIVector
&
seq
)
{
const
size_t
numSeq
=
seq
.
getSize
()
-
1
;
const
size_t
contextLength
=
filter
.
getHeight
();
...
...
paddle/math/Matrix.cpp
浏览文件 @
d0785ce9
...
...
@@ -2157,26 +2157,20 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
int
wend
=
wstart
+
sizeX
;
wstart
=
wstart
<
0
?
0
:
wstart
;
wend
=
wend
<
(
int
)
imgSizeW
?
wend
:
(
int
)
imgSizeW
;
if
(
maskData
==
NULL
)
{
real
tmp
=
-
(
real
)
FLT_MAX
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
tmp
=
tmp
<
inputData
[
h
*
imgSizeW
+
w
]
?
inputData
[
h
*
imgSizeW
+
w
]
:
tmp
;
}
}
outData
[
ph
*
outputW
+
pw
]
=
tmp
;
}
else
{
real
maxval
=
-
(
real
)
FLT_MAX
;
int
max_index
=
-
1
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
if
(
outData
[
ph
*
outputW
+
pw
]
<
inputData
[
h
*
imgSizeW
+
w
])
{
outData
[
ph
*
outputW
+
pw
]
=
inputData
[
h
*
imgSizeW
+
w
];
maskData
[
ph
*
outputW
+
pw
]
=
h
*
imgSizeW
+
w
;
}
if
(
maxval
<
inputData
[
h
*
imgSizeW
+
w
])
{
maxval
=
inputData
[
h
*
imgSizeW
+
w
];
max_index
=
h
*
imgSizeW
+
w
;
}
}
}
outData
[
ph
*
outputW
+
pw
]
=
maxval
;
if
(
maskData
!=
NULL
)
maskData
[
ph
*
outputW
+
pw
]
=
max_index
;
}
}
// compute offset
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录