Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
452373de
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
452373de
编写于
2月 19, 2019
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
resize box in input image scale. test=develop
上级
3896d955
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
70 addition
and
29 deletion
+70
-29
paddle/fluid/operators/detection/yolo_box_op.cc
paddle/fluid/operators/detection/yolo_box_op.cc
+14
-0
paddle/fluid/operators/detection/yolo_box_op.h
paddle/fluid/operators/detection/yolo_box_op.h
+16
-7
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+19
-8
python/paddle/fluid/tests/test_detection.py
python/paddle/fluid/tests/test_detection.py
+3
-1
python/paddle/fluid/tests/unittests/test_yolo_box_op.py
python/paddle/fluid/tests/unittests/test_yolo_box_op.py
+18
-13
未找到文件。
paddle/fluid/operators/detection/yolo_box_op.cc
浏览文件 @
452373de
...
@@ -23,12 +23,15 @@ class YoloBoxOp : public framework::OperatorWithKernel {
...
@@ -23,12 +23,15 @@ class YoloBoxOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of YoloBoxOp should not be null."
);
"Input(X) of YoloBoxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ImgSize"
),
"Input(ImgSize) of YoloBoxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Boxes"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Boxes"
),
"Output(Boxes) of YoloBoxOp should not be null."
);
"Output(Boxes) of YoloBoxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Scores"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Scores"
),
"Output(Scores) of YoloBoxOp should not be null."
);
"Output(Scores) of YoloBoxOp should not be null."
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
auto
dim_imgsize
=
ctx
->
GetInputDim
(
"ImgSize"
);
auto
anchors
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"anchors"
);
auto
anchors
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"anchors"
);
int
anchor_num
=
anchors
.
size
()
/
2
;
int
anchor_num
=
anchors
.
size
()
/
2
;
auto
class_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"class_num"
);
auto
class_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"class_num"
);
...
@@ -39,6 +42,12 @@ class YoloBoxOp : public framework::OperatorWithKernel {
...
@@ -39,6 +42,12 @@ class YoloBoxOp : public framework::OperatorWithKernel {
dim_x
[
1
],
anchor_num
*
(
5
+
class_num
),
dim_x
[
1
],
anchor_num
*
(
5
+
class_num
),
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
);
"+ class_num))."
);
PADDLE_ENFORCE_EQ
(
dim_imgsize
.
size
(),
2
,
"Input(ImgSize) should be a 2-D tensor."
);
PADDLE_ENFORCE_EQ
(
dim_imgsize
[
0
],
dim_x
[
0
],
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same."
);
PADDLE_ENFORCE_EQ
(
dim_imgsize
[
1
],
2
,
"Input(ImgSize) dim[1] should be 2."
);
PADDLE_ENFORCE_GT
(
anchors
.
size
(),
0
,
PADDLE_ENFORCE_GT
(
anchors
.
size
(),
0
,
"Attr(anchors) length should be greater then 0."
);
"Attr(anchors) length should be greater then 0."
);
PADDLE_ENFORCE_EQ
(
anchors
.
size
()
%
2
,
0
,
PADDLE_ENFORCE_EQ
(
anchors
.
size
()
%
2
,
0
,
...
@@ -72,6 +81,11 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -72,6 +81,11 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"box locations, confidence score and classification one-hot"
"box locations, confidence score and classification one-hot"
"keys of each anchor box. Generally, X should be the output"
"keys of each anchor box. Generally, X should be the output"
"of YOLOv3 network."
);
"of YOLOv3 network."
);
AddInput
(
"ImgSize"
,
"The image size tensor of YoloBox operator, "
"This is a 2-D tensor with shape of [N, 2]. This tensor holds"
"height and width of each input image using for resize output"
"box in input image scale."
);
AddOutput
(
"Boxes"
,
AddOutput
(
"Boxes"
,
"The output tensor of detection boxes of YoloBox operator, "
"The output tensor of detection boxes of YoloBox operator, "
"This is a 3-D tensor with shape of [N, M, 4], N is the"
"This is a 3-D tensor with shape of [N, M, 4], N is the"
...
...
paddle/fluid/operators/detection/yolo_box_op.h
浏览文件 @
452373de
...
@@ -32,12 +32,15 @@ static inline T sigmoid(T x) {
...
@@ -32,12 +32,15 @@ static inline T sigmoid(T x) {
template
<
typename
T
>
template
<
typename
T
>
static
inline
Box
<
T
>
GetYoloBox
(
const
T
*
x
,
std
::
vector
<
int
>
anchors
,
int
i
,
static
inline
Box
<
T
>
GetYoloBox
(
const
T
*
x
,
std
::
vector
<
int
>
anchors
,
int
i
,
int
j
,
int
an_idx
,
int
grid_size
,
int
j
,
int
an_idx
,
int
grid_size
,
int
input_size
,
int
index
,
int
stride
)
{
int
input_size
,
int
index
,
int
stride
,
int
img_height
,
int
img_width
)
{
Box
<
T
>
b
;
Box
<
T
>
b
;
b
.
x
=
(
i
+
sigmoid
<
T
>
(
x
[
index
]))
*
input_size
/
grid_size
;
b
.
x
=
(
i
+
sigmoid
<
T
>
(
x
[
index
]))
*
img_width
/
grid_size
;
b
.
y
=
(
j
+
sigmoid
<
T
>
(
x
[
index
+
stride
]))
*
input_size
/
grid_size
;
b
.
y
=
(
j
+
sigmoid
<
T
>
(
x
[
index
+
stride
]))
*
img_height
/
grid_size
;
b
.
w
=
std
::
exp
(
x
[
index
+
2
*
stride
])
*
anchors
[
2
*
an_idx
];
b
.
w
=
std
::
exp
(
x
[
index
+
2
*
stride
])
*
anchors
[
2
*
an_idx
]
*
img_width
/
b
.
h
=
std
::
exp
(
x
[
index
+
3
*
stride
])
*
anchors
[
2
*
an_idx
+
1
];
input_size
;
b
.
h
=
std
::
exp
(
x
[
index
+
3
*
stride
])
*
anchors
[
2
*
an_idx
+
1
]
*
img_height
/
input_size
;
return
b
;
return
b
;
}
}
...
@@ -69,6 +72,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
...
@@ -69,6 +72,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
imgsize
=
ctx
.
Input
<
Tensor
>
(
"ImgSize"
);
auto
*
boxes
=
ctx
.
Output
<
Tensor
>
(
"Boxes"
);
auto
*
boxes
=
ctx
.
Output
<
Tensor
>
(
"Boxes"
);
auto
*
scores
=
ctx
.
Output
<
Tensor
>
(
"Scores"
);
auto
*
scores
=
ctx
.
Output
<
Tensor
>
(
"Scores"
);
auto
anchors
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"anchors"
);
auto
anchors
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"anchors"
);
...
@@ -87,6 +91,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
...
@@ -87,6 +91,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
const
int
an_stride
=
(
class_num
+
5
)
*
stride
;
const
int
an_stride
=
(
class_num
+
5
)
*
stride
;
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
input_data
=
input
->
data
<
T
>
();
const
int
*
imgsize_data
=
imgsize
->
data
<
int
>
();
T
*
boxes_data
=
boxes
->
mutable_data
<
T
>
({
n
,
box_num
,
4
},
ctx
.
GetPlace
());
T
*
boxes_data
=
boxes
->
mutable_data
<
T
>
({
n
,
box_num
,
4
},
ctx
.
GetPlace
());
memset
(
boxes_data
,
0
,
boxes
->
numel
()
*
sizeof
(
T
));
memset
(
boxes_data
,
0
,
boxes
->
numel
()
*
sizeof
(
T
));
T
*
scores_data
=
T
*
scores_data
=
...
@@ -94,6 +99,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
...
@@ -94,6 +99,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
memset
(
scores_data
,
0
,
scores
->
numel
()
*
sizeof
(
T
));
memset
(
scores_data
,
0
,
scores
->
numel
()
*
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
int
img_height
=
imgsize_data
[
2
*
i
];
int
img_width
=
imgsize_data
[
2
*
i
+
1
];
for
(
int
j
=
0
;
j
<
an_num
;
j
++
)
{
for
(
int
j
=
0
;
j
<
an_num
;
j
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
...
@@ -106,8 +114,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
...
@@ -106,8 +114,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int
box_idx
=
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
0
);
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
0
);
Box
<
T
>
pred
=
GetYoloBox
(
input_data
,
anchors
,
l
,
k
,
j
,
h
,
Box
<
T
>
pred
=
input_size
,
box_idx
,
stride
);
GetYoloBox
(
input_data
,
anchors
,
l
,
k
,
j
,
h
,
input_size
,
box_idx
,
stride
,
img_height
,
img_width
);
box_idx
=
(
i
*
box_num
+
j
*
stride
+
k
*
w
+
l
)
*
4
;
box_idx
=
(
i
*
box_num
+
j
*
stride
+
k
*
w
+
l
)
*
4
;
CalcDetectionBox
<
T
>
(
boxes_data
,
pred
,
box_idx
);
CalcDetectionBox
<
T
>
(
boxes_data
,
pred
,
box_idx
);
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
452373de
...
@@ -611,12 +611,19 @@ def yolov3_loss(x,
...
@@ -611,12 +611,19 @@ def yolov3_loss(x,
@
templatedoc
(
op_type
=
"yolo_box"
)
@
templatedoc
(
op_type
=
"yolo_box"
)
def
yolo_box
(
x
,
anchors
,
class_num
,
conf_thresh
,
downsample_ratio
,
name
=
None
):
def
yolo_box
(
x
,
img_size
,
anchors
,
class_num
,
conf_thresh
,
downsample_ratio
,
name
=
None
):
"""
"""
${comment}
${comment}
Args:
Args:
x (Variable): ${x_comment}
x (Variable): ${x_comment}
img_size (Variable): ${img_size_comment}
anchors (list|tuple): ${anchors_comment}
anchors (list|tuple): ${anchors_comment}
class_num (int): ${class_num_comment}
class_num (int): ${class_num_comment}
conf_thresh (float): ${conf_thresh_comment}
conf_thresh (float): ${conf_thresh_comment}
...
@@ -643,16 +650,17 @@ def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None):
...
@@ -643,16 +650,17 @@ def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None):
helper
=
LayerHelper
(
'yolo_box'
,
**
locals
())
helper
=
LayerHelper
(
'yolo_box'
,
**
locals
())
if
not
isinstance
(
x
,
Variable
):
if
not
isinstance
(
x
,
Variable
):
raise
TypeError
(
"Input x of yolov3_loss must be Variable"
)
raise
TypeError
(
"Input x of yolo_box must be Variable"
)
if
not
isinstance
(
img_size
,
Variable
):
raise
TypeError
(
"Input img_size of yolo_box must be Variable"
)
if
not
isinstance
(
anchors
,
list
)
and
not
isinstance
(
anchors
,
tuple
):
if
not
isinstance
(
anchors
,
list
)
and
not
isinstance
(
anchors
,
tuple
):
raise
TypeError
(
"Attr anchors of yolo
v3_loss
must be list or tuple"
)
raise
TypeError
(
"Attr anchors of yolo
_box
must be list or tuple"
)
if
not
isinstance
(
anchor_mask
,
list
)
and
not
isinstance
(
anchor_mask
,
tuple
):
if
not
isinstance
(
anchor_mask
,
list
)
and
not
isinstance
(
anchor_mask
,
tuple
):
raise
TypeError
(
"Attr anchor_mask of yolo
v3_loss
must be list or tuple"
)
raise
TypeError
(
"Attr anchor_mask of yolo
_box
must be list or tuple"
)
if
not
isinstance
(
class_num
,
int
):
if
not
isinstance
(
class_num
,
int
):
raise
TypeError
(
"Attr class_num of yolo
v3_loss
must be an integer"
)
raise
TypeError
(
"Attr class_num of yolo
_box
must be an integer"
)
if
not
isinstance
(
conf_thresh
,
float
):
if
not
isinstance
(
conf_thresh
,
float
):
raise
TypeError
(
raise
TypeError
(
"Attr ignore_thresh of yolo_box must be a float number"
)
"Attr ignore_thresh of yolov3_loss must be a float number"
)
boxes
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
boxes
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
scores
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
scores
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
...
@@ -666,7 +674,10 @@ def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None):
...
@@ -666,7 +674,10 @@ def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None):
helper
.
append_op
(
helper
.
append_op
(
type
=
'yolo_box'
,
type
=
'yolo_box'
,
inputs
=
{
"X"
:
x
,
},
inputs
=
{
"X"
:
x
,
"ImgSize"
:
img_size
,
},
outputs
=
{
outputs
=
{
'Boxes'
:
boxes
,
'Boxes'
:
boxes
,
'Scores'
:
scores
,
'Scores'
:
scores
,
...
...
python/paddle/fluid/tests/test_detection.py
浏览文件 @
452373de
...
@@ -484,7 +484,9 @@ class TestYoloDetection(unittest.TestCase):
...
@@ -484,7 +484,9 @@ class TestYoloDetection(unittest.TestCase):
program
=
Program
()
program
=
Program
()
with
program_guard
(
program
):
with
program_guard
(
program
):
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
30
,
7
,
7
],
dtype
=
'float32'
)
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
30
,
7
,
7
],
dtype
=
'float32'
)
boxes
,
scores
=
layers
.
yolo_box
(
x
,
[
10
,
13
,
30
,
13
],
10
,
0.01
,
32
)
img_size
=
layers
.
data
(
name
=
'x'
,
shape
=
[
2
],
dtype
=
'int32'
)
boxes
,
scores
=
layers
.
yolo_box
(
x
,
img_size
,
[
10
,
13
,
30
,
13
],
10
,
0.01
,
32
)
self
.
assertIsNotNone
(
boxes
)
self
.
assertIsNotNone
(
boxes
)
self
.
assertIsNotNone
(
scores
)
self
.
assertIsNotNone
(
scores
)
...
...
python/paddle/fluid/tests/unittests/test_yolo_box_op.py
浏览文件 @
452373de
...
@@ -25,7 +25,7 @@ def sigmoid(x):
...
@@ -25,7 +25,7 @@ def sigmoid(x):
return
1.0
/
(
1.0
+
np
.
exp
(
-
1.0
*
x
))
return
1.0
/
(
1.0
+
np
.
exp
(
-
1.0
*
x
))
def
YoloBox
(
x
,
attrs
):
def
YoloBox
(
x
,
img_size
,
attrs
):
n
,
c
,
h
,
w
=
x
.
shape
n
,
c
,
h
,
w
=
x
.
shape
anchors
=
attrs
[
'anchors'
]
anchors
=
attrs
[
'anchors'
]
an_num
=
int
(
len
(
anchors
)
//
2
)
an_num
=
int
(
len
(
anchors
)
//
2
)
...
@@ -56,15 +56,14 @@ def YoloBox(x, attrs):
...
@@ -56,15 +56,14 @@ def YoloBox(x, attrs):
pred_box
=
pred_box
*
(
pred_conf
>
0.
).
astype
(
'float32'
)
pred_box
=
pred_box
*
(
pred_conf
>
0.
).
astype
(
'float32'
)
pred_box
=
pred_box
.
reshape
((
n
,
-
1
,
4
))
pred_box
=
pred_box
.
reshape
((
n
,
-
1
,
4
))
pred_box
[:,
:,
:
pred_box
[:,
:,
:
2
],
pred_box
[:,
:,
2
:
4
]
=
\
2
],
pred_box
[:,
:,
2
:
pred_box
[:,
:,
:
2
]
-
pred_box
[:,
:,
2
:
4
]
/
2.
,
\
4
]
=
pred_box
[:,
:,
:
pred_box
[:,
:,
:
2
]
+
pred_box
[:,
:,
2
:
4
]
/
2.0
2
]
-
pred_box
[:,
:,
2
:
# pred_box = pred_box * input_size
4
]
/
2.
,
pred_box
[:,
:,
:
pred_box
[:,
:,
0
]
=
pred_box
[:,
:,
0
]
*
img_size
[:,
1
][:,
np
.
newaxis
]
2
]
+
pred_box
[:,
:,
pred_box
[:,
:,
1
]
=
pred_box
[:,
:,
1
]
*
img_size
[:,
0
][:,
np
.
newaxis
]
2
:
pred_box
[:,
:,
2
]
=
pred_box
[:,
:,
2
]
*
img_size
[:,
1
][:,
np
.
newaxis
]
4
]
/
2.0
pred_box
[:,
:,
3
]
=
pred_box
[:,
:,
3
]
*
img_size
[:,
0
][:,
np
.
newaxis
]
pred_box
=
pred_box
*
input_size
return
pred_box
,
pred_score
.
reshape
((
n
,
-
1
,
class_num
))
return
pred_box
,
pred_score
.
reshape
((
n
,
-
1
,
class_num
))
...
@@ -74,6 +73,7 @@ class TestYoloBoxOp(OpTest):
...
@@ -74,6 +73,7 @@ class TestYoloBoxOp(OpTest):
self
.
initTestCase
()
self
.
initTestCase
()
self
.
op_type
=
'yolo_box'
self
.
op_type
=
'yolo_box'
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
'float32'
)
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
'float32'
)
img_size
=
np
.
random
.
randint
(
10
,
20
,
self
.
imgsize_shape
).
astype
(
'int32'
)
self
.
attrs
=
{
self
.
attrs
=
{
"anchors"
:
self
.
anchors
,
"anchors"
:
self
.
anchors
,
...
@@ -82,8 +82,11 @@ class TestYoloBoxOp(OpTest):
...
@@ -82,8 +82,11 @@ class TestYoloBoxOp(OpTest):
"downsample"
:
self
.
downsample
,
"downsample"
:
self
.
downsample
,
}
}
self
.
inputs
=
{
'X'
:
x
,
}
self
.
inputs
=
{
boxes
,
scores
=
YoloBox
(
x
,
self
.
attrs
)
'X'
:
x
,
'ImgSize'
:
img_size
,
}
boxes
,
scores
=
YoloBox
(
x
,
img_size
,
self
.
attrs
)
self
.
outputs
=
{
self
.
outputs
=
{
"Boxes"
:
boxes
,
"Boxes"
:
boxes
,
"Scores"
:
scores
,
"Scores"
:
scores
,
...
@@ -95,10 +98,12 @@ class TestYoloBoxOp(OpTest):
...
@@ -95,10 +98,12 @@ class TestYoloBoxOp(OpTest):
def
initTestCase
(
self
):
def
initTestCase
(
self
):
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
]
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
]
an_num
=
int
(
len
(
self
.
anchors
)
//
2
)
an_num
=
int
(
len
(
self
.
anchors
)
//
2
)
self
.
batch_size
=
3
self
.
class_num
=
2
self
.
class_num
=
2
self
.
conf_thresh
=
0.5
self
.
conf_thresh
=
0.5
self
.
downsample
=
32
self
.
downsample
=
32
self
.
x_shape
=
(
3
,
an_num
*
(
5
+
self
.
class_num
),
5
,
5
)
self
.
x_shape
=
(
self
.
batch_size
,
an_num
*
(
5
+
self
.
class_num
),
5
,
5
)
self
.
imgsize_shape
=
(
self
.
batch_size
,
2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录