Unverified commit ff93284b authored by SunGaofeng, committed by GitHub

Modify NeXtVLAD, TSM and Nonlocal to support v1.5 (#2386)

Parent 700310a7
......
@@ -77,18 +77,16 @@ class FeatureReader(DataReader):
rgb = rgb[0:nframes, :]
audio = audio[0:nframes, :]
rgb = dequantize(
rgb, max_quantized_value=2., min_quantized_value=-2.)
audio = dequantize(
audio, max_quantized_value=2, min_quantized_value=-2)
if self.name == 'NEXTVLAD':
# add the effect of eigen values
eigen_file = self.eigen_file
eigen_val = np.sqrt(np.load(eigen_file)
[:1024, 0]).astype(np.float32)
eigen_val = eigen_val + 1e-4
rgb = (rgb - 4. / 512) * eigen_val
if self.name != 'NEXTVLAD':
rgb = dequantize(
rgb,
max_quantized_value=2.,
min_quantized_value=-2.)
audio = dequantize(
audio,
max_quantized_value=2,
min_quantized_value=-2)
if self.name == 'ATTENTIONCLUSTER':
sample_inds = generate_random_idx(rgb.shape[0],
self.seg_num)
......
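Note for readers of this hunk: the reader now skips dequantization for NEXTVLAD and hands the raw uint8 features to the network (see the model changes below), while other models still dequantize on the CPU. The `dequantize` helper itself is not shown here; a minimal sketch of the standard YouTube-8M-style dequantization it corresponds to, assuming the usual range of ±2 (not the repository's exact code):

```python
import numpy as np

def dequantize(feat, max_quantized_value=2., min_quantized_value=-2.):
    """Sketch of YouTube-8M feature dequantization (illustration only)."""
    quantized_range = max_quantized_value - min_quantized_value
    scalar = quantized_range / 255.0                      # 4/255 for the default range
    bias = quantized_range / 512.0 + min_quantized_value  # 4/512 - 2 for the default range
    return feat.astype(np.float32) * scalar + bias
```

With the default range this is `x * 4/255 + (4/512 - 2)`; the model-side changes below reproduce the same arithmetic on the GPU with `fluid.layers.scale`.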
......
@@ -12,6 +12,8 @@
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
......
@@ -71,7 +73,7 @@ class NEXTVLAD(ModelBase):
shapes=[[-1] + rgb_shape, [-1] + audio_shape,
[-1] + label_shape],
lod_levels=[1, 1, 0],
dtypes=['float32', 'float32', 'float32'],
dtypes=['uint8', 'uint8', 'float32'],
name='train_py_reader'
if self.is_training else 'test_py_reader',
use_double_buffer=True)
......
@@ -81,12 +83,12 @@ class NEXTVLAD(ModelBase):
rgb = fluid.layers.data(
name='train_rgb' if self.is_training else 'test_rgb',
shape=rgb_shape,
dtype='float32',
dtype='uint8',
lod_level=1)
audio = fluid.layers.data(
name='train_audio' if self.is_training else 'test_audio',
shape=audio_shape,
dtype='float32',
dtype='uint8',
lod_level=1)
if self.mode == 'infer':
label = None
......
@@ -115,6 +117,31 @@ class NEXTVLAD(ModelBase):
videomodel = nextvlad_model.NeXtVLADModel()
rgb = self.feature_input[0]
audio = self.feature_input[1]
# move data processing from data reader to fluid to process on gpu
rgb = fluid.layers.cast(rgb, 'float32')
audio = fluid.layers.cast(audio, 'float32')
bias = -2.
scale = 4. / 255
offset = 4. / 512
rgb = fluid.layers.scale(rgb, scale=scale, bias=bias)
audio = fluid.layers.scale(audio, scale=scale, bias=bias + offset)
eigen_value = np.sqrt(np.load(self.eigen_file)[:1024, 0])
eigen_value = (eigen_value + 1e-4).astype(np.float32)
eigen_param = fluid.layers.create_parameter(
shape=eigen_value.shape,
dtype='float32',
attr=fluid.ParamAttr(
name='eigen_param', trainable=False),
default_initializer=fluid.initializer.NumpyArrayInitializer(
eigen_value))
rgb = fluid.layers.elementwise_mul(rgb, eigen_param)
rgb.stop_gradient = True
audio.stop_gradient = True
out = videomodel.create_model(
rgb, audio, is_training=(self.mode == 'train'), **model_args)
self.logits = out['logits']
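A quick sanity check (not part of the patch) that the GPU-side `cast`/`scale`/`elementwise_mul` above reproduces the arithmetic removed from the data reader: the old path computed `(dequantize(rgb) - 4/512) * eigen_val`, which collapses to `rgb * 4/255 - 2` before the eigen scaling, hence `bias = -2.` for rgb, while audio keeps the full dequantize bias `-2 + 4/512`:

```python
import numpy as np

rng = np.random.RandomState(0)
rgb_u8 = rng.randint(0, 256, size=(5, 1024)).astype(np.uint8)    # fake quantized rgb features
eigen_val = (np.sqrt(rng.rand(1024)) + 1e-4).astype(np.float32)  # stand-in for the eigen npy file

x = rgb_u8.astype(np.float32)
old_reader_path = (x * (4. / 255) + (4. / 512 - 2.) - 4. / 512) * eigen_val  # dequantize, subtract 4/512, eigen scale
new_gpu_path = (x * (4. / 255) - 2.) * eigen_val                             # scale=4/255, bias=-2, elementwise_mul
assert np.allclose(old_reader_path, new_gpu_path, atol=1e-5)
```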
......
@@ -146,10 +173,12 @@ class NEXTVLAD(ModelBase):
return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input
]
def weights_info(self):
return ('nextvlad_youtube8m',
'https://paddlemodels.bj.bcebos.com/video_classification/nextvlad_youtube8m.tar.gz')
return (
'nextvlad_youtube8m',
'https://paddlemodels.bj.bcebos.com/video_classification/nextvlad_youtube8m.tar.gz'
)
def get_learning_rate_decay_list(base_learning_rate, decay, max_iter,
......
......
@@ -264,6 +264,13 @@ def res_stage_nonlocal(block_fn,
for idx in range(num_blocks):
block_prefix = '{}{}'.format(prefix, chr(idx + 97))
if cfg.MODEL.depth == 101:
if num_blocks == 23:
if idx == 0:
block_prefix = '{}{}'.format(prefix, chr(97))
else:
block_prefix = '{}{}{}'.format(prefix, 'b', idx)
block_stride = 2 if ((idx == 0) and (stride == 2)) else 1
blob_in = _generic_residual_block_3d(
blob_in,
......
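The new branch above only affects a ResNet-101 stage with 23 blocks: block 0 keeps the `a` suffix, and later blocks are named `b1, b2, ...` instead of continuing through the alphabet (presumably so parameter names line up with the pretrained ResNet-101 weights). A standalone illustration with a hypothetical `res4_` prefix:

```python
def block_names(prefix, num_blocks, depth):
    """Reproduce the naming logic from res_stage_nonlocal above (illustration only)."""
    names = []
    for idx in range(num_blocks):
        name = '{}{}'.format(prefix, chr(idx + 97))  # default: a, b, c, ...
        if depth == 101 and num_blocks == 23:
            name = prefix + 'a' if idx == 0 else '{}{}{}'.format(prefix, 'b', idx)
        names.append(name)
    return names

print(block_names('res4_', 23, 101)[:4])  # ['res4_a', 'res4_b1', 'res4_b2', 'res4_b3']
print(block_names('res3_', 4, 101))       # ['res3_a', 'res3_b', 'res3_c', 'res3_d']
```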
# activate eager gc to reduce memory use
#export FLAGS_fraction_of_gpu_memory_to_use=1.0
#export FLAGS_fast_eager_deletion_mode=1
#export FLAGS_eager_delete_tensor_gb=0.0
#export FLAGS_limit_of_tmp_allocation=0
export CUDA_VISIBLE_DEVICES=0,1,2,3
python train.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt --epoch=6 \
--valid_interval=1 --log_interval=10
......
@@ -4,6 +4,8 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
#export FLAGS_limit_of_tmp_allocation=0
#export FLAGS_conv_workspace_size_limit=1024
python train.py --model_name="TSM" --config=./configs/tsm.txt --epoch=65 \
--valid_interval=1 --log_interval=10
......
@@ -101,7 +101,7 @@ def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder,
info = '[TRAIN] Epoch {}, iter {} '.format(epoch, train_iter))
train_iter += 1
logger.info('[TRAIN] Epoch {} training finished, average time: {}'.
format(epoch, np.mean(epoch_periods)))
format(epoch, np.mean(epoch_periods[1:])))
save_model(exe, train_prog, save_dir, save_model_name,
"_epoch{}".format(epoch))
if test_exe and valid_interval > 0 and (epoch + 1
......
@@ -144,7 +144,7 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \
except fluid.core.EOFException:
# eval here
logger.info('[TRAIN] Epoch {} training finished, average time: {}'.
format(epoch, np.mean(epoch_periods)))
format(epoch, np.mean(epoch_periods[1:])))
save_model(exe, train_prog, save_dir, save_model_name,
"_epoch{}".format(epoch))
if test_exe and valid_interval > 0 and (epoch + 1
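Both training loops above now report `np.mean(epoch_periods[1:])`, i.e. the first iteration's time is dropped from the epoch average, presumably because it is dominated by one-off warm-up cost. A toy illustration of the difference (the numbers are hypothetical):

```python
import numpy as np

epoch_periods = [3.20, 0.41, 0.39, 0.40]   # hypothetical per-iteration times in seconds
print(np.mean(epoch_periods))              # 1.10 -- skewed by the first, warm-up iteration
print(np.mean(epoch_periods[1:]))          # 0.40 -- steady-state average reported after this change
```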
......
@@ -159,7 +159,9 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \
cards = os.environ.get('CUDA_VISIBLE_DEVICES')
gpu_num = len(cards.split(","))
print("kpis\ttrain_cost_card{}\t{}".format(gpu_num, train_loss))
print("kpis\ttrain_speed_card{}\t{}".format(gpu_num, np.mean(epoch_periods)))
print("kpis\ttrain_speed_card{}\t{}".format(gpu_num,
np.mean(epoch_periods)))
def save_model(exe, program, save_dir, model_name, postfix=None):
model_path = os.path.join(save_dir, model_name + postfix)
......
......
@@ -140,9 +140,6 @@ def train(args):
optimizer.minimize(train_loss)
train_pyreader = train_model.pyreader()
if not args.no_memory_optimize:
fluid.memory_optimize(train_prog)
valid_prog = fluid.Program()
with fluid.program_guard(valid_prog, startup):
with fluid.unique_name.guard():
......
@@ -176,10 +173,15 @@
if pretrain:
train_model.load_pretrain_params(exe, pretrain, train_prog, place)
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
#build_strategy.memory_optimize = True
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu,
loss_name=train_loss.name,
main_program=train_prog)
main_program=train_prog,
build_strategy=build_strategy)
valid_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu,
share_vars_from=train_exe,
......
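Taken together, the last two hunks are the v1.5 executor setup: the removed `fluid.memory_optimize(train_prog)` call is replaced by memory reuse configured through `BuildStrategy` on the `ParallelExecutor`. A minimal sketch of the resulting wiring (`train_loss`, `train_prog`, `valid_prog` come from the surrounding code; the truncated `valid_exe` arguments are assumed, not taken from the patch):

```python
import paddle.fluid as fluid

def build_executors(train_loss, train_prog, valid_prog, use_gpu=True):
    """Sketch of the v1.5-style executor setup used above (assumed helper, not in the patch)."""
    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_inplace = True      # in-place variable reuse, replacing fluid.memory_optimize
    # build_strategy.memory_optimize = True   # optional cross-op reuse, left commented out in the patch

    train_exe = fluid.ParallelExecutor(
        use_cuda=use_gpu,
        loss_name=train_loss.name,
        main_program=train_prog,
        build_strategy=build_strategy)
    valid_exe = fluid.ParallelExecutor(
        use_cuda=use_gpu,
        share_vars_from=train_exe,            # validation reuses the training parameters
        main_program=valid_prog)
    return train_exe, valid_exe
```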