提交 82db4374 编写于 作者: Y Yibing Liu

Support classifier's conversion

上级 b9dae026
...@@ -134,6 +134,11 @@ def parse(init_checkpoint): ...@@ -134,6 +134,11 @@ def parse(init_checkpoint):
fluid_param_name = 'cls_squad_out_b' fluid_param_name = 'cls_squad_out_b'
else: else:
print("ignored param: %s" % var_name) print("ignored param: %s" % var_name)
else:
if var_name == 'output_weights':
fluid_param_name = 'cls_out_w'
elif var_name == 'output_bias':
fluid_param_name = 'cls_out_b'
else: else:
print("ignored param: %s" % var_name) print("ignored param: %s" % var_name)
...@@ -172,6 +177,8 @@ def convert(args): ...@@ -172,6 +177,8 @@ def convert(args):
value = np.transpose(value) value = np.transpose(value)
if param == 'cls/squad/output_weights': if param == 'cls/squad/output_weights':
value = np.transpose(value) value = np.transpose(value)
if param == 'output_weights':
value = np.transpose(value)
fluid.global_scope().find_var(tf_fluid_param_name_map[ fluid.global_scope().find_var(tf_fluid_param_name_map[
param]).get_tensor().set(value, place) param]).get_tensor().set(value, place)
print(param, ' --> ', tf_fluid_param_name_map[param], ' ', value.shape) print(param, ' --> ', tf_fluid_param_name_map[param], ' ', value.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册