提交 81506640 编写于 作者: B baiyfbupt

fix demo details

上级 fff7c57b
...@@ -140,43 +140,40 @@ def compress(args): ...@@ -140,43 +140,40 @@ def compress(args):
# define teacher program # define teacher program
teacher_program = fluid.Program() teacher_program = fluid.Program()
t_startup = fluid.Program() t_startup = fluid.Program()
teacher_scope = fluid.Scope() with fluid.program_guard(teacher_program, t_startup):
with fluid.scope_guard(teacher_scope): with fluid.unique_name.guard():
with fluid.program_guard(teacher_program, t_startup): image = fluid.layers.data(
with fluid.unique_name.guard(): name='image', shape=image_shape, dtype='float32')
image = fluid.layers.data( predict = teacher_model.net(image, class_dim=class_dim)
name='image', shape=image_shape, dtype='float32')
predict = teacher_model.net(image, class_dim=class_dim) #print("="*50+"teacher_model_params"+"="*50)
#for v in teacher_program.list_vars():
#print("="*50+"teacher_model_params"+"="*50) # print(v.name, v.shape)
#for v in teacher_program.list_vars():
# print(v.name, v.shape) exe.run(t_startup)
_download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.')
exe.run(t_startup) _decompress('./ResNet50_pretrained.tar')
_download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.') assert args.teacher_pretrained_model and os.path.exists(
_decompress('./ResNet50_pretrained.tar') args.teacher_pretrained_model
assert args.teacher_pretrained_model and os.path.exists( ), "teacher_pretrained_model should be set when teacher_model is not None."
args.teacher_pretrained_model
), "teacher_pretrained_model should be set when teacher_model is not None." def if_exist(var):
return os.path.exists(
def if_exist(var): os.path.join(args.teacher_pretrained_model, var.name)
return os.path.exists( ) and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
os.path.join(args.teacher_pretrained_model, var.name)
) and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0' fluid.io.load_vars(
exe,
fluid.io.load_vars( args.teacher_pretrained_model,
exe, main_program=teacher_program,
args.teacher_pretrained_model, predicate=if_exist)
main_program=teacher_program,
predicate=if_exist)
data_name_map = {'image': 'image'} data_name_map = {'image': 'image'}
main = merge( main = merge(
teacher_program, teacher_program,
student_program, student_program,
data_name_map, data_name_map,
place, place)
teacher_scope=teacher_scope)
with fluid.program_guard(main, s_startup): with fluid.program_guard(main, s_startup):
l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main) l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册