Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
328169a9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
328169a9
编写于
10月 27, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
im2col cpu gpu dilation support
上级
7a5b3846
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
85 addition
and
32 deletion
+85
-32
paddle/function/Im2Col.h
paddle/function/Im2Col.h
+6
-2
paddle/function/Im2ColOp.cpp
paddle/function/Im2ColOp.cpp
+24
-14
paddle/function/Im2ColOpGpu.cu
paddle/function/Im2ColOpGpu.cu
+55
-16
未找到文件。
paddle/function/Im2Col.h
浏览文件 @
328169a9
...
@@ -78,7 +78,9 @@ public:
...
@@ -78,7 +78,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
);
int
paddingWidth
,
int
dilationHeight
=
1
,
int
dilationWidth
=
1
);
};
};
template
<
ColFormat
Format
,
DeviceType
Device
,
class
T
>
template
<
ColFormat
Format
,
DeviceType
Device
,
class
T
>
...
@@ -91,7 +93,9 @@ public:
...
@@ -91,7 +93,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
);
int
paddingWidth
,
int
dilationHeight
=
1
,
int
dilationWidth
=
1
);
};
};
}
// namespace paddle
}
// namespace paddle
paddle/function/Im2ColOp.cpp
浏览文件 @
328169a9
...
@@ -31,7 +31,9 @@ public:
...
@@ -31,7 +31,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
inputWidth
=
imShape
[
2
];
...
@@ -47,8 +49,8 @@ public:
...
@@ -47,8 +49,8 @@ public:
int
c_im
=
c
/
filterWidth
/
filterHeight
;
int
c_im
=
c
/
filterWidth
/
filterHeight
;
for
(
int
h
=
0
;
h
<
outputHeight
;
++
h
)
{
for
(
int
h
=
0
;
h
<
outputHeight
;
++
h
)
{
for
(
int
w
=
0
;
w
<
outputWidth
;
++
w
)
{
for
(
int
w
=
0
;
w
<
outputWidth
;
++
w
)
{
int
imRowIdx
=
h
*
strideHeight
+
hOffset
;
int
imRowIdx
=
h
*
strideHeight
+
hOffset
*
dilationHeight
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
*
dilationWidth
;
if
((
imRowIdx
-
paddingHeight
)
<
0
||
if
((
imRowIdx
-
paddingHeight
)
<
0
||
(
imRowIdx
-
paddingHeight
)
>=
inputHeight
||
(
imRowIdx
-
paddingHeight
)
>=
inputHeight
||
(
imColIdx
-
paddingWidth
)
<
0
||
(
imColIdx
-
paddingWidth
)
<
0
||
...
@@ -81,7 +83,9 @@ public:
...
@@ -81,7 +83,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
inputWidth
=
imShape
[
2
];
...
@@ -97,8 +101,8 @@ public:
...
@@ -97,8 +101,8 @@ public:
int
c_im
=
c
/
filterWidth
/
filterHeight
;
int
c_im
=
c
/
filterWidth
/
filterHeight
;
for
(
int
h
=
0
;
h
<
outputHeight
;
++
h
)
{
for
(
int
h
=
0
;
h
<
outputHeight
;
++
h
)
{
for
(
int
w
=
0
;
w
<
outputWidth
;
++
w
)
{
for
(
int
w
=
0
;
w
<
outputWidth
;
++
w
)
{
int
imRowIdx
=
h
*
strideHeight
+
hOffset
;
int
imRowIdx
=
h
*
strideHeight
+
hOffset
*
dilationHeight
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
*
dilationWidth
;
if
((
imRowIdx
-
paddingHeight
)
>=
0
&&
if
((
imRowIdx
-
paddingHeight
)
>=
0
&&
(
imRowIdx
-
paddingHeight
)
<
inputHeight
&&
(
imRowIdx
-
paddingHeight
)
<
inputHeight
&&
(
imColIdx
-
paddingWidth
)
>=
0
&&
(
imColIdx
-
paddingWidth
)
>=
0
&&
...
@@ -134,7 +138,9 @@ public:
...
@@ -134,7 +138,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
=
1
,
int
dilationWidth
=
1
)
{
int
inputChannels
=
imShape
[
0
];
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
inputWidth
=
imShape
[
2
];
...
@@ -147,9 +153,10 @@ public:
...
@@ -147,9 +153,10 @@ public:
for
(
int
channel
=
0
;
channel
<
inputChannels
;
++
channel
)
{
for
(
int
channel
=
0
;
channel
<
inputChannels
;
++
channel
)
{
for
(
int
filterH
=
0
;
filterH
<
filterHeight
;
++
filterH
)
{
for
(
int
filterH
=
0
;
filterH
<
filterHeight
;
++
filterH
)
{
for
(
int
filterW
=
0
;
filterW
<
filterWidth
;
++
filterW
)
{
for
(
int
filterW
=
0
;
filterW
<
filterWidth
;
++
filterW
)
{
int
imRowOffset
=
int
imRowOffset
=
outputH
*
strideHeight
+
outputH
*
strideHeight
+
filterH
-
paddingHeight
;
filterH
*
dilationHeight
-
paddingHeight
;
int
imColOffset
=
outputW
*
strideWidth
+
filterW
-
paddingWidth
;
int
imColOffset
=
outputW
*
strideWidth
+
filterW
*
dilationWidth
-
paddingWidth
;
int
colDataOffset
=
int
colDataOffset
=
(((
outputH
*
outputWidth
+
outputW
)
*
inputChannels
+
(((
outputH
*
outputWidth
+
outputW
)
*
inputChannels
+
channel
)
*
channel
)
*
...
@@ -189,7 +196,9 @@ public:
...
@@ -189,7 +196,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
=
1
,
int
dilationWidth
=
1
)
{
int
inputChannels
=
imShape
[
0
];
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
inputWidth
=
imShape
[
2
];
...
@@ -202,9 +211,10 @@ public:
...
@@ -202,9 +211,10 @@ public:
for
(
int
channel
=
0
;
channel
<
inputChannels
;
++
channel
)
{
for
(
int
channel
=
0
;
channel
<
inputChannels
;
++
channel
)
{
for
(
int
filterH
=
0
;
filterH
<
filterHeight
;
++
filterH
)
{
for
(
int
filterH
=
0
;
filterH
<
filterHeight
;
++
filterH
)
{
for
(
int
filterW
=
0
;
filterW
<
filterWidth
;
++
filterW
)
{
for
(
int
filterW
=
0
;
filterW
<
filterWidth
;
++
filterW
)
{
int
imRowOffset
=
int
imRowOffset
=
outputH
*
strideHeight
+
outputH
*
strideHeight
+
filterH
-
paddingHeight
;
filterH
*
dilationHeight
-
paddingHeight
;
int
imColOffset
=
outputW
*
strideWidth
+
filterW
-
paddingWidth
;
int
imColOffset
=
outputW
*
strideWidth
+
filterW
*
dilationWidth
-
paddingWidth
;
int
colDataOffset
=
int
colDataOffset
=
(((
outputH
*
outputWidth
+
outputW
)
*
inputChannels
+
(((
outputH
*
outputWidth
+
outputW
)
*
inputChannels
+
channel
)
*
channel
)
*
...
...
paddle/function/Im2ColOpGpu.cu
浏览文件 @
328169a9
...
@@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im,
...
@@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im,
int
strideW
,
int
strideW
,
int
paddingH
,
int
paddingH
,
int
paddingW
,
int
paddingW
,
int
dilationH
,
int
dilationW
,
int
height_col
,
int
height_col
,
int
width_col
,
int
width_col
,
T
*
data_col
)
{
T
*
data_col
)
{
...
@@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im,
...
@@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im,
data_col
+=
(
channel_out
*
height_col
+
h_out
)
*
width_col
+
w_out
;
data_col
+=
(
channel_out
*
height_col
+
h_out
)
*
width_col
+
w_out
;
for
(
int
i
=
0
;
i
<
blockH
;
++
i
)
{
for
(
int
i
=
0
;
i
<
blockH
;
++
i
)
{
for
(
int
j
=
0
;
j
<
blockW
;
++
j
)
{
for
(
int
j
=
0
;
j
<
blockW
;
++
j
)
{
int
rIdx
=
int
(
h_in
+
i
);
int
rIdx
=
int
(
h_in
+
i
*
dilationH
);
int
cIdx
=
int
(
w_in
+
j
);
int
cIdx
=
int
(
w_in
+
j
*
dilationW
);
if
((
rIdx
-
(
int
)
paddingH
)
>=
(
int
)
height
||
if
((
rIdx
-
(
int
)
paddingH
)
>=
(
int
)
height
||
(
rIdx
-
(
int
)
paddingH
)
<
0
||
(
rIdx
-
(
int
)
paddingH
)
<
0
||
(
cIdx
-
(
int
)
paddingW
)
>=
(
int
)
width
||
(
cIdx
-
(
int
)
paddingW
)
>=
(
int
)
width
||
...
@@ -77,7 +79,9 @@ public:
...
@@ -77,7 +79,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
inputWidth
=
imShape
[
2
];
...
@@ -102,6 +106,8 @@ public:
...
@@ -102,6 +106,8 @@ public:
strideWidth
,
strideWidth
,
paddingHeight
,
paddingHeight
,
paddingWidth
,
paddingWidth
,
dilationHeight
,
dilationWidth
,
outputHeight
,
outputHeight
,
outputWidth
,
outputWidth
,
colData
);
colData
);
...
@@ -121,6 +127,8 @@ __global__ void col2im(size_t n,
...
@@ -121,6 +127,8 @@ __global__ void col2im(size_t n,
size_t
strideW
,
size_t
strideW
,
size_t
paddingH
,
size_t
paddingH
,
size_t
paddingW
,
size_t
paddingW
,
size_t
dilationH
,
size_t
dilationW
,
size_t
height_col
,
size_t
height_col
,
size_t
width_col
,
size_t
width_col
,
T
*
data_im
)
{
T
*
data_im
)
{
...
@@ -131,23 +139,34 @@ __global__ void col2im(size_t n,
...
@@ -131,23 +139,34 @@ __global__ void col2im(size_t n,
int
w
=
int
(
index
%
width
);
int
w
=
int
(
index
%
width
);
int
h
=
int
((
index
/
width
)
%
height
);
int
h
=
int
((
index
/
width
)
%
height
);
int
c
=
int
(
index
/
(
width
*
height
));
int
c
=
int
(
index
/
(
width
*
height
));
int
filterH
=
(
blockH
-
1
)
*
dilationH
+
1
;
int
filterW
=
(
blockW
-
1
)
*
dilationW
+
1
;
if
((
w
-
(
int
)
paddingW
)
>=
0
&&
if
((
w
-
(
int
)
paddingW
)
>=
0
&&
(
w
-
(
int
)
paddingW
)
<
(
width
-
2
*
paddingW
)
&&
(
w
-
(
int
)
paddingW
)
<
(
width
-
2
*
paddingW
)
&&
(
h
-
(
int
)
paddingH
)
>=
0
&&
(
h
-
paddingH
)
<
(
height
-
2
*
paddingH
))
{
(
h
-
(
int
)
paddingH
)
>=
0
&&
(
h
-
paddingH
)
<
(
height
-
2
*
paddingH
))
{
// compute the start and end of the output
// compute the start and end of the output
int
w_col_start
=
int
w_col_start
=
(
w
<
(
int
)
blockW
)
?
0
:
(
w
-
int
(
block
W
))
/
(
int
)
strideW
+
1
;
(
w
<
(
int
)
filterW
)
?
0
:
(
w
-
int
(
filter
W
))
/
(
int
)
strideW
+
1
;
int
w_col_end
=
min
((
int
)(
w
/
(
int
)
strideW
+
1
),
(
int
)(
width_col
));
int
w_col_end
=
min
((
int
)(
w
/
(
int
)
strideW
+
1
),
(
int
)(
width_col
));
int
h_col_start
=
int
h_col_start
=
(
h
<
(
int
)
blockH
)
?
0
:
(
h
-
(
int
)
block
H
)
/
(
int
)
strideH
+
1
;
(
h
<
(
int
)
filterH
)
?
0
:
(
h
-
(
int
)
filter
H
)
/
(
int
)
strideH
+
1
;
int
h_col_end
=
min
(
int
(
h
/
strideH
+
1
),
int
(
height_col
));
int
h_col_end
=
min
(
int
(
h
/
strideH
+
1
),
int
(
height_col
));
for
(
int
h_col
=
h_col_start
;
h_col
<
h_col_end
;
++
h_col
)
{
for
(
int
h_col
=
h_col_start
;
h_col
<
h_col_end
;
++
h_col
)
{
for
(
int
w_col
=
w_col_start
;
w_col
<
w_col_end
;
++
w_col
)
{
for
(
int
w_col
=
w_col_start
;
w_col
<
w_col_end
;
++
w_col
)
{
// the col location: [c * width * height + h_out, w_out]
// the col location: [c * width * height + h_out, w_out]
int
c_col
=
int
(
c
*
blockH
*
blockW
)
+
int
h_k
=
(
h
-
h_col
*
strideH
);
(
h
-
h_col
*
(
int
)
strideH
)
*
(
int
)
blockW
+
int
w_k
=
(
w
-
w_col
*
strideW
);
(
w
-
w_col
*
(
int
)
strideW
);
if
(
h_k
%
dilationH
==
0
&&
w_k
%
dilationW
==
0
)
{
val
+=
data_col
[(
c_col
*
height_col
+
h_col
)
*
width_col
+
w_col
];
h_k
/=
dilationH
;
w_k
/=
dilationW
;
int
c_col
=
(((
c
*
blockH
+
h_k
)
*
blockW
+
w_k
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
val
+=
data_col
[
c_col
];
}
}
}
}
}
h
-=
paddingH
;
h
-=
paddingH
;
...
@@ -173,7 +192,9 @@ public:
...
@@ -173,7 +192,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
inputWidth
=
imShape
[
2
];
...
@@ -205,6 +226,8 @@ public:
...
@@ -205,6 +226,8 @@ public:
strideWidth
,
strideWidth
,
paddingHeight
,
paddingHeight
,
paddingWidth
,
paddingWidth
,
dilationHeight
,
dilationWidth
,
outputHeight
,
outputHeight
,
outputWidth
,
outputWidth
,
imData
);
imData
);
...
@@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData,
...
@@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
,
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
,
int
outputHeight
,
int
outputHeight
,
int
outputWidth
)
{
int
outputWidth
)
{
int
swId
=
blockIdx
.
x
;
int
swId
=
blockIdx
.
x
;
...
@@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData,
...
@@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData,
channelId
+=
blockDim
.
z
)
{
channelId
+=
blockDim
.
z
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filterHeight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filterHeight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filterWidth
;
idx
+=
blockDim
.
x
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filterWidth
;
idx
+=
blockDim
.
x
)
{
int
widthOffset
=
idx
+
swId
*
strideWidth
-
paddingWidth
;
int
widthOffset
=
int
heightOffset
=
idy
+
shId
*
strideHeight
-
paddingHeight
;
idx
*
dilationHeight
+
swId
*
strideWidth
-
paddingWidth
;
int
heightOffset
=
idy
*
dilationWidth
+
shId
*
strideHeight
-
paddingHeight
;
int
imOffset
=
widthOffset
+
heightOffset
*
inputWidth
+
int
imOffset
=
widthOffset
+
heightOffset
*
inputWidth
+
channelId
*
inputHeight
*
inputWidth
;
channelId
*
inputHeight
*
inputWidth
;
...
@@ -273,7 +300,9 @@ public:
...
@@ -273,7 +300,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
inputWidth
=
imShape
[
2
];
...
@@ -312,6 +341,8 @@ public:
...
@@ -312,6 +341,8 @@ public:
strideWidth
,
strideWidth
,
paddingHeight
,
paddingHeight
,
paddingWidth
,
paddingWidth
,
dilationHeight
,
dilationWidth
,
outputHeight
,
outputHeight
,
outputWidth
);
outputWidth
);
CHECK_SYNC
(
"Im2ColFunctor GPU failed"
);
CHECK_SYNC
(
"Im2ColFunctor GPU failed"
);
...
@@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData,
...
@@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
,
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
,
int
outputHeight
,
int
outputHeight
,
int
outputWidth
)
{
int
outputWidth
)
{
int
swId
=
blockIdx
.
x
;
int
swId
=
blockIdx
.
x
;
...
@@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData,
...
@@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData,
channelId
+=
blockDim
.
z
)
{
channelId
+=
blockDim
.
z
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filterHeight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filterHeight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filterWidth
;
idx
+=
blockDim
.
x
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filterWidth
;
idx
+=
blockDim
.
x
)
{
int
widthOffset
=
idx
+
swId
*
strideWidth
-
paddingWidth
;
int
widthOffset
=
int
heightOffset
=
idy
+
shId
*
strideHeight
-
paddingHeight
;
idx
*
dilationWidth
+
swId
*
strideWidth
-
paddingWidth
;
int
heightOffset
=
idy
*
dilationHeight
+
shId
*
strideHeight
-
paddingHeight
;
int
imOffset
=
widthOffset
+
heightOffset
*
inputWidth
+
int
imOffset
=
widthOffset
+
heightOffset
*
inputWidth
+
channelId
*
inputHeight
*
inputWidth
;
channelId
*
inputHeight
*
inputWidth
;
...
@@ -372,7 +407,9 @@ public:
...
@@ -372,7 +407,9 @@ public:
int
strideHeight
,
int
strideHeight
,
int
strideWidth
,
int
strideWidth
,
int
paddingHeight
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
inputWidth
=
imShape
[
2
];
...
@@ -411,6 +448,8 @@ public:
...
@@ -411,6 +448,8 @@ public:
strideWidth
,
strideWidth
,
paddingHeight
,
paddingHeight
,
paddingWidth
,
paddingWidth
,
dilationHeight
,
dilationWidth
,
outputHeight
,
outputHeight
,
outputWidth
);
outputWidth
);
CHECK_SYNC
(
"Col2ImFunctor GPU failed"
);
CHECK_SYNC
(
"Col2ImFunctor GPU failed"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录