Unverified commit d72e1e5a, authored by: Y yukavio, committed by: GitHub

fix prune demo and migrate static apis in prune related module (#573)

......@@ -124,8 +124,7 @@ def compress(args):
     # model definition
     model = models.__dict__[args.model]()
     out = model.net(input=image, class_dim=class_dim)
-    cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
-    avg_cost = paddle.mean(x=cost)
+    avg_cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
     acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
     acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
     val_program = paddle.static.default_main_program().clone(for_test=True)
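The two-step loss collapses into one call because paddle.nn.functional.cross_entropy applies reduction='mean' by default in Paddle 2.x. A minimal dynamic-mode sketch of that equivalence; the toy tensors are illustrative, not from the demo:

    import paddle

    # Hypothetical toy inputs, just to show the reduction behaviour.
    logits = paddle.rand([8, 10])
    labels = paddle.randint(0, 10, [8])

    # Old two-step form: per-sample losses, then an explicit mean.
    per_sample = paddle.nn.functional.cross_entropy(logits, labels, reduction='none')
    old_style = paddle.mean(per_sample)

    # New one-step form: reduction='mean' is the default.
    new_style = paddle.nn.functional.cross_entropy(logits, labels)
    print(float(old_style), float(new_style))  # the two values should match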
......@@ -144,8 +143,8 @@ def compress(args):
     _logger.info("Load pretrained model from {}".format(
         args.pretrained_model))
-    paddle.fluid.io.load_vars(
-        exe, args.pretrained_model, predicate=if_exist)
+    paddle.static.load(paddle.static.default_main_program(),
+                       args.pretrained_model, exe)
     train_loader = paddle.io.DataLoader(
         train_dataset,
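paddle.static.load restores the persistable variables of a program from a saved path prefix, replacing the predicate-driven paddle.fluid.io.load_vars. A self-contained round-trip sketch under an assumed Paddle 2.x static-graph setup; the fc layer and /tmp prefix are placeholders:

    import paddle

    paddle.enable_static()
    image = paddle.static.data(name='image', shape=[None, 4], dtype='float32')
    out = paddle.static.nn.fc(image, size=2)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())

    # Save once so there is something to restore, then load as the demo does:
    # program first, path prefix second, executor third.
    paddle.static.save(paddle.static.default_main_program(), '/tmp/pretrained')
    paddle.static.load(paddle.static.default_main_program(), '/tmp/pretrained', exe)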
......@@ -155,7 +154,7 @@ def compress(args):
         batch_size=args.batch_size,
         shuffle=True,
         return_list=False,
-        use_shared_memory=False,
+        use_shared_memory=True,
         num_workers=16)
     valid_loader = paddle.io.DataLoader(
         val_dataset,
......@@ -163,7 +162,7 @@ def compress(args):
         feed_list=[image, label],
         drop_last=False,
         return_list=False,
-        use_shared_memory=False,
+        use_shared_memory=True,
         batch_size=args.batch_size,
         shuffle=False)
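These two hunks, and the matching one in the evaluation script below, flip use_shared_memory to True: with num_workers > 0, worker processes then hand batches to the trainer through shared memory rather than pipe serialization, at the cost of needing enough /dev/shm space. A self-contained sketch with a stand-in dataset; RandomDataset is hypothetical, not part of the demo:

    import numpy as np
    import paddle
    from paddle.io import DataLoader, Dataset

    paddle.enable_static()

    class RandomDataset(Dataset):
        # Stand-in dataset so the loader can run on its own.
        def __len__(self):
            return 64
        def __getitem__(self, idx):
            image = np.random.rand(3, 32, 32).astype('float32')
            label = np.array([idx % 10], dtype='int64')
            return image, label

    image = paddle.static.data(name='image', shape=[None, 3, 32, 32], dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')

    loader = DataLoader(
        RandomDataset(),
        feed_list=[image, label],   # static-graph mode feeds by variable
        batch_size=16,
        shuffle=True,
        return_list=False,
        use_shared_memory=True,     # effective when num_workers > 0
        num_workers=2)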
......@@ -245,8 +244,10 @@ def compress(args):
     if args.save_inference:
         infer_model_path = os.path.join(args.model_path, "infer_models",
                                         str(i))
-        paddle.fluid.io.save_inference_model(infer_model_path, ["image"],
-                                             [out], exe, pruned_val_program)
+        paddle.static.save_inference_model(
+            infer_model_path, [image], [out],
+            exe,
+            program=pruned_val_program)
         _logger.info("Saved inference model into [{}]".format(
             infer_model_path))
......
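Two migration details are visible in the export hunk above: the 2.x API takes feed variables ([image]) where the old one took their string names (["image"]), and the program is passed as a keyword. A minimal sketch under an assumed Paddle 2.x setup; the fc graph and path prefix are placeholders:

    import paddle

    paddle.enable_static()
    image = paddle.static.data(name='image', shape=[None, 4], dtype='float32')
    out = paddle.static.nn.fc(image, size=2)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())

    # Feed targets are Variables now; the old API took their string names.
    paddle.static.save_inference_model('/tmp/infer_demo', [image], [out], exe)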
......@@ -75,6 +75,7 @@ def compress(args):
         feed_list=[image, label],
         drop_last=False,
         batch_size=args.batch_size,
+        use_shared_memory=True,
         shuffle=False)
     def test(program):
......
......@@ -36,7 +36,7 @@ class AutoPruner(object):
     Args:
         program(Program): The program to be pruned.
         scope(Scope): The scope to be pruned.
-        place(fluid.Place): The device place of parameters.
+        place(paddle.CUDAPlace||paddle.CPUPlace): The device place of parameters.
         params(list<str>): The names of parameters to be pruned.
         init_ratios(list<float>|float): Init ratios used to prune parameters in `params`.
             List means ratios used for pruning each parameter in `params`.
......
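One common way to construct the place argument described above, using only core Paddle calls:

    import paddle

    # Prefer the first GPU when this Paddle build has CUDA, else fall back to CPU.
    place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()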
......@@ -25,17 +25,16 @@ def save_model(exe, graph, dirname):
     graph = GraphWrapper(graph) if isinstance(graph,
                                               paddle.static.Program) else graph
-    paddle.fluid.io.save_persistables(
-        executor=exe,
-        dirname=dirname,
-        main_program=graph.program,
-        filename=None)
+    paddle.static.save(program=graph.program, model_path=dirname)
     weights_file = dirname
     _logger.info("Save model weights into {}".format(weights_file))
     shapes = {}
-    for var in paddle.fluid.io.get_program_persistable_vars(graph.program):
-        shapes[var.name] = var.shape
+    for var in graph.program.list_vars():
+        if var.persistable:
+            shapes[var.name] = var.shape
     SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE)
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
     with open(SHAPES_FILE, "w") as f:
         json.dump(shapes, f)
     _logger.info("Save shapes of weights into {}".format(SHAPES_FILE))
......@@ -65,11 +64,7 @@ def load_model(exe, graph, dirname):
             _logger.info('{} is not loaded'.format(param_name))
     _logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
-    paddle.fluid.io.load_persistables(
-        executor=exe,
-        dirname=dirname,
-        main_program=graph.program,
-        filename=None)
+    paddle.static.load(program=graph.program, model_path=dirname, executor=exe)
     graph.update_groups_of_conv()
     graph.infer_shape()
     _logger.info("Load weights from {}".format(dirname))
......@@ -65,10 +65,10 @@ class Pruner():
     Args:
         program(paddle.static.Program): The program to be pruned.
-        scope(fluid.Scope): The scope storing paramaters to be pruned.
+        scope(paddle.static.Scope): The scope storing parameters to be pruned.
         params(list<str>): A list of parameter names to be pruned.
         ratios(list<float>): A list of ratios used to prune the parameters.
-        place(fluid.Place): The device place of filter parameters. Defalut: None.
+        place(paddle.CUDAPlace||paddle.CPUPlace): The device place of filter parameters. Default: None.
         lazy(bool): True means setting the pruned elements to zero.
             False means cutting down the pruned elements. Default: False.
         only_graph(bool): True means only modifying the graph.
......
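A hedged usage sketch of this Pruner, assuming the current paddleslim.prune.Pruner API in which prune(program, scope, params, ratios, place) returns the pruned program plus optional backups; the one-conv network and parameter name are placeholders:

    import paddle
    import paddleslim

    paddle.enable_static()
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        image = paddle.static.data(name='image', shape=[None, 3, 16, 16], dtype='float32')
        conv = paddle.static.nn.conv2d(
            image, num_filters=8, filter_size=3,
            param_attr=paddle.ParamAttr(name='conv1_weights'))

    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)

    # Cut half of conv1's filters; lazy=False (the default) removes them outright.
    pruner = paddleslim.prune.Pruner()
    pruned_program, _, _ = pruner.prune(
        main_prog,
        paddle.static.global_scope(),
        params=['conv1_weights'],
        ratios=[0.5],
        place=place)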