Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
4d1187d5
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4d1187d5
编写于
9月 16, 2020
作者:
H
huangjun12
提交者:
GitHub
9月 16, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update tsn Reader using dataloader and pipline (#4856)
上级
bde994e1
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
544 addition
and
52 deletion
+544
-52
dygraph/tsn/augmentations.py
dygraph/tsn/augmentations.py
+209
-0
dygraph/tsn/compose.py
dygraph/tsn/compose.py
+125
-0
dygraph/tsn/data/dataset/ucf101/build_ucf101_file_list.py
dygraph/tsn/data/dataset/ucf101/build_ucf101_file_list.py
+4
-3
dygraph/tsn/loader.py
dygraph/tsn/loader.py
+149
-0
dygraph/tsn/multi_tsn_frame.yaml
dygraph/tsn/multi_tsn_frame.yaml
+4
-6
dygraph/tsn/multi_tsn_video.yaml
dygraph/tsn/multi_tsn_video.yaml
+6
-8
dygraph/tsn/single_tsn_frame.yaml
dygraph/tsn/single_tsn_frame.yaml
+5
-7
dygraph/tsn/single_tsn_video.yaml
dygraph/tsn/single_tsn_video.yaml
+5
-7
dygraph/tsn/train.py
dygraph/tsn/train.py
+37
-21
未找到文件。
dygraph/tsn/augmentations.py
0 → 100644
浏览文件 @
4d1187d5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
numpy
as
np
from
PIL
import
Image
class Scale(object):
    """
    Resize images so that their short side equals ``short_size``.

    The long side is always set to ``int(short_size * 4.0 / 3.0)``, so every
    resized image ends up with a fixed 4:3 (or 3:4) shape; the input aspect
    ratio is not preserved.

    Args:
        short_size(float | int): Target length of the short side.
    """

    def __init__(self, short_size):
        self.short_size = short_size

    def __call__(self, imgs):
        """
        Performs resize operations.
        Args:
            imgs: List where each item is a PIL.Image.
                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
        return:
            List where each item is a PIL.Image after scaling. Images whose
            short side already equals ``short_size`` are returned as the very
            same objects.
        """
        scaled = []
        for frame in imgs:
            width, height = frame.size
            # Short side already has the requested length: nothing to do.
            if min(width, height) == self.short_size:
                scaled.append(frame)
                continue
            long_size = int(self.short_size * 4.0 / 3.0)
            # Portrait images get the short size as width, others as height.
            if width < height:
                new_size = (self.short_size, long_size)
            else:
                new_size = (long_size, self.short_size)
            scaled.append(frame.resize(new_size, Image.BILINEAR))
        return scaled
class RandomCrop(object):
    """
    Random crop images.

    One crop window is drawn per call and applied to every image, so all
    frames of a clip are cropped at the same position.

    Args:
        target_size(int): Random crop a square with the target_size from an image.
    """

    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, imgs):
        """
        Performs random crop operations.
        Args:
            imgs: List where each item is a PIL.Image.
                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
                  Only the size of the first image is inspected.
        return:
            crop_imgs: List where each item is a PIL.Image after random crop.
        raises:
            AssertionError: if the first image is smaller than the crop size.
        """
        w, h = imgs[0].size
        th, tw = self.target_size, self.target_size
        # Raised explicitly (instead of using an ``assert`` statement) so the
        # check still runs under ``python -O``; the exception type is kept.
        if not (w >= tw and h >= th):
            raise AssertionError(
                "image width({}) and height({}) should be at least the crop size({})".
                format(w, h, self.target_size))

        # Both offsets are always drawn, even when no crop is needed, so the
        # random stream is consumed the same way for every call.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)

        if w == tw and h == th:
            # Already the requested size: hand back the same image objects.
            return list(imgs)

        box = (x1, y1, x1 + tw, y1 + th)
        return [img.crop(box) for img in imgs]
class CenterCrop(object):
    """
    Center crop images.
    Args:
        target_size(int): Center crop a square with the target_size from an image.
    """

    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, imgs):
        """
        Performs Center crop operations.
        Args:
            imgs: List where each item is a PIL.Image.
                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
        return:
            ccrop_imgs: List where each item is a PIL.Image after Center crop.
        raises:
            AssertionError: if any image is smaller than the crop size.
        """
        th, tw = self.target_size, self.target_size
        ccrop_imgs = []
        for img in imgs:
            w, h = img.size
            # Raised explicitly (instead of using an ``assert`` statement) so
            # the check still runs under ``python -O``; the type is unchanged.
            if not (w >= tw and h >= th):
                raise AssertionError(
                    "image width({}) and height({}) should be at least the crop size({})".
                    format(w, h, self.target_size))
            # Offsets are computed per image because sizes are read per image.
            x1 = int(round((w - tw) / 2.))
            y1 = int(round((h - th) / 2.))
            ccrop_imgs.append(img.crop((x1, y1, x1 + tw, y1 + th)))
        return ccrop_imgs
class RandomFlip(object):
    """
    Horizontally flip a whole list of images with a given probability.

    A single random draw decides for the entire list, so either every image
    is flipped or none is.

    Args:
        p(float): Probability of flipping the images.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, imgs):
        """
        Performs random flip operations.
        Args:
            imgs: List where each item is a PIL.Image.
                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
        return:
            A new list of left-right flipped images when the draw is below
            ``p``; otherwise the input list itself, unchanged.
        """
        if random.random() < self.p:
            return [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in imgs]
        return imgs
class Image2Array(object):
    """
    Convert a list of PIL.Image into one float32 Numpy array and move the
    channel axis forward, i.e. from 'dhwc' to 'dchw'.
    """

    def __init__(self):
        self.format = "dhwc"

    def __call__(self, imgs):
        """
        Performs Image to NumpyArray operations.
        Args:
            imgs: List where each item is a PIL.Image.
                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
        return:
            np_imgs: Numpy array of dtype float32 with axes ordered 'dchw'.
        """
        per_frame = [np.array(frame).astype('float32') for frame in imgs]
        stacked = np.array(per_frame)  # dhwc
        return np.transpose(stacked, (0, 3, 1, 2))  # dchw
class Normalization(object):
    """
    Scale pixel values by 1/255, then subtract the mean and divide by the std.
    Args:
        mean(list[float]): mean values of different channels.
        std(list[float]): std values of different channels.
    Note:
        ``mean`` and ``std`` are applied with plain Numpy broadcasting, so
        they must be broadcastable against the image array.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, imgs):
        """
        Performs normalization operations.
        Args:
            imgs: Numpy array.
        return:
            np_imgs: Numpy array after normalization.
        """
        # The division allocates the result; the next two steps update that
        # result in place.
        scaled = imgs / 255
        scaled -= self.mean
        scaled /= self.std
        return scaled
dygraph/tsn/compose.py
0 → 100644
浏览文件 @
4d1187d5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
numpy
as
np
import
logging
from
paddle.io
import
Dataset
from
augmentations
import
*
from
loader
import
*
logger
=
logging
.
getLogger
(
__name__
)
class TSN_UCF101_Dataset(Dataset):
    """
    Dataset that reads a file list and yields (frames, label) samples.

    Each sample is produced by chaining a loader (video decoding + sampling,
    or frame loading) with the augmentation operators.

    Args:
        cfg: Configuration object. ``cfg.MODEL`` provides format, seg_num,
            seglen, image_mean and image_std; ``cfg.TRAIN`` provides
            short_size and target_size; ``cfg[mode.upper()]['filelist']`` is
            the path of the file list.
        mode(str): Name of the config section to read the file list from;
            'train' additionally selects the random augmentations.
    """

    def __init__(self, cfg, mode):
        self.mode = mode
        self.format = cfg.MODEL.format  # 'videos' or 'frames'
        self.seg_num = cfg.MODEL.seg_num
        self.seglen = cfg.MODEL.seglen

        # NOTE(review): sizes are read from cfg.TRAIN for every mode --
        # confirm this is intended for 'valid'/'test'.
        self.short_size = cfg.TRAIN.short_size
        self.target_size = cfg.TRAIN.target_size

        # Reshaped to [3, 1, 1] so they broadcast over channel-first arrays.
        self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
            [3, 1, 1]).astype(np.float32)
        self.img_std = np.array(cfg.MODEL.image_std).reshape(
            [3, 1, 1]).astype(np.float32)

        self.filelist = cfg[mode.upper()]['filelist']
        self._construct_loader()

    def _construct_loader(self):
        """
        Construct the video loader.

        Parses the file list into parallel lists of paths, frame counts and
        labels. For 'videos' each line is "<path> <label>" and '.avi' is
        appended to the path; for 'frames' each line is
        "<path> <num_frames> <label>".
        """
        self._num_retries = 5
        self._path_to_videos = []
        self._labels = []
        self._num_frames = []
        with open(self.filelist, "r") as f:
            for path_label in f.read().splitlines():
                # Blank lines carry no sample; skip them instead of failing
                # on tuple unpacking.
                if not path_label.strip():
                    continue
                if self.format == "videos":
                    path, label = path_label.split()
                    self._path_to_videos.append(path + '.avi')
                    self._num_frames.append(0)  # unused
                    self._labels.append(int(label))
                elif self.format == "frames":
                    path, num_frames, label = path_label.split()
                    self._path_to_videos.append(path)
                    self._num_frames.append(int(num_frames))
                    self._labels.append(int(label))

    def __len__(self):
        """Return the number of samples listed in the file list."""
        return len(self._path_to_videos)

    def __getitem__(self, idx):
        """
        Load one sample, retrying with a random other sample on failure.

        Args:
            idx(int): Index of the requested sample.
        return:
            (frames, label) where label is a Numpy array of one element, or
            (None, None) when all ``_num_retries`` attempts failed.
        """
        for ir in range(self._num_retries):
            path = self._path_to_videos[idx]
            num_frames = self._num_frames[idx]
            try:
                frames = self.pipline(
                    path,
                    num_frames,
                    format=self.format,
                    seg_num=self.seg_num,
                    seglen=self.seglen,
                    short_size=self.short_size,
                    target_size=self.target_size,
                    img_mean=self.img_mean,
                    img_std=self.img_std,
                    mode=self.mode)
            except Exception:
                # ``Exception`` (not a bare except) so KeyboardInterrupt and
                # SystemExit still propagate; the traceback is logged so the
                # cause of the failure is not lost.
                if ir < self._num_retries - 1:
                    logger.error(
                        'Error when loading {}, have {} trys, will try again'.
                        format(path, ir),
                        exc_info=True)
                    # Retry with a different, randomly chosen sample.
                    idx = random.randint(0, len(self._path_to_videos) - 1)
                    continue
                else:
                    logger.error(
                        'Error when loading {}, have {} trys, will not try again'.
                        format(path, ir),
                        exc_info=True)
                    return None, None
            # The label is looked up with the index that actually loaded.
            label = self._labels[idx]
            return frames, np.array([label])

    def pipline(self, filepath, num_frames, format, seg_num, seglen, short_size,
                target_size, img_mean, img_std, mode):
        """
        Build and run the loading + augmentation pipeline for one sample.

        Args:
            filepath(str): Path passed to the loader operator.
            num_frames(int): Frame count, only used for format 'frames'.
            format(str): 'videos' or 'frames'.
            seg_num(int): Number of segments.
            seglen(int): Number of sampled frames in each segment.
            short_size: Argument of ``Scale``.
            target_size: Argument of the crop operator.
            img_mean: Mean passed to ``Normalization``.
            img_std: Std passed to ``Normalization``.
            mode(str): 'train' selects random crop + flip, anything else
                selects center crop.
        return:
            Output of the last operator (``Normalization``).
        raises:
            ValueError: if ``format`` is neither 'videos' nor 'frames'.
        """
        # Loader
        if format == 'videos':
            loader_ops = [
                VideoDecoder(filepath), VideoSampler(seg_num, seglen, mode)
            ]
        elif format == 'frames':
            loader_ops = [
                FrameLoader(filepath, num_frames, seg_num, seglen, mode)
            ]
        else:
            raise ValueError(
                "unsupported format {!r}, expected 'videos' or 'frames'".format(
                    format))

        # Augmentation
        if mode == 'train':
            aug_ops = [
                Scale(short_size), RandomCrop(target_size), RandomFlip(),
                Image2Array(), Normalization(img_mean, img_std)
            ]
        else:
            aug_ops = [
                Scale(short_size), CenterCrop(target_size), Image2Array(),
                Normalization(img_mean, img_std)
            ]

        ops = loader_ops + aug_ops
        # The first operator takes no input; every following operator
        # consumes the output of the previous one.
        data = ops[0]()
        for op in ops[1:]:
            data = op(data)
        return data
dygraph/tsn/data/dataset/ucf101/build_ucf101_file_list.py
浏览文件 @
4d1187d5
...
...
@@ -103,7 +103,7 @@ def parse_args():
default
=
'rawframes'
,
choices
=
[
'rawframes'
,
'videos'
])
parser
.
add_argument
(
'--out_list_path'
,
type
=
str
,
default
=
'./'
)
parser
.
add_argument
(
'--shuffle'
,
action
=
'store_true'
,
default
=
Tru
e
)
parser
.
add_argument
(
'--shuffle'
,
action
=
'store_true'
,
default
=
Fals
e
)
args
=
parser
.
parse_args
()
return
args
...
...
@@ -146,11 +146,12 @@ def main():
lists
=
build_split_list
(
split_tp
[
i
],
frame_info
,
shuffle
=
args
.
shuffle
)
filename
=
'ucf101_train_split_{}_{}.txt'
.
format
(
i
+
1
,
args
.
format
)
PATH
=
os
.
path
.
abspath
(
args
.
frame_path
)
with
open
(
os
.
path
.
join
(
out_path
,
filename
),
'w'
)
as
f
:
f
.
writelines
(
lists
[
0
])
f
.
writelines
(
[
os
.
path
.
join
(
PATH
,
item
)
for
item
in
lists
[
0
]
])
filename
=
'ucf101_val_split_{}_{}.txt'
.
format
(
i
+
1
,
args
.
format
)
with
open
(
os
.
path
.
join
(
out_path
,
filename
),
'w'
)
as
f
:
f
.
writelines
(
lists
[
1
])
f
.
writelines
(
[
os
.
path
.
join
(
PATH
,
item
)
for
item
in
lists
[
1
]
])
if
__name__
==
"__main__"
:
...
...
dygraph/tsn/loader.py
0 → 100644
浏览文件 @
4d1187d5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
random
from
PIL
import
Image
class VideoDecoder(object):
    """
    Decode a video file to frames.
    Args:
        filepath: the file path of the video file
    """

    def __init__(self, filepath):
        self.filepath = filepath

    def __call__(self):
        """
        Perform video decode operations.
        return:
            List where each item is a numpy array after decoder. Reads that
            fail are skipped, so the list may be shorter than the reported
            frame count.
        """
        cap = cv2.VideoCapture(self.filepath)
        try:
            videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            sampledFrames = []
            for _ in range(videolen):
                ret, frame = cap.read()
                # maybe first frame is empty
                if not ret:
                    continue
                # Reverse the last axis: OpenCV decodes to BGR, while the
                # sampler builds 'RGB' images from these arrays.
                sampledFrames.append(frame[:, :, ::-1])
            return sampledFrames
        finally:
            # Always free the capture handle, even if decoding raised.
            cap.release()
class VideoSampler(object):
    """
    Pick ``seg_len`` consecutive frames from each of ``num_seg`` segments.
    Args:
        num_seg(int): number of segments.
        seg_len(int): number of sampled frames in each segment.
        mode(str): 'train' samples a random offset inside every segment;
            any other value uses a fixed offset near the segment middle.
    """

    def __init__(self, num_seg, seg_len, mode):
        self.num_seg = num_seg
        self.seg_len = seg_len
        self.mode = mode

    def __call__(self, frames):
        """
        Args:
            frames: List where each item is a numpy array decoding from video.
        return:
            List where each item is a PIL.Image after sampling.
        """
        total = len(frames)
        average_dur = int(total / self.num_seg)
        sampled = []
        for seg in range(self.num_seg):
            if average_dur >= self.seg_len:
                # Segment is long enough: choose the offset inside it.
                if self.mode == 'train':
                    offset = random.randint(0, average_dur - self.seg_len)
                else:
                    offset = (average_dur - 1) // 2
                start = offset + seg * average_dur
            elif average_dur >= 1:
                start = seg * average_dur
            else:
                # Fewer frames than segments: fall back to the segment index.
                start = seg

            for pos in range(start, start + self.seg_len):
                # Wrap around so indices past the end reuse earlier frames.
                buf = frames[int(pos % total)]
                sampled.append(Image.fromarray(buf, mode='RGB'))
        return sampled
class FrameLoader(object):
    """
    Load frames.
    Args:
        filepath(str): directory holding the frames, stored as
            'img_00001.jpg', 'img_00002.jpg', ...
        num_frames(int): number of frames in a video file.
        num_seg(int): number of segments.
        seg_len(int): number of sampled frames in each segment.
        mode(str): 'train' samples a random offset inside every segment;
            any other value uses a fixed offset near the segment middle.
    """

    def __init__(self, filepath, num_frames, num_seg, seg_len, mode):
        self.filepath = filepath
        self.num_frames = num_frames
        self.num_seg = num_seg
        self.seg_len = seg_len
        self.mode = mode

    def __call__(self):
        """
        return:
            imgs: List where each item is a PIL.Image.
        """
        average_dur = int(self.num_frames / self.num_seg)
        imgs = []
        for i in range(self.num_seg):
            if average_dur >= self.seg_len:
                # Segment is long enough: choose the offset inside it.
                if self.mode == 'train':
                    idx = random.randint(0, average_dur - self.seg_len)
                else:
                    idx = (average_dur - 1) // 2
                idx += i * average_dur
            elif average_dur >= 1:
                idx = i * average_dur
            else:
                # Fewer frames than segments: fall back to the segment index.
                idx = i

            for jj in range(idx, idx + self.seg_len):
                # Wrap around like VideoSampler does, so short clips reuse
                # earlier frames instead of requesting a file past the end.
                frame_idx = jj % self.num_frames if self.num_frames > 0 else jj
                # File names are 1-based.
                frame_path = os.path.join(self.filepath,
                                          'img_{:05d}.jpg'.format(frame_idx + 1))
                imgs.append(Image.open(frame_path).convert('RGB'))
        return imgs
dygraph/tsn/multi_tsn_frame.yaml
浏览文件 @
4d1187d5
...
...
@@ -13,8 +13,6 @@ TRAIN:
epoch
:
80
short_size
:
256
target_size
:
224
num_reader_threads
:
16
buf_size
:
256
batch_size
:
128
use_gpu
:
True
filelist
:
"
./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
...
...
@@ -24,19 +22,19 @@ TRAIN:
l2_weight_decay
:
1e-4
momentum
:
0.9
total_videos
:
9738
num_workers
:
4
use_shuffle
:
True
VALID
:
short_size
:
256
target_size
:
224
num_reader_threads
:
16
buf_size
:
256
batch_size
:
128
filelist
:
"
./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
num_workers
:
4
TEST
:
short_size
:
256
target_size
:
224
num_reader_threads
:
16
buf_size
:
256
batch_size
:
128
filelist
:
"
./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
num_workers
:
4
dygraph/tsn/multi_tsn_video.yaml
浏览文件 @
4d1187d5
...
...
@@ -13,8 +13,6 @@ TRAIN:
epoch
:
80
short_size
:
256
target_size
:
224
num_reader_threads
:
16
buf_size
:
256
batch_size
:
128
use_gpu
:
True
filelist
:
"
./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
...
...
@@ -24,19 +22,19 @@ TRAIN:
l2_weight_decay
:
1e-4
momentum
:
0.9
total_videos
:
9738
num_workers
:
4
use_shuffle
:
True
VALID
:
short_size
:
256
target_size
:
224
num_reader_threads
:
16
buf_size
:
256
batch_size
:
128
filelist
:
"
./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
num_workers
:
4
TEST
:
short_size
:
256
target_size
:
224
num_reader_threads
:
16
buf_size
:
256
batch_size
:
128
filelist
:
"
./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
num_workers
:
4
dygraph/tsn/single_tsn_frame.yaml
浏览文件 @
4d1187d5
...
...
@@ -13,8 +13,6 @@ TRAIN:
epoch
:
80
short_size
:
256
target_size
:
224
num_reader_threads
:
8
buf_size
:
64
batch_size
:
32
use_gpu
:
True
filelist
:
"
./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
...
...
@@ -24,19 +22,19 @@ TRAIN:
l2_weight_decay
:
1e-4
momentum
:
0.9
total_videos
:
9738
num_workers
:
4
use_shuffle
:
True
VALID
:
short_size
:
256
target_size
:
224
num_reader_threads
:
8
buf_size
:
64
batch_size
:
32
filelist
:
"
./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
num_workers
:
4
TEST
:
short_size
:
256
target_size
:
224
num_reader_threads
:
8
buf_size
:
64
batch_size
:
32
filelist
:
"
./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
num_workers
:
4
dygraph/tsn/single_tsn_video.yaml
浏览文件 @
4d1187d5
...
...
@@ -13,8 +13,6 @@ TRAIN:
epoch
:
80
short_size
:
256
target_size
:
224
num_reader_threads
:
8
buf_size
:
64
batch_size
:
32
use_gpu
:
True
filelist
:
"
./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
...
...
@@ -24,19 +22,19 @@ TRAIN:
l2_weight_decay
:
1e-4
momentum
:
0.9
total_videos
:
9738
num_workers
:
4
use_shuffle
:
True
VALID
:
short_size
:
256
target_size
:
224
num_reader_threads
:
8
buf_size
:
64
batch_size
:
32
filelist
:
"
./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
num_workers
:
4
TEST
:
short_size
:
256
target_size
:
224
num_reader_threads
:
8
buf_size
:
64
batch_size
:
32
filelist
:
"
./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
num_workers
:
4
dygraph/tsn/train.py
浏览文件 @
4d1187d5
...
...
@@ -27,6 +27,9 @@ from paddle.fluid.dygraph.base import to_variable
from
model
import
TSN_ResNet
from
utils.config_utils
import
*
from
reader.ucf101_reader
import
UCF101Reader
import
paddle
from
paddle.io
import
DataLoader
,
DistributedBatchSampler
from
compose
import
TSN_UCF101_Dataset
logging
.
root
.
handlers
=
[]
FORMAT
=
'[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
...
...
@@ -111,19 +114,15 @@ def init_model(model, pre_state_dict):
return
model
def
val
(
epoch
,
model
,
cfg
,
args
):
reader
=
UCF101Reader
(
name
=
"TSN"
,
mode
=
"valid"
,
cfg
=
cfg
)
reader
=
reader
.
create_reader
()
def
val
(
epoch
,
model
,
val_loader
,
cfg
,
args
):
total_loss
=
0.0
total_acc1
=
0.0
total_acc5
=
0.0
total_sample
=
0
for
batch_id
,
data
in
enumerate
(
reader
()):
x_data
=
np
.
array
([
item
[
0
]
for
item
in
data
])
y_data
=
np
.
array
([
item
[
1
]
for
item
in
data
]).
reshape
([
-
1
,
1
])
imgs
=
to_variable
(
x_data
)
labels
=
to_variable
(
y_data
)
for
batch_id
,
data
in
enumerate
(
val_loader
):
imgs
=
paddle
.
to_tensor
(
data
[
0
])
labels
=
paddle
.
to_tensor
(
data
[
1
])
labels
.
stop_gradient
=
True
outputs
=
model
(
imgs
)
...
...
@@ -210,11 +209,30 @@ def train(args):
gpus
=
gpus
.
split
(
","
)
num_gpus
=
len
(
gpus
)
bs_denominator
=
num_gpus
train_config
.
TRAIN
.
batch_size
=
int
(
train_config
.
TRAIN
.
batch_size
/
bs_denominator
)
train_reader
=
UCF101Reader
(
name
=
"TSN"
,
mode
=
"train"
,
cfg
=
train_config
)
train_reader
=
train_reader
.
create_reader
()
bs_train_single
=
int
(
train_config
.
TRAIN
.
batch_size
/
bs_denominator
)
bs_val_single
=
int
(
valid_config
.
VALID
.
batch_size
/
bs_denominator
)
train_dataset
=
TSN_UCF101_Dataset
(
train_config
,
'train'
)
val_dataset
=
TSN_UCF101_Dataset
(
valid_config
,
'valid'
)
train_sampler
=
DistributedBatchSampler
(
train_dataset
,
batch_size
=
bs_train_single
,
shuffle
=
train_config
.
TRAIN
.
use_shuffle
,
drop_last
=
True
)
train_loader
=
DataLoader
(
train_dataset
,
batch_sampler
=
train_sampler
,
places
=
place
,
num_workers
=
train_config
.
TRAIN
.
num_workers
,
return_list
=
True
)
val_sampler
=
DistributedBatchSampler
(
val_dataset
,
batch_size
=
bs_val_single
)
val_loader
=
DataLoader
(
val_dataset
,
batch_sampler
=
val_sampler
,
places
=
place
,
num_workers
=
valid_config
.
VALID
.
num_workers
,
return_list
=
True
)
if
use_data_parallel
:
# (data_parallel step4/6)
...
...
@@ -234,12 +252,10 @@ def train(args):
total_acc5
=
0.0
total_sample
=
0
batch_start
=
time
.
time
()
for
batch_id
,
data
in
enumerate
(
train_
reader
()
):
for
batch_id
,
data
in
enumerate
(
train_
loader
):
train_reader_cost
=
time
.
time
()
-
batch_start
x_data
=
np
.
array
([
item
[
0
]
for
item
in
data
]).
astype
(
"float32"
)
y_data
=
np
.
array
([
item
[
1
]
for
item
in
data
]).
reshape
([
-
1
,
1
])
imgs
=
to_variable
(
x_data
)
labels
=
to_variable
(
y_data
)
imgs
=
paddle
.
to_tensor
(
data
[
0
])
labels
=
paddle
.
to_tensor
(
data
[
1
])
labels
.
stop_gradient
=
True
outputs
=
video_model
(
imgs
)
...
...
@@ -292,13 +308,13 @@ def train(args):
model_path
=
os
.
path
.
join
(
args
.
checkpoint
,
"_"
+
model_path_pre
+
"_epoch{}"
.
format
(
epoch
))
fluid
.
dygraph
.
save_dygraph
(
video_model
.
state_dict
(),
model_path
)
fluid
.
dygraph
.
save_dygraph
(
video_model
.
state_dict
(),
model_path
)
fluid
.
dygraph
.
save_dygraph
(
optimizer
.
state_dict
(),
model_path
)
if
args
.
validate
:
video_model
.
eval
()
val_acc
=
val
(
epoch
,
video_model
,
valid_config
,
args
)
val_acc
=
val
(
epoch
,
video_model
,
val_loader
,
valid_config
,
args
)
# save the best parameters in trainging stage
if
epoch
==
1
:
best_acc
=
val_acc
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录