提交 613375e5 编写于 作者: W wuzewu

Fix paddle2.0 adaptation problem

上级 73e68b4b
......@@ -17,7 +17,7 @@ import os
import pickle
import time
from collections import defaultdict
from typing import Any, Callable, List
from typing import Any, Callable, Generic, List
import paddle
from paddle.distributed import ParallelEnv
......@@ -185,7 +185,7 @@ class Trainer(object):
timer.count()
if (batch_idx + 1) % log_interval == 0 and self.local_rank == 0:
lr = self.optimizer.current_step_lr()
lr = self.optimizer.get_lr()
avg_loss /= log_interval
if self.use_vdl:
self.log_writer.add_scalar(tag='TRAIN/loss', step=timer.current_step, value=avg_loss)
......@@ -346,7 +346,12 @@ class Trainer(object):
optimizer(paddle.optimizer.Optimizer) : Optimizer used.
loss(paddle.Tensor) : Loss tensor.
'''
self.optimizer.minimize(loss)
self.optimizer.step()
self.learning_rate_step(epoch_idx, batch_idx, self.optimizer.get_lr(), loss)
def learning_rate_step(self, epoch_idx: int, batch_idx: int, learning_rate: Generic, loss: paddle.Tensor):
if isinstance(learning_rate, paddle.optimizer._LRScheduler):
learning_rate.step()
def optimizer_zero_grad(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer):
'''
......
......@@ -135,7 +135,7 @@ class Module(object):
manager = LocalModuleManager()
user_module_cls = manager.search(name)
if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name, version)
user_module_cls = manager.install(name=name, version=version)
directory = manager._get_normalized_path(user_module_cls.name)
......@@ -148,7 +148,8 @@ class Module(object):
user_module = user_module_cls(directory=directory)
user_module._initialize(**kwargs)
return user_module
return user_module_cls(directory=directory, **kwargs)
user_module_cls.directory = directory
return user_module_cls(**kwargs)
@classmethod
def init_with_directory(cls, directory: str, **kwargs):
......@@ -165,7 +166,8 @@ class Module(object):
user_module = user_module_cls(directory=directory)
user_module._initialize(**kwargs)
return user_module
return user_module_cls(directory=directory, **kwargs)
user_module_cls.directory = directory
return user_module_cls(**kwargs)
@classmethod
def get_py_requirements(cls):
......
......@@ -73,7 +73,7 @@ class Logger(object):
self.__dict__[key.lower()] = functools.partial(self.__call__, conf['level'])
self.format = colorlog.ColoredFormatter(
'%(log_color)s[%(asctime)-15s] [%(levelname)8s] - %(message)s',
'%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s',
log_colors={key: conf['color']
for key, conf in log_config.items()})
......@@ -178,13 +178,13 @@ class FormattedText(object):
======== ====================================
color(str) : Text color, default is None(depends on terminal configuration)
'''
_MAP = {'red': Fore.RED, 'yellow': Fore.YELLOW, 'green': Fore.GREEN, 'blue': Fore.BLUE}
_MAP = {'red': Fore.RED, 'yellow': Fore.YELLOW, 'green': Fore.GREEN, 'blue': Fore.BLUE, 'cyan': Fore.CYAN}
def __init__(self, text: str, width: int, align: str = '<', color: str = None):
def __init__(self, text: str, width: int = None, align: str = '<', color: str = None):
self.text = text
self.align = align
self.color = FormattedText._MAP[color] if color else color
self.width = width
self.width = width if width else len(self.text)
def __repr__(self) -> str:
form = '{{:{}{}}}'.format(self.align, self.width)
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import platform
......@@ -22,3 +23,13 @@ def get_platform() -> str:
def is_windows() -> str:
return get_platform().lower().startswith("windows")
def get_platform_info() -> dict:
return {
'python_version': '.'.join(map(str, sys.version_info[0:3])),
'platform_version': platform.version(),
'platform_system': platform.system(),
'platform_architecture': platform.architecture(),
'platform_type': platform.platform()
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册