From af232bf3a2b0bb0d24a94f93125f690f43a2a476 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Sun, 29 Sep 2019 15:07:23 +0800 Subject: [PATCH] Update ppdet/utils/checkpoint.py to skip the var shape check when the var is not initialized. (#3452) --- ppdet/utils/checkpoint.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index 947a0803f..b1dfa864e 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -47,7 +47,6 @@ def is_url(path): return path.startswith('http://') or path.startswith('https://') - def load_params(exe, prog, path, ignore_params=[]): """ Load model from the given path. @@ -177,6 +176,7 @@ def load_and_fusebn(exe, prog, path): inner_prog = fluid.Program() inner_start_prog = fluid.Program() + inner_block = inner_prog.global_block() with fluid.program_guard(inner_prog, inner_start_prog): for block in prog.blocks: ops = list(block.ops) @@ -199,10 +199,20 @@ def load_and_fusebn(exe, prog, path): break bias = block.var(bias_name) - mean_vb = fluid.layers.create_parameter( - bias.shape, bias.dtype, mean_name) - variance_vb = fluid.layers.create_parameter( - bias.shape, bias.dtype, variance_name) + + mean_vb = inner_block.create_var( + name=mean_name, + type=bias.type, + shape=bias.shape, + dtype=bias.dtype, + persistable=True) + variance_vb = inner_block.create_var( + name=variance_name, + type=bias.type, + shape=bias.shape, + dtype=bias.dtype, + persistable=True) + mean_variances.add(mean_vb) mean_variances.add(variance_vb) -- GitLab