提交 d6636846 编写于 作者: C Channingss

optimize code

上级 ae5a1811
...@@ -521,7 +521,7 @@ class PaddleGraph(object): ...@@ -521,7 +521,7 @@ class PaddleGraph(object):
gen_codes( gen_codes(
comment_list, comment_list,
indent=1)) indent=1))
use_structured_name = False if self.source_type in ["tf", "onnx"] else True use_structured_name = False if self.source_type in ["tf"] else True
self.run_func.extend( self.run_func.extend(
gen_codes(["paddle.disable_static()", gen_codes(["paddle.disable_static()",
"params = paddle.load('{}/model.pdparams')".format(osp.abspath(code_dir)), "params = paddle.load('{}/model.pdparams')".format(osp.abspath(code_dir)),
...@@ -673,7 +673,7 @@ class PaddleGraph(object): ...@@ -673,7 +673,7 @@ class PaddleGraph(object):
paddle.disable_static() paddle.disable_static()
restore = paddle.load(osp.join(save_dir, "model.pdparams")) restore = paddle.load(osp.join(save_dir, "model.pdparams"))
model = getattr(x2paddle_code, self.name)() model = getattr(x2paddle_code, self.name)()
if self.source_type in ["tf", "onnx"]: if self.source_type in ["tf"]:
model.set_dict(restore, use_structured_name=False) model.set_dict(restore, use_structured_name=False)
else: else:
model.set_dict(restore) model.set_dict(restore)
......
...@@ -1898,7 +1898,7 @@ class OpSet9(): ...@@ -1898,7 +1898,7 @@ class OpSet9():
reform_permutation = [(0, 1), (2, 4), (1, 2)] reform_permutation = [(0, 1), (2, 4), (1, 2)]
input_weight_np, hidden_weight_np, input_bias_np, hidden_bias_np = transform_weight_with_bias( weights = transform_weight_with_bias(
[input_weight_np, hidden_weight_np, input_bias_np, hidden_bias_np], [input_weight_np, hidden_weight_np, input_bias_np, hidden_bias_np],
hidden_size, reform_permutation) hidden_size, reform_permutation)
...@@ -1907,30 +1907,27 @@ class OpSet9(): ...@@ -1907,30 +1907,27 @@ class OpSet9():
yh_out = node.output(1) yh_out = node.output(1)
yc_out = node.output(2) yc_out = node.output(2)
direction = node.get_attr('direction', 'forward') direction = node.get_attr('direction', 'forward')
if direction == 'backward':
raise Exception("LSTM support 'forward' or 'bidirectional', except '{}'.".format(direction)) def generate_paddle_param_names(op_name, suffix=''):
elif direction == 'forward':
self.weights[input_weight.name] = input_weight_np.squeeze(0)
self.weights[hidden_weight.name] = hidden_weight_np.squeeze(0)
self.weights[input_bias_name] = input_bias_np.squeeze(0)
self.weights[hidden_bias_name] = hidden_bias_np.squeeze(0)
else:
param_names = [] param_names = []
for direct in range(2):
suffix = '_reverse' if direct == 1 else ''
param_names.extend(['{}.weight_ih_l0{}', '{}.weight_hh_l0{}']) param_names.extend(['{}.weight_ih_l0{}', '{}.weight_hh_l0{}'])
if have_bias != False: param_names.append('{}.bias_ih_l0{}') if have_bias != False: param_names.append('{}.bias_ih_l0{}')
if have_bias != False: param_names.append('{}.bias_hh_l0{}') if have_bias != False: param_names.append('{}.bias_hh_l0{}')
param_names = [x.format(op_name, suffix) for x in param_names] param_names = [x.format(op_name, suffix) for x in param_names]
return param_names
self.weights[param_names[0]] = input_weight_np[0] def assign_params(op_name, weights, weight_idx=0, suffix=''):
self.weights[param_names[4]] = input_weight_np[1] param_names = generate_paddle_param_names(op_name, suffix)
self.weights[param_names[1]] = hidden_weight_np[0] print(param_names)
self.weights[param_names[5]] = hidden_weight_np[1] for param_name, weight in zip(param_names, weights):
self.weights[param_names[2]] = input_bias_np[0] self.weights[param_name] = weight[weight_idx]
self.weights[param_names[6]] = input_bias_np[1]
self.weights[param_names[3]] = hidden_bias_np[0] if direction == 'backward':
self.weights[param_names[7]] = hidden_bias_np[1] raise Exception("LSTM support 'forward' or 'bidirectional', except '{}'.".format(direction))
else:
assign_params(op_name, weights)
if direction == 'bidirectional':
assign_params(op_name, weights, 1, '_reverse')
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.nn.LSTM', 'paddle.nn.LSTM',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册