提交 613375e5 编写于 作者: W wuzewu

Fix paddle2.0 adaptation problem

上级 73e68b4b
...@@ -17,7 +17,7 @@ import os ...@@ -17,7 +17,7 @@ import os
import pickle import pickle
import time import time
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, List from typing import Any, Callable, Generic, List
import paddle import paddle
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
...@@ -185,7 +185,7 @@ class Trainer(object): ...@@ -185,7 +185,7 @@ class Trainer(object):
timer.count() timer.count()
if (batch_idx + 1) % log_interval == 0 and self.local_rank == 0: if (batch_idx + 1) % log_interval == 0 and self.local_rank == 0:
lr = self.optimizer.current_step_lr() lr = self.optimizer.get_lr()
avg_loss /= log_interval avg_loss /= log_interval
if self.use_vdl: if self.use_vdl:
self.log_writer.add_scalar(tag='TRAIN/loss', step=timer.current_step, value=avg_loss) self.log_writer.add_scalar(tag='TRAIN/loss', step=timer.current_step, value=avg_loss)
...@@ -346,7 +346,12 @@ class Trainer(object): ...@@ -346,7 +346,12 @@ class Trainer(object):
optimizer(paddle.optimizer.Optimizer) : Optimizer used. optimizer(paddle.optimizer.Optimizer) : Optimizer used.
loss(paddle.Tensor) : Loss tensor. loss(paddle.Tensor) : Loss tensor.
''' '''
self.optimizer.minimize(loss) self.optimizer.step()
self.learning_rate_step(epoch_idx, batch_idx, self.optimizer.get_lr(), loss)
def learning_rate_step(self, epoch_idx: int, batch_idx: int, learning_rate: Generic, loss: paddle.Tensor):
if isinstance(learning_rate, paddle.optimizer._LRScheduler):
learning_rate.step()
def optimizer_zero_grad(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer): def optimizer_zero_grad(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer):
''' '''
......
...@@ -135,7 +135,7 @@ class Module(object): ...@@ -135,7 +135,7 @@ class Module(object):
manager = LocalModuleManager() manager = LocalModuleManager()
user_module_cls = manager.search(name) user_module_cls = manager.search(name)
if not user_module_cls or not user_module_cls.version.match(version): if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name, version) user_module_cls = manager.install(name=name, version=version)
directory = manager._get_normalized_path(user_module_cls.name) directory = manager._get_normalized_path(user_module_cls.name)
...@@ -148,7 +148,8 @@ class Module(object): ...@@ -148,7 +148,8 @@ class Module(object):
user_module = user_module_cls(directory=directory) user_module = user_module_cls(directory=directory)
user_module._initialize(**kwargs) user_module._initialize(**kwargs)
return user_module return user_module
return user_module_cls(directory=directory, **kwargs) user_module_cls.directory = directory
return user_module_cls(**kwargs)
@classmethod @classmethod
def init_with_directory(cls, directory: str, **kwargs): def init_with_directory(cls, directory: str, **kwargs):
...@@ -165,7 +166,8 @@ class Module(object): ...@@ -165,7 +166,8 @@ class Module(object):
user_module = user_module_cls(directory=directory) user_module = user_module_cls(directory=directory)
user_module._initialize(**kwargs) user_module._initialize(**kwargs)
return user_module return user_module
return user_module_cls(directory=directory, **kwargs) user_module_cls.directory = directory
return user_module_cls(**kwargs)
@classmethod @classmethod
def get_py_requirements(cls): def get_py_requirements(cls):
......
...@@ -73,7 +73,7 @@ class Logger(object): ...@@ -73,7 +73,7 @@ class Logger(object):
self.__dict__[key.lower()] = functools.partial(self.__call__, conf['level']) self.__dict__[key.lower()] = functools.partial(self.__call__, conf['level'])
self.format = colorlog.ColoredFormatter( self.format = colorlog.ColoredFormatter(
'%(log_color)s[%(asctime)-15s] [%(levelname)8s] - %(message)s', '%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s',
log_colors={key: conf['color'] log_colors={key: conf['color']
for key, conf in log_config.items()}) for key, conf in log_config.items()})
...@@ -178,13 +178,13 @@ class FormattedText(object): ...@@ -178,13 +178,13 @@ class FormattedText(object):
======== ==================================== ======== ====================================
color(str) : Text color, default is None(depends on terminal configuration) color(str) : Text color, default is None(depends on terminal configuration)
''' '''
_MAP = {'red': Fore.RED, 'yellow': Fore.YELLOW, 'green': Fore.GREEN, 'blue': Fore.BLUE} _MAP = {'red': Fore.RED, 'yellow': Fore.YELLOW, 'green': Fore.GREEN, 'blue': Fore.BLUE, 'cyan': Fore.CYAN}
def __init__(self, text: str, width: int, align: str = '<', color: str = None): def __init__(self, text: str, width: int = None, align: str = '<', color: str = None):
self.text = text self.text = text
self.align = align self.align = align
self.color = FormattedText._MAP[color] if color else color self.color = FormattedText._MAP[color] if color else color
self.width = width self.width = width if width else len(self.text)
def __repr__(self) -> str: def __repr__(self) -> str:
form = '{{:{}{}}}'.format(self.align, self.width) form = '{{:{}{}}}'.format(self.align, self.width)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import platform import platform
...@@ -22,3 +23,13 @@ def get_platform() -> str: ...@@ -22,3 +23,13 @@ def get_platform() -> str:
def is_windows() -> str: def is_windows() -> str:
return get_platform().lower().startswith("windows") return get_platform().lower().startswith("windows")
def get_platform_info() -> dict:
return {
'python_version': '.'.join(map(str, sys.version_info[0:3])),
'platform_version': platform.version(),
'platform_system': platform.system(),
'platform_architecture': platform.architecture(),
'platform_type': platform.platform()
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册