未验证 提交 5043708c 编写于 作者: J JYChen 提交者: GitHub

fix var order in paddle.jit.save (#56510)

* fix var order in paddle.jit.save

* support complex/bfloat16 for scale in static-mode
上级 f3fa2ed3
...@@ -1248,10 +1248,10 @@ def save(layer, path, input_spec=None, **configs): ...@@ -1248,10 +1248,10 @@ def save(layer, path, input_spec=None, **configs):
file_prefix = file_prefix + '.' + attr_func file_prefix = file_prefix + '.' + attr_func
file_prefix = os.path.join(model_path, file_prefix) file_prefix = os.path.join(model_path, file_prefix)
with scope_guard(scope): with scope_guard(scope):
input_vars = [] input_vars = [
for var in concrete_program.main_program.clone().list_vars(): concrete_program.main_program.global_block().var(name)
if var.name in input_var_names: for name in input_var_names
input_vars.append(var) ]
save_inference_model( save_inference_model(
path_prefix=file_prefix, path_prefix=file_prefix,
feed_vars=input_vars, feed_vars=input_vars,
......
...@@ -270,6 +270,7 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): ...@@ -270,6 +270,7 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
"x", "x",
[ [
'float16', 'float16',
'bfloat16',
'uint16', 'uint16',
'float32', 'float32',
'float64', 'float64',
...@@ -278,6 +279,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): ...@@ -278,6 +279,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
'int32', 'int32',
'int64', 'int64',
'uint8', 'uint8',
'complex64',
'complex128',
], ],
"scale", "scale",
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册