提交 ee496dc6 编写于 作者: L LielinJiang

move common functions to utils.py

上级 c3930dc5
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -20,15 +20,9 @@ from paddle import fluid ...@@ -20,15 +20,9 @@ from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode, Variable from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
__all__ = ['Loss', 'CrossEntropy', 'SoftmaxWithCrossEntropy'] from hapi.utils import to_list
def to_list(value): __all__ = ['Loss', 'CrossEntropy', 'SoftmaxWithCrossEntropy']
if value is None:
return value
if isinstance(value, (list, tuple)):
return list(value)
return [value]
class Loss(object): class Loss(object):
......
...@@ -38,6 +38,7 @@ from hapi.loss import Loss ...@@ -38,6 +38,7 @@ from hapi.loss import Loss
from hapi.distributed import DistributedBatchSampler, _all_gather, prepare_distributed_context, _parallel_context_initialized from hapi.distributed import DistributedBatchSampler, _all_gather, prepare_distributed_context, _parallel_context_initialized
from hapi.metrics import Metric from hapi.metrics import Metric
from hapi.callbacks import config_callbacks from hapi.callbacks import config_callbacks
from hapi.utils import to_list, to_numpy, flatten_list, restore_flatten_list
__all__ = [ __all__ = [
'Model', 'Model',
...@@ -65,49 +66,6 @@ def set_device(device): ...@@ -65,49 +66,6 @@ def set_device(device):
return place return place
def to_list(value):
if value is None:
return value
if isinstance(value, (list, tuple)):
return list(value)
return [value]
def to_numpy(var):
assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
if isinstance(var, fluid.core.VarBase):
return var.numpy()
t = global_scope().find_var(var.name).get_tensor()
return np.array(t)
def flatten_list(l):
assert isinstance(l, list), "not a list"
outl = []
splits = []
for sl in l:
assert isinstance(sl, list), "sub content not a list"
splits.append(len(sl))
outl += sl
return outl, splits
def restore_flatten_list(l, splits):
outl = []
for split in splits:
assert len(l) >= split, "list length invalid"
sl, l = l[:split], l[split:]
outl.append(sl)
return outl
def extract_args(func):
if hasattr(inspect, 'getfullargspec'):
return inspect.getfullargspec(func)[0]
else:
return inspect.getargspec(func)[0]
class Input(fluid.dygraph.Layer): class Input(fluid.dygraph.Layer):
def __init__(self, shape=None, dtype=None, name=None): def __init__(self, shape=None, dtype=None, name=None):
super(Input, self).__init__() super(Input, self).__init__()
...@@ -1180,7 +1138,6 @@ class Model(fluid.dygraph.Layer): ...@@ -1180,7 +1138,6 @@ class Model(fluid.dygraph.Layer):
save_dir, save_dir,
model_filename=None, model_filename=None,
params_filename=None, params_filename=None,
export_for_deployment=True,
program_only=False): program_only=False):
""" """
Save inference model must in static mode. Save inference model must in static mode.
...@@ -1193,12 +1150,6 @@ class Model(fluid.dygraph.Layer): ...@@ -1193,12 +1150,6 @@ class Model(fluid.dygraph.Layer):
params_filename(str|None): The name of file to save all related parameters. params_filename(str|None): The name of file to save all related parameters.
If it is set None, parameters will be saved If it is set None, parameters will be saved
in separate files . in separate files .
export_for_deployment(bool): If True, programs are modified to only support
direct inference deployment. Otherwise,
more information will be stored for flexible
optimization and re-training. Currently, only
True is supported.
Default: True.
program_only(bool): If True, It will save inference program only, and do not program_only(bool): If True, It will save inference program only, and do not
save params of Program. save params of Program.
Default: False. Default: False.
...@@ -1226,7 +1177,6 @@ class Model(fluid.dygraph.Layer): ...@@ -1226,7 +1177,6 @@ class Model(fluid.dygraph.Layer):
main_program=infer_prog, main_program=infer_prog,
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename, params_filename=params_filename,
export_for_deployment=export_for_deployment,
program_only=program_only) program_only=program_only)
def _run_one_epoch(self, def _run_one_epoch(self,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import numpy as np
from paddle import fluid
from paddle.fluid.framework import Variable
from paddle.fluid.executor import global_scope
def to_list(value):
    """Normalize *value* into a list.

    ``None`` passes through unchanged, a ``list``/``tuple`` becomes a
    plain ``list``, and any other single object is wrapped in a
    one-element list.
    """
    if value is None:
        return None
    if not isinstance(value, (list, tuple)):
        return [value]
    return list(value)
def to_numpy(var):
    """Convert a fluid ``Variable`` or dygraph ``VarBase`` to a numpy array.

    Dygraph tensors expose ``.numpy()`` directly; static-graph variables
    are fetched by name from the global scope.
    """
    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
    if isinstance(var, fluid.core.VarBase):
        # Eager/dygraph tensor: direct conversion.
        return var.numpy()
    # Static-graph variable: look up its tensor in the executor scope.
    tensor = global_scope().find_var(var.name).get_tensor()
    return np.array(tensor)
def flatten_list(l):
    """Concatenate a list of lists into a single flat list.

    Returns ``(flat, splits)`` where ``splits[i]`` is the length of
    ``l[i]``, so :func:`restore_flatten_list` can invert the operation.
    """
    assert isinstance(l, list), "not a list"
    flat = []
    splits = []
    for sub in l:
        assert isinstance(sub, list), "sub content not a list"
        splits.append(len(sub))
        flat.extend(sub)
    return flat, splits
def restore_flatten_list(l, splits):
    """Invert :func:`flatten_list`.

    Slices the flat list *l* back into consecutive sub-lists whose
    lengths are given by *splits*.
    """
    restored = []
    start = 0
    for size in splits:
        # Same guard as the original: remaining items must cover `size`.
        assert len(l) - start >= size, "list length invalid"
        restored.append(l[start:start + size])
        start += size
    return restored
def extract_args(func):
    """Return the list of positional argument names of *func*.

    Prefers ``inspect.getfullargspec`` (Python 3); falls back to the
    legacy ``inspect.getargspec`` on interpreters that lack it.
    """
    spec = getattr(inspect, 'getfullargspec', None)
    if spec is None:
        spec = inspect.getargspec
    return spec(func)[0]
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册