Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
70d4809f
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
70d4809f
编写于
11月 29, 2019
作者:
Y
yiicy
提交者:
Xiaoyang LI
11月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick][ARM] conv_transpose operator support padding_algorithm
上级
9ceb67bf
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
276 addition
and
512 deletion
+276
-512
lite/backends/arm/math/col_im_transform.cc
lite/backends/arm/math/col_im_transform.cc
+11
-6
lite/backends/arm/math/col_im_transform.h
lite/backends/arm/math/col_im_transform.h
+4
-2
lite/kernels/arm/CMakeLists.txt
lite/kernels/arm/CMakeLists.txt
+0
-1
lite/kernels/arm/conv_transpose_compute.cc
lite/kernels/arm/conv_transpose_compute.cc
+4
-1
lite/kernels/arm/conv_transpose_compute_test.cc
lite/kernels/arm/conv_transpose_compute_test.cc
+0
-378
lite/operators/conv_transpose_op.cc
lite/operators/conv_transpose_op.cc
+60
-8
lite/operators/conv_transpose_op.h
lite/operators/conv_transpose_op.h
+1
-0
lite/tests/kernels/conv2d_transpose_compute_test.cc
lite/tests/kernels/conv2d_transpose_compute_test.cc
+128
-57
lite/tests/math/conv_transpose_compute_test.cc
lite/tests/math/conv_transpose_compute_test.cc
+45
-46
lite/tests/utils/naive_math_impl.h
lite/tests/utils/naive_math_impl.h
+23
-13
未找到文件。
lite/backends/arm/math/col_im_transform.cc
浏览文件 @
70d4809f
...
...
@@ -32,8 +32,10 @@ void col2im<float>(const float* data_col,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
pad_h0
,
const
int
pad_h1
,
const
int
pad_w0
,
const
int
pad_w1
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
...
...
@@ -41,19 +43,22 @@ void col2im<float>(const float* data_col,
float
*
data_im
)
{
memset
(
data_im
,
0
,
height
*
width
*
channels
*
sizeof
(
float
));
const
int
output_h
=
(
height
+
2
*
pad_h
-
(
dilation_h
*
(
kernel_h
-
1
)
+
1
))
/
stride_h
+
1
;
(
height
+
pad_h0
+
pad_h1
-
(
dilation_h
*
(
kernel_h
-
1
)
+
1
))
/
stride_h
+
1
;
const
int
output_w
=
(
width
+
2
*
pad_w
-
(
dilation_w
*
(
kernel_w
-
1
)
+
1
))
/
stride_w
+
1
;
(
width
+
pad_w0
+
pad_w1
-
(
dilation_w
*
(
kernel_w
-
1
)
+
1
))
/
stride_w
+
1
;
const
int
channel_size
=
height
*
width
;
for
(
int
channel
=
channels
;
channel
--
;
data_im
+=
channel_size
)
{
for
(
int
kernel_row
=
0
;
kernel_row
<
kernel_h
;
kernel_row
++
)
{
for
(
int
kernel_col
=
0
;
kernel_col
<
kernel_w
;
kernel_col
++
)
{
int
input_row
=
-
pad_h
+
kernel_row
*
dilation_h
;
int
input_row
=
-
pad_h
0
+
kernel_row
*
dilation_h
;
for
(
int
output_rows
=
output_h
;
output_rows
;
output_rows
--
)
{
if
(
!
is_a_ge_zero_and_a_lt_b
(
input_row
,
height
))
{
data_col
+=
output_w
;
}
else
{
int
input_col
=
-
pad_w
+
kernel_col
*
dilation_w
;
int
input_col
=
-
pad_w
0
+
kernel_col
*
dilation_w
;
for
(
int
output_col
=
output_w
;
output_col
;
output_col
--
)
{
if
(
is_a_ge_zero_and_a_lt_b
(
input_col
,
width
))
{
data_im
[
input_row
*
width
+
input_col
]
+=
*
data_col
;
...
...
lite/backends/arm/math/col_im_transform.h
浏览文件 @
70d4809f
...
...
@@ -26,8 +26,10 @@ void col2im(const Dtype* data_col,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
pad_h0
,
const
int
pad_h1
,
const
int
pad_w0
,
const
int
pad_w1
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
...
...
lite/kernels/arm/CMakeLists.txt
浏览文件 @
70d4809f
...
...
@@ -101,7 +101,6 @@ lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_
lite_cc_test
(
test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra
)
lite_cc_test
(
test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm
)
lite_cc_test
(
test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm
)
lite_cc_test
(
test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm
)
if
(
LITE_BUILD_EXTRA
)
lite_cc_test
(
test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm
)
lite_cc_test
(
test_lookup_table_compute_arm SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_arm
)
...
...
lite/kernels/arm/conv_transpose_compute.cc
浏览文件 @
70d4809f
...
...
@@ -96,7 +96,8 @@ void Conv2DTransposeCompute::Run() {
int
group_size_weights
=
((
m_roundup
*
k
+
15
)
/
16
)
*
16
;
bool
flag_1x1s1p1
=
(
kw
==
1
)
&&
(
kh
==
1
)
&&
(
param
.
strides
[
0
]
==
1
)
&&
(
param
.
strides
[
1
]
==
1
)
&&
pads_all_qual
&&
(
dilations
[
0
]
==
1
)
&&
(
dilations
[
1
]
==
1
);
(
paddings
[
0
]
==
0
)
&&
(
dilations
[
0
]
==
1
)
&&
(
dilations
[
1
]
==
1
);
ctx
.
ExtendWorkspace
(
sizeof
(
float
)
*
group
*
m
*
n
);
auto
din
=
param
.
x
->
data
<
float
>
();
...
...
@@ -138,7 +139,9 @@ void Conv2DTransposeCompute::Run() {
kh
,
kw
,
paddings
[
0
],
paddings
[
1
],
paddings
[
2
],
paddings
[
3
],
param
.
strides
[
0
],
param
.
strides
[
1
],
dilations
[
0
],
...
...
lite/kernels/arm/conv_transpose_compute_test.cc
已删除
100644 → 0
浏览文件 @
9ceb67bf
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/conv_transpose_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
arm
{
template
<
typename
type
,
typename
type2
>
static
void
basic_gemm
(
int
m
,
int
n
,
int
k
,
const
type
*
a
,
const
type
*
b
,
const
type2
*
bias
,
type2
*
c
,
type2
alpha
,
type2
beta
,
bool
trans_a
=
false
,
bool
trans_b
=
false
,
bool
flag_bias
=
false
,
bool
flag_relu
=
false
)
{
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
type2
bias_data
=
(
type2
)
0
;
if
(
flag_bias
)
{
bias_data
=
bias
[
i
];
}
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
type2
sum
=
static_cast
<
type2
>
(
0
);
for
(
int
l
=
0
;
l
<
k
;
++
l
)
{
type
av
;
type
bv
;
if
(
trans_a
)
{
av
=
a
[
l
*
m
+
i
];
}
else
{
av
=
a
[
i
*
k
+
l
];
}
if
(
trans_b
)
{
bv
=
b
[
j
*
k
+
l
];
}
else
{
bv
=
b
[
l
*
n
+
j
];
}
sum
+=
av
*
bv
;
}
type2
tmp
=
alpha
*
sum
+
beta
*
c
[
i
*
n
+
j
]
+
bias_data
;
if
(
flag_relu
)
{
c
[
i
*
n
+
j
]
=
tmp
>
(
type2
)
0
?
tmp
:
(
type2
)
0
;
}
else
{
c
[
i
*
n
+
j
]
=
tmp
;
}
}
}
}
//! for float, dtype1 and type2 is float
//! for int8, dytpe1 is char, dtype2 is int
template
<
typename
Dtype1
,
typename
Dtype2
>
bool
deconv_basic
(
const
Dtype1
*
din
,
Dtype2
*
dout
,
int
num
,
int
chout
,
int
hout
,
int
wout
,
int
chin
,
int
hin
,
int
win
,
const
Dtype1
*
weights
,
const
Dtype2
*
bias
,
int
group
,
int
kernel_w
,
int
kernel_h
,
int
stride_w
,
int
stride_h
,
int
dila_w
,
int
dila_h
,
int
pad_w
,
int
pad_h
,
bool
flag_bias
,
bool
flag_relu
)
{
int
m
=
chout
*
kernel_w
*
kernel_h
/
group
;
int
n
=
hin
*
win
;
int
k
=
chin
/
group
;
if
(
chin
!=
chout
||
group
!=
chin
)
{
CHECK_OR_FALSE
(
chin
%
group
==
0
);
CHECK_OR_FALSE
(
chout
%
group
==
0
);
}
lite
::
Tensor
workspace_tensor
;
std
::
vector
<
int64_t
>
wt_shape
=
{
1
,
1
,
1
,
group
*
m
*
n
};
workspace_tensor
.
Resize
(
wt_shape
);
auto
*
workspace_ptr
=
workspace_tensor
.
mutable_data
<
Dtype2
>
();
int
group_size_in
=
win
*
hin
*
chin
/
group
;
int
group_size_out
=
wout
*
hout
*
chout
/
group
;
int
group_size_coldata
=
m
*
n
;
int
group_size_weights
=
chin
*
chout
*
kernel_w
*
kernel_h
/
(
group
*
group
);
bool
flag_1x1s1p1
=
(
kernel_w
==
1
)
&&
(
kernel_h
==
1
)
&&
(
stride_h
==
1
)
&&
(
stride_w
==
1
)
&&
(
pad_w
==
1
)
&&
(
pad_h
==
1
)
&&
(
dila_w
==
1
)
&&
(
dila_h
==
1
);
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
const
Dtype1
*
din_batch
=
din
+
i
*
chin
*
hin
*
win
;
Dtype2
*
dout_batch
=
dout
+
i
*
chout
*
hout
*
wout
;
Dtype2
*
col_data
=
workspace_ptr
;
if
(
flag_1x1s1p1
)
{
col_data
=
dout_batch
;
}
memset
(
col_data
,
0
,
sizeof
(
Dtype2
)
*
group_size_coldata
);
for
(
int
g
=
0
;
g
<
group
;
++
g
)
{
const
Dtype1
*
din_group
=
din_batch
+
g
*
group_size_in
;
const
Dtype1
*
weights_group
=
weights
+
g
*
group_size_weights
;
Dtype2
*
coldata_group
=
col_data
+
g
*
group_size_coldata
;
basic_gemm
<
Dtype1
,
Dtype2
>
(
m
,
n
,
k
,
weights_group
,
din_group
,
nullptr
,
coldata_group
,
(
Dtype2
)
1
,
(
Dtype2
)
0
,
true
,
false
,
false
,
(
!
flag_bias
&&
flag_relu
));
}
if
(
!
flag_1x1s1p1
)
{
lite
::
arm
::
math
::
col2im
(
col_data
,
chout
,
hout
,
wout
,
kernel_h
,
kernel_w
,
pad_h
,
pad_w
,
stride_h
,
stride_w
,
dila_h
,
dila_w
,
dout_batch
);
}
if
(
flag_bias
)
{
lite
::
arm
::
math
::
fill_bias_relu
(
dout_batch
,
bias
,
chout
,
wout
*
hout
,
flag_bias
,
flag_relu
);
}
}
return
true
;
}
template
<
typename
Dtype1
,
typename
Dtype2
>
void
conv2d_transpose_compute_ref
(
const
operators
::
ConvParam
&
param
)
{
const
Dtype1
*
din
=
param
.
x
->
data
<
Dtype1
>
();
Dtype2
*
dout
=
param
.
output
->
mutable_data
<
Dtype2
>
();
int
num
=
param
.
x
->
dims
()[
0
];
int
chout
=
param
.
output
->
dims
()[
1
];
int
hout
=
param
.
output
->
dims
()[
2
];
int
wout
=
param
.
output
->
dims
()[
3
];
int
chin
=
param
.
x
->
dims
()[
1
];
int
hin
=
param
.
x
->
dims
()[
2
];
int
win
=
param
.
x
->
dims
()[
3
];
const
Dtype1
*
weights
=
param
.
filter
->
mutable_data
<
Dtype1
>
();
Dtype2
*
bias
=
nullptr
;
if
(
param
.
bias
!=
nullptr
)
{
bias
=
param
.
bias
->
mutable_data
<
Dtype2
>
();
}
int
group
=
param
.
groups
;
auto
paddings
=
*
param
.
paddings
;
auto
dilations
=
*
param
.
dilations
;
int
kernel_h
=
param
.
filter
->
dims
()[
2
];
int
kernel_w
=
param
.
filter
->
dims
()[
3
];
int
stride_h
=
param
.
strides
[
0
];
int
stride_w
=
param
.
strides
[
1
];
int
dila_h
=
dilations
[
0
];
int
dila_w
=
dilations
[
1
];
int
pad_h
=
paddings
[
0
];
int
pad_w
=
paddings
[
2
];
bool
flag_bias
=
(
param
.
bias
!=
nullptr
);
bool
flag_relu
=
param
.
fuse_relu
;
deconv_basic
<
float
,
float
>
(
din
,
dout
,
num
,
chout
,
hout
,
wout
,
chin
,
hin
,
win
,
weights
,
bias
,
group
,
kernel_w
,
kernel_h
,
stride_w
,
stride_h
,
dila_w
,
dila_h
,
pad_w
,
pad_h
,
flag_bias
,
flag_relu
);
}
TEST
(
conv2d_transpose_arm
,
retrive_op
)
{
auto
op
=
KernelRegistry
::
Global
().
Create
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
(
"conv2d_transpose"
);
ASSERT_FALSE
(
op
.
empty
());
ASSERT_TRUE
(
op
.
front
());
}
TEST
(
conv2d_transpose_arm
,
init
)
{
Conv2DTransposeCompute
compute
;
ASSERT_EQ
(
compute
.
precision
(),
PRECISION
(
kFloat
));
ASSERT_EQ
(
compute
.
target
(),
TARGET
(
kARM
));
}
TEST
(
conv2d_transpose_arm
,
compute
)
{
DeviceInfo
::
Init
();
for
(
auto
n
:
{
1
,
2
})
{
for
(
auto
ic
:
{
1
,
3
/*, 128*/
})
{
for
(
auto
oc
:
{
1
,
3
/*, 128*/
})
{
for
(
auto
ih
:
{
2
,
8
/*, 56 , 112, 224, 512*/
})
{
for
(
auto
iw
:
{
2
,
8
/*, 56, 112, 224, 512*/
})
{
for
(
auto
flag_bias
:
{
false
,
true
})
{
for
(
auto
flag_relu
:
{
false
,
true
})
{
for
(
auto
dilation
:
{
1
,
2
})
{
for
(
auto
stride
:
{
1
,
2
})
{
for
(
auto
padding
:
{
0
,
1
,
2
})
{
for
(
auto
ks
:
{
2
,
3
,
5
})
{
for
(
auto
group
:
{
1
,
2
})
{
// obtain shape
if
(
ic
%
group
!=
0
||
oc
%
group
!=
0
)
{
group
=
1
;
}
std
::
vector
<
int64_t
>
input_shape
=
{
n
,
ic
,
ih
,
iw
};
std
::
vector
<
int64_t
>
filter_shape
=
{
oc
/
group
,
ic
,
ks
,
ks
};
int
oh
=
(
ih
-
1
)
*
stride
-
2
*
padding
+
dilation
*
(
ks
-
1
)
+
1
;
int
ow
=
(
iw
-
1
)
*
stride
-
2
*
padding
+
dilation
*
(
ks
-
1
)
+
1
;
if
(
oh
<
1
||
ow
<
1
)
{
break
;
}
std
::
vector
<
int64_t
>
output_shape
=
{
n
,
oc
,
oh
,
ow
};
std
::
vector
<
int64_t
>
bias_shape
=
{
1
,
oc
,
1
,
1
};
// define and resize tensor
Tensor
input
;
Tensor
filter
;
Tensor
filter_copy
;
Tensor
bias
;
Tensor
output
;
Tensor
output_ref
;
input
.
Resize
(
input_shape
);
filter
.
Resize
(
filter_shape
);
filter_copy
.
Resize
(
filter_shape
);
output
.
Resize
(
output_shape
);
output_ref
.
Resize
(
output_shape
);
auto
*
input_data
=
input
.
mutable_data
<
float
>
();
auto
*
filter_data
=
filter
.
mutable_data
<
float
>
();
auto
*
filter_copy_data
=
filter_copy
.
mutable_data
<
float
>
();
auto
*
output_data
=
output
.
mutable_data
<
float
>
();
// initialize tensor
for
(
int
i
=
0
;
i
<
input
.
dims
().
production
();
i
++
)
{
float
sign
=
i
%
3
==
0
?
-
1.0
f
:
1.0
f
;
input_data
[
i
]
=
sign
*
static_cast
<
float
>
(
i
%
128
);
}
for
(
int
i
=
0
;
i
<
filter
.
dims
().
production
();
i
++
)
{
filter_data
[
i
]
=
i
/
static_cast
<
float
>
(
filter
.
dims
().
production
());
filter_copy_data
[
i
]
=
i
/
static_cast
<
float
>
(
filter_copy
.
dims
().
production
());
}
if
(
flag_bias
)
{
bias
.
Resize
(
bias_shape
);
auto
*
bias_data
=
bias
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
bias
.
dims
().
production
();
i
++
)
{
bias_data
[
i
]
=
static_cast
<
float
>
(
i
);
}
}
// prepare kernel params and run
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
ctx
->
As
<
ARMContext
>
();
Conv2DTransposeCompute
conv2d_transpose
;
conv2d_transpose
.
SetContext
(
std
::
move
(
ctx
));
operators
::
ConvParam
param
;
param
.
x
=
&
input
;
param
.
filter
=
&
filter
;
param
.
output
=
&
output
;
param
.
bias
=
nullptr
;
if
(
flag_bias
)
{
bias
.
Resize
(
bias_shape
);
auto
*
bias_data
=
bias
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
bias
.
dims
().
production
();
i
++
)
{
bias_data
[
i
]
=
static_cast
<
float
>
(
i
);
}
param
.
bias
=
&
bias
;
}
param
.
fuse_relu
=
flag_relu
;
std
::
vector
<
int
>
paddings
=
{
padding
,
padding
,
padding
,
padding
};
param
.
strides
=
std
::
vector
<
int
>
({
stride
,
stride
});
std
::
vector
<
int
>
dilations
=
{
dilation
,
dilation
};
param
.
paddings
=
std
::
make_shared
<
std
::
vector
<
int
>>
(
paddings
);
param
.
dilations
=
std
::
make_shared
<
std
::
vector
<
int
>>
(
dilations
);
param
.
groups
=
group
;
conv2d_transpose
.
SetParam
(
param
);
conv2d_transpose
.
Launch
();
// invoking ref implementation and compare results
param
.
filter
=
&
filter_copy
;
param
.
output
=
&
output_ref
;
conv2d_transpose_compute_ref
<
float
,
float
>
(
param
);
auto
*
output_ref_data
=
output_ref
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
output
.
dims
().
production
();
i
++
)
{
EXPECT_NEAR
(
output_data
[
i
],
output_ref_data
[
i
],
1e-3
);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
// namespace arm
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
USE_LITE_KERNEL
(
conv2d_transpose
,
kARM
,
kFloat
,
kNCHW
,
def
);
lite/operators/conv_transpose_op.cc
浏览文件 @
70d4809f
...
...
@@ -34,24 +34,73 @@ bool ConvTransposeOpLite::CheckShape() const {
CHECK_OR_FALSE
(
in_dims
.
size
()
-
param_
.
strides
.
size
()
==
2U
);
CHECK_OR_FALSE
(
in_dims
[
1
]
%
param_
.
groups
==
0
);
CHECK_EQ_OR_FALSE
(
filter_dims
.
size
(),
4UL
);
return
true
;
}
inline
int
ConvTransposeOutputSize
(
int
input_size
,
int
filter_size
,
int
dilation
,
int
pad_left
,
int
pad_right
,
int
stride
)
{
const
int
dkernel
=
dilation
*
(
filter_size
-
1
)
+
1
;
int
output_size
=
(
input_size
-
1
)
*
stride
-
pad_left
-
pad_right
+
dkernel
;
return
output_size
;
}
inline
void
UpdatePaddingAndDilation
(
std
::
vector
<
int
>*
paddings
,
std
::
vector
<
int
>*
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
string
padding_algorithm
,
const
lite
::
DDim
data_dims
,
const
lite
::
DDim
&
ksize
)
{
// when padding_desc is "VALID" or "SAME"
if
(
padding_algorithm
==
"SAME"
)
{
for
(
size_t
i
=
0
;
i
<
strides
.
size
();
++
i
)
{
int
out_size
=
(
data_dims
[
i
+
2
]
+
strides
[
i
]
-
1
)
/
strides
[
i
];
int
pad_sum
=
std
::
max
(
(
out_size
-
1
)
*
strides
[
i
]
+
ksize
[
i
+
2
]
-
data_dims
[
i
+
2
],
(
int64_t
)
0
);
int
pad_0
=
pad_sum
/
2
;
int
pad_1
=
pad_sum
-
pad_0
;
// pad
*
(
paddings
->
begin
()
+
i
*
2
)
=
pad_0
;
*
(
paddings
->
begin
()
+
i
*
2
+
1
)
=
pad_1
;
// dilation
*
(
dilations
->
begin
()
+
i
)
=
1
;
}
}
else
if
(
padding_algorithm
==
"VALID"
)
{
for
(
auto
&
it
:
*
paddings
)
{
it
=
0
;
}
}
}
bool
ConvTransposeOpLite
::
InferShape
()
const
{
const
auto
in_dims
=
param_
.
x
->
dims
();
const
auto
filter_dims
=
param_
.
filter
->
dims
();
UpdatePaddingAndDilation
(
param_
.
paddings
.
get
(),
param_
.
dilations
.
get
(),
param_
.
strides
,
padding_algorithm_
,
in_dims
,
filter_dims
);
auto
paddings
=
*
param_
.
paddings
;
auto
dilations
=
*
param_
.
dilations
;
std
::
vector
<
int64_t
>
output_shape
;
output_shape
.
push_back
(
in_dims
[
0
]);
output_shape
.
push_back
(
filter_dims
[
1
]
*
param_
.
groups
);
for
(
int
i
=
0
;
i
<
param_
.
strides
.
size
();
i
++
)
{
int
kernel_extent
=
dilations
[
i
]
*
(
filter_dims
[
i
+
2
]
-
1
)
+
1
;
int
output_len
=
(
in_dims
[
i
+
2
]
-
1
)
*
param_
.
strides
[
i
]
+
kernel_extent
-
(
paddings
[
2
*
i
]
+
paddings
[
2
*
i
+
1
]);
output_shape
.
push_back
(
output_len
);
for
(
size_t
i
=
0
;
i
<
param_
.
strides
.
size
();
++
i
)
{
output_shape
.
push_back
(
ConvTransposeOutputSize
(
in_dims
[
i
+
2
],
filter_dims
[
i
+
2
],
dilations
[
i
],
paddings
[
i
*
2
],
paddings
[
i
*
2
+
1
],
param_
.
strides
[
i
]));
}
// Set output dims
...
...
@@ -60,8 +109,8 @@ bool ConvTransposeOpLite::InferShape() const {
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
ConvTransposeOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
ConvTransposeOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
X
=
op_desc
.
Input
(
"Input"
).
front
();
auto
Filter
=
op_desc
.
Input
(
"Filter"
).
front
();
auto
Out
=
op_desc
.
Output
(
"Output"
).
front
();
...
...
@@ -74,6 +123,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
param_
.
groups
=
op_desc
.
GetAttr
<
int
>
(
"groups"
);
auto
dilations
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"dilations"
);
if
(
op_desc
.
HasAttr
(
"padding_algorithm"
))
{
padding_algorithm_
=
op_desc
.
GetAttr
<
std
::
string
>
(
"padding_algorithm"
);
}
// 2-pad to 4-pad
if
(
paddings
.
size
()
==
2L
)
{
for
(
size_t
i
=
0
;
i
<
2L
;
++
i
)
{
...
...
@@ -98,7 +150,7 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
auto
bias_var
=
scope
->
FindVar
(
bias_arguments
.
front
());
if
(
bias_var
!=
nullptr
)
{
param_
.
bias
=
const_cast
<
lite
::
Tensor
*>
(
&
(
bias_var
->
Get
<
lite
::
Tensor
>
()));
const_cast
<
lite
::
Tensor
*>
(
&
(
bias_var
->
Get
<
lite
::
Tensor
>
()));
}
}
}
...
...
lite/operators/conv_transpose_op.h
浏览文件 @
70d4809f
...
...
@@ -44,6 +44,7 @@ class ConvTransposeOpLite : public OpLite {
private:
mutable
ConvParam
param_
;
std
::
string
padding_algorithm_
{
""
};
};
}
// namespace operators
...
...
lite/tests/kernels/conv2d_transpose_compute_test.cc
浏览文件 @
70d4809f
...
...
@@ -31,8 +31,10 @@ void col2im(const Dtype* data_col,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
pad_h0
,
const
int
pad_h1
,
const
int
pad_w0
,
const
int
pad_w1
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
...
...
@@ -40,19 +42,22 @@ void col2im(const Dtype* data_col,
Dtype
*
data_im
)
{
memset
(
data_im
,
0
,
height
*
width
*
channels
*
sizeof
(
float
));
const
int
output_h
=
(
height
+
2
*
pad_h
-
(
dilation_h
*
(
kernel_h
-
1
)
+
1
))
/
stride_h
+
1
;
(
height
+
pad_h0
+
pad_h1
-
(
dilation_h
*
(
kernel_h
-
1
)
+
1
))
/
stride_h
+
1
;
const
int
output_w
=
(
width
+
2
*
pad_w
-
(
dilation_w
*
(
kernel_w
-
1
)
+
1
))
/
stride_w
+
1
;
(
width
+
pad_w0
+
pad_w1
-
(
dilation_w
*
(
kernel_w
-
1
)
+
1
))
/
stride_w
+
1
;
const
int
channel_size
=
height
*
width
;
for
(
int
channel
=
channels
;
channel
--
;
data_im
+=
channel_size
)
{
for
(
int
kernel_row
=
0
;
kernel_row
<
kernel_h
;
kernel_row
++
)
{
for
(
int
kernel_col
=
0
;
kernel_col
<
kernel_w
;
kernel_col
++
)
{
int
input_row
=
-
pad_h
+
kernel_row
*
dilation_h
;
int
input_row
=
-
pad_h
0
+
kernel_row
*
dilation_h
;
for
(
int
output_rows
=
output_h
;
output_rows
;
output_rows
--
)
{
if
(
!
is_a_ge_zero_and_a_lt_b
(
input_row
,
height
))
{
data_col
+=
output_w
;
}
else
{
int
input_col
=
-
pad_w
+
kernel_col
*
dilation_w
;
int
input_col
=
-
pad_w
0
+
kernel_col
*
dilation_w
;
for
(
int
output_col
=
output_w
;
output_col
;
output_col
--
)
{
if
(
is_a_ge_zero_and_a_lt_b
(
input_col
,
width
))
{
data_im
[
input_row
*
width
+
input_col
]
+=
*
data_col
;
...
...
@@ -104,6 +109,34 @@ void fill_bias_relu<float>(float* tensor,
}
}
inline
void
UpdatePaddingAndDilation
(
std
::
vector
<
int
>*
paddings
,
std
::
vector
<
int
>*
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
string
padding_algorithm
,
const
DDim
data_dims
,
const
std
::
vector
<
int
>&
ksize
)
{
// when padding_desc is "VALID" or "SAME"
if
(
padding_algorithm
==
"SAME"
)
{
for
(
size_t
i
=
0
;
i
<
strides
.
size
();
++
i
)
{
int
out_size
=
(
data_dims
[
i
+
2
]
+
strides
[
i
]
-
1
)
/
strides
[
i
];
int
pad_sum
=
std
::
max
(
(
out_size
-
1
)
*
strides
[
i
]
+
ksize
[
i
+
2
]
-
data_dims
[
i
+
2
],
(
int64_t
)
0
);
int
pad_0
=
pad_sum
/
2
;
int
pad_1
=
pad_sum
-
pad_0
;
// pad
*
(
paddings
->
begin
()
+
i
*
2
)
=
pad_0
;
*
(
paddings
->
begin
()
+
i
*
2
+
1
)
=
pad_1
;
// dilation
*
(
dilations
->
begin
()
+
i
)
=
1
;
}
}
else
if
(
padding_algorithm
==
"VALID"
)
{
for
(
auto
&
it
:
*
paddings
)
{
it
=
0
;
}
}
}
template
<
typename
type
,
typename
type2
>
static
void
basic_gemm
(
int
m
,
int
n
,
...
...
@@ -172,8 +205,10 @@ bool deconv_basic(const Dtype1* din,
int
stride_h
,
int
dila_w
,
int
dila_h
,
int
pad_w
,
int
pad_h
,
int
pad_w0
,
int
pad_w1
,
int
pad_h0
,
int
pad_h1
,
bool
flag_bias
,
bool
flag_relu
)
{
int
m
=
chout
*
kernel_w
*
kernel_h
/
group
;
...
...
@@ -193,8 +228,9 @@ bool deconv_basic(const Dtype1* din,
int
group_size_coldata
=
m
*
n
;
int
group_size_weights
=
chin
*
chout
*
kernel_w
*
kernel_h
/
(
group
*
group
);
bool
flag_1x1s1p1
=
(
kernel_w
==
1
)
&&
(
kernel_h
==
1
)
&&
(
stride_h
==
1
)
&&
(
stride_w
==
1
)
&&
(
pad_w
==
1
)
&&
(
pad_h
==
1
)
&&
(
dila_w
==
1
)
&&
(
dila_h
==
1
);
(
stride_w
==
1
)
&&
(
pad_w0
==
0
)
&&
(
pad_h0
==
0
)
&&
(
pad_w1
==
0
)
&&
(
pad_h1
==
0
)
&&
(
dila_w
==
1
)
&&
(
dila_h
==
1
);
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
const
Dtype1
*
din_batch
=
din
+
i
*
chin
*
hin
*
win
;
...
...
@@ -204,7 +240,7 @@ bool deconv_basic(const Dtype1* din,
if
(
flag_1x1s1p1
)
{
col_data
=
dout_batch
;
}
memset
(
col_data
,
0
,
sizeof
(
Dtype2
)
*
group_size_coldata
);
memset
(
col_data
,
0
,
sizeof
(
Dtype2
)
*
group_size_coldata
*
group
);
for
(
int
g
=
0
;
g
<
group
;
++
g
)
{
const
Dtype1
*
din_group
=
din_batch
+
g
*
group_size_in
;
const
Dtype1
*
weights_group
=
weights
+
g
*
group_size_weights
;
...
...
@@ -230,8 +266,10 @@ bool deconv_basic(const Dtype1* din,
wout
,
kernel_h
,
kernel_w
,
pad_h
,
pad_w
,
pad_h0
,
pad_h1
,
pad_w0
,
pad_w1
,
stride_h
,
stride_w
,
dila_h
,
...
...
@@ -253,9 +291,10 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
std
::
string
output_
=
"out"
;
std
::
string
filter_
=
"filter"
;
std
::
string
bias_
=
"bias"
;
std
::
string
padding_algorithm_
=
""
;
std
::
vector
<
int
>
strides_
{
1
,
1
};
std
::
vector
<
int
>
paddings_
{
0
,
0
};
std
::
vector
<
int
>
paddings_
{
0
,
0
,
0
,
0
};
int
groups_
{
1
};
std
::
vector
<
int
>
dilations_
{
1
,
1
};
bool
flag_relu_
{
false
};
...
...
@@ -280,9 +319,13 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
bool
flag_relu
,
int
dilation
,
int
stride
,
int
padding
,
int
pad_h0
,
int
pad_h1
,
int
pad_w0
,
int
pad_w1
,
int
ks
,
int
groups
)
int
groups
,
std
::
string
padding_algorithm
)
:
TestCase
(
place
,
alias
)
{
n_
=
n
;
ic_
=
ic
;
...
...
@@ -291,20 +334,29 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
iw_
=
iw
;
ks_
=
ks
;
flag_bias_
=
flag_bias
;
padding_algorithm_
=
padding_algorithm
;
strides_
=
std
::
vector
<
int
>
({
stride
,
stride
});
paddings_
=
std
::
vector
<
int
>
({
padding
,
padding
});
groups_
=
groups
;
paddings_
=
std
::
vector
<
int
>
({
pad_h0
,
pad_h1
,
pad_w0
,
pad_w1
});
dilations_
=
std
::
vector
<
int
>
({
dilation
,
dilation
});
groups_
=
groups
;
flag_relu_
=
flag_relu
;
}
void
RunBaseline
(
Scope
*
scope
)
override
{
auto
*
out
=
scope
->
NewTensor
(
output_
);
CHECK
(
out
);
int
oh
=
(
ih_
-
1
)
*
strides_
[
0
]
-
2
*
paddings_
[
0
]
+
auto
*
x
=
scope
->
FindTensor
(
x_
);
auto
input_dim
=
x
->
dims
();
std
::
vector
<
int
>
ksize
({
1
,
1
,
ks_
,
ks_
});
UpdatePaddingAndDilation
(
&
paddings_
,
&
dilations_
,
strides_
,
padding_algorithm_
,
input_dim
,
ksize
);
int
oh
=
(
ih_
-
1
)
*
strides_
[
0
]
-
paddings_
[
0
]
-
paddings_
[
1
]
+
dilations_
[
0
]
*
(
ks_
-
1
)
+
1
;
int
ow
=
(
iw_
-
1
)
*
strides_
[
1
]
-
2
*
paddings_
[
1
]
+
int
ow
=
(
iw_
-
1
)
*
strides_
[
1
]
-
paddings_
[
2
]
-
paddings_
[
3
]
+
dilations_
[
1
]
*
(
ks_
-
1
)
+
1
;
CHECK
(
oh
>
0
||
ow
>
0
);
...
...
@@ -313,7 +365,6 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
out
->
Resize
(
output_dims
);
auto
*
output_data
=
out
->
mutable_data
<
float
>
();
auto
*
x
=
scope
->
FindTensor
(
x_
);
const
auto
*
x_data
=
x
->
data
<
float
>
();
auto
*
filter
=
scope
->
FindTensor
(
filter_
);
const
auto
*
filter_data
=
filter
->
data
<
float
>
();
...
...
@@ -341,8 +392,10 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
strides_
[
0
],
dilations_
[
1
],
dilations_
[
0
],
paddings_
[
1
],
paddings_
[
2
],
paddings_
[
3
],
paddings_
[
0
],
paddings_
[
1
],
flag_bias_
,
flag_relu_
);
}
...
...
@@ -360,6 +413,7 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
op_desc
->
SetInput
(
"Bias"
,
{
bias_
});
}
op_desc
->
SetAttr
(
"fuse_relu"
,
flag_relu_
);
op_desc
->
SetAttr
(
"padding_algorithm"
,
padding_algorithm_
);
}
void
PrepareData
()
override
{
...
...
@@ -402,34 +456,42 @@ TEST(conv2d_transpose, precision) {
LOG
(
INFO
)
<<
"test conv2d_transpose op"
;
#ifdef LITE_WITH_ARM
Place
place
(
TARGET
(
kARM
));
for
(
auto
n
:
{
1
,
2
})
{
for
(
auto
n
:
{
2
})
{
for
(
auto
ic
:
{
1
,
4
/*, 128*/
})
{
for
(
auto
oc
:
{
1
,
4
/*, 128*/
})
{
LOG
(
INFO
)
<<
"n:"
<<
n
<<
",ic:"
<<
ic
<<
",oc:"
<<
oc
;
for
(
auto
ih
:
{
8
,
16
/*, 56 , 112, 224, 512*/
})
{
for
(
auto
ih
:
{
8
,
8
/*, 56 , 112, 224, 512*/
})
{
for
(
auto
iw
:
{
8
,
16
/*, 56, 112, 224, 512*/
})
{
for
(
auto
flag_bias
:
{
false
,
true
})
{
for
(
auto
flag_relu
:
{
false
,
true
})
{
for
(
auto
dilation
:
{
1
,
2
})
{
for
(
auto
stride
:
{
1
,
2
})
{
for
(
auto
padding
:
{
0
,
2
})
{
for
(
auto
ks
:
{
2
,
5
})
{
for
(
auto
pad_h0
:
{
0
,
1
})
{
for
(
auto
pad_h1
:
{
0
,
1
})
{
for
(
auto
pad_w0
:
{
0
,
1
})
{
for
(
auto
pad_w1
:
{
0
,
1
})
{
for
(
auto
ks
:
{
1
,
4
})
{
for
(
auto
group
:
{
1
,
2
})
{
for
(
auto
padding_algorithm
:
{
""
,
"SAME"
,
"VALID"
})
{
// obtain shape
// LOG(INFO) << "n:" << n << ",ic:" << ic << ",oc:" <<
// LOG(INFO) << "n:" << n << ",ic:" << ic <<
// ",oc:" <<
// oc
// << ",ih:" << ih << ",iw:" << iw
// << ",flag_bias:" << flag_bias
// << ",flag_relu:" << flag_relu
// << ",dila:" << dilation
// << ",stride:" << stride
// << ",padding:" << padding << ",ks:" << ks
// << ",padding:" << padding <<
// ",ks:" << ks
// << ",group:" << group;
if
(
ic
%
group
!=
0
||
oc
%
group
!=
0
)
{
group
=
1
;
}
std
::
unique_ptr
<
arena
::
TestCase
>
tester
(
new
Conv2DTransposeComputeTester
(
place
,
new
Conv2DTransposeComputeTester
(
place
,
"def"
,
n
,
ic
,
...
...
@@ -440,10 +502,15 @@ TEST(conv2d_transpose, precision) {
flag_relu
,
dilation
,
stride
,
padding
,
pad_h0
,
pad_h1
,
pad_w0
,
pad_w1
,
ks
,
group
));
arena
::
Arena
arena
(
std
::
move
(
tester
),
place
,
2e-5
);
group
,
padding_algorithm
));
arena
::
Arena
arena
(
std
::
move
(
tester
),
place
,
2e-5
);
arena
.
TestPrecision
();
}
}
...
...
@@ -457,6 +524,10 @@ TEST(conv2d_transpose, precision) {
}
}
}
}
}
}
}
#endif
}
...
...
lite/tests/math/conv_transpose_compute_test.cc
浏览文件 @
70d4809f
...
...
@@ -111,11 +111,11 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
param
.
output
=
new
Tensor
;
param
.
output
->
set_precision
(
PRECISION
(
kFloat
));
//
paddle::lite::fill_tensor_rand(*param.filter, -1.f, 1.f);
paddle
::
lite
::
fill_tensor_const
(
*
param
.
filter
,
1.
f
);
paddle
::
lite
::
fill_tensor_rand
(
*
param
.
filter
,
-
1.
f
,
1.
f
);
//
paddle::lite::fill_tensor_const(*param.filter, 1.f);
if
(
flag_bias
)
{
//
paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f);
paddle
::
lite
::
fill_tensor_const
(
*
param
.
bias
,
1.
f
);
paddle
::
lite
::
fill_tensor_rand
(
*
param
.
bias
,
-
1.
f
,
1.
f
);
//
paddle::lite::fill_tensor_const(*param.bias, 1.f);
}
Tensor
tmp_weights
;
tmp_weights
.
Resize
(
weight_dim
);
...
...
@@ -130,21 +130,8 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
new
paddle
::
lite
::
KernelContext
);
auto
&
ctx
=
ctx1
->
As
<
paddle
::
lite
::
ARMContext
>
();
ctx
.
SetRunMode
(
static_cast
<
paddle
::
lite_api
::
PowerMode
>
(
cls
),
th
);
/// set param and context
for
(
auto
&
dim_in
:
input_dims
)
{
param
.
x
->
Resize
(
dim_in
);
DDim
out_tmp_dims
=
compute_out_dim
(
dim_in
,
param
);
if
(
out_tmp_dims
[
2
]
<
1
||
out_tmp_dims
[
3
]
<
1
)
{
continue
;
}
param
.
output
->
Resize
(
out_tmp_dims
);
break
;
}
conv_t
.
SetParam
(
param
);
conv_t
.
SetContext
(
std
::
move
(
ctx1
));
/// prepare for run
conv_t
.
PrepareForRun
();
for
(
auto
&
dim_in
:
input_dims
)
{
CHECK_EQ
(
weight_dim
[
0
],
dim_in
[
1
])
<<
"input channel must equal to weights channel"
;
...
...
@@ -154,9 +141,11 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
}
param
.
x
->
Resize
(
dim_in
);
param
.
output
->
Resize
(
dim_out
);
// paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f);
paddle
::
lite
::
fill_tensor_const
(
*
param
.
x
,
1.
f
);
param
.
filter
->
CopyDataFrom
(
tmp_weights
);
// prepare for run
conv_t
.
PrepareForRun
();
paddle
::
lite
::
fill_tensor_rand
(
*
param
.
x
,
-
1.
f
,
1.
f
);
// paddle::lite::fill_tensor_const(*param.x, 1.f);
auto
din
=
param
.
x
->
data
<
float
>
();
Tensor
tout_basic
;
...
...
@@ -185,7 +174,9 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
dilas
[
1
],
dilas
[
0
],
pads
[
2
],
pads
[
3
],
pads
[
0
],
pads
[
1
],
flag_bias
,
flag_relu
);
}
...
...
@@ -230,7 +221,8 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
LOG
(
FATAL
)
<<
"test fp32 conv: input: "
<<
dim_in
<<
", output: "
<<
dim_out
<<
", weight dim: "
<<
weight_dim
<<
", pad: "
<<
pads
[
0
]
<<
", "
<<
pads
[
1
]
<<
", pad: "
<<
pads
[
0
]
<<
", "
<<
pads
[
1
]
<<
", "
<<
pads
[
2
]
<<
", "
<<
pads
[
3
]
<<
", stride: "
<<
strides
[
0
]
<<
", "
<<
strides
[
1
]
<<
", dila_: "
<<
dilas
[
0
]
<<
", "
<<
dilas
[
1
]
<<
", bias: "
<<
(
flag_bias
?
"true"
:
"false"
)
...
...
@@ -242,9 +234,9 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
}
LOG
(
INFO
)
<<
"test fp32 conv: input: "
<<
dim_in
<<
", output: "
<<
dim_out
<<
", weight dim: "
<<
weight_dim
<<
", pad: "
<<
pads
[
0
]
<<
", "
<<
pads
[
1
]
<<
",
stride: "
<<
strides
[
0
]
<<
", "
<<
strides
[
1
]
<<
", dila_: "
<<
dilas
[
0
]
<<
", "
<<
dilas
[
1
]
<<
", pad: "
<<
pads
[
0
]
<<
", "
<<
pads
[
1
]
<<
", "
<<
pads
[
2
]
<<
",
"
<<
pads
[
3
]
<<
", stride: "
<<
strides
[
0
]
<<
", "
<<
strides
[
1
]
<<
", dila_: "
<<
dilas
[
0
]
<<
", "
<<
dilas
[
1
]
<<
", bias: "
<<
(
flag_bias
?
"true"
:
"false"
)
<<
", relu: "
<<
(
flag_relu
?
"true"
:
"false"
)
<<
", threads: "
<<
th
<<
", power_mode: "
<<
cls
...
...
@@ -280,7 +272,10 @@ TEST(TestConvRand, test_conv_transpose_rand) {
for
(
auto
&
kw
:
{
1
,
2
,
3
})
{
for
(
auto
&
kh
:
{
1
,
2
,
3
})
{
for
(
auto
&
stride
:
{
1
,
2
})
{
for
(
auto
&
pad
:
{
0
,
1
,
2
})
{
for
(
auto
&
pad_h0
:
{
0
,
1
,
2
})
{
for
(
auto
&
pad_h1
:
{
0
,
1
,
2
})
{
for
(
auto
&
pad_w0
:
{
0
,
1
,
2
})
{
for
(
auto
&
pad_w1
:
{
0
,
1
,
2
})
{
for
(
auto
&
dila
:
{
1
,
2
})
{
for
(
auto
&
flag_bias
:
{
false
,
true
})
{
for
(
auto
&
flag_relu
:
{
false
,
true
})
{
...
...
@@ -294,15 +289,16 @@ TEST(TestConvRand, test_conv_transpose_rand) {
dims
.
push_back
(
DDim
({
batch
,
cin
,
h
,
h
}));
}
}
test_conv_transpose_fp32
(
dims
,
test_conv_transpose_fp32
(
dims
,
weights_dim
,
g
,
{
stride
,
stride
},
{
pad
,
pad
,
pad
,
pad
},
{
pad_h0
,
pad_h1
,
pad_w0
,
pad_w1
},
{
dila
,
dila
},
flag_bias
,
flag_relu
,
{
1
,
2
,
4
},
{
1
,
4
},
{
FLAGS_power_mode
});
}
}
...
...
@@ -315,6 +311,9 @@ TEST(TestConvRand, test_conv_transpose_rand) {
}
}
}
}
}
}
}
#endif /// random param conv
...
...
lite/tests/utils/naive_math_impl.h
浏览文件 @
70d4809f
...
...
@@ -330,8 +330,10 @@ static void col2im(const Dtype* data_col,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
pad_h0
,
const
int
pad_h1
,
const
int
pad_w0
,
const
int
pad_w1
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
...
...
@@ -339,21 +341,24 @@ static void col2im(const Dtype* data_col,
Dtype
*
data_im
)
{
memset
(
data_im
,
0
,
height
*
width
*
channels
*
sizeof
(
Dtype
));
const
int
output_h
=
(
height
+
2
*
pad_h
-
(
dilation_h
*
(
kernel_h
-
1
)
+
1
))
/
stride_h
+
1
;
(
height
+
pad_h0
+
pad_h1
-
(
dilation_h
*
(
kernel_h
-
1
)
+
1
))
/
stride_h
+
1
;
const
int
output_w
=
(
width
+
2
*
pad_w
-
(
dilation_w
*
(
kernel_w
-
1
)
+
1
))
/
stride_w
+
1
;
(
width
+
pad_w0
+
pad_w1
-
(
dilation_w
*
(
kernel_w
-
1
)
+
1
))
/
stride_w
+
1
;
const
int
channel_size
=
height
*
width
;
for
(
int
channel
=
channels
;
channel
--
;
data_im
+=
channel_size
)
{
for
(
int
kernel_row
=
0
;
kernel_row
<
kernel_h
;
kernel_row
++
)
{
for
(
int
kernel_col
=
0
;
kernel_col
<
kernel_w
;
kernel_col
++
)
{
int
input_row
=
-
pad_h
+
kernel_row
*
dilation_h
;
int
input_row
=
-
pad_h
0
+
kernel_row
*
dilation_h
;
for
(
int
output_rows
=
output_h
;
output_rows
;
output_rows
--
)
{
if
(
!
is_a_ge_zero_and_a_lt_b
(
input_row
,
height
))
{
data_col
+=
output_w
;
}
else
{
int
input_col
=
-
pad_w
+
kernel_col
*
dilation_w
;
int
input_col
=
-
pad_w
0
+
kernel_col
*
dilation_w
;
for
(
int
output_col
=
output_w
;
output_col
;
output_col
--
)
{
if
(
is_a_ge_zero_and_a_lt_b
(
input_col
,
width
))
{
...
...
@@ -391,8 +396,10 @@ void deconv_basic(const Dtype1* din,
int
stride_h
,
int
dila_w
,
int
dila_h
,
int
pad_w
,
int
pad_h
,
int
pad_w0
,
int
pad_w1
,
int
pad_h0
,
int
pad_h1
,
bool
flag_bias
,
bool
flag_relu
)
{
int
m
=
chout
*
kernel_w
*
kernel_h
/
group
;
...
...
@@ -404,8 +411,9 @@ void deconv_basic(const Dtype1* din,
int
group_size_coldata
=
m
*
n
;
int
group_size_weights
=
chin
*
chout
*
kernel_w
*
kernel_h
/
(
group
*
group
);
bool
flag_1x1s1p1
=
(
kernel_w
==
1
)
&&
(
kernel_h
==
1
)
&&
(
stride_h
==
1
)
&&
(
stride_w
==
1
)
&&
(
pad_w
==
1
)
&&
(
pad_h
==
1
)
&&
(
dila_w
==
1
)
&&
(
dila_h
==
1
);
(
stride_w
==
1
)
&&
(
pad_w0
==
0
)
&&
(
pad_h0
==
0
)
&&
(
pad_w1
==
0
)
&&
(
pad_h1
==
0
)
&&
(
dila_w
==
1
)
&&
(
dila_h
==
1
);
Dtype2
*
workspace_ptr
=
static_cast
<
Dtype2
*>
(
malloc
(
sizeof
(
float
)
*
m
*
n
*
group
));
...
...
@@ -418,7 +426,7 @@ void deconv_basic(const Dtype1* din,
if
(
flag_1x1s1p1
)
{
col_data
=
dout_batch
;
}
memset
(
col_data
,
0
,
sizeof
(
Dtype2
)
*
group_size_coldata
);
memset
(
col_data
,
0
,
sizeof
(
Dtype2
)
*
group_size_coldata
*
group
);
for
(
int
g
=
0
;
g
<
group
;
++
g
)
{
const
Dtype1
*
din_group
=
din_batch
+
g
*
group_size_in
;
const
Dtype1
*
weights_group
=
weights
+
g
*
group_size_weights
;
...
...
@@ -448,8 +456,10 @@ void deconv_basic(const Dtype1* din,
wout
,
kernel_h
,
kernel_w
,
pad_h
,
pad_w
,
pad_h0
,
pad_h1
,
pad_w0
,
pad_w1
,
stride_h
,
stride_w
,
dila_h
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录