提交 d6636846 编写于 作者: C Channingss

optimize code

上级 ae5a1811
......@@ -521,7 +521,7 @@ class PaddleGraph(object):
gen_codes(
comment_list,
indent=1))
use_structured_name = False if self.source_type in ["tf", "onnx"] else True
use_structured_name = False if self.source_type in ["tf"] else True
self.run_func.extend(
gen_codes(["paddle.disable_static()",
"params = paddle.load('{}/model.pdparams')".format(osp.abspath(code_dir)),
......@@ -673,7 +673,7 @@ class PaddleGraph(object):
paddle.disable_static()
restore = paddle.load(osp.join(save_dir, "model.pdparams"))
model = getattr(x2paddle_code, self.name)()
if self.source_type in ["tf", "onnx"]:
if self.source_type in ["tf"]:
model.set_dict(restore, use_structured_name=False)
else:
model.set_dict(restore)
......
......@@ -1898,7 +1898,7 @@ class OpSet9():
reform_permutation = [(0, 1), (2, 4), (1, 2)]
input_weight_np, hidden_weight_np, input_bias_np, hidden_bias_np = transform_weight_with_bias(
weights = transform_weight_with_bias(
[input_weight_np, hidden_weight_np, input_bias_np, hidden_bias_np],
hidden_size, reform_permutation)
......@@ -1907,30 +1907,27 @@ class OpSet9():
yh_out = node.output(1)
yc_out = node.output(2)
direction = node.get_attr('direction', 'forward')
def generate_paddle_param_names(op_name, suffix=''):
param_names = []
param_names.extend(['{}.weight_ih_l0{}', '{}.weight_hh_l0{}'])
if have_bias != False: param_names.append('{}.bias_ih_l0{}')
if have_bias != False: param_names.append('{}.bias_hh_l0{}')
param_names = [x.format(op_name, suffix) for x in param_names]
return param_names
def assign_params(op_name, weights, weight_idx=0, suffix=''):
param_names = generate_paddle_param_names(op_name, suffix)
print(param_names)
for param_name, weight in zip(param_names, weights):
self.weights[param_name] = weight[weight_idx]
if direction == 'backward':
raise Exception("LSTM support 'forward' or 'bidirectional', except '{}'.".format(direction))
elif direction == 'forward':
self.weights[input_weight.name] = input_weight_np.squeeze(0)
self.weights[hidden_weight.name] = hidden_weight_np.squeeze(0)
self.weights[input_bias_name] = input_bias_np.squeeze(0)
self.weights[hidden_bias_name] = hidden_bias_np.squeeze(0)
else:
param_names = []
for direct in range(2):
suffix = '_reverse' if direct == 1 else ''
param_names.extend(['{}.weight_ih_l0{}', '{}.weight_hh_l0{}'])
if have_bias != False: param_names.append('{}.bias_ih_l0{}')
if have_bias != False: param_names.append('{}.bias_hh_l0{}')
param_names = [x.format(op_name, suffix) for x in param_names]
self.weights[param_names[0]] = input_weight_np[0]
self.weights[param_names[4]] = input_weight_np[1]
self.weights[param_names[1]] = hidden_weight_np[0]
self.weights[param_names[5]] = hidden_weight_np[1]
self.weights[param_names[2]] = input_bias_np[0]
self.weights[param_names[6]] = input_bias_np[1]
self.weights[param_names[3]] = hidden_bias_np[0]
self.weights[param_names[7]] = hidden_bias_np[1]
assign_params(op_name, weights)
if direction == 'bidirectional':
assign_params(op_name, weights, 1, '_reverse')
self.paddle_graph.add_layer(
'paddle.nn.LSTM',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册