Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a8890110
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a8890110
编写于
8月 24, 2017
作者:
Z
zchen0211
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into develop
上级
fabfe17a
3663bd88
变更
28
隐藏空白更改
内联
并排
Showing
28 changed file
with
812 addition
and
93 deletion
+812
-93
go/master/client.go
go/master/client.go
+16
-8
paddle/cuda/include/hl_cuda_cudnn.h
paddle/cuda/include/hl_cuda_cudnn.h
+8
-3
paddle/cuda/include/stub/hl_cuda_cudnn_stub.h
paddle/cuda/include/stub/hl_cuda_cudnn_stub.h
+8
-3
paddle/cuda/src/hl_cuda_cudnn.cc
paddle/cuda/src/hl_cuda_cudnn.cc
+70
-54
paddle/gserver/layers/ConvBaseLayer.cpp
paddle/gserver/layers/ConvBaseLayer.cpp
+11
-5
paddle/gserver/layers/ConvBaseLayer.h
paddle/gserver/layers/ConvBaseLayer.h
+4
-0
paddle/gserver/layers/ConvBaseOperator.cpp
paddle/gserver/layers/ConvBaseOperator.cpp
+2
-1
paddle/gserver/layers/ConvBaseProjection.cpp
paddle/gserver/layers/ConvBaseProjection.cpp
+17
-3
paddle/gserver/layers/ConvBaseProjection.h
paddle/gserver/layers/ConvBaseProjection.h
+1
-0
paddle/gserver/layers/ConvProjection.cpp
paddle/gserver/layers/ConvProjection.cpp
+2
-2
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+30
-8
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+3
-2
paddle/operators/fill_zeros_like_op.h
paddle/operators/fill_zeros_like_op.h
+1
-1
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+72
-0
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+116
-0
paddle/operators/lookup_table_op.h
paddle/operators/lookup_table_op.h
+75
-0
paddle/operators/rowwise_add_op.cu
paddle/operators/rowwise_add_op.cu
+3
-0
paddle/platform/cuda_helper.h
paddle/platform/cuda_helper.h
+51
-0
paddle/pybind/CMakeLists.txt
paddle/pybind/CMakeLists.txt
+1
-0
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+1
-0
proto/ModelConfig.proto
proto/ModelConfig.proto
+3
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+5
-1
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+18
-0
python/paddle/trainer_config_helpers/tests/configs/img_layers.py
...paddle/trainer_config_helpers/tests/configs/img_layers.py
+1
-0
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+2
-0
python/paddle/v2/framework/tests/gradient_checker.py
python/paddle/v2/framework/tests/gradient_checker.py
+11
-2
python/paddle/v2/framework/tests/mnist.py
python/paddle/v2/framework/tests/mnist.py
+249
-0
python/paddle/v2/framework/tests/test_lookup_table.py
python/paddle/v2/framework/tests/test_lookup_table.py
+31
-0
未找到文件。
go/master/client.go
浏览文件 @
a8890110
...
...
@@ -63,13 +63,24 @@ func WithAddr(addr string) func(c *Client) error {
// WithEtcd sets the client to use etcd for master discovery.
func
WithEtcd
(
endpoints
[]
string
,
timeout
time
.
Duration
)
func
(
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
cli
,
err
:=
clientv3
.
New
(
clientv3
.
Config
{
Endpoints
:
endpoints
,
DialTimeout
:
timeout
,
})
if
err
!=
nil
{
var
cli
*
clientv3
.
Client
f
:=
func
()
error
{
var
err
error
cli
,
err
=
clientv3
.
New
(
clientv3
.
Config
{
Endpoints
:
endpoints
,
DialTimeout
:
timeout
,
})
return
err
}
for
{
err
:=
f
()
if
err
!=
nil
{
log
.
Warningln
(
err
)
}
else
{
break
}
time
.
Sleep
(
time
.
Second
)
}
ch
:=
make
(
chan
string
,
1
)
a
,
err
:=
GetKey
(
cli
,
DefaultAddrPath
,
timeout
)
...
...
@@ -101,9 +112,6 @@ func NewClient(opts ...func(*Client) error) (*Client, error) {
}
}
c
.
ch
=
make
(
chan
record
,
c
.
bufSize
)
// FIXME: connection is created asyncrosly in monitorMaster go routine,
// ensure the connection is ready for use before calling c.addClient.
time
.
Sleep
(
time
.
Second
)
return
c
,
nil
}
...
...
paddle/cuda/include/hl_cuda_cudnn.h
浏览文件 @
a8890110
...
...
@@ -214,7 +214,8 @@ extern void hl_conv_workspace(hl_tensor_descriptor input,
int
*
convBwdDataAlgo
,
size_t
*
bwdDataLimitBytes
,
int
*
convBwdFilterAlgo
,
size_t
*
bwdFilterLimitBytes
);
size_t
*
bwdFilterLimitBytes
,
bool
useDilation
);
/**
* @brief destroy filter descriptor.
...
...
@@ -242,7 +243,9 @@ extern void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
int
padding_height
,
int
padding_width
,
int
stride_height
,
int
stride_width
);
int
stride_width
,
int
dilation_h
=
1
,
int
dilation_w
=
1
);
/**
* @brief reset convolution descriptor.
...
...
@@ -262,7 +265,9 @@ extern void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
int
padding_height
,
int
padding_width
,
int
stride_height
,
int
stride_width
);
int
stride_width
,
int
dilation_h
=
1
,
int
dilation_w
=
1
);
/**
* @brief destroy convolution descriptor.
...
...
paddle/cuda/include/stub/hl_cuda_cudnn_stub.h
浏览文件 @
a8890110
...
...
@@ -78,7 +78,9 @@ inline void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
int
padding_height
,
int
padding_width
,
int
stride_height
,
int
stride_width
)
{}
int
stride_width
,
int
dilation_h
,
int
dilation_w
)
{}
inline
void
hl_reset_convolution_descriptor
(
hl_convolution_descriptor
conv
,
hl_tensor_descriptor
image
,
...
...
@@ -86,7 +88,9 @@ inline void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
int
padding_height
,
int
padding_width
,
int
stride_height
,
int
stride_width
)
{}
int
stride_width
,
int
dilation_h
,
int
dilation_w
)
{}
inline
void
hl_destroy_convolution_descriptor
(
hl_convolution_descriptor
conv
)
{}
...
...
@@ -99,7 +103,8 @@ inline void hl_conv_workspace(hl_tensor_descriptor input,
int
*
convBwdDataAlgo
,
size_t
*
bwdDataLimitBytes
,
int
*
convBwdFilterAlgo
,
size_t
*
bwdFilterLimitBytes
)
{}
size_t
*
bwdFilterLimitBytes
,
bool
useDilation
)
{}
inline
void
hl_convolution_forward
(
hl_tensor_descriptor
input
,
real
*
input_data
,
...
...
paddle/cuda/src/hl_cuda_cudnn.cc
浏览文件 @
a8890110
...
...
@@ -201,7 +201,8 @@ void hl_conv_workspace(hl_tensor_descriptor input,
int
*
convBwdDataAlgo
,
size_t
*
bwdDataLimitBytes
,
int
*
convBwdFilterAlgo
,
size_t
*
bwdFilterLimitBytes
)
{
size_t
*
bwdFilterLimitBytes
,
bool
useDilation
)
{
#if CUDNN_VERSION >= 4000
CHECK_NOTNULL
(
input
);
...
...
@@ -213,21 +214,60 @@ void hl_conv_workspace(hl_tensor_descriptor input,
size_t
memoryLimitBytes
=
(
1LL
<<
20
)
*
FLAGS_cudnn_conv_workspace_limit_in_mb
;
// For dilation
int
algo
=
0
;
// cudnn convolution forward configuration
cudnnTensorDescriptor_t
fwd_src_desc
=
GET_TENSOR_DESCRIPTOR
(
input
);
cudnnTensorDescriptor_t
fwd_dest_desc
=
GET_TENSOR_DESCRIPTOR
(
output
);
cudnnFilterDescriptor_t
fwd_filter_desc
=
GET_FILTER_DESCRIPTOR
(
filter
);
cudnnConvolutionDescriptor_t
fwd_conv_desc
=
GET_CONVOLUTION_DESCRIPTOR
(
conv
);
// cudnn convolution backward data configuration
cudnnFilterDescriptor_t
bwd_data_filter_desc
=
GET_FILTER_DESCRIPTOR
(
filter
);
cudnnTensorDescriptor_t
bwd_data_diff_desc
=
GET_TENSOR_DESCRIPTOR
(
output
);
cudnnTensorDescriptor_t
bwd_data_grad_desc
=
GET_TENSOR_DESCRIPTOR
(
input
);
cudnnConvolutionDescriptor_t
bwd_data_conv_desc
=
GET_CONVOLUTION_DESCRIPTOR
(
conv
);
// cudnn convolution backward filter configuration
cudnnTensorDescriptor_t
bwd_filter_src_desc
=
GET_TENSOR_DESCRIPTOR
(
input
);
cudnnTensorDescriptor_t
bwd_filter_diff_desc
=
GET_TENSOR_DESCRIPTOR
(
output
);
cudnnConvolutionDescriptor_t
bwd_filter_conv_desc
=
GET_CONVOLUTION_DESCRIPTOR
(
conv
);
cudnnFilterDescriptor_t
bwd_filter_grad_desc
=
GET_FILTER_DESCRIPTOR
(
filter
);
CHECK_CUDNN
(
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
t_resource
.
cudnn_handle
,
fwd_src_desc
,
fwd_filter_desc
,
fwd_conv_desc
,
fwd_dest_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
memoryLimitBytes
,
reinterpret_cast
<
cudnnConvolutionFwdAlgo_t
*>
(
convFwdAlgo
)));
if
(
useDilation
)
{
convFwdAlgo
=
&
algo
;
convBwdDataAlgo
=
&
algo
;
convBwdFilterAlgo
=
&
algo
;
}
else
{
CHECK_CUDNN
(
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
t_resource
.
cudnn_handle
,
fwd_src_desc
,
fwd_filter_desc
,
fwd_conv_desc
,
fwd_dest_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
memoryLimitBytes
,
reinterpret_cast
<
cudnnConvolutionFwdAlgo_t
*>
(
convFwdAlgo
)));
CHECK_CUDNN
(
dynload
::
cudnnGetConvolutionBackwardDataAlgorithm
(
t_resource
.
cudnn_handle
,
bwd_data_filter_desc
,
bwd_data_diff_desc
,
bwd_data_conv_desc
,
bwd_data_grad_desc
,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
,
memoryLimitBytes
,
reinterpret_cast
<
cudnnConvolutionBwdDataAlgo_t
*>
(
convBwdDataAlgo
)));
CHECK_CUDNN
(
dynload
::
cudnnGetConvolutionBackwardFilterAlgorithm
(
t_resource
.
cudnn_handle
,
bwd_filter_src_desc
,
bwd_filter_diff_desc
,
bwd_filter_conv_desc
,
bwd_filter_grad_desc
,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
,
memoryLimitBytes
,
reinterpret_cast
<
cudnnConvolutionBwdFilterAlgo_t
*>
(
convBwdFilterAlgo
)));
}
CHECK_CUDNN
(
dynload
::
cudnnGetConvolutionForwardWorkspaceSize
(
t_resource
.
cudnn_handle
,
...
...
@@ -238,23 +278,6 @@ void hl_conv_workspace(hl_tensor_descriptor input,
static_cast
<
cudnnConvolutionFwdAlgo_t
>
(
*
convFwdAlgo
),
fwdLimitBytes
));
// cudnn convolution backward data configuration
cudnnFilterDescriptor_t
bwd_data_filter_desc
=
GET_FILTER_DESCRIPTOR
(
filter
);
cudnnTensorDescriptor_t
bwd_data_diff_desc
=
GET_TENSOR_DESCRIPTOR
(
output
);
cudnnTensorDescriptor_t
bwd_data_grad_desc
=
GET_TENSOR_DESCRIPTOR
(
input
);
cudnnConvolutionDescriptor_t
bwd_data_conv_desc
=
GET_CONVOLUTION_DESCRIPTOR
(
conv
);
CHECK_CUDNN
(
dynload
::
cudnnGetConvolutionBackwardDataAlgorithm
(
t_resource
.
cudnn_handle
,
bwd_data_filter_desc
,
bwd_data_diff_desc
,
bwd_data_conv_desc
,
bwd_data_grad_desc
,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
,
memoryLimitBytes
,
reinterpret_cast
<
cudnnConvolutionBwdDataAlgo_t
*>
(
convBwdDataAlgo
)));
CHECK_CUDNN
(
dynload
::
cudnnGetConvolutionBackwardDataWorkspaceSize
(
t_resource
.
cudnn_handle
,
bwd_data_filter_desc
,
...
...
@@ -264,23 +287,6 @@ void hl_conv_workspace(hl_tensor_descriptor input,
static_cast
<
cudnnConvolutionBwdDataAlgo_t
>
(
*
convBwdDataAlgo
),
bwdDataLimitBytes
));
// cudnn convolution backward filter configuration
cudnnTensorDescriptor_t
bwd_filter_src_desc
=
GET_TENSOR_DESCRIPTOR
(
input
);
cudnnTensorDescriptor_t
bwd_filter_diff_desc
=
GET_TENSOR_DESCRIPTOR
(
output
);
cudnnConvolutionDescriptor_t
bwd_filter_conv_desc
=
GET_CONVOLUTION_DESCRIPTOR
(
conv
);
cudnnFilterDescriptor_t
bwd_filter_grad_desc
=
GET_FILTER_DESCRIPTOR
(
filter
);
CHECK_CUDNN
(
dynload
::
cudnnGetConvolutionBackwardFilterAlgorithm
(
t_resource
.
cudnn_handle
,
bwd_filter_src_desc
,
bwd_filter_diff_desc
,
bwd_filter_conv_desc
,
bwd_filter_grad_desc
,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
,
memoryLimitBytes
,
reinterpret_cast
<
cudnnConvolutionBwdFilterAlgo_t
*>
(
convBwdFilterAlgo
)));
CHECK_CUDNN
(
dynload
::
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
t_resource
.
cudnn_handle
,
bwd_filter_src_desc
,
...
...
@@ -603,7 +609,9 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
int
padding_height
,
int
padding_width
,
int
stride_height
,
int
stride_width
)
{
int
stride_width
,
int
dilation_h
,
int
dilation_w
)
{
CHECK_NOTNULL
(
conv
);
cudnn_convolution_descriptor
hl_conv
=
(
cudnn_convolution_descriptor
)
malloc
(
...
...
@@ -625,18 +633,24 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
padding_width
,
stride_height
,
stride_width
,
1
,
1
,
dilation_h
,
dilation_w
,
mode
,
data_type
));
#else
if
(
dilation_h
>
1
||
dilation_w
>
1
)
{
LOG
(
FATAL
)
<<
"Current cuDNN version does't support for dilation convolution. "
<<
"The dilation convolution requires cuDNN >= v6.0."
;
}
CHECK_CUDNN
(
dynload
::
cudnnSetConvolution2dDescriptor
(
hl_conv
->
desc
,
padding_height
,
padding_width
,
stride_height
,
stride_width
,
1
,
1
,
dilation_h
,
dilation_w
,
mode
));
#endif
...
...
@@ -659,7 +673,9 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
int
padding_height
,
int
padding_width
,
int
stride_height
,
int
stride_width
)
{
int
stride_width
,
int
dilation_h
,
int
dilation_w
)
{
CHECK_NOTNULL
(
conv
);
CHECK_NOTNULL
(
image
);
CHECK_NOTNULL
(
filter
);
...
...
@@ -678,8 +694,8 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
padding_width
,
stride_height
,
stride_width
,
1
,
1
,
dilation_h
,
dilation_w
,
mode
,
data_type
));
#else
...
...
@@ -688,8 +704,8 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
padding_width
,
stride_height
,
stride_width
,
1
,
1
,
dilation_h
,
dilation_w
,
mode
));
#endif
...
...
paddle/gserver/layers/ConvBaseLayer.cpp
浏览文件 @
a8890110
...
...
@@ -32,9 +32,11 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
const
ConvConfig
&
conf
=
inputConfig
.
conv_conf
();
padding_
.
push_back
(
conf
.
padding
());
stride_
.
push_back
(
conf
.
stride
());
dilation_
.
push_back
(
conf
.
dilation
());
filterSize_
.
push_back
(
conf
.
filter_size
());
paddingY_
.
push_back
(
conf
.
padding_y
());
strideY_
.
push_back
(
conf
.
stride_y
());
dilationY_
.
push_back
(
conf
.
dilation_y
());
filterSizeY_
.
push_back
(
conf
.
filter_size_y
());
filterPixels_
.
push_back
(
filterSize_
.
back
()
*
filterSizeY_
.
back
());
channels_
.
push_back
(
conf
.
channels
());
...
...
@@ -89,7 +91,11 @@ size_t ConvBaseLayer::calOutputSize() {
size_t
layerSize
=
0
;
auto
setLayerSize
=
[
&
](
IntV
&
inH
,
IntV
&
inW
,
IntV
&
outH
,
IntV
&
outW
)
{
size_t
filterSizeY
;
size_t
filterSize
;
for
(
size_t
i
=
0
;
i
<
inputLayers_
.
size
();
i
++
)
{
filterSizeY
=
(
filterSizeY_
[
i
]
-
1
)
*
dilationY_
[
i
]
+
1
;
filterSize
=
(
filterSize_
[
i
]
-
1
)
*
dilation_
[
i
]
+
1
;
inH
.
push_back
(
inputLayers_
[
i
]
->
getOutput
().
getFrameHeight
());
inW
.
push_back
(
inputLayers_
[
i
]
->
getOutput
().
getFrameWidth
());
const
ConvConfig
&
conf
=
config_
.
inputs
(
i
).
conv_conf
();
...
...
@@ -98,17 +104,17 @@ size_t ConvBaseLayer::calOutputSize() {
inH
[
i
]
=
conf
.
has_output_y
()
?
conf
.
output_y
()
:
conf
.
output_x
();
if
(
inW
[
i
]
==
0
)
inW
[
i
]
=
conf
.
output_x
();
outH
.
push_back
(
imageSize
(
inH
[
i
],
filterSizeY
_
[
i
]
,
paddingY_
[
i
],
strideY_
[
i
],
caffeMode_
));
outW
.
push_back
(
imageSize
(
i
nW
[
i
],
filterSize_
[
i
]
,
padding_
[
i
],
stride_
[
i
],
caffeMode_
));
inH
[
i
],
filterSizeY
,
paddingY_
[
i
],
strideY_
[
i
],
caffeMode_
));
outW
.
push_back
(
i
mageSize
(
inW
[
i
],
filterSize
,
padding_
[
i
],
stride_
[
i
],
caffeMode_
));
}
else
{
if
(
inH
[
i
]
==
0
)
inH
[
i
]
=
conf
.
has_img_size_y
()
?
conf
.
img_size_y
()
:
conf
.
img_size
();
if
(
inW
[
i
]
==
0
)
inW
[
i
]
=
conf
.
img_size
();
outH
.
push_back
(
outputSize
(
inH
[
i
],
filterSizeY
_
[
i
]
,
paddingY_
[
i
],
strideY_
[
i
],
caffeMode_
));
inH
[
i
],
filterSizeY
,
paddingY_
[
i
],
strideY_
[
i
],
caffeMode_
));
outW
.
push_back
(
outputSize
(
inW
[
i
],
filterSize
_
[
i
]
,
padding_
[
i
],
stride_
[
i
],
caffeMode_
));
inW
[
i
],
filterSize
,
padding_
[
i
],
stride_
[
i
],
caffeMode_
));
}
CHECK_EQ
(
outH
[
i
],
outH
[
0
]);
CHECK_EQ
(
outW
[
i
],
outW
[
0
]);
...
...
paddle/gserver/layers/ConvBaseLayer.h
浏览文件 @
a8890110
...
...
@@ -40,6 +40,10 @@ protected:
IntV
stride_
;
/// The y dimension of the stride.
IntV
strideY_
;
/// The x dimension of the dilation.
IntV
dilation_
;
/// The y dimension of the dilation.
IntV
dilationY_
;
/// The x dimension of a filter kernel.
IntV
filterSize_
;
/// The y dimension of a filter kernel.
...
...
paddle/gserver/layers/ConvBaseOperator.cpp
浏览文件 @
a8890110
...
...
@@ -59,7 +59,8 @@ void ConvBaseOperator::allocConvWorkSpace() {
&
bwdDataAlgo_
,
&
bwdDataLimitBytes_
,
&
bwdFilterAlgo_
,
&
bwdFilterLimitBytes_
);
&
bwdFilterLimitBytes_
,
/*useDilation*/
false
);
size_t
maxWorkSpace
=
0
;
maxWorkSpace
=
std
::
max
(
fwdLimitBytes_
,
bwdDataLimitBytes_
);
...
...
paddle/gserver/layers/ConvBaseProjection.cpp
浏览文件 @
a8890110
...
...
@@ -41,6 +41,11 @@ void ConvBaseProjection::getConvParams() {
strideH_
=
conf
.
stride_y
();
strideW_
=
conf
.
stride
();
dilationH_
=
conf
.
dilation_y
();
dilationW_
=
conf
.
dilation
();
CHECK_GT
(
dilationH_
,
0
);
CHECK_GT
(
dilationW_
,
0
);
filterH_
=
conf
.
filter_size_y
();
filterW_
=
conf
.
filter_size
();
...
...
@@ -77,7 +82,9 @@ void ConvBaseProjection::initCudnn() {
paddingH_
,
paddingW_
,
strideH_
,
strideW_
);
strideW_
,
dilationH_
,
dilationW_
);
// initialize all to default algorithms
fwdAlgo_
=
0
;
...
...
@@ -131,7 +138,9 @@ void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
paddingH_
,
paddingW_
,
strideH_
,
strideW_
);
strideW_
,
dilationH_
,
dilationW_
);
}
void
ConvBaseProjection
::
reshape
(
int
batchSize
)
{
...
...
@@ -140,6 +149,10 @@ void ConvBaseProjection::reshape(int batchSize) {
CHECK_EQ
(
calInputSize
(),
in_
->
value
->
getWidth
());
reshapeTensorDesc
(
batchSize
);
bool
useDilation
=
false
;
if
(
dilationH_
>
1
||
dilationW_
>
1
)
{
useDilation
=
true
;
}
hl_conv_workspace
(
imageDesc_
,
outputDesc_
,
filterDesc_
,
...
...
@@ -149,7 +162,8 @@ void ConvBaseProjection::reshape(int batchSize) {
&
bwdDataAlgo_
,
&
bwdDataLimitBytes_
,
&
bwdFilterAlgo_
,
&
bwdFilterLimitBytes_
);
&
bwdFilterLimitBytes_
,
useDilation
);
size_t
maxWorkSpace
=
0
;
maxWorkSpace
=
std
::
max
(
fwdLimitBytes_
,
bwdDataLimitBytes_
);
...
...
paddle/gserver/layers/ConvBaseProjection.h
浏览文件 @
a8890110
...
...
@@ -63,6 +63,7 @@ protected:
int
configChannels_
,
configNumFilters_
;
int
paddingH_
,
paddingW_
;
int
strideH_
,
strideW_
;
int
dilationH_
,
dilationW_
;
int
filterH_
,
filterW_
;
/// One group offset of input data.
int
inputOffset_
;
...
...
paddle/gserver/layers/ConvProjection.cpp
浏览文件 @
a8890110
...
...
@@ -25,12 +25,12 @@ size_t ConvProjection::calOutputSize() {
if
(
imageH_
==
0
)
imageH_
=
configImgH_
;
if
(
imageW_
==
0
)
imageW_
=
configImgW_
;
outputH_
=
outputSize
(
imageH_
,
filterH_
,
(
filterH_
-
1
)
*
dilationH_
+
1
,
paddingH_
,
strideH_
,
/* caffeMode */
true
);
outputW_
=
outputSize
(
imageW_
,
filterW_
,
(
filterW_
-
1
)
*
dilationW_
+
1
,
paddingW_
,
strideW_
,
/* caffeMode */
true
);
...
...
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
a8890110
...
...
@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#include <cudnn.h>
#endif
#include <gtest/gtest.h>
#include <string>
#include <vector>
...
...
@@ -189,10 +192,16 @@ TEST(Projection, scaling) {
void
testProjectionConv
(
size_t
groups
,
bool
isDeconv
)
{
const
int
NUM_FILTERS
=
18
;
const
int
FILTER_SIZE
=
2
;
const
int
FILTER_SIZE_Y
=
4
;
const
int
FILTER_SIZE_Y
=
2
;
const
int
CHANNELS
=
3
;
const
int
IMAGE_SIZE
=
16
;
#if CUDNN_VERSION >= 6000
const
int
DILATION
=
2
;
#else
const
int
DILATION
=
1
;
#endif
ProjectionConfig
conf
;
if
(
isDeconv
)
{
conf
.
set_type
(
"convt"
);
...
...
@@ -209,6 +218,8 @@ void testProjectionConv(size_t groups, bool isDeconv) {
conv
->
set_padding_y
(
1
);
conv
->
set_stride
(
2
);
conv
->
set_stride_y
(
2
);
conv
->
set_dilation
(
DILATION
);
conv
->
set_dilation_y
(
DILATION
);
conv
->
set_groups
(
groups
);
if
(
isDeconv
)
{
conv
->
set_filter_channels
(
NUM_FILTERS
/
conv
->
groups
());
...
...
@@ -217,12 +228,12 @@ void testProjectionConv(size_t groups, bool isDeconv) {
}
conv
->
set_img_size
(
IMAGE_SIZE
);
int
output_x
=
outputSize
(
conv
->
img_size
(),
conv
->
filter_size
()
,
(
conv
->
filter_size
()
-
1
)
*
DILATION
+
1
,
conv
->
padding
(),
conv
->
stride
(),
/* caffeMode */
true
);
int
output_y
=
outputSize
(
conv
->
img_size
(),
conv
->
filter_size_y
()
,
(
conv
->
filter_size_y
()
-
1
)
*
DILATION
+
1
,
conv
->
padding_y
(),
conv
->
stride_y
(),
/* caffeMode */
true
);
...
...
@@ -424,27 +435,38 @@ void testConvLayer(const string& type, bool trans, bool useGpu) {
config
.
layerConfig
.
set_partial_sum
(
1
);
config
.
layerConfig
.
set_shared_biases
(
true
);
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"layer_0"
,
384
,
288
});
int
dilation
=
1
;
if
(
type
==
"cudnn_conv"
)
{
#if CUDNN_VERSION >= 6000
dilation
=
2
;
#else
dilation
=
1
;
#endif
}
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"layer_0"
,
768
,
192
});
LayerInputConfig
*
input
=
config
.
layerConfig
.
add_inputs
();
ConvConfig
*
conv
=
input
->
mutable_conv_conf
();
conv
->
set_filter_size
(
2
);
conv
->
set_filter_size_y
(
3
);
conv
->
set_filter_size_y
(
2
);
conv
->
set_channels
(
3
);
conv
->
set_padding
(
0
);
conv
->
set_padding_y
(
1
);
conv
->
set_stride
(
2
);
conv
->
set_stride_y
(
2
);
conv
->
set_dilation
(
dilation
);
conv
->
set_dilation_y
(
dilation
);
conv
->
set_groups
(
1
);
conv
->
set_filter_channels
(
conv
->
channels
()
/
conv
->
groups
());
conv
->
set_img_size
(
16
);
conv
->
set_img_size_y
(
8
);
conv
->
set_img_size_y
(
16
);
conv
->
set_output_x
(
outputSize
(
conv
->
img_size
(),
conv
->
filter_size
()
,
(
conv
->
filter_size
()
-
1
)
*
dilation
+
1
,
conv
->
padding
(),
conv
->
stride
(),
/* caffeMode */
true
));
conv
->
set_output_y
(
outputSize
(
conv
->
img_size_y
(),
conv
->
filter_size_y
()
,
(
conv
->
filter_size_y
()
-
1
)
*
dilation
+
1
,
conv
->
padding_y
(),
conv
->
stride_y
(),
/* caffeMode */
true
));
...
...
paddle/operators/CMakeLists.txt
浏览文件 @
a8890110
...
...
@@ -42,6 +42,7 @@ function(op_library TARGET)
endfunction
()
add_subdirectory
(
math
)
cc_test
(
gather_test SRCS gather_test.cc DEPS tensor
)
op_library
(
gather_op SRCS gather_op.cc gather_op.cu
)
...
...
@@ -67,7 +68,7 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op
)
op_library
(
uniform_random_op
SRCS uniform_random_op.cc uniform_random
_op.cu
)
op_library
(
uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu
)
op_library
(
lookup_table_op SRCS lookup_table_op.cc lookup_table
_op.cu
)
op_library
(
scale_op SRCS scale_op.cc scale_op.cu DEPS net_op
)
op_library
(
minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op
)
paddle/operators/fill_zeros_like_op.h
浏览文件 @
a8890110
...
...
@@ -26,7 +26,7 @@ class FillZerosLikeKernel : public framework::OpKernel {
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Dst"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
T
(
0
));
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
}
};
...
...
paddle/operators/lookup_table_op.cc
0 → 100644
浏览文件 @
a8890110
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lookup_table_op.h"
namespace
paddle
{
namespace
operators
{
class
LookupTableOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
auto
table_t
=
context
.
Input
<
Tensor
>
(
"W"
);
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
auto
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
output_t
->
Resize
({
ids_t
->
dims
()[
0
],
table_t
->
dims
()[
1
]});
}
};
class
LookupTableOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
LookupTableOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"W"
,
"An input represents embedding tensors,"
" which is a learnable parameter."
);
AddInput
(
"Ids"
,
"An input with type int32 or int64"
"contains the ids to be looked up in W."
);
AddOutput
(
"Out"
,
"The lookup results, which have the same type with W."
);
AddComment
(
"This operator is used to perform lookups on the parameter W,"
"then concatenated into a dense tensor."
);
}
};
class
LookupTableOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
auto
table
=
context
.
Input
<
Tensor
>
(
"W"
);
auto
d_table
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"W"
));
d_table
->
Resize
(
table
->
dims
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
lookup_table
,
ops
::
LookupTableOp
,
ops
::
LookupTableOpMaker
,
lookup_table_grad
,
ops
::
LookupTableOpGrad
);
REGISTER_OP_CPU_KERNEL
(
lookup_table
,
ops
::
LookupTableKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
lookup_table_grad
,
ops
::
LookupTableGradKernel
<
float
>
);
paddle/operators/lookup_table_op.cu
0 → 100644
浏览文件 @
a8890110
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int32_t
*
ids
,
const
int
N
,
const
int
K
,
const
int
D
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
while
(
idy
<
K
)
{
int
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
out
[
i
]
=
tab
[
i
];
}
idy
+=
BlockDimY
*
GridDimX
;
}
}
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
__global__
void
LookupTableGrad
(
T
*
table
,
const
T
*
output
,
const
int32_t
*
ids
,
const
int
N
,
const
int
K
,
const
int
D
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
while
(
idy
<
K
)
{
int
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
const
T
*
out
=
output
+
idy
*
D
;
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
}
idy
+=
BlockDimY
*
GridDimX
;
}
}
template
<
typename
T
>
class
LookupTableCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
table_t
=
context
.
Input
<
Tensor
>
(
"W"
);
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
auto
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
K
=
product
(
ids_t
->
dims
());
auto
ids
=
ids_t
->
data
<
int32_t
>
();
auto
table
=
table_t
->
data
<
T
>
();
auto
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
LookupTable
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
);
}
};
template
<
typename
T
>
class
LookupTableGradCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
auto
d_output_t
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
d_table_t
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"W"
));
int
N
=
d_table_t
->
dims
()[
0
];
int
D
=
d_table_t
->
dims
()[
1
];
int
K
=
product
(
ids_t
->
dims
());
const
int32_t
*
ids
=
ids_t
->
data
<
int32_t
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
t
.
device
(
context
.
GetEigenDevice
<
platform
::
GPUPlace
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
LookupTableGrad
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
lookup_table
,
ops
::
LookupTableCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
lookup_table_grad
,
ops
::
LookupTableGradCUDAKernel
<
float
>
);
paddle/operators/lookup_table_op.h
0 → 100644
浏览文件 @
a8890110
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
LookupTableKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
table_t
=
context
.
Input
<
Tensor
>
(
"W"
);
// float tensor
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
// int tensor
auto
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
// float tensor
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
auto
ids
=
ids_t
->
data
<
int32_t
>
();
auto
table
=
table_t
->
data
<
T
>
();
auto
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
size_t
i
=
0
;
i
<
product
(
ids_t
->
dims
());
++
i
)
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
}
}
};
template
<
typename
T
>
class
LookupTableGradKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
auto
d_output_t
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
d_table_t
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"W"
));
size_t
N
=
d_table_t
->
dims
()[
0
];
size_t
D
=
d_table_t
->
dims
()[
1
];
auto
ids
=
ids_t
->
data
<
int32_t
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
t
.
device
(
context
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
for
(
size_t
i
=
0
;
i
<
product
(
ids_t
->
dims
());
++
i
)
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
for
(
size_t
j
=
0
;
j
<
D
;
++
j
)
{
d_table
[
ids
[
i
]
*
D
+
j
]
+=
d_output
[
i
*
D
+
j
];
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/operators/rowwise_add_op.cu
浏览文件 @
a8890110
...
...
@@ -18,3 +18,6 @@
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
rowwise_add
,
ops
::
RowwiseAddKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
rowwise_add_grad
,
ops
::
RowwiseAddGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/platform/cuda_helper.h
0 → 100644
浏览文件 @
a8890110
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
namespace
paddle
{
namespace
platform
{
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
// For atomicAdd.
USE_CUDA_ATOMIC
(
Add
,
float
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
USE_CUDA_ATOMIC
(
Add
,
double
);
#else
CUDA_ATOMIC_WRAPPER
(
Add
,
double
)
{
unsigned
long
long
int
*
address_as_ull
=
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
);
unsigned
long
long
int
old
=
*
address_as_ull
,
assumed
;
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ull
,
assumed
,
__double_as_longlong
(
val
+
__longlong_as_double
(
assumed
)));
// Note: uses integer comparison to avoid hang in case of NaN
}
while
(
assumed
!=
old
);
return
__longlong_as_double
(
old
);
}
#endif
}
// namespace platform
}
// namespace paddle
paddle/pybind/CMakeLists.txt
浏览文件 @
a8890110
...
...
@@ -15,6 +15,7 @@ cc_library(paddle_pybind SHARED
uniform_random_op
gaussian_random_op
fill_zeros_like_op
lookup_table_op
scale_op
minus_op
)
endif
(
WITH_PYTHON
)
paddle/pybind/pybind.cc
浏览文件 @
a8890110
...
...
@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF
(
recurrent_op
);
USE_OP
(
gaussian_random
);
USE_OP
(
uniform_random
);
USE_OP
(
lookup_table
);
USE_OP
(
scale
);
USE_OP_ITSELF
(
identity
);
USE_OP
(
minus
);
...
...
proto/ModelConfig.proto
浏览文件 @
a8890110
...
...
@@ -82,6 +82,9 @@ message ConvConfig {
// if not set, use img_size
optional
uint32
img_size_y
=
14
;
optional
uint32
dilation
=
15
[
default
=
1
];
optional
uint32
dilation_y
=
16
[
default
=
1
];
}
message
PoolConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
a8890110
...
...
@@ -870,12 +870,16 @@ class Conv(Cfg):
caffe_mode
=
True
,
filter_size_y
=
None
,
padding_y
=
None
,
stride_y
=
None
):
stride_y
=
None
,
dilation
=
None
,
dilation_y
=
None
):
self
.
add_keys
(
locals
())
if
filter_size_y
is
None
:
self
.
filter_size_y
=
filter_size
if
padding_y
is
None
:
self
.
padding_y
=
padding
if
dilation_y
is
None
:
self
.
dilation_y
=
dilation
if
stride_y
is
None
:
self
.
stride_y
=
stride
if
output_x
is
not
None
:
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
a8890110
...
...
@@ -2342,6 +2342,7 @@ def img_conv_layer(input,
groups
=
1
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
bias_attr
=
None
,
param_attr
=
None
,
shared_biases
=
True
,
...
...
@@ -2349,6 +2350,7 @@ def img_conv_layer(input,
filter_size_y
=
None
,
stride_y
=
None
,
padding_y
=
None
,
dilation_y
=
None
,
trans
=
False
,
layer_type
=
None
):
"""
...
...
@@ -2413,6 +2415,11 @@ def img_conv_layer(input,
:type padding: int|tuple|list
:param padding_y: The y dimension of the padding.
:type padding_y: int
:param dilation: The x dimension of the dilation. Or input a tuple for two
image dimension
:type dilation: int|tuple|list
:param dilation_y: The y dimension of the dilation.
:type dilation_y: int
:param bias_attr: Convolution bias attribute. None means default bias.
False means no bias.
:type bias_attr: ParameterAttribute|False
...
...
@@ -2460,6 +2467,13 @@ def img_conv_layer(input,
else
:
padding_y
=
padding
if
dilation_y
is
None
:
if
isinstance
(
dilation
,
collections
.
Sequence
):
assert
len
(
dilation
)
==
2
dilation
,
dilation_y
=
dilation
else
:
dilation_y
=
dilation
if
param_attr
.
attr
.
get
(
'initial_smart'
):
# special initial for conv layers.
init_w
=
(
2.0
/
(
filter_size
**
2
*
num_channels
))
**
0.5
...
...
@@ -2469,6 +2483,8 @@ def img_conv_layer(input,
param_attr
.
attr
[
"initial_smart"
]
=
False
if
layer_type
:
if
dilation
>
1
or
dilation_y
>
1
:
assert
layer_type
in
[
"cudnn_conv"
,
"cudnn_convt"
]
if
trans
:
assert
layer_type
in
[
"exconvt"
,
"cudnn_convt"
]
else
:
...
...
@@ -2484,11 +2500,13 @@ def img_conv_layer(input,
conv
=
Conv
(
filter_size
=
filter_size
,
padding
=
padding
,
dilation
=
dilation
,
stride
=
stride
,
channels
=
num_channels
,
groups
=
groups
,
filter_size_y
=
filter_size_y
,
padding_y
=
padding_y
,
dilation_y
=
dilation_y
,
stride_y
=
stride_y
),
**
param_attr
.
attr
),
active_type
=
act
.
name
,
...
...
python/paddle/trainer_config_helpers/tests/configs/img_layers.py
浏览文件 @
a8890110
...
...
@@ -12,6 +12,7 @@ img_conv = img_conv_layer(
num_filters
=
64
,
filter_size
=
(
32
,
32
),
padding
=
(
1
,
1
),
dilation
=
(
1
,
1
),
stride
=
(
1
,
1
),
act
=
LinearActivation
())
img_bn
=
batch_norm_layer
(
input
=
img_conv
,
act
=
ReluActivation
())
...
...
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
a8890110
...
...
@@ -28,4 +28,6 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
py_test
(
test_gradient_checker SRCS test_gradient_checker.py
)
py_test
(
test_lookup_table SRCS test_lookup_table.py
)
py_test
(
test_scale_and_identity_op SRCS test_scale_and_identity_op.py
)
py_test
(
mnist SRCS mnist.py
)
python/paddle/v2/framework/tests/gradient_checker.py
浏览文件 @
a8890110
...
...
@@ -23,6 +23,10 @@ def grad_var_name(var_name):
return
var_name
+
"@GRAD"
def
empty_var_name
():
return
"@EMPTY@"
def
get_numeric_gradient
(
op
,
input_values
,
output_name
,
...
...
@@ -182,7 +186,7 @@ class GradientChecker(unittest.TestCase):
]
return
outs
def
compare_grad
(
self
,
forward_op
,
input_value
):
def
compare_grad
(
self
,
forward_op
,
input_value
,
no_grad_set
=
None
):
""" Compare the input gradients between CPU and GPU for the given forward
operator.
...
...
@@ -190,15 +194,20 @@ class GradientChecker(unittest.TestCase):
:type forward_op: Operator
:param input_value: input values.
:type input_value: dict{string:numpy.array}
:param no_grad_set: the set of variables names without gradients.
:type no_grad_set: a set of string
:raises: AssertionError, there is different gradient value.
"""
backward_op
=
core
.
Operator
.
backward
(
forward_op
,
set
())
if
no_grad_set
is
None
:
no_grad_set
=
set
()
backward_op
=
core
.
Operator
.
backward
(
forward_op
,
no_grad_set
)
# return if not compile with GPU or not implementing GPU kernel
if
not
(
core
.
is_compile_gpu
()
and
backward_op
.
support_gpu
()):
return
outputs
=
backward_op
.
outputs
()
out_names
=
[
item
for
k
in
outputs
for
item
in
outputs
[
k
]]
out_names
=
filter
(
lambda
x
:
x
!=
empty_var_name
(),
out_names
)
cpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
out_names
,
core
.
CPUPlace
())
gpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
...
...
python/paddle/v2/framework/tests/mnist.py
0 → 100644
浏览文件 @
a8890110
import
paddle.v2.framework.core
as
core
from
paddle.v2.framework.op
import
Operator
import
numpy
import
paddle.v2
as
paddle
BATCH_SIZE
=
100
scope
=
core
.
Scope
()
place
=
core
.
CPUPlace
()
# if you want to test GPU training, you can use gpu place
# place = core.GPUPlace(0)
dev_ctx
=
core
.
DeviceContext
.
create
(
place
)
init_net
=
core
.
Net
.
create
()
forward_net
=
core
.
Net
.
create
()
backward_net
=
None
optimize_net
=
core
.
Net
.
create
()
def
atomic_id
():
id
=
0
while
True
:
yield
id
id
+=
1
uniq_id
=
atomic_id
().
next
def
data_layer
(
name
,
dims
):
var
=
scope
.
new_var
(
name
)
tensor
=
var
.
get_tensor
()
tensor
.
set_dims
(
dims
)
# 1 is batch size holder.
return
name
def
feed_data
(
name
,
data
):
assert
isinstance
(
data
,
numpy
.
ndarray
)
tensor
=
scope
.
find_var
(
name
).
get_tensor
()
tensor
.
set_dims
(
data
.
shape
)
if
data
.
dtype
==
numpy
.
dtype
(
'int32'
):
tensor
.
alloc_int
(
place
)
elif
data
.
dtype
==
numpy
.
dtype
(
'float32'
):
tensor
.
alloc_float
(
place
)
else
:
raise
ValueError
(
"data type not supported"
)
tensor
.
set
(
data
,
place
)
def
grad_var_name
(
var_name
):
return
var_name
+
"@GRAD"
def
sgd_optimizer
(
net
,
param_name
,
learning_rate
=
0.005
):
grad_name
=
grad_var_name
(
param_name
)
optimize_op
=
Operator
(
"sgd"
,
param
=
param_name
,
grad
=
grad_name
,
param_out
=
param_name
,
learning_rate
=
learning_rate
)
net
.
append_op
(
optimize_op
)
# should use operator and add these to the init_network
def
init_param
(
net
,
param_name
,
dims
):
scope
.
new_var
(
param_name
)
op
=
Operator
(
"uniform_random"
,
Out
=
param_name
,
dims
=
dims
,
min
=-
0.5
,
max
=
0.5
,
seed
=
10
)
op
.
infer_shape
(
scope
)
net
.
append_op
(
op
)
# fc_layer
def
fc_layer
(
net
,
input
,
size
,
act
=
"softmax"
,
bias
=
True
,
param
=
None
,
name
=
None
):
"""
Add a fc layer to net
:param input: input variable name.
:type input: str
:param size: fully connected layer size.
:param act: activation name
:param param: parameter attribute, used for initialize parameters.
:param bias: bias attribute. False will not have a bias.
:param name: the name of fc layer. If not set, model will generate a
readable name
:return: output variable name.
"""
if
name
is
None
:
name
=
'fc_%d'
%
uniq_id
()
if
not
isinstance
(
name
,
str
):
raise
ValueError
(
"name should be string"
)
input_dims
=
scope
.
find_var
(
input
).
get_tensor
().
get_dims
()
w_name
=
param
or
name
+
".w"
init_param
(
net
=
init_net
,
param_name
=
w_name
,
dims
=
[
input_dims
[
1
],
size
])
sgd_optimizer
(
net
=
optimize_net
,
param_name
=
w_name
,
learning_rate
=
0.01
)
pre_activation
=
name
+
".mul.out"
scope
.
new_var
(
pre_activation
)
mul_op
=
Operator
(
"mul"
,
X
=
input
,
Y
=
w_name
,
Out
=
pre_activation
)
net
.
append_op
(
mul_op
)
# create bias variable if needed
if
bias
:
bias_name
=
name
+
".b"
init_param
(
net
=
init_net
,
param_name
=
bias_name
,
dims
=
[
size
])
sgd_optimizer
(
net
=
optimize_net
,
param_name
=
bias_name
,
learning_rate
=
0.001
)
bias_out
=
name
+
".rowwise_add.out"
scope
.
new_var
(
bias_out
)
rowwise_append_op
=
Operator
(
"rowwise_add"
,
X
=
pre_activation
,
b
=
bias_name
,
Out
=
bias_out
)
net
.
append_op
(
rowwise_append_op
)
pre_activation
=
bias_out
activation_op
=
Operator
(
act
,
X
=
pre_activation
,
Y
=
name
)
net
.
append_op
(
activation_op
)
scope
.
new_var
(
name
)
net
.
infer_shape
(
scope
)
return
name
def
cross_entropy_layer
(
net
,
input
,
label
):
cost_name
=
'cross_entropy_%d'
%
uniq_id
()
cross_entropy_op
=
Operator
(
"onehot_cross_entropy"
,
X
=
input
,
label
=
label
,
Y
=
cost_name
)
net
.
append_op
(
cross_entropy_op
)
scope
.
new_var
(
cost_name
)
net
.
infer_shape
(
scope
)
return
cost_name
def
create_backward_net
(
forward_net
):
net
=
core
.
Operator
.
backward
(
forward_net
,
set
())
for
input
in
net
.
inputs
()[
"all"
]:
var
=
scope
.
new_var
(
input
)
var
.
get_tensor
()
for
output
in
net
.
outputs
()[
"all"
]:
var
=
scope
.
new_var
(
output
)
var
.
get_tensor
()
return
net
def
debug_print_op
(
op
):
print
(
"==============="
+
op
.
type
()
+
"=============="
)
print
(
"***inputs:***"
)
for
input
in
op
.
inputs
()[
"all"
]:
print
input
,
scope
.
find_var
(
input
).
get_tensor
().
get_dims
()
print
(
"
\n
***outputs:***"
)
for
output
in
op
.
outputs
()[
"all"
]:
print
output
,
scope
.
find_var
(
output
).
get_tensor
().
get_dims
()
print
(
""
)
print
(
""
)
def
set_cost
(
cost
):
cost_shape
=
numpy
.
array
(
scope
.
find_var
(
cost
).
get_tensor
()).
shape
cost_grad
=
\
scope
.
find_var
(
grad_var_name
(
cost
)).
get_tensor
()
cost_grad
.
set_dims
(
cost_shape
)
cost_grad
.
alloc_float
(
place
)
cost_grad
.
set
(
numpy
.
ones
(
cost_shape
).
astype
(
"float32"
),
place
)
def
get_cost_mean
(
cost
):
cost_data
=
numpy
.
array
(
scope
.
find_var
(
cost
).
get_tensor
())
return
cost_data
.
sum
()
/
len
(
cost_data
)
def
error_rate
(
predict
,
label
):
predict_var
=
numpy
.
array
(
scope
.
find_var
(
predict
).
get_tensor
()).
argmax
(
axis
=
1
)
label
=
numpy
.
array
(
scope
.
find_var
(
label
).
get_tensor
())
error_num
=
numpy
.
sum
(
predict_var
!=
label
)
return
error_num
/
float
(
len
(
label
))
images
=
data_layer
(
name
=
'pixel'
,
dims
=
[
BATCH_SIZE
,
784
])
labels
=
data_layer
(
name
=
'label'
,
dims
=
[
BATCH_SIZE
])
fc1
=
fc_layer
(
net
=
forward_net
,
input
=
images
,
size
=
100
,
act
=
"sigmoid"
)
fc2
=
fc_layer
(
net
=
forward_net
,
input
=
fc1
,
size
=
100
,
act
=
"sigmoid"
)
predict
=
fc_layer
(
net
=
forward_net
,
input
=
fc2
,
size
=
100
,
act
=
"softmax"
)
cost
=
cross_entropy_layer
(
net
=
forward_net
,
input
=
predict
,
label
=
labels
)
init_net
.
complete_add_op
(
True
)
forward_net
.
complete_add_op
(
True
)
backward_net
=
create_backward_net
(
forward_net
)
optimize_net
.
complete_add_op
(
True
)
print
(
init_net
)
print
(
forward_net
)
print
(
backward_net
)
print
(
optimize_net
)
debug_print_op
(
forward_net
)
debug_print_op
(
backward_net
)
debug_print_op
(
optimize_net
)
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
buf_size
=
8192
),
batch_size
=
BATCH_SIZE
)
def
test
(
cost_name
):
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
BATCH_SIZE
)
cost
=
[]
error
=
[]
for
data
in
test_reader
():
image_data
=
numpy
.
array
(
map
(
lambda
x
:
x
[
0
],
data
)).
astype
(
"float32"
)
label_data
=
numpy
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int32"
)
feed_data
(
images
,
image_data
)
feed_data
(
labels
,
label_data
)
forward_net
.
infer_shape
(
scope
)
forward_net
.
run
(
scope
,
dev_ctx
)
cost
.
append
(
get_cost_mean
(
cost_name
))
error
.
append
(
error_rate
(
predict
,
"label"
))
print
(
"cost="
+
str
(
sum
(
cost
)
/
float
(
len
(
cost
)))
+
" error_rate="
+
str
(
sum
(
error
)
/
float
(
len
(
error
))))
PASS_NUM
=
1
init_net
.
run
(
scope
,
dev_ctx
)
for
pass_id
in
range
(
PASS_NUM
):
batch_id
=
0
for
data
in
train_reader
():
image_data
=
numpy
.
array
(
map
(
lambda
x
:
x
[
0
],
data
)).
astype
(
"float32"
)
label_data
=
numpy
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int32"
)
feed_data
(
images
,
image_data
)
feed_data
(
labels
,
label_data
)
forward_net
.
infer_shape
(
scope
)
forward_net
.
run
(
scope
,
dev_ctx
)
set_cost
(
cost
)
backward_net
.
infer_shape
(
scope
)
backward_net
.
run
(
scope
,
dev_ctx
)
optimize_net
.
run
(
scope
,
dev_ctx
)
if
batch_id
%
100
==
0
:
print
(
"pass["
+
str
(
pass_id
)
+
"] batch_id["
+
str
(
batch_id
)
+
"]"
)
test
(
cost
)
batch_id
=
batch_id
+
1
python/paddle/v2/framework/tests/test_lookup_table.py
0 → 100644
浏览文件 @
a8890110
import
unittest
import
numpy
as
np
from
op_test_util
import
OpTestMeta
from
gradient_checker
import
GradientChecker
,
create_op
class
TestSigmoidOp
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
self
.
type
=
'lookup_table'
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
).
astype
(
'int32'
)
self
.
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
self
.
outputs
=
{
'Out'
:
table
[
ids
]}
class
TestSigmoidGradOp
(
GradientChecker
):
def
test_grad
(
self
):
op
=
create_op
(
'lookup_table'
)
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
).
astype
(
'int32'
)
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
# comapre gradients
self
.
compare_grad
(
op
,
inputs
,
set
([
'Ids'
]))
# check gradients
self
.
check_grad
(
op
,
inputs
,
set
(
'W'
),
'Out'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录