diff --git a/demo/distillation/train.py b/demo/distillation/train.py index 7f389168440a59f0872d44ab6e62f262e373f6f0..e17678a70c73cfd1af746118ae7d0685317e96aa 100644 --- a/demo/distillation/train.py +++ b/demo/distillation/train.py @@ -140,41 +140,38 @@ def compress(args): # define teacher program teacher_program = fluid.Program() t_startup = fluid.Program() - teacher_scope = fluid.Scope() - with fluid.scope_guard(teacher_scope): - with fluid.program_guard(teacher_program, t_startup): - with fluid.unique_name.guard(): - image = fluid.layers.data( - name='image', shape=image_shape, dtype='float32') - predict = teacher_model.net(image, class_dim=class_dim) - - #print("="*50+"teacher_model_params"+"="*50) - #for v in teacher_program.list_vars(): - # print(v.name, v.shape) - - exe.run(t_startup) - assert args.teacher_pretrained_model and os.path.exists( - args.teacher_pretrained_model - ), "teacher_pretrained_model should be set when teacher_model is not None." - - def if_exist(var): - return os.path.exists( - os.path.join(args.teacher_pretrained_model, var.name) - ) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0' - - fluid.io.load_vars( - exe, - args.teacher_pretrained_model, - main_program=teacher_program, - predicate=if_exist) + with fluid.program_guard(teacher_program, t_startup): + with fluid.unique_name.guard(): + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + predict = teacher_model.net(image, class_dim=class_dim) + + #print("="*50+"teacher_model_params"+"="*50) + #for v in teacher_program.list_vars(): + # print(v.name, v.shape) + + exe.run(t_startup) + assert args.teacher_pretrained_model and os.path.exists( + args.teacher_pretrained_model + ), "teacher_pretrained_model should be set when teacher_model is not None." + + def if_exist(var): + return os.path.exists( + os.path.join(args.teacher_pretrained_model, var.name) + ) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0' + + fluid.io.load_vars( + exe, + args.teacher_pretrained_model, + main_program=teacher_program, + predicate=if_exist) data_name_map = {'image': 'image'} main = merge( teacher_program, student_program, data_name_map, - place, - teacher_scope=teacher_scope) + place) #print("="*50+"teacher_vars"+"="*50) #for v in teacher_program.list_vars():