提交 5e682b3c 编写于 作者: D dengkaipeng

change split/phase to mode

上级 133458b1
......@@ -3,3 +3,4 @@ checkpoints
output*
*.py
*.swp
*_result
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import sys
from .reader_utils import DataReader
try:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import math
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import random
import time
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import cv2
import numpy as np
import random
......@@ -45,10 +59,3 @@ def video_fast_get_frame(video_path,
cap.release()
return video_output
if __name__ == '__main__':
video_path = '~/docker/dockermount/data/k400/Kinetics_trimmed_processed_val/dancing_gangnam_style/rC7d3L8nSB4.mp4'
vout = video_fast_get_frame(video_path)
vout2 = video_fast_get_frame(video_path, \
sampling_rate = 2, length = 8, \
start_frm = 3, sample_times = 10)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import pickle
import cv2
import numpy as np
......@@ -5,7 +19,7 @@ import random
class ReaderNotFoundError(Exception):
"Error: model not found"
"Error: reader not found"
def __init__(self, reader_name, avail_readers):
super(ReaderNotFoundError, self).__init__()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
......@@ -11,13 +25,13 @@ logger = logging.getLogger(__name__)
class MetricsCalculator():
def __init__(self, name, split):
def __init__(self, name, mode):
self.name = name
self.split = split # 'train', 'val', 'test'
self.mode = mode # 'train', 'val', 'test'
self.reset()
def reset(self):
logger.info('Resetting {} metrics...'.format(self.split))
logger.info('Resetting {} metrics...'.format(self.mode))
self.aggr_acc1 = 0.0
self.aggr_acc5 = 0.0
self.aggr_loss = 0.0
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
......@@ -14,7 +28,7 @@ logger = logging.getLogger(__name__)
class Metrics(object):
def __init__(self, name, phase, **metrics_args):
def __init__(self, name, mode, **metrics_args):
"""Not implemented"""
pass
......@@ -36,9 +50,9 @@ class Metrics(object):
class Youtube8mMetrics(Metrics):
def __init__(self, name, phase, **metrics_args):
def __init__(self, name, mode, **metrics_args):
self.name = name
self.phase = phase
self.mode = mode
self.metrics_args = metrics_args
self.num_classes = metrics_args['num_classes']
self.topk = metrics_args['topk']
......@@ -68,12 +82,12 @@ class Youtube8mMetrics(Metrics):
class Kinetics400Metrics(Metrics):
def __init__(self, name, phase, **metrics_args):
def __init__(self, name, mode, **metrics_args):
self.name = name
self.phase = phase
self.mode = mode
self.metrics_args = metrics_args
self.calculator = kinetics_metrics.MetricsCalculator(name,
phase.lower())
mode.lower())
def calculate_and_log_out(self, loss, pred, label, info=''):
if loss is not None:
......@@ -101,19 +115,19 @@ class Kinetics400Metrics(Metrics):
class NonlocalMetrics(Metrics):
def __init__(self, name, phase, **metrics_args):
def __init__(self, name, mode, **metrics_args):
self.name = name
self.phase = phase
self.mode = mode
self.metrics_args = metrics_args
if phase == 'test':
if mode == 'test':
self.calculator = nonlocal_test_metrics.MetricsCalculator(
name, phase.lower(), **metrics_args)
name, mode.lower(), **metrics_args)
else:
self.calculator = kinetics_metrics.MetricsCalculator(name,
phase.lower())
mode.lower())
def calculate_and_log_out(self, loss, pred, label, info=''):
if self.phase == 'test':
if self.mode == 'test':
pass
else:
if loss is not None:
......@@ -128,7 +142,7 @@ class NonlocalMetrics(Metrics):
self.calculator.accumulate(loss, pred, label)
def finalize_and_log_out(self, info=''):
if self.phase == 'test':
if self.mode == 'test':
self.calculator.finalize_metrics()
else:
self.calculator.finalize_metrics()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
......@@ -15,7 +29,7 @@ logger = logging.getLogger(__name__)
class MetricsCalculator():
def __init__(self, name, split, **metrics_args):
def __init__(self, name, mode, **metrics_args):
"""
dataset args:
num_test_clips
......@@ -25,7 +39,7 @@ class MetricsCalculator():
num_classes
"""
self.name = name
self.split = split # 'train', 'val', 'test'
self.mode = mode # 'train', 'val', 'test'
self.metrics_args = metrics_args
self.num_test_clips = metrics_args['num_test_clips']
......@@ -36,7 +50,7 @@ class MetricsCalculator():
self.reset()
def reset(self):
logger.info('Resetting {} metrics...'.format(self.split))
logger.info('Resetting {} metrics...'.format(self.mode))
self.aggr_acc1 = 0.0
self.aggr_acc5 = 0.0
self.aggr_loss = 0.0
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.fluid as fluid
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.fluid as fluid
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import time
import sys
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import time
import sys
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......@@ -15,7 +15,6 @@
import os
import sys
import time
import shutil
import argparse
import logging
import numpy as np
......@@ -24,6 +23,7 @@ import paddle.fluid as fluid
from tools.train_utils import train_with_pyreader, train_without_pyreader
import models
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
......@@ -59,11 +59,6 @@ def parse_args():
)
parser.add_argument(
'--use-gpu', type=bool, default=True, help='default use gpu.')
parser.add_argument(
'--no-parallel',
action='store_true',
default=False,
help='whether to use parallel executor')
parser.add_argument(
'--no-use-pyreader',
action='store_true',
......@@ -104,9 +99,9 @@ def parse_args():
def train(train_model, valid_model, args):
startup = fluid.Program()
train_prog = fluid.Program()
train_startup = fluid.Program()
with fluid.program_guard(train_prog, train_startup):
with fluid.program_guard(train_prog, startup):
with fluid.unique_name.guard():
train_model.build_input(not args.no_use_pyreader)
train_model.build_model()
......@@ -130,8 +125,7 @@ def train(train_model, valid_model, args):
fluid.memory_optimize(train_prog)
valid_prog = fluid.Program()
valid_startup = fluid.Program()
with fluid.program_guard(valid_prog, valid_startup):
with fluid.program_guard(valid_prog, startup):
with fluid.unique_name.guard():
valid_model.build_input(not args.no_use_pyreader)
valid_model.build_model()
......@@ -144,8 +138,7 @@ def train(train_model, valid_model, args):
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(train_startup)
exe.run(valid_startup)
exe.run(startup)
if args.pretrain:
assert os.path.exists(args.pretrain), \
......@@ -154,18 +147,14 @@ def train(train_model, valid_model, args):
if pretrain:
train_model.load_pretrain_params(exe, pretrain, train_prog)
if args.no_parallel:
train_exe = exe
valid_exe = exe
else:
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu,
loss_name=train_loss.name,
main_program=train_prog)
valid_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu,
share_vars_from=train_exe,
main_program=valid_prog)
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu,
loss_name=train_loss.name,
main_program=train_prog)
valid_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu,
share_vars_from=train_exe,
main_program=valid_prog)
train_fetch_list = [train_loss.name] + [x.name for x in train_outputs
] + [train_feeds[-1].name]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册