提交 82db4374 编写于 作者: Y Yibing Liu

Support classifier's conversion

上级 b9dae026
...@@ -135,7 +135,12 @@ def parse(init_checkpoint): ...@@ -135,7 +135,12 @@ def parse(init_checkpoint):
else: else:
print("ignored param: %s" % var_name) print("ignored param: %s" % var_name)
else: else:
print("ignored param: %s" % var_name) if var_name == 'output_weights':
fluid_param_name = 'cls_out_w'
elif var_name == 'output_bias':
fluid_param_name = 'cls_out_b'
else:
print("ignored param: %s" % var_name)
if fluid_param_name != '': if fluid_param_name != '':
tf_fluid_param_name_map[var_name] = fluid_param_name tf_fluid_param_name_map[var_name] = fluid_param_name
...@@ -172,6 +177,8 @@ def convert(args): ...@@ -172,6 +177,8 @@ def convert(args):
value = np.transpose(value) value = np.transpose(value)
if param == 'cls/squad/output_weights': if param == 'cls/squad/output_weights':
value = np.transpose(value) value = np.transpose(value)
if param == 'output_weights':
value = np.transpose(value)
fluid.global_scope().find_var(tf_fluid_param_name_map[ fluid.global_scope().find_var(tf_fluid_param_name_map[
param]).get_tensor().set(value, place) param]).get_tensor().set(value, place)
print(param, ' --> ', tf_fluid_param_name_map[param], ' ', value.shape) print(param, ' --> ', tf_fluid_param_name_map[param], ' ', value.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册