diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index 947a0803fde16082561914d8a61980f8f174ab3d..b1dfa864efdfad056f758fbbde56192ea666c916 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -47,7 +47,6 @@ def is_url(path): return path.startswith('http://') or path.startswith('https://') - def load_params(exe, prog, path, ignore_params=[]): """ Load model from the given path. @@ -177,6 +176,7 @@ def load_and_fusebn(exe, prog, path): inner_prog = fluid.Program() inner_start_prog = fluid.Program() + inner_block = inner_prog.global_block() with fluid.program_guard(inner_prog, inner_start_prog): for block in prog.blocks: ops = list(block.ops) @@ -199,10 +199,20 @@ def load_and_fusebn(exe, prog, path): break bias = block.var(bias_name) - mean_vb = fluid.layers.create_parameter( - bias.shape, bias.dtype, mean_name) - variance_vb = fluid.layers.create_parameter( - bias.shape, bias.dtype, variance_name) + + mean_vb = inner_block.create_var( + name=mean_name, + type=bias.type, + shape=bias.shape, + dtype=bias.dtype, + persistable=True) + variance_vb = inner_block.create_var( + name=variance_name, + type=bias.type, + shape=bias.shape, + dtype=bias.dtype, + persistable=True) + mean_variances.add(mean_vb) mean_variances.add(variance_vb)