diff --git a/demo/distillation/distillation_demo.py b/demo/distillation/distillation_demo.py index d0dd181409adb6e58c8884621345477669ee0dc0..79142026359b525e3cf3754d9dd52808081b3a57 100644 --- a/demo/distillation/distillation_demo.py +++ b/demo/distillation/distillation_demo.py @@ -140,43 +140,40 @@ def compress(args): # define teacher program teacher_program = fluid.Program() t_startup = fluid.Program() - teacher_scope = fluid.Scope() - with fluid.scope_guard(teacher_scope): - with fluid.program_guard(teacher_program, t_startup): - with fluid.unique_name.guard(): - image = fluid.layers.data( - name='image', shape=image_shape, dtype='float32') - predict = teacher_model.net(image, class_dim=class_dim) - - #print("="*50+"teacher_model_params"+"="*50) - #for v in teacher_program.list_vars(): - # print(v.name, v.shape) - - exe.run(t_startup) - _download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.') - _decompress('./ResNet50_pretrained.tar') - assert args.teacher_pretrained_model and os.path.exists( - args.teacher_pretrained_model - ), "teacher_pretrained_model should be set when teacher_model is not None." - - def if_exist(var): - return os.path.exists( - os.path.join(args.teacher_pretrained_model, var.name) - ) and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0' - - fluid.io.load_vars( - exe, - args.teacher_pretrained_model, - main_program=teacher_program, - predicate=if_exist) + with fluid.program_guard(teacher_program, t_startup): + with fluid.unique_name.guard(): + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + predict = teacher_model.net(image, class_dim=class_dim) + + #print("="*50+"teacher_model_params"+"="*50) + #for v in teacher_program.list_vars(): + # print(v.name, v.shape) + + exe.run(t_startup) + _download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.') + _decompress('./ResNet50_pretrained.tar') + assert args.teacher_pretrained_model and os.path.exists( + args.teacher_pretrained_model + ), "teacher_pretrained_model should be set when teacher_model is not None." + + def if_exist(var): + return os.path.exists( + os.path.join(args.teacher_pretrained_model, var.name) + ) and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0' + + fluid.io.load_vars( + exe, + args.teacher_pretrained_model, + main_program=teacher_program, + predicate=if_exist) data_name_map = {'image': 'image'} main = merge( teacher_program, student_program, data_name_map, - place, - teacher_scope=teacher_scope) + place) with fluid.program_guard(main, s_startup): l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)