未验证 提交 5043708c 编写于 作者: J JYChen 提交者: GitHub

fix var order in paddle.jit.save (#56510)

* fix var order in paddle.jit.save

* support complex/bfloat16 for scale in static-mode
上级 f3fa2ed3
......@@ -1248,10 +1248,10 @@ def save(layer, path, input_spec=None, **configs):
file_prefix = file_prefix + '.' + attr_func
file_prefix = os.path.join(model_path, file_prefix)
with scope_guard(scope):
input_vars = []
for var in concrete_program.main_program.clone().list_vars():
if var.name in input_var_names:
input_vars.append(var)
input_vars = [
concrete_program.main_program.global_block().var(name)
for name in input_var_names
]
save_inference_model(
path_prefix=file_prefix,
feed_vars=input_vars,
......
......@@ -270,6 +270,7 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
"x",
[
'float16',
'bfloat16',
'uint16',
'float32',
'float64',
......@@ -278,6 +279,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
'int32',
'int64',
'uint8',
'complex64',
'complex128',
],
"scale",
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册