未验证 提交 c1631d04 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #116 from PaddlePaddle/fix_param_convert

Support classifier's conversion
......@@ -135,7 +135,12 @@ def parse(init_checkpoint):
else:
print("ignored param: %s" % var_name)
else:
print("ignored param: %s" % var_name)
if var_name == 'output_weights':
fluid_param_name = 'cls_out_w'
elif var_name == 'output_bias':
fluid_param_name = 'cls_out_b'
else:
print("ignored param: %s" % var_name)
if fluid_param_name != '':
tf_fluid_param_name_map[var_name] = fluid_param_name
......@@ -172,6 +177,8 @@ def convert(args):
value = np.transpose(value)
if param == 'cls/squad/output_weights':
value = np.transpose(value)
if param == 'output_weights':
value = np.transpose(value)
fluid.global_scope().find_var(tf_fluid_param_name_map[
param]).get_tensor().set(value, place)
print(param, ' --> ', tf_fluid_param_name_map[param], ' ', value.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册