diff --git a/BERT/convert_params.py b/BERT/convert_params.py index 5cb95dbddfad200e402d94ccd835b5efb9dd2248..e2930defa01c6ab9be6e3f24e5956ce2dbecf4da 100644 --- a/BERT/convert_params.py +++ b/BERT/convert_params.py @@ -135,7 +135,12 @@ def parse(init_checkpoint): else: print("ignored param: %s" % var_name) else: - print("ignored param: %s" % var_name) + if var_name == 'output_weights': + fluid_param_name = 'cls_out_w' + elif var_name == 'output_bias': + fluid_param_name = 'cls_out_b' + else: + print("ignored param: %s" % var_name) if fluid_param_name != '': tf_fluid_param_name_map[var_name] = fluid_param_name @@ -172,6 +177,8 @@ def convert(args): value = np.transpose(value) if param == 'cls/squad/output_weights': value = np.transpose(value) + if param == 'output_weights': + value = np.transpose(value) fluid.global_scope().find_var(tf_fluid_param_name_map[ param]).get_tensor().set(value, place) print(param, ' --> ', tf_fluid_param_name_map[param], ' ', value.shape)