Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
c3976c83
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c3976c83
编写于
1月 30, 2019
作者:
S
SunGaofeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move reader and metrics out of models
上级
9ba1db46
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
137 addition
and
148 deletion
+137
-148
fluid/PaddleCV/video/configs/attention_cluster.txt
fluid/PaddleCV/video/configs/attention_cluster.txt
+2
-1
fluid/PaddleCV/video/configs/attention_lstm.txt
fluid/PaddleCV/video/configs/attention_lstm.txt
+2
-1
fluid/PaddleCV/video/configs/nextvlad.txt
fluid/PaddleCV/video/configs/nextvlad.txt
+1
-0
fluid/PaddleCV/video/datareader/feature_reader.py
fluid/PaddleCV/video/datareader/feature_reader.py
+5
-7
fluid/PaddleCV/video/datareader/kinetics_reader.py
fluid/PaddleCV/video/datareader/kinetics_reader.py
+13
-13
fluid/PaddleCV/video/datareader/nonlocal_reader.py
fluid/PaddleCV/video/datareader/nonlocal_reader.py
+70
-23
fluid/PaddleCV/video/datareader/nonlocal_video_io.py
fluid/PaddleCV/video/datareader/nonlocal_video_io.py
+0
-61
fluid/PaddleCV/video/datareader/reader_utils.py
fluid/PaddleCV/video/datareader/reader_utils.py
+1
-1
fluid/PaddleCV/video/metrics/metrics_util.py
fluid/PaddleCV/video/metrics/metrics_util.py
+22
-20
fluid/PaddleCV/video/metrics/multicrop_test/__init__.py
fluid/PaddleCV/video/metrics/multicrop_test/__init__.py
+0
-0
fluid/PaddleCV/video/metrics/multicrop_test/multicrop_test_metrics.py
...CV/video/metrics/multicrop_test/multicrop_test_metrics.py
+0
-0
fluid/PaddleCV/video/models/attention_cluster/attention_cluster.py
...dleCV/video/models/attention_cluster/attention_cluster.py
+1
-1
fluid/PaddleCV/video/models/stnet/stnet.py
fluid/PaddleCV/video/models/stnet/stnet.py
+6
-6
fluid/PaddleCV/video/train.py
fluid/PaddleCV/video/train.py
+14
-14
未找到文件。
fluid/PaddleCV/video/configs/attention_cluster.txt
浏览文件 @
c3976c83
...
...
@@ -8,7 +8,8 @@ feature_names = ['rgb', 'audio']
feature_dims = [1024, 128]
seg_num = 100
cluster_nums = [32, 32]
class_num = 3862
num_classes = 3862
topk = 20
[TRAIN]
epoch = 5
...
...
fluid/PaddleCV/video/configs/attention_lstm.txt
浏览文件 @
c3976c83
...
...
@@ -8,7 +8,8 @@ feature_names = ['rgb', 'audio']
feature_dims = [1024, 128]
embedding_size = 512
lstm_size = 1024
class_num = 3862
num_classes = 3862
topk = 20
[TRAIN]
epoch = 10
...
...
fluid/PaddleCV/video/configs/nextvlad.txt
浏览文件 @
c3976c83
[MODEL]
name = "NEXTVLAD"
num_classes = 3862
topk = 20
video_feature_size = 1024
audio_feature_size = 128
cluster_size = 128
...
...
fluid/PaddleCV/video/datareader/feature_reader.py
浏览文件 @
c3976c83
...
...
@@ -40,15 +40,13 @@ class FeatureReader(DataReader):
def
__init__
(
self
,
name
,
phase
,
cfg
):
self
.
name
=
name
self
.
phase
=
phase
self
.
num_classes
=
cfg
[
'num_classes'
]
self
.
num_classes
=
cfg
.
MODEL
.
num_classes
# set batch size and file list
self
.
batch_size
=
cfg
[
'batch_size'
]
self
.
filelist
=
cfg
[
'list'
]
if
'eigen_file'
in
cfg
.
keys
():
self
.
eigen_file
=
cfg
[
'eigen_file'
]
if
'seg_num'
in
cfg
.
keys
():
self
.
seg_num
=
cfg
[
'seg_num'
]
self
.
batch_size
=
cfg
[
phase
.
upper
()][
'batch_size'
]
self
.
filelist
=
cfg
[
phase
.
upper
()][
'filelist'
]
self
.
eigen_file
=
cfg
.
MODEL
.
get
(
'eigen_file'
,
None
)
self
.
seg_num
=
cfg
.
MODEL
.
get
(
'seg_num'
,
None
)
def
create_reader
(
self
):
fl
=
open
(
self
.
filelist
).
readlines
()
...
...
fluid/PaddleCV/video/datareader/kinetics_reader.py
浏览文件 @
c3976c83
...
...
@@ -56,22 +56,22 @@ class KineticsReader(DataReader):
def
__init__
(
self
,
name
,
phase
,
cfg
):
self
.
name
=
name
self
.
phase
=
phase
self
.
format
=
cfg
[
'format'
]
self
.
num_classes
=
cfg
[
'num_classes'
]
self
.
seg_num
=
cfg
[
'seg_num'
]
self
.
seglen
=
cfg
[
'seglen'
]
self
.
short_size
=
cfg
[
'short_size'
]
self
.
target_size
=
cfg
[
'target_size'
]
self
.
num_reader_threads
=
cfg
[
'num_reader_threads'
]
self
.
buf_size
=
cfg
[
'buf_size'
]
self
.
img_mean
=
np
.
array
(
cfg
[
'image_mean'
]
).
reshape
(
self
.
format
=
cfg
.
MODEL
.
format
#cfg
['format']
self
.
num_classes
=
cfg
.
MODEL
.
num_classes
#cfg
['num_classes']
self
.
seg_num
=
cfg
.
MODEL
.
segnum
#
['seg_num']
self
.
seglen
=
cfg
.
MODEL
.
seglen
#
['seglen']
self
.
short_size
=
cfg
[
phase
.
upper
()][
'short_size'
]
# [
'short_size']
self
.
target_size
=
cfg
[
phase
.
upper
()][
'target_size'
]
#[
'target_size']
self
.
num_reader_threads
=
cfg
[
phase
.
upper
()][
'num_reader_threads'
]
self
.
buf_size
=
cfg
[
phase
.
upper
()][
'buf_size'
]
self
.
img_mean
=
np
.
array
(
cfg
.
MODEL
.
image_mean
).
reshape
(
[
3
,
1
,
1
]).
astype
(
np
.
float32
)
self
.
img_std
=
np
.
array
(
cfg
[
'image_std'
]
).
reshape
(
self
.
img_std
=
np
.
array
(
cfg
.
MODEL
.
image_std
).
reshape
(
[
3
,
1
,
1
]).
astype
(
np
.
float32
)
# set batch size and file list
self
.
batch_size
=
cfg
[
'batch_size'
]
self
.
filelist
=
cfg
[
'
list'
]
self
.
batch_size
=
cfg
[
phase
.
upper
()][
'batch_size'
]
self
.
filelist
=
cfg
[
phase
.
upper
()][
'file
list'
]
def
create_reader
(
self
):
_reader
=
_reader_creator
(
self
.
filelist
,
self
.
phase
,
seg_num
=
self
.
seg_num
,
seglen
=
self
.
seglen
,
\
...
...
fluid/PaddleCV/video/datareader/nonlocal_reader.py
浏览文件 @
c3976c83
...
...
@@ -20,7 +20,6 @@ import numpy as np
import
cv2
import
logging
from
.
import
nonlocal_video_io
from
.reader_utils
import
DataReader
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -51,43 +50,91 @@ class NonlocalReader(DataReader):
def
create_reader
(
self
):
cfg
=
self
.
cfg
assert
cfg
[
'num_reader_threads'
]
>=
1
,
\
"number of reader threads({}) should be a positive integer"
.
format
(
cfg
[
'num_reader_threads'
])
if
cfg
[
'num_reader_threads'
]
==
1
:
phase
=
self
.
phase
num_reader_threads
=
cfg
[
phase
.
upper
()][
'num_reader_threads'
]
assert
num_reader_threads
>=
1
,
\
"number of reader threads({}) should be a positive integer"
.
format
(
num_reader_threads
)
if
num_reader_threads
==
1
:
reader_func
=
make_reader
else
:
reader_func
=
make_multi_reader
dataset_args
=
{}
dataset_args
[
'image_mean'
]
=
cfg
[
'image_mean'
]
dataset_args
[
'image_std'
]
=
cfg
[
'image_std'
]
dataset_args
[
'crop_size'
]
=
cfg
[
'crop_size'
]
dataset_args
[
'sample_rate'
]
=
cfg
[
'sample_rate'
]
dataset_args
[
'video_length'
]
=
cfg
[
'video_length'
]
dataset_args
[
'min_size'
]
=
cfg
[
'jitter_scales'
][
0
]
dataset_args
[
'max_size'
]
=
cfg
[
'jitter_scales'
][
1
]
dataset_args
[
'num_reader_threads'
]
=
cfg
[
'num_reader_threads'
]
dataset_args
[
'image_mean'
]
=
cfg
.
MODEL
.
image_mean
dataset_args
[
'image_std'
]
=
cfg
.
MODEL
.
image_std
dataset_args
[
'crop_size'
]
=
cfg
[
phase
.
upper
()][
'crop_size'
]
dataset_args
[
'sample_rate'
]
=
cfg
[
phase
.
upper
()][
'sample_rate'
]
dataset_args
[
'video_length'
]
=
cfg
[
phase
.
upper
()][
'video_length'
]
dataset_args
[
'min_size'
]
=
cfg
[
phase
.
upper
()][
'jitter_scales'
][
0
]
dataset_args
[
'max_size'
]
=
cfg
[
phase
.
upper
()][
'jitter_scales'
][
1
]
dataset_args
[
'num_reader_threads'
]
=
num_reader_threads
filelist
=
cfg
[
phase
.
upper
()][
'list'
]
batch_size
=
cfg
[
phase
.
upper
()][
'batch_size'
]
if
self
.
phase
==
'train'
:
sample_times
=
1
return
reader_func
(
cfg
[
'list'
],
cfg
[
'batch_size'
],
sample_times
,
True
,
True
,
**
dataset_args
)
return
reader_func
(
filelist
,
batch_size
,
sample_times
,
True
,
True
,
**
dataset_args
)
elif
self
.
phase
==
'valid'
:
sample_times
=
1
return
reader_func
(
cfg
[
'list'
],
cfg
[
'batch_size'
],
sample_times
,
False
,
False
,
**
dataset_args
)
return
reader_func
(
filelist
,
batch_size
,
sample_times
,
False
,
False
,
**
dataset_args
)
elif
self
.
phase
==
'test'
:
sample_times
=
cfg
[
'num_test_clips'
]
if
cfg
[
'use_multi_crop'
]
==
1
:
sample_times
=
cfg
[
'
TEST'
][
'
num_test_clips'
]
if
cfg
[
'
TEST'
][
'
use_multi_crop'
]
==
1
:
sample_times
=
int
(
sample_times
/
3
)
if
cfg
[
'use_multi_crop'
]
==
2
:
if
cfg
[
'
TEST'
][
'
use_multi_crop'
]
==
2
:
sample_times
=
int
(
sample_times
/
6
)
return
reader_func
(
cfg
[
'list'
],
cfg
[
'batch_size'
],
sample_times
,
False
,
False
,
**
dataset_args
)
return
reader_func
(
filelist
,
batch_size
,
sample_times
,
False
,
False
,
**
dataset_args
)
else
:
logger
.
info
(
'Not implemented'
)
raise
def
video_fast_get_frame
(
video_path
,
sampling_rate
=
1
,
length
=
64
,
start_frm
=-
1
,
sample_times
=
1
):
cap
=
cv2
.
VideoCapture
(
video_path
)
frame_cnt
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
width
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
height
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
sampledFrames
=
[]
# n_frame < sample area
video_output
=
np
.
ndarray
(
shape
=
[
length
,
height
,
width
,
3
],
dtype
=
np
.
uint8
)
use_start_frm
=
start_frm
if
start_frm
<
0
:
if
(
frame_cnt
-
length
*
sampling_rate
>
0
):
use_start_frm
=
random
.
randint
(
0
,
frame_cnt
-
length
*
sampling_rate
)
else
:
use_start_frm
=
0
else
:
frame_gaps
=
float
(
frame_cnt
)
/
float
(
sample_times
)
use_start_frm
=
int
(
frame_gaps
*
start_frm
)
%
frame_cnt
for
i
in
range
(
frame_cnt
):
ret
,
frame
=
cap
.
read
()
# maybe first frame is empty
if
ret
==
False
:
continue
img
=
frame
[:,
:,
::
-
1
]
sampledFrames
.
append
(
img
)
for
idx
in
range
(
length
):
i
=
use_start_frm
+
idx
*
sampling_rate
i
=
i
%
len
(
sampledFrames
)
video_output
[
idx
]
=
sampledFrames
[
i
]
cap
.
release
()
return
video_output
def
apply_resize
(
rgbdata
,
min_size
,
max_size
):
length
,
height
,
width
,
channel
=
rgbdata
.
shape
ratio
=
1.0
...
...
@@ -177,7 +224,7 @@ def make_reader(filelist, batch_size, sample_times, is_training, shuffle,
label
=
np
.
array
([
label
]).
astype
(
np
.
int64
)
# 1, get rgb data for fixed length of frames
try
:
rgbdata
=
nonlocal_video_io
.
video_fast_get_frame
(
fn
,
\
rgbdata
=
video_fast_get_frame
(
fn
,
\
sampling_rate
=
dataset_args
[
'sample_rate'
],
length
=
dataset_args
[
'video_length'
],
\
start_frm
=
start_frm
,
sample_times
=
in_sample_times
)
except
:
...
...
@@ -244,7 +291,7 @@ def make_multi_reader(filelist, batch_size, sample_times, is_training, shuffle,
label
=
np
.
array
([
label
]).
astype
(
np
.
int64
)
# 1, get rgb data for fixed length of frames
try
:
rgbdata
=
nonlocal_video_io
.
video_fast_get_frame
(
fn
,
\
rgbdata
=
video_fast_get_frame
(
fn
,
\
sampling_rate
=
dataset_args
[
'sample_rate'
],
length
=
dataset_args
[
'video_length'
],
\
start_frm
=
start_frm
,
sample_times
=
in_sample_times
)
except
:
...
...
fluid/PaddleCV/video/datareader/nonlocal_video_io.py
已删除
100644 → 0
浏览文件 @
9ba1db46
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
cv2
import
numpy
as
np
import
random
def
video_fast_get_frame
(
video_path
,
sampling_rate
=
1
,
length
=
64
,
start_frm
=-
1
,
sample_times
=
1
):
cap
=
cv2
.
VideoCapture
(
video_path
)
frame_cnt
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
width
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
height
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
sampledFrames
=
[]
# n_frame < sample area
video_output
=
np
.
ndarray
(
shape
=
[
length
,
height
,
width
,
3
],
dtype
=
np
.
uint8
)
use_start_frm
=
start_frm
if
start_frm
<
0
:
if
(
frame_cnt
-
length
*
sampling_rate
>
0
):
use_start_frm
=
random
.
randint
(
0
,
frame_cnt
-
length
*
sampling_rate
)
else
:
use_start_frm
=
0
else
:
frame_gaps
=
float
(
frame_cnt
)
/
float
(
sample_times
)
use_start_frm
=
int
(
frame_gaps
*
start_frm
)
%
frame_cnt
for
i
in
range
(
frame_cnt
):
ret
,
frame
=
cap
.
read
()
# maybe first frame is empty
if
ret
==
False
:
continue
img
=
frame
[:,
:,
::
-
1
]
sampledFrames
.
append
(
img
)
for
idx
in
range
(
length
):
i
=
use_start_frm
+
idx
*
sampling_rate
i
=
i
%
len
(
sampledFrames
)
video_output
[
idx
]
=
sampledFrames
[
i
]
cap
.
release
()
return
video_output
fluid/PaddleCV/video/datareader/reader_utils.py
浏览文件 @
c3976c83
...
...
@@ -70,6 +70,6 @@ def regist_reader(name, reader):
reader_zoo
.
regist
(
name
,
reader
)
def
get_reader
(
name
,
mode
=
'train'
,
**
cfg
):
def
get_reader
(
name
,
mode
,
cfg
):
reader_model
=
reader_zoo
.
get
(
name
,
mode
,
cfg
)
return
reader_model
.
create_reader
()
fluid/PaddleCV/video/metrics/metrics_util.py
浏览文件 @
c3976c83
...
...
@@ -22,13 +22,13 @@ import logging
import
numpy
as
np
from
metrics.youtube8m
import
eval_util
as
youtube8m_metrics
from
metrics.kinetics
import
accuracy_metrics
as
kinetics_metrics
from
metrics.
non_local
import
nonlocal_test_metrics
as
nonlocal
_test_metrics
from
metrics.
multicrop_test
import
multicrop_test_metrics
as
multicrop
_test_metrics
logger
=
logging
.
getLogger
(
__name__
)
class
Metrics
(
object
):
def
__init__
(
self
,
name
,
mode
,
**
metrics_args
):
def
__init__
(
self
,
name
,
mode
,
metrics_args
):
"""Not implemented"""
pass
...
...
@@ -50,12 +50,11 @@ class Metrics(object):
class
Youtube8mMetrics
(
Metrics
):
def
__init__
(
self
,
name
,
mode
,
**
metrics_args
):
def
__init__
(
self
,
name
,
mode
,
metrics_args
):
self
.
name
=
name
self
.
mode
=
mode
self
.
metrics_args
=
metrics_args
self
.
num_classes
=
metrics_args
[
'num_classes'
]
self
.
topk
=
metrics_args
[
'topk'
]
self
.
num_classes
=
metrics_args
[
'MODEL'
][
'num_classes'
]
self
.
topk
=
metrics_args
[
'MODEL'
][
'topk'
]
self
.
calculator
=
youtube8m_metrics
.
EvaluationMetrics
(
self
.
num_classes
,
self
.
topk
)
...
...
@@ -82,12 +81,10 @@ class Youtube8mMetrics(Metrics):
class
Kinetics400Metrics
(
Metrics
):
def
__init__
(
self
,
name
,
mode
,
**
metrics_args
):
def
__init__
(
self
,
name
,
mode
,
metrics_args
):
self
.
name
=
name
self
.
mode
=
mode
self
.
metrics_args
=
metrics_args
self
.
calculator
=
kinetics_metrics
.
MetricsCalculator
(
name
,
mode
.
lower
())
self
.
calculator
=
kinetics_metrics
.
MetricsCalculator
(
name
,
mode
.
lower
())
def
calculate_and_log_out
(
self
,
loss
,
pred
,
label
,
info
=
''
):
if
loss
is
not
None
:
...
...
@@ -114,14 +111,19 @@ class Kinetics400Metrics(Metrics):
self
.
calculator
.
reset
()
class
Nonlocal
Metrics
(
Metrics
):
def
__init__
(
self
,
name
,
mode
,
**
metrics_args
):
class
Multicrop
Metrics
(
Metrics
):
def
__init__
(
self
,
name
,
mode
,
metrics_args
):
self
.
name
=
name
self
.
mode
=
mode
self
.
metrics_args
=
metrics_args
if
mode
==
'test'
:
self
.
calculator
=
nonlocal_test_metrics
.
MetricsCalculator
(
name
,
mode
.
lower
(),
**
metrics_args
)
args
=
{}
args
[
'num_test_clips'
]
=
metrics_args
.
TEST
.
num_test_clips
args
[
'dataset_size'
]
=
metrics_args
.
TEST
.
dataset_size
args
[
'filename_gt'
]
=
metrics_args
.
TEST
.
filename_gt
args
[
'checkpoint_dir'
]
=
metrics_args
.
TEST
.
checkpoint_dir
args
[
'num_classes'
]
=
metrics_args
.
MODEL
.
num_classes
self
.
calculator
=
multicrop_test_metrics
.
MetricsCalculator
(
name
,
mode
.
lower
(),
**
args
)
else
:
self
.
calculator
=
kinetics_metrics
.
MetricsCalculator
(
name
,
mode
.
lower
())
...
...
@@ -166,10 +168,10 @@ class MetricsZoo(object):
type
(
metrics
))
self
.
metrics_zoo
[
name
]
=
metrics
def
get
(
self
,
name
,
mode
,
**
cfg
):
def
get
(
self
,
name
,
mode
,
cfg
):
for
k
,
v
in
self
.
metrics_zoo
.
items
():
if
k
==
name
:
return
v
(
name
,
mode
,
**
cfg
)
return
v
(
name
,
mode
,
cfg
)
raise
MetricsNotFoundError
(
name
,
self
.
metrics_zoo
.
keys
())
...
...
@@ -181,8 +183,8 @@ def regist_metrics(name, metrics):
metrics_zoo
.
regist
(
name
,
metrics
)
def
get_metrics
(
name
,
mode
=
'train'
,
**
cfg
):
return
metrics_zoo
.
get
(
name
,
mode
,
**
cfg
)
def
get_metrics
(
name
,
mode
,
cfg
):
return
metrics_zoo
.
get
(
name
,
mode
,
cfg
)
regist_metrics
(
"NEXTVLAD"
,
Youtube8mMetrics
)
...
...
@@ -191,4 +193,4 @@ regist_metrics("ATTENTIONCLUSTER", Youtube8mMetrics)
regist_metrics
(
"TSN"
,
Kinetics400Metrics
)
regist_metrics
(
"TSM"
,
Kinetics400Metrics
)
regist_metrics
(
"STNET"
,
Kinetics400Metrics
)
regist_metrics
(
"NONLOCAL"
,
Nonlocal
Metrics
)
regist_metrics
(
"NONLOCAL"
,
Multicrop
Metrics
)
fluid/PaddleCV/video/metrics/
non_local
/__init__.py
→
fluid/PaddleCV/video/metrics/
multicrop_test
/__init__.py
浏览文件 @
c3976c83
文件已移动
fluid/PaddleCV/video/metrics/
non_local/nonlocal
_test_metrics.py
→
fluid/PaddleCV/video/metrics/
multicrop_test/multicrop
_test_metrics.py
浏览文件 @
c3976c83
文件已移动
fluid/PaddleCV/video/models/attention_cluster/attention_cluster.py
浏览文件 @
c3976c83
...
...
@@ -34,7 +34,7 @@ class AttentionCluster(ModelBase):
self
.
feature_dims
=
self
.
cfg
.
MODEL
.
feature_dims
self
.
cluster_nums
=
self
.
cfg
.
MODEL
.
cluster_nums
self
.
seg_num
=
self
.
cfg
.
MODEL
.
seg_num
self
.
class_num
=
self
.
cfg
.
MODEL
.
class_num
self
.
class_num
=
self
.
cfg
.
MODEL
.
num_classes
#self.cfg.MODEL.
class_num
self
.
drop_rate
=
self
.
cfg
.
MODEL
.
drop_rate
# get mode configs
...
...
fluid/PaddleCV/video/models/stnet/stnet.py
浏览文件 @
c3976c83
...
...
@@ -154,13 +154,13 @@ class STNET(ModelBase):
return
{}
def
load_pretrain_params
(
self
,
exe
,
pretrain
,
prog
):
def
is_parameter
(
var
):
if
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
return
isinstance
(
var
,
fluid
.
framework
.
Parameter
)
and
(
not
(
"fc_0"
in
var
.
name
))
\
and
(
not
(
"batch_norm"
in
var
.
name
))
and
(
not
(
"xception"
in
var
.
name
))
and
(
not
(
"conv3d"
in
var
.
name
))
def
is_parameter
(
var
):
if
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
return
isinstance
(
var
,
fluid
.
framework
.
Parameter
)
and
(
not
(
"fc_0"
in
var
.
name
))
\
and
(
not
(
"batch_norm"
in
var
.
name
))
and
(
not
(
"xception"
in
var
.
name
))
and
(
not
(
"conv3d"
in
var
.
name
))
vars
=
filter
(
is_parameter
,
prog
.
list_vars
())
fluid
.
io
.
load_vars
(
exe
,
pretrain
,
vars
=
vars
)
vars
=
filter
(
is_parameter
,
prog
.
list_vars
())
fluid
.
io
.
load_vars
(
exe
,
pretrain
,
vars
=
vars
)
param_tensor
=
fluid
.
global_scope
().
find_var
(
"conv1_weights"
).
get_tensor
()
...
...
fluid/PaddleCV/video/train.py
浏览文件 @
c3976c83
...
...
@@ -21,8 +21,10 @@ import numpy as np
import
paddle.fluid
as
fluid
from
tools.train_utils
import
train_with_pyreader
,
train_without_pyreader
from
config
import
*
import
models
from
config
import
*
from
datareader
import
get_reader
from
metrics
import
get_metrics
logging
.
root
.
handlers
=
[]
FORMAT
=
'[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
...
...
@@ -104,10 +106,8 @@ def train(args):
config
=
parse_config
(
args
.
config
)
train_config
=
merge_configs
(
config
,
'train'
,
vars
(
args
))
valid_config
=
merge_configs
(
config
,
'valid'
,
vars
(
args
))
train_model
=
models
.
get_model
(
args
.
model_name
,
train_config
,
mode
=
'train'
)
valid_model
=
models
.
get_model
(
args
.
model_name
,
valid_config
,
mode
=
'valid'
)
train_model
=
models
.
get_model
(
args
.
model_name
,
train_config
,
mode
=
'train'
)
valid_model
=
models
.
get_model
(
args
.
model_name
,
valid_config
,
mode
=
'valid'
)
# build model
startup
=
fluid
.
Program
()
...
...
@@ -141,7 +141,7 @@ def train(args):
valid_feeds
=
valid_model
.
feeds
()
valid_outputs
=
valid_model
.
outputs
()
valid_loss
=
valid_model
.
loss
()
valid_metrics
=
valid_model
.
metrics
()
#
valid_metrics = valid_model.metrics()
valid_pyreader
=
valid_model
.
pyreader
()
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
...
...
@@ -165,16 +165,16 @@ def train(args):
main_program
=
valid_prog
)
# get reader
# train_reader = get_reader(
train_config)
# valid_reader = get_reader(
valid_config)
train_reader
=
train_model
.
reader
()
valid_reader
=
valid_model
.
reader
()
train_reader
=
get_reader
(
args
.
model_name
.
upper
(),
'train'
,
train_config
)
valid_reader
=
get_reader
(
args
.
model_name
.
upper
(),
'valid'
,
valid_config
)
#
train_reader = train_model.reader()
#
valid_reader = valid_model.reader()
# get metrics
# train_metrics = get_metrics(
train_config)
# valid_metrics = get_metrics(
valid_config)
train_metrics
=
train_model
.
metrics
()
train_metrics
=
train_model
.
metrics
()
train_metrics
=
get_metrics
(
args
.
model_name
.
upper
(),
'train'
,
train_config
)
valid_metrics
=
get_metrics
(
args
.
model_name
.
upper
(),
'valid'
,
valid_config
)
#
train_metrics = train_model.metrics()
#
train_metrics = train_model.metrics()
train_fetch_list
=
[
train_loss
.
name
]
+
[
x
.
name
for
x
in
train_outputs
]
+
[
train_feeds
[
-
1
].
name
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录