Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d5768ebc
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d5768ebc
编写于
8月 18, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix above comments
上级
38cc5dad
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
245 addition
and
199 deletion
+245
-199
paddle/cuda/include/hl_matrix.h
paddle/cuda/include/hl_matrix.h
+38
-20
paddle/cuda/include/stub/hl_matrix_stub.h
paddle/cuda/include/stub/hl_matrix_stub.h
+33
-14
paddle/cuda/src/hl_cuda_matrix.cu
paddle/cuda/src/hl_cuda_matrix.cu
+42
-42
paddle/gserver/layers/Conv3DLayer.cpp
paddle/gserver/layers/Conv3DLayer.cpp
+18
-8
paddle/gserver/layers/Conv3DLayer.h
paddle/gserver/layers/Conv3DLayer.h
+4
-10
paddle/gserver/layers/ConvBaseLayer.cpp
paddle/gserver/layers/ConvBaseLayer.cpp
+3
-23
paddle/gserver/layers/ConvBaseLayer.h
paddle/gserver/layers/ConvBaseLayer.h
+0
-1
paddle/gserver/layers/CudnnConvBaseLayer.cpp
paddle/gserver/layers/CudnnConvBaseLayer.cpp
+18
-0
paddle/gserver/layers/DeConv3DLayer.cpp
paddle/gserver/layers/DeConv3DLayer.cpp
+27
-19
paddle/gserver/layers/DeConv3DLayer.h
paddle/gserver/layers/DeConv3DLayer.h
+19
-25
paddle/gserver/layers/ExpandConvBaseLayer.cpp
paddle/gserver/layers/ExpandConvBaseLayer.cpp
+20
-1
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+16
-15
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+7
-21
未找到文件。
paddle/cuda/include/hl_matrix.h
浏览文件 @
d5768ebc
...
...
@@ -244,12 +244,21 @@ extern void hl_matrix_rotate(
* @param[out] matDst output matrix.
*
*/
extern
void
hl_matrix_vol2Col
(
real
*
matSrc
,
int
channel
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
matDst
);
extern
void
hl_matrix_vol2Col
(
const
real
*
dataSrc
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
dataDst
);
/**
* @brief Matrix col2Vol: Convert col matrix into 3D volume
...
...
@@ -273,13 +282,22 @@ extern void hl_matrix_vol2Col(real* matSrc,
* @param[in] alpha input
*
*/
extern
void
hl_matrix_col2Vol
(
real
*
matDst
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
matSrc
,
real
alpha
,
real
beta
);
extern
void
hl_matrix_col2Vol
(
real
*
dataDst
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
const
real
*
dataSrc
,
real
alpha
,
real
beta
);
#endif
/* HL_MATRIX_H_ */
paddle/cuda/include/stub/hl_matrix_stub.h
浏览文件 @
d5768ebc
...
...
@@ -99,19 +99,38 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
inline
void
hl_matrix_rotate
(
real
*
mat
,
real
*
matRot
,
int
dimM
,
int
dimN
,
bool
clockWise
)
{}
inline
void
hl_matrix_vol2Col
(
real
*
data
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
data_col
)
{}
inline
void
hl_matrix_col2Vol
(
real
*
data
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
data_Im
,
real
alpha
,
real
beta
)
{}
inline
void
hl_matrix_vol2Col
(
const
real
*
dataSrc
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
dataDst
)
{}
inline
void
hl_matrix_col2Vol
(
real
*
dataDst
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
const
real
*
dataSrc
,
real
alpha
,
real
beta
)
{}
#endif // HL_MATRIX_STUB_H_
paddle/cuda/src/hl_cuda_matrix.cu
浏览文件 @
d5768ebc
...
...
@@ -594,7 +594,7 @@ void hl_matrix_rotate(
}
__global__
void
keMatrixVol2Col
(
int
num_kernels
,
real
*
dataSrc
,
const
real
*
dataSrc
,
real
*
dataDst
,
int
depth
,
int
height
,
...
...
@@ -643,7 +643,7 @@ __global__ void keMatrixVol2Col(int num_kernels,
}
}
void
hl_matrix_vol2Col
(
real
*
dataSrc
,
void
hl_matrix_vol2Col
(
const
real
*
dataSrc
,
int
channels
,
int
depth
,
int
height
,
...
...
@@ -666,7 +666,7 @@ void hl_matrix_vol2Col(real* dataSrc,
const
int
threads
=
512
;
const
int
blocks
=
DIVUP
(
num_kernels
,
threads
);
keMatrixVol2Col
<<<
blocks
,
threads
>>>
(
num_kernels
,
keMatrixVol2Col
<<<
blocks
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
dataSrc
,
dataDst
,
depth
,
...
...
@@ -689,7 +689,7 @@ void hl_matrix_vol2Col(real* dataSrc,
__global__
void
keMatrixCol2Vol
(
int
num_kernels
,
real
*
dataDst
,
real
*
dataSrc
,
const
real
*
dataSrc
,
int
depth
,
int
height
,
int
width
,
...
...
@@ -759,7 +759,7 @@ void hl_matrix_col2Vol(real* dataDst,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
dataSrc
,
const
real
*
dataSrc
,
real
alpha
,
real
beta
)
{
int
depth_col
=
(
depth
+
2
*
paddingD
-
filterD
)
/
strideD
+
1
;
...
...
@@ -770,7 +770,7 @@ void hl_matrix_col2Vol(real* dataDst,
const
int
threads
=
512
;
const
int
blocks
=
DIVUP
(
num_kernels
,
threads
);
keMatrixCol2Vol
<<<
blocks
,
threads
>>>
(
num_kernels
,
keMatrixCol2Vol
<<<
blocks
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
dataDst
,
dataSrc
,
depth
,
...
...
paddle/gserver/layers/Conv3DLayer.cpp
浏览文件 @
d5768ebc
...
...
@@ -28,16 +28,26 @@ bool Conv3DLayer::init(const LayerMap &layerMap,
const
ConvConfig
&
conf
=
inputConfig
.
conv_conf
();
M_
.
push_back
(
numFilters_
/
conf
.
groups
());
K_
.
push_back
(
filterPixels_
[
index
]
*
filterChannels_
[
index
]);
if
(
nullptr
!=
weights_
[
index
]
->
getW
())
weights_
[
index
]
->
getW
()
->
reshape
(
weights_
[
index
]
->
getW
()
->
getWidth
(),
weights_
[
index
]
->
getW
()
->
getHeight
());
if
(
nullptr
!=
weights_
[
index
]
->
getWGrad
())
weights_
[
index
]
->
getWGrad
()
->
reshape
(
weights_
[
index
]
->
getWGrad
()
->
getWidth
(),
weights_
[
index
]
->
getWGrad
()
->
getHeight
());
// create a new weight
size_t
height
,
width
;
width
=
filterPixels_
[
index
]
*
filterChannels_
[
index
];
height
=
numFilters_
;
CHECK_EQ
(
parameters_
[
index
]
->
getSize
(),
width
*
height
);
Weight
*
w
=
new
Weight
(
height
,
width
,
parameters_
[
index
]);
weights_
.
emplace_back
(
w
);
++
index
;
}
CHECK
(
inputLayers_
.
size
()
==
parameters_
.
size
());
if
(
biasParameter_
.
get
())
{
if
(
sharedBiases_
)
{
CHECK_EQ
((
size_t
)
numFilters_
,
biasParameter_
->
getSize
());
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
1
,
numFilters_
,
biasParameter_
));
}
else
{
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
1
,
getSize
(),
biasParameter_
));
}
}
return
true
;
}
...
...
paddle/gserver/layers/Conv3DLayer.h
浏览文件 @
d5768ebc
...
...
@@ -12,13 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "ConvBaseLayer.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/MathUtils.h"
#include
<vector>
#include
"paddle/math/Matrix.h"
namespace
paddle
{
...
...
@@ -30,21 +28,17 @@ namespace paddle {
class
Conv3DLayer
:
public
ConvBaseLayer
{
public:
explicit
Conv3DLayer
(
const
LayerConfig
&
config
)
:
ConvBaseLayer
(
config
)
{}
~
Conv3DLayer
()
{}
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
);
size_t
getSize
();
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
);
void
forward
(
PassType
passType
);
void
addBias
();
void
backward
(
const
UpdateCallback
&
callback
);
void
bpropBiases
();
void
bpropData
(
int
i
);
void
bpropWeights
(
int
i
);
size_t
getSize
();
protected:
// Figure out the dimensions for individual gemms.
...
...
paddle/gserver/layers/ConvBaseLayer.cpp
浏览文件 @
d5768ebc
...
...
@@ -21,8 +21,7 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
const
ParameterMap
&
parameterMap
)
{
/* Initialize the basic parent class */
Layer
::
init
(
layerMap
,
parameterMap
);
isDeconv_
=
(
config_
.
type
()
==
"exconv"
||
config_
.
type
()
==
"cudnn_conv"
||
config_
.
type
()
==
"conv3d"
||
config_
.
type
()
==
"deconv3d"
)
isDeconv_
=
(
config_
.
type
()
==
"exconv"
||
config_
.
type
()
==
"cudnn_conv"
)
?
false
:
true
;
...
...
@@ -56,28 +55,9 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
}
CHECK
(
inputLayers_
.
size
()
==
parameters_
.
size
());
for
(
size_t
i
=
0
;
i
<
inputLayers_
.
size
();
i
++
)
{
size_t
height
,
width
;
height
=
filterPixels_
[
i
]
*
filterChannels_
[
i
];
width
=
(
!
isDeconv_
)
?
numFilters_
:
channels_
[
i
];
// create a new weight
CHECK_EQ
(
parameters_
[
i
]
->
getSize
(),
width
*
height
);
Weight
*
w
=
new
Weight
(
height
,
width
,
parameters_
[
i
]);
weights_
.
emplace_back
(
w
);
}
/* initialize the biases_ */
if
(
biasParameter_
.
get
())
{
if
(
sharedBiases_
)
{
CHECK_EQ
((
size_t
)
numFilters_
,
biasParameter_
->
getSize
());
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
1
,
numFilters_
,
biasParameter_
));
}
else
{
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
1
,
getSize
(),
biasParameter_
));
}
}
// create new weights_ in derived class
// create new biases_ in derived class
// default caffe model
caffeMode_
=
true
;
...
...
paddle/gserver/layers/ConvBaseLayer.h
浏览文件 @
d5768ebc
...
...
@@ -23,7 +23,6 @@ namespace paddle {
* with learned filters and (optionally) adds biases.
*/
class
ConvBaseLayer
:
public
Layer
{
protected:
typedef
std
::
vector
<
int
>
IntV
;
...
...
paddle/gserver/layers/CudnnConvBaseLayer.cpp
浏览文件 @
d5768ebc
...
...
@@ -46,8 +46,26 @@ bool CudnnConvBaseLayer::init(const LayerMap &layerMap,
projConf_
.
emplace_back
(
conf
);
projections_
.
emplace_back
(
Projection
::
create
(
*
projConf_
[
i
],
parameters_
[
i
],
useGpu_
));
// create a new weight
size_t
height
,
width
;
height
=
filterPixels_
[
i
]
*
filterChannels_
[
i
];
width
=
(
!
isDeconv_
)
?
numFilters_
:
channels_
[
i
];
CHECK_EQ
(
parameters_
[
i
]
->
getSize
(),
width
*
height
);
Weight
*
w
=
new
Weight
(
height
,
width
,
parameters_
[
i
]);
weights_
.
emplace_back
(
w
);
}
if
(
biasParameter_
.
get
())
{
if
(
sharedBiases_
)
{
CHECK_EQ
((
size_t
)
numFilters_
,
biasParameter_
->
getSize
());
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
numFilters_
,
1
,
biasParameter_
));
}
else
{
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
getSize
(),
1
,
biasParameter_
));
}
}
if
(
biases_
.
get
()
&&
sharedBiases_
)
{
hl_create_tensor_descriptor
(
&
biasDesc_
);
hl_create_tensor_descriptor
(
&
outputDesc_
);
...
...
paddle/gserver/layers/DeConv3DLayer.cpp
浏览文件 @
d5768ebc
...
...
@@ -20,9 +20,6 @@ namespace paddle {
REGISTER_LAYER
(
deconv3d
,
DeConv3DLayer
);
#define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \
(((IN_SIZE)-1) * (STRID)-2 * (PAD) + (KSIZE))
bool
DeConv3DLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
if
(
!
ConvBaseLayer
::
init
(
layerMap
,
parameterMap
))
return
false
;
...
...
@@ -32,14 +29,25 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
for
(
int
index
=
0
;
index
<
config_
.
inputs
().
size
();
++
index
)
{
M_
.
push_back
(
filterChannels_
[
index
]);
K_
.
push_back
(
filterPixels_
[
index
]
*
(
numFilters_
/
groups_
[
index
]));
if
(
weights_
[
index
]
->
getW
())
weights_
[
index
]
->
getW
()
->
reshape
(
filterPixels_
[
index
]
*
numFilters_
,
filterChannels_
[
index
]);
if
(
weights_
[
index
]
->
getWGrad
())
weights_
[
index
]
->
getWGrad
()
->
reshape
(
filterPixels_
[
index
]
*
numFilters_
,
filterChannels_
[
index
]);
}
CHECK
(
inputLayers_
.
size
()
==
parameters_
.
size
());
// create a new weight
size_t
height
,
width
;
height
=
filterPixels_
[
index
]
*
numFilters_
;
width
=
filterChannels_
[
index
];
CHECK_EQ
(
parameters_
[
index
]
->
getSize
(),
width
*
height
);
Weight
*
w
=
new
Weight
(
height
,
width
,
parameters_
[
index
]);
weights_
.
emplace_back
(
w
);
}
if
(
biasParameter_
.
get
())
{
if
(
sharedBiases_
)
{
CHECK_EQ
((
size_t
)
numFilters_
,
biasParameter_
->
getSize
());
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
1
,
numFilters_
,
biasParameter_
));
}
else
{
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
1
,
getSize
(),
biasParameter_
));
}
}
return
true
;
}
...
...
@@ -52,22 +60,22 @@ size_t DeConv3DLayer::getSize() {
outputW_
.
clear
();
outputD_
.
clear
();
N_
.
clear
();
N
o
_
.
clear
();
N
Out
_
.
clear
();
size_t
layerSize
=
0
;
for
(
size_t
i
=
0
;
i
<
inputLayers_
.
size
();
++
i
)
{
// imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight());
// imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth());
// imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth());
outputW_
.
push_back
(
DECONV_OUTPUT_SIZE
(
im
gSizeW_
[
i
],
stride_
[
i
],
padding_
[
i
],
filterSize_
[
i
]
));
outputH_
.
push_back
(
DECONV_OUTPUT_SIZE
(
imgSizeH_
[
i
],
strideY_
[
i
],
paddingY_
[
i
],
filterSizeY_
[
i
]
));
outputD_
.
push_back
(
DECONV_OUTPUT_SIZE
(
imgSizeD_
[
i
],
strideZ_
[
i
],
paddingZ_
[
i
],
filterSizeZ_
[
i
]
));
N
o
_
.
push_back
(
outputD_
[
i
]
*
outputH_
[
i
]
*
outputW_
[
i
]);
outputW_
.
push_back
(
im
ageSize
(
imgSizeW_
[
i
],
filterSize_
[
i
],
padding_
[
i
],
stride_
[
i
],
true
));
outputH_
.
push_back
(
imageSize
(
imgSizeH_
[
i
],
filterSizeY_
[
i
],
paddingY_
[
i
],
strideY_
[
i
],
true
));
outputD_
.
push_back
(
imageSize
(
imgSizeD_
[
i
],
filterSizeZ_
[
i
],
paddingZ_
[
i
],
strideZ_
[
i
],
true
));
N
Out
_
.
push_back
(
outputD_
[
i
]
*
outputH_
[
i
]
*
outputW_
[
i
]);
N_
.
push_back
(
imgSizeD_
[
i
]
*
imgSizeH_
[
i
]
*
imgSizeW_
[
i
]);
CHECK
(
layerSize
==
0
||
N_
[
i
]
*
size_t
(
numFilters_
)
==
layerSize
);
layerSize
+=
N
o
_
[
i
]
*
numFilters_
;
layerSize
+=
N
Out
_
[
i
]
*
numFilters_
;
}
getOutput
().
setFrameHeight
(
outputH_
[
0
]);
getOutput
().
setFrameWidth
(
outputW_
[
0
]);
...
...
paddle/gserver/layers/DeConv3DLayer.h
浏览文件 @
d5768ebc
...
...
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "ConvBaseLayer.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/MathUtils.h"
#include
<vector>
#include
"paddle/math/Matrix.h"
namespace
paddle
{
...
...
@@ -30,28 +29,23 @@ namespace paddle {
class
DeConv3DLayer
:
public
ConvBaseLayer
{
public:
explicit
DeConv3DLayer
(
const
LayerConfig
&
config
)
:
ConvBaseLayer
(
config
)
{}
~
DeConv3DLayer
()
{}
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
);
size_t
getSize
();
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
);
void
forward
(
PassType
passType
);
void
addBias
();
void
backward
(
const
UpdateCallback
&
callback
);
void
bpropBiases
();
void
bpropData
(
int
i
);
void
bpropWeights
(
int
i
);
size_t
getSize
();
protected:
// Figure out the dimensions for individual gemms.
IntV
M_
;
/// numFilters_ / filter_group_;
IntV
N_
;
/// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_
IntV
K_
;
/// outputD_ * outputH_ * outputW_
IntV
No
_
;
IntV
NOut
_
;
MatrixPtr
colBuf_
;
};
...
...
paddle/gserver/layers/ExpandConvBaseLayer.cpp
浏览文件 @
d5768ebc
...
...
@@ -22,12 +22,31 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
/* Initialize the basic convolutional parent class */
ConvBaseLayer
::
init
(
layerMap
,
parameterMap
);
int
index
=
0
;
for
(
auto
&
inputConfig
:
config_
.
inputs
())
{
const
ConvConfig
&
conf
=
inputConfig
.
conv_conf
();
/* Consistent caffe mode for multiple input */
caffeMode_
=
conf
.
caffe_mode
();
}
// create a new weight
size_t
height
,
width
;
height
=
filterPixels_
[
index
]
*
filterChannels_
[
index
];
width
=
(
!
isDeconv_
)
?
numFilters_
:
channels_
[
index
];
CHECK_EQ
(
parameters_
[
index
]
->
getSize
(),
width
*
height
);
Weight
*
w
=
new
Weight
(
height
,
width
,
parameters_
[
index
]);
weights_
.
emplace_back
(
w
);
index
++
;
}
if
(
biasParameter_
.
get
())
{
if
(
sharedBiases_
)
{
CHECK_EQ
((
size_t
)
numFilters_
,
biasParameter_
->
getSize
());
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
numFilters_
,
1
,
biasParameter_
));
}
else
{
biases_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
getSize
(),
1
,
biasParameter_
));
}
}
getOutputSize
();
return
true
;
...
...
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
d5768ebc
...
...
@@ -2019,7 +2019,7 @@ void test3DConvLayer(const string& type, bool trans, bool useGpu) {
const
int
CHANNELS
=
3
;
const
int
IMAGE_SIZE
=
9
;
const
int
IMAGE_SIZE_Y
=
9
;
const
int
IMAGE_SIZE_Z
=
9
;
// 2, 3, 5, 5, 5
const
int
IMAGE_SIZE_Z
=
9
;
TestConfig
config
;
config
.
biasSize
=
NUM_FILTERS
;
...
...
@@ -2084,10 +2084,6 @@ TEST(Layer, test3DConvLayer) {
#endif
}
int
deConvOutputSize
(
int
inSize
,
int
kSize
,
int
pad
,
int
stride
)
{
return
(
inSize
-
1
)
*
stride
-
2
*
pad
+
kSize
;
}
void
test3DDeConvLayer
(
const
string
&
type
,
bool
trans
,
bool
useGpu
)
{
// filter size
const
int
NUM_FILTERS
=
6
;
...
...
@@ -2126,16 +2122,21 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
conv
->
set_img_size
(
IMAGE_SIZE
);
conv
->
set_img_size_y
(
IMAGE_SIZE_Y
);
conv
->
set_img_size_z
(
IMAGE_SIZE_Z
);
conv
->
set_output_x
(
deConvOutputSize
(
conv
->
img_size
(),
conv
->
filter_size
(),
conv
->
padding
(),
conv
->
stride
()));
conv
->
set_output_y
(
deConvOutputSize
(
conv
->
img_size_y
(),
conv
->
set_output_x
(
imageSize
(
conv
->
img_size
(),
conv
->
filter_size
(),
conv
->
padding
(),
conv
->
stride
(),
true
));
conv
->
set_output_y
(
imageSize
(
conv
->
img_size_y
(),
conv
->
filter_size_y
(),
conv
->
padding_y
(),
conv
->
stride_y
()));
conv
->
set_output_z
(
deConvOutputSize
(
conv
->
img_size_z
(),
conv
->
stride_y
(),
true
));
conv
->
set_output_z
(
imageSize
(
conv
->
img_size_z
(),
conv
->
filter_size_z
(),
conv
->
padding_z
(),
conv
->
stride_z
()));
conv
->
stride_z
(),
true
));
config
.
layerConfig
.
set_size
(
conv
->
output_x
()
*
conv
->
output_y
()
*
conv
->
output_z
()
*
NUM_FILTERS
);
conv
->
set_groups
(
1
);
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
d5768ebc
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include "TensorCheck.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
#include "paddle/testing/TestUtil.h"
...
...
@@ -1203,19 +1204,6 @@ TEST(Matrix, warpCTC) {
}
}
int
outputSizeCol2Vol
(
int
imageSize
,
int
filterSize
,
int
padding
,
int
stride
,
bool
caffeMode
)
{
int
outputSize
;
if
(
!
caffeMode
)
{
outputSize
=
(
imageSize
-
filterSize
+
2
*
padding
+
stride
-
1
)
/
stride
+
1
;
}
else
{
outputSize
=
(
imageSize
-
filterSize
+
2
*
padding
)
/
stride
+
1
;
}
CHECK_GE
(
outputSize
,
1
);
return
outputSize
;
}
void
testMatrixCol2Vol
(
int
depth
,
int
height
,
int
width
)
{
int
channel
=
3
;
int
filterX
=
3
,
filterY
=
4
,
filterZ
=
5
;
...
...
@@ -1229,9 +1217,9 @@ void testMatrixCol2Vol(int depth, int height, int width) {
cpuImage
->
randomizeUniform
();
gpuImage
->
copyFrom
(
*
cpuImage
);
int
outD
=
outputSize
Col2Vol
(
depth
,
filterZ
,
padZ
,
strideZ
,
true
);
int
outH
=
outputSize
Col2Vol
(
height
,
filterY
,
padZ
,
strideY
,
true
);
int
outW
=
outputSize
Col2Vol
(
width
,
filterX
,
padZ
,
strideX
,
true
);
int
outD
=
outputSize
(
depth
,
filterZ
,
padZ
,
strideZ
,
true
);
int
outH
=
outputSize
(
height
,
filterY
,
padY
,
strideY
,
true
);
int
outW
=
outputSize
(
width
,
filterX
,
padX
,
strideX
,
true
);
int
colBufHeight
=
channel
*
filterZ
*
filterY
*
filterX
;
int
colBufWidth
=
outD
*
outH
*
outW
;
...
...
@@ -1305,11 +1293,9 @@ void testMatrixCol2Vol(int depth, int height, int width) {
}
TEST
(
Matrix
,
col2Vol
)
{
for
(
auto
depth
:
{
9
,
16
,
64
,
128
})
{
for
(
auto
height
:
{
9
,
11
,
73
,
128
,
256
})
{
for
(
auto
width
:
{
9
,
32
,
100
,
512
,
})
{
for
(
auto
depth
:
{
9
,
16
,
64
})
{
for
(
auto
height
:
{
9
,
11
,
128
})
{
for
(
auto
width
:
{
9
,
32
,
128
})
{
VLOG
(
3
)
<<
"depth="
<<
depth
<<
" height="
<<
height
<<
" width="
<<
width
;
testMatrixCol2Vol
(
depth
,
height
,
width
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录