未验证 提交 c22bdc7e 编写于 作者: G gaotingquan

remove fluid

上级 bd2869de
...@@ -221,7 +221,7 @@ class Engine(object): ...@@ -221,7 +221,7 @@ class Engine(object):
AMP_RELATED_FLAGS_SETTING.update({ AMP_RELATED_FLAGS_SETTING.update({
'FLAGS_cudnn_batchnorm_spatial_persistent': 1 'FLAGS_cudnn_batchnorm_spatial_persistent': 1
}) })
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0) self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get( self.use_dynamic_loss_scaling = self.config["AMP"].get(
......
...@@ -62,8 +62,8 @@ def load_params(exe, prog, path, ignore_params=None): ...@@ -62,8 +62,8 @@ def load_params(exe, prog, path, ignore_params=None):
""" """
Load model from the given path. Load model from the given path.
Args: Args:
exe (fluid.Executor): The fluid.Executor object. exe (paddle.static.Executor): The paddle.static.Executor object.
prog (fluid.Program): load weight to which Program object. prog (paddle.static.Program): load weight to which Program object.
path (string): URL string or loca model path. path (string): URL string or loca model path.
ignore_params (list): ignore variable to load when finetuning. ignore_params (list): ignore variable to load when finetuning.
It can be specified by finetune_exclude_pretrained_params It can be specified by finetune_exclude_pretrained_params
......
...@@ -87,7 +87,7 @@ def main(args): ...@@ -87,7 +87,7 @@ def main(args):
'FLAGS_max_inplace_grad_add': 8, 'FLAGS_max_inplace_grad_add': 8,
} }
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1' os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
use_xpu = global_config.get("use_xpu", False) use_xpu = global_config.get("use_xpu", False)
use_npu = global_config.get("use_npu", False) use_npu = global_config.get("use_npu", False)
......
...@@ -112,7 +112,7 @@ def get_path_from_url(url, ...@@ -112,7 +112,7 @@ def get_path_from_url(url,
str: a local path to save downloaded models & weights & datasets. str: a local path to save downloaded models & weights & datasets.
""" """
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.distributed import ParallelEnv
assert is_url(url), "downloading from {} not a url".format(url) assert is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir # parse path after download to decompress under root_dir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册