From 82db4374caec188a30f5da3181bfba731c5a7fe4 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 30 Apr 2019 14:24:17 +0000 Subject: [PATCH] Support classifier's conversion --- BERT/convert_params.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/BERT/convert_params.py b/BERT/convert_params.py index 5cb95db..e2930de 100644 --- a/BERT/convert_params.py +++ b/BERT/convert_params.py @@ -135,7 +135,12 @@ def parse(init_checkpoint): else: print("ignored param: %s" % var_name) else: - print("ignored param: %s" % var_name) + if var_name == 'output_weights': + fluid_param_name = 'cls_out_w' + elif var_name == 'output_bias': + fluid_param_name = 'cls_out_b' + else: + print("ignored param: %s" % var_name) if fluid_param_name != '': tf_fluid_param_name_map[var_name] = fluid_param_name @@ -172,6 +177,8 @@ def convert(args): value = np.transpose(value) if param == 'cls/squad/output_weights': value = np.transpose(value) + if param == 'output_weights': + value = np.transpose(value) fluid.global_scope().find_var(tf_fluid_param_name_map[ param]).get_tensor().set(value, place) print(param, ' --> ', tf_fluid_param_name_map[param], ' ', value.shape) -- GitLab