未验证 提交 92874cc0 编写于 作者: Z zhouzj 提交者: GitHub

Clear fluid api and fix tests (#1641)

* remove fluid apis.

* fix hpo.

* fix asp.
上级 b248f202
...@@ -2,10 +2,7 @@ from __future__ import absolute_import ...@@ -2,10 +2,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle import paddle
import paddle.fluid as fluid
from paddle.nn.initializer import KaimingUniform from paddle.nn.initializer import KaimingUniform
import os, sys, time, math
import numpy as np
from collections import namedtuple from collections import namedtuple
BLOCK_TYPE_MCRELU = 'BLOCK_TYPE_MCRELU' BLOCK_TYPE_MCRELU = 'BLOCK_TYPE_MCRELU'
...@@ -458,15 +455,24 @@ def loss(f_score, f_geo, l_score, l_geo, l_mask, class_num=1): ...@@ -458,15 +455,24 @@ def loss(f_score, f_geo, l_score, l_geo, l_mask, class_num=1):
abs_geo_diff = paddle.abs(geo_diff) abs_geo_diff = paddle.abs(geo_diff)
l_flag = l_score >= 1 l_flag = l_score >= 1
l_flag = paddle.cast(x=l_flag, dtype="float32") l_flag = paddle.cast(x=l_flag, dtype="float32")
l_flag = fluid.layers.expand(x=l_flag, expand_times=[1, channels, 1, 1]) l_flag = paddle.expand(
x=l_flag,
shape=[
l_flag.shape[0], l_flag.shape[1] * channels, l_flag.shape[2],
l_flag.shape[3]
])
smooth_l1_sign = abs_geo_diff < l_flag smooth_l1_sign = abs_geo_diff < l_flag
smooth_l1_sign = paddle.cast(x=smooth_l1_sign, dtype="float32") smooth_l1_sign = paddle.cast(x=smooth_l1_sign, dtype="float32")
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + ( in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + (
abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign) abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
l_short_edge = fluid.layers.expand( l_short_edge = paddle.expand(
x=l_short_edge, expand_times=[1, channels, 1, 1]) x=l_short_edge,
shape=[
l_short_edge.shape[0], l_short_edge.shape[1] * channels,
l_short_edge.shape[2], l_short_edge.shape[3]
])
out_loss = l_short_edge * in_loss * l_flag out_loss = l_short_edge * in_loss * l_flag
out_loss = out_loss * l_flag out_loss = out_loss * l_flag
smooth_l1_loss = paddle.mean(out_loss) smooth_l1_loss = paddle.mean(out_loss)
......
...@@ -18,7 +18,7 @@ from paddleslim.analysis import flops ...@@ -18,7 +18,7 @@ from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert from paddleslim.quant import quant_aware, quant_post, convert
import models import models
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
from paddle.fluid.layer_helper import LayerHelper from paddle.common_ops_import import LayerHelper
quantization_model_save_dir = './quantization_models/' quantization_model_save_dir = './quantization_models/'
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
...@@ -146,8 +146,8 @@ def compress(args): ...@@ -146,8 +146,8 @@ def compress(args):
raise ValueError("{} is not supported.".format(args.data)) raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")] image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model, assert args.model in model_list, "{} is not in lists: {}".format(
model_list) args.model, model_list)
image = paddle.static.data( image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32') name='image', shape=[None] + image_shape, dtype='float32')
if args.use_pact: if args.use_pact:
......
...@@ -15,15 +15,17 @@ ...@@ -15,15 +15,17 @@
neural network for word2vec neural network for word2vec
""" """
from __future__ import print_function from __future__ import print_function
import math
import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.nn.functional as F
def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5): def skip_gram_word2vec(dict_size,
embedding_size,
batch_size,
is_sparse=False,
neg_num=5):
datas = [] words = []
input_word = paddle.static.data( input_word = paddle.static.data(
name="input_word", shape=[None, 1], dtype='int64') name="input_word", shape=[None, 1], dtype='int64')
true_word = paddle.static.data( true_word = paddle.static.data(
...@@ -31,14 +33,13 @@ def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5): ...@@ -31,14 +33,13 @@ def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5):
neg_word = paddle.static.data( neg_word = paddle.static.data(
name="neg_label", shape=[None, neg_num], dtype='int64') name="neg_label", shape=[None, neg_num], dtype='int64')
datas.append(input_word) words.append(input_word)
datas.append(true_word) words.append(true_word)
datas.append(neg_word) words.append(neg_word)
py_reader = fluid.layers.create_py_reader_by_data( py_reader = paddle.io.DataLoader.from_generator(
capacity=64, feed_list=datas, name='py_reader', use_double_buffer=True) capacity=64, feed_list=words, use_double_buffer=True, iterable=False)
words = fluid.layers.read_file(py_reader)
words[0] = paddle.reshape(words[0], [-1]) words[0] = paddle.reshape(words[0], [-1])
words[1] = paddle.reshape(words[1], [-1]) words[1] = paddle.reshape(words[1], [-1])
init_width = 0.5 / embedding_size init_width = 0.5 / embedding_size
...@@ -72,8 +73,7 @@ def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5): ...@@ -72,8 +73,7 @@ def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5):
input=neg_word_reshape, input=neg_word_reshape,
is_sparse=is_sparse, is_sparse=is_sparse,
size=[dict_size, embedding_size], size=[dict_size, embedding_size],
param_attr=paddle.ParamAttr( param_attr=paddle.ParamAttr(name='emb_w', learning_rate=1.0))
name='emb_w', learning_rate=1.0))
neg_emb_w_re = paddle.reshape( neg_emb_w_re = paddle.reshape(
neg_emb_w, shape=[-1, neg_num, embedding_size]) neg_emb_w, shape=[-1, neg_num, embedding_size])
...@@ -81,12 +81,11 @@ def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5): ...@@ -81,12 +81,11 @@ def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5):
input=neg_word_reshape, input=neg_word_reshape,
is_sparse=is_sparse, is_sparse=is_sparse,
size=[dict_size, 1], size=[dict_size, 1],
param_attr=paddle.ParamAttr( param_attr=paddle.ParamAttr(name='emb_b', learning_rate=1.0))
name='emb_b', learning_rate=1.0))
neg_emb_b_vec = paddle.reshape(neg_emb_b, shape=[-1, neg_num]) neg_emb_b_vec = paddle.reshape(neg_emb_b, shape=[-1, neg_num])
true_logits = paddle.add(paddle.mean( true_logits = paddle.add(
paddle.multiply(input_emb, true_emb_w), keepdim=True), paddle.mean(paddle.multiply(input_emb, true_emb_w), keepdim=True),
true_emb_b) true_emb_b)
input_emb_re = paddle.reshape(input_emb, shape=[-1, 1, embedding_size]) input_emb_re = paddle.reshape(input_emb, shape=[-1, 1, embedding_size])
neg_matmul = paddle.matmul(input_emb_re, neg_emb_w_re, transpose_y=True) neg_matmul = paddle.matmul(input_emb_re, neg_emb_w_re, transpose_y=True)
...@@ -94,18 +93,17 @@ def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5): ...@@ -94,18 +93,17 @@ def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5):
neg_logits = paddle.add(neg_matmul_re, neg_emb_b_vec) neg_logits = paddle.add(neg_matmul_re, neg_emb_b_vec)
#nce loss #nce loss
# TODO: replaced by paddle.tensor.creation.fill_constant_batch_size_like label_ones = paddle.full(
label_ones = fluid.layers.fill_constant_batch_size_like( shape=[batch_size, 1], fill_value=1.0, dtype='float32')
true_logits, shape=[-1, 1], value=1.0, dtype='float32') label_zeros = paddle.full(
label_zeros = fluid.layers.fill_constant_batch_size_like( shape=[batch_size, neg_num], fill_value=0.0, dtype='float32')
true_logits, shape=[-1, neg_num], value=0.0, dtype='float32')
true_xent = F.binary_cross_entropy_with_logits(
true_xent = paddle.nn.functional.binary_cross_entropy(true_logits, true_logits, label_ones, reduction='none')
label_ones) neg_xent = F.binary_cross_entropy_with_logits(
neg_xent = paddle.nn.functional.binary_cross_entropy(neg_logits, neg_logits, label_zeros, reduction='none')
label_zeros) cost = paddle.add(
cost = paddle.add(paddle.sum(true_xent, axis=1), paddle.sum(true_xent, axis=1), paddle.sum(neg_xent, axis=1))
paddle.sum(neg_xent, axis=1))
avg_cost = paddle.mean(cost) avg_cost = paddle.mean(cost)
return avg_cost, py_reader return avg_cost, py_reader
......
...@@ -121,7 +121,7 @@ def convert_python_to_tensor(weight, batch_size, sample_reader): ...@@ -121,7 +121,7 @@ def convert_python_to_tensor(weight, batch_size, sample_reader):
def train_loop(args, train_program, reader, py_reader, loss, trainer_id, weight, def train_loop(args, train_program, reader, py_reader, loss, trainer_id, weight,
lr): lr):
py_reader.decorate_tensor_provider( py_reader.set_batch_generator(
convert_python_to_tensor(weight, args.batch_size, reader.train())) convert_python_to_tensor(weight, args.batch_size, reader.train()))
place = paddle.CPUPlace() place = paddle.CPUPlace()
...@@ -213,6 +213,7 @@ def train(args): ...@@ -213,6 +213,7 @@ def train(args):
loss, py_reader = skip_gram_word2vec( loss, py_reader = skip_gram_word2vec(
word2vec_reader.dict_size, word2vec_reader.dict_size,
args.embedding_size, args.embedding_size,
args.batch_size,
is_sparse=args.is_sparse, is_sparse=args.is_sparse,
neg_num=args.nce_num) neg_num=args.nce_num)
......
...@@ -78,10 +78,8 @@ def _create_optimizer(train_config): ...@@ -78,10 +78,8 @@ def _create_optimizer(train_config):
### build optimizer ### build optimizer
optim_params = optimizer_builder['optimizer'] optim_params = optimizer_builder['optimizer']
optim_type = optim_params.pop('type') optim_type = optim_params.pop('type')
opt = getattr(optimizer, optim_type)(learning_rate=lr, opt = getattr(optimizer, optim_type)(
grad_clip=grad_clip, learning_rate=lr, grad_clip=grad_clip, weight_decay=reg, **optim_params)
weight_decay=reg,
**optim_params)
return opt, lr return opt, lr
...@@ -160,8 +158,8 @@ def _parse_distill_loss(distill_node_pair, ...@@ -160,8 +158,8 @@ def _parse_distill_loss(distill_node_pair,
for node, loss_clas, lam in zip(distill_node_pair, distill_loss, for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
distill_lambda): distill_lambda):
tmp_loss = losses.get(loss_clas, 0.0) tmp_loss = losses.get(loss_clas, 0.0)
_logger.info("train config.distill_node_pair: {}".format( _logger.info(
node, loss_clas, lam)) "train config.distill_node_pair: {}".format(node, loss_clas, lam))
assert len(node) % 2 == 0, \ assert len(node) % 2 == 0, \
"distill_node_pair config wrong, the length needs to be an even number" "distill_node_pair config wrong, the length needs to be an even number"
for i in range(len(node) // 2): for i in range(len(node) // 2):
...@@ -529,9 +527,7 @@ def build_prune_program(executor, ...@@ -529,9 +527,7 @@ def build_prune_program(executor,
original_shapes = {} original_shapes = {}
for param in train_program_info.program.global_block( for param in train_program_info.program.global_block(
).all_parameters(): ).all_parameters():
if config[ if config['prune_params_name'] is not None and param.name in config['prune_params_name']:
'prune_params_name'] is not None and param.name in config[
'prune_params_name']:
params.append(param.name) params.append(param.name)
original_shapes[param.name] = param.shape original_shapes[param.name] = param.shape
...@@ -541,9 +537,8 @@ def build_prune_program(executor, ...@@ -541,9 +537,8 @@ def build_prune_program(executor,
train_program_info.program, train_program_info.program,
paddle.static.global_scope(), paddle.static.global_scope(),
params=params, params=params,
ratios=[config['pruned_ratio']] * len(params) ratios=[config['pruned_ratio']] * len(params) if isinstance(
if isinstance(config['pruned_ratio'], float) else config['pruned_ratio'], float) else config['pruned_ratio'],
config['pruned_ratio'],
place=place) place=place)
_logger.info( _logger.info(
"####################channel pruning##########################") "####################channel pruning##########################")
...@@ -577,8 +572,9 @@ def build_prune_program(executor, ...@@ -577,8 +572,9 @@ def build_prune_program(executor,
pruner.add_supported_layer(param.name) pruner.add_supported_layer(param.name)
if "teacher_" in param.name: if "teacher_" in param.name:
excluded_params_name.append(param.name) excluded_params_name.append(param.name)
pruner.set_excluded_layers(train_program_info.program, pruner.set_excluded_layers(
excluded_params_name) main_program=train_program_info.program,
param_names=excluded_params_name)
elif strategy.startswith('transformer_prune'): elif strategy.startswith('transformer_prune'):
from .transformer_pruner import TransformerPruner from .transformer_pruner import TransformerPruner
assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader" assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader"
......
...@@ -83,7 +83,7 @@ class QuantConfig(object): ...@@ -83,7 +83,7 @@ class QuantConfig(object):
"""QuantConfig init""" """QuantConfig init"""
self.executor = executor self.executor = executor
self.place = place self.place = place
self.float_infer_model_path = float_infer_model_path self.float_infer_model_path = float_infer_model_path.rstrip('/')
self.quantize_model_path = quantize_model_path self.quantize_model_path = quantize_model_path
self.algo = algo, self.algo = algo,
self.hist_percent = hist_percent, self.hist_percent = hist_percent,
......
...@@ -25,7 +25,9 @@ from ..dist import merge ...@@ -25,7 +25,9 @@ from ..dist import merge
from ..core.graph_wrapper import GraphWrapper from ..core.graph_wrapper import GraphWrapper
from ..common import get_logger from ..common import get_logger
__all__ = ['ReconstructionQuantization', ] __all__ = [
'ReconstructionQuantization',
]
_logger = get_logger( _logger = get_logger(
__name__, __name__,
...@@ -91,7 +93,8 @@ class ReconstructionQuantization(PostTrainingQuantization): ...@@ -91,7 +93,8 @@ class ReconstructionQuantization(PostTrainingQuantization):
batch_id = 0 batch_id = 0
with utils.tqdm( with utils.tqdm(
total=self._batch_nums, total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', bar_format=
'Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80, ) as t: ncols=80, ) as t:
for data in self._data_loader(): for data in self._data_loader():
self._executor.run( self._executor.run(
...@@ -111,7 +114,8 @@ class ReconstructionQuantization(PostTrainingQuantization): ...@@ -111,7 +114,8 @@ class ReconstructionQuantization(PostTrainingQuantization):
batch_id = 0 batch_id = 0
with utils.tqdm( with utils.tqdm(
total=self._batch_nums, total=self._batch_nums,
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', bar_format=
'Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80, ) as t: ncols=80, ) as t:
for data in self._data_loader(): for data in self._data_loader():
self._executor.run( self._executor.run(
...@@ -237,7 +241,7 @@ class ReconstructionQuanter(object): ...@@ -237,7 +241,7 @@ class ReconstructionQuanter(object):
return a batch every time. return a batch every time.
executor(paddle.static.Executor): The executor to load, run and save the executor(paddle.static.Executor): The executor to load, run and save the
quantized model. quantized model.
scope(fluid.Scope, optional): The scope of the program, use it to load scope(static.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope(). and save variables. If scope=None, get scope by global_scope().
place(CPUPlace()|CUDAPlace(N)): This parameter represents place(CPUPlace()|CUDAPlace(N)): This parameter represents
paddle run on which device. paddle run on which device.
...@@ -385,8 +389,8 @@ class ReconstructionQuanter(object): ...@@ -385,8 +389,8 @@ class ReconstructionQuanter(object):
with paddle.static.program_guard(tmp_program, startup_program): with paddle.static.program_guard(tmp_program, startup_program):
student_var = tmp_program.global_block().var(quant_op_out_name) student_var = tmp_program.global_block().var(quant_op_out_name)
teacher_var = tmp_program.global_block().var("teacher_" + teacher_var = tmp_program.global_block().var(
quant_op_out_name) "teacher_" + quant_op_out_name)
total_loss, recon_loss, round_loss = loss_function.get_loss( total_loss, recon_loss, round_loss = loss_function.get_loss(
student_var, student_var,
teacher_var, ) teacher_var, )
...@@ -471,7 +475,8 @@ class ReconstructionQuanter(object): ...@@ -471,7 +475,8 @@ class ReconstructionQuanter(object):
shape=weight.shape, shape=weight.shape,
dtype=weight.dtype, dtype=weight.dtype,
name=weight.name + ".alpha", name=weight.name + ".alpha",
default_initializer=paddle.nn.initializer.Assign(self._alpha, ), ) default_initializer=paddle.nn.initializer.Assign(
self._alpha, ), )
h_v = paddle.clip( h_v = paddle.clip(
paddle.nn.functional.sigmoid(v) * (ZETA - GAMMA) + GAMMA, paddle.nn.functional.sigmoid(v) * (ZETA - GAMMA) + GAMMA,
...@@ -483,13 +488,14 @@ class ReconstructionQuanter(object): ...@@ -483,13 +488,14 @@ class ReconstructionQuanter(object):
dtype=weight.dtype, dtype=weight.dtype,
shape=weight.shape, shape=weight.shape,
name=weight.name + '.scale', name=weight.name + '.scale',
default_initializer=paddle.nn.initializer.Assign(scale, )) default_initializer=paddle.nn.initializer.Assign(
scale, ))
else: else:
scale_var = scale scale_var = scale
quantized_weight = _quant(weight_copy, scale_var) quantized_weight = _quant(weight_copy, scale_var)
floor_weight = (paddle.floor(quantized_weight) - quantized_weight floor_weight = (paddle.floor(quantized_weight) -
).detach() + quantized_weight quantized_weight).detach() + quantized_weight
clip_weight = paddle.clip(floor_weight + h_v, -bnt, bnt) clip_weight = paddle.clip(floor_weight + h_v, -bnt, bnt)
w = _dequant(clip_weight, scale_var) w = _dequant(clip_weight, scale_var)
return w return w
...@@ -525,8 +531,9 @@ class ReconstructionQuanter(object): ...@@ -525,8 +531,9 @@ class ReconstructionQuanter(object):
def _insert_drop_quant_dequant(self): def _insert_drop_quant_dequant(self):
for op in self._graph.ops(): for op in self._graph.ops():
if op.type( if op.type() in [
) in ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2']: 'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
]:
if op.type() in ['conv2d', 'depthwise_conv2d']: if op.type() in ['conv2d', 'depthwise_conv2d']:
if op.inputs("Filter")[0].name().startswith("teacher"): if op.inputs("Filter")[0].name().startswith("teacher"):
break break
...@@ -670,8 +677,8 @@ class ReconstructionQuanter(object): ...@@ -670,8 +677,8 @@ class ReconstructionQuanter(object):
'X': var._var, 'X': var._var,
'Y': op.input('Y')[0] + '.qdrop', 'Y': op.input('Y')[0] + '.qdrop',
} }
elif _type == 'scale' and op.input('X')[ elif _type == 'scale' and op.input(
0] == inputs.name + '.tmp': 'X')[0] == inputs.name + '.tmp':
_inputs = {'X': var._var} _inputs = {'X': var._var}
else: else:
_inputs = {'X': op.input('X')[0] + '.qdrop'} _inputs = {'X': op.input('X')[0] + '.qdrop'}
...@@ -687,11 +694,13 @@ class ReconstructionQuanter(object): ...@@ -687,11 +694,13 @@ class ReconstructionQuanter(object):
'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2' 'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
]: ]:
continue continue
if op.type() in ['conv2d', 'depthwise_conv2d'] and op.inputs( if op.type() in [
'Filter')[0].name().startswith('teacher'): 'conv2d', 'depthwise_conv2d'
] and op.inputs('Filter')[0].name().startswith('teacher'):
continue continue
if op.type() in ['mul', 'matmul', 'matmul_v2'] and op.inputs('Y')[ if op.type() in [
0].name().startswith('teacher'): 'mul', 'matmul', 'matmul_v2'
] and op.inputs('Y')[0].name().startswith('teacher'):
continue continue
if func == '_soft_rounding': if func == '_soft_rounding':
op._op._rename_input(inputs.name, out.name + '.rounding') op._op._rename_input(inputs.name, out.name + '.rounding')
...@@ -964,8 +973,8 @@ class RegionBuilder(object): ...@@ -964,8 +973,8 @@ class RegionBuilder(object):
else: else:
future_ep = _find_multi_input_ep(ep) future_ep = _find_multi_input_ep(ep)
if future_ep is None or self._depth[future_ep.idx()] - self._depth[ if future_ep is None or self._depth[future_ep.idx(
sp.idx()] >= limit: )] - self._depth[sp.idx()] >= limit:
return self._create_region(sp, ep) return self._create_region(sp, ep)
ep = future_ep ep = future_ep
......
...@@ -147,10 +147,8 @@ class ModelCase6(paddle.nn.Layer): ...@@ -147,10 +147,8 @@ class ModelCase6(paddle.nn.Layer):
x = paddle.unsqueeze(x=x, axis=[2]) x = paddle.unsqueeze(x=x, axis=[2])
x = self.relu1(x) x = self.relu1(x)
y = paddle.full(shape=x.shape, fill_value=1) y = paddle.full(shape=x.shape, fill_value=1)
# x = paddle.stack([x, y], axis=3)
x = paddle.slice(x, axes=[0], starts=[0], ends=[1]) x = paddle.slice(x, axes=[0], starts=[0], ends=[1])
x = paddle.exp(x) x = paddle.exp(x)
# y += paddle.fluid.layers.uniform_random(y.shape)
y = paddle.expand(y, shape=[1, 768, 768, 2]) y = paddle.expand(y, shape=[1, 768, 768, 2])
x = paddle.expand(x, shape=[1, 768, 768, 2]) x = paddle.expand(x, shape=[1, 768, 768, 2])
out = paddle.concat([x, y]) out = paddle.concat([x, y])
...@@ -161,8 +159,8 @@ class ModelCase6(paddle.nn.Layer): ...@@ -161,8 +159,8 @@ class ModelCase6(paddle.nn.Layer):
max_idx = paddle.argmax( max_idx = paddle.argmax(
out1.reshape((outshape[0], outshape[1], outshape[2] * outshape[3])), out1.reshape((outshape[0], outshape[1], outshape[2] * outshape[3])),
axis=-1) axis=-1)
out2 = out2.reshape( out2 = out2.reshape((outshape[0], outshape[1],
(outshape[0], outshape[1], outshape[2] * outshape[3])) outshape[2] * outshape[3]))
res, _ = self.lstm(out2) res, _ = self.lstm(out2)
return res, max_idx return res, max_idx
...@@ -238,8 +236,8 @@ class TestCase2(unittest.TestCase): ...@@ -238,8 +236,8 @@ class TestCase2(unittest.TestCase):
model_name = '.'.join(model_filename.split('.')[:-1]) model_name = '.'.join(model_filename.split('.')[:-1])
model_path_prefix = os.path.join(model_dir, model_name) model_path_prefix = os.path.join(model_dir, model_name)
[inference_program, feed_target_names, fetch_targets] = ( [inference_program, feed_target_names,
paddle.static.load_inference_model( fetch_targets] = (paddle.static.load_inference_model(
path_prefix=model_path_prefix, executor=exe)) path_prefix=model_path_prefix, executor=exe))
if type(input_shapes) in [list, tuple]: if type(input_shapes) in [list, tuple]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册