未验证 提交 5cc9ead7 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix reprod log (#5548)

上级 4e8726a3
......@@ -83,14 +83,14 @@
为了减少数据对比中标准不一致、人工对比过程繁杂的问题,我们建立了数据对比日志工具reprod_log。
#### 2.2.1 reprod_log工具简介
`reprod_log`是用于论文复现赛中辅助自查和核验工具。查看它的[源代码](https://github.com/WenmuZhou/reprod_log)能对它有个更全面理解。我们常用的功能如下:
`reprod_log`是用于论文复现赛中辅助自查和核验工具。查看它的[源代码](../reprod_log/)能对它有个更全面理解。我们常用的功能如下:
* 存取指定节点的输入输出tensor;
* 基于文件的tensor读写;
* 2个字典的对比验证;
* 对比结果的输出与记录;
更多API与使用方法可以参考:[reprod_log API使用说明](https://github.com/WenmuZhou/reprod_log/blob/master/README.md)
更多API与使用方法可以参考:[reprod_log API使用说明](../reprod_log/README.md)
#### 2.2.2 reprod_log使用demo
......
......@@ -90,14 +90,14 @@
### 2.2 reprod_log whl包
#### 2.2.1 reprod_log工具简介
`reprod_log`是用于论文复现赛中辅助自查和验收工具。该工具源代码地址在:[https://github.com/WenmuZhou/reprod_log](https://github.com/WenmuZhou/reprod_log)。主要功能如下:
`reprod_log`是用于论文复现赛中辅助自查和验收工具。该工具源代码地址在:[链接](../reprod_log)。主要功能如下:
* 存取指定节点的输入输出tensor
* 基于文件的tensor读写
* 2个字典的对比验证
* 对比结果的输出与记录
更多API与使用方法可以参考:[reprod_log API使用说明](https://github.com/WenmuZhou/reprod_log/blob/master/README.md)
更多API与使用方法可以参考:[reprod_log API使用说明](../reprod_log/README.md)
#### 2.2.2 reprod_log使用demo
......
......@@ -31,7 +31,7 @@
<a name="2.1"></a>
### 2.1 reprod_log 简介
Reprod_log 是一个用于 numpy 数据记录和对比工具,通过传入需要对比的两个 numpy 数组就可以在指定的规则下得到数据之差是否满足期望的结论。其主要接口的说明可以看它的 [github 主页](https://github.com/WenmuZhou/reprod_log)
Reprod_log 是一个用于 numpy 数据记录和对比工具,通过传入需要对比的两个 numpy 数组就可以在指定的规则下得到数据之差是否满足期望的结论。其主要接口的说明可以看它的 [github 主页](../../reprod_log/)
<a name="3"></a>
## 3. 准备数据和环境
......
# Byte-compiled / optimized / DLL files
__pycache__/
.ipynb_checkpoints/
*.py[cod]
*$py.class
# C extensions
*.so
inference/
inference_results/
output/
*.DS_Store
*.vs
*.user
*~
*.vscode
*.idea
*.log
.clang-format
.clang_format.hook
build/
dist/
reprod_log.egg-info/
\ No newline at end of file
# reprod_log
主要用于对比和记录模型复现过程中的各个步骤精度对齐情况
## 安装
1. 本地编译安装
```bash
cd models/tutorials/reprod_log
python3 setup.py bdist_wheel
python3 install dist/reprod_log-x.x.-py3-none-any.whl --force-reinstall
```
2. pip直接安装
```bash
# from pypi
pip3 install reprod_log --force-reinstall
# from bcebos
pip3 install https://paddle-model-ecology.bj.bcebos.com/whl/reprod_log-1.0.1-py3-none-any.whl
```
## 提供的类和方法
### 论文复现赛
在论文复现赛中,主要用到的类如下所示。
* ReprodLogger
* 功能:记录和保存复现过程中的中间变量,用于后续的diff排查
* 初始化参数:无
* 方法
* add(key, val)
* 功能:向logger中添加key-val pair
* 输入
* key (str) : PaddlePaddle中的key与参考代码中保存的key应该完全相同,否则会提示报错
* value (numpy.ndarray) : key对应的值
* 返回: None
* remove(key)
* 功能:移除logger中的关键字段key及其value
* 输入
* key (str) : 关键字段
* value (numpy.ndarray) : key对应的值
* 返回: None
* clear()
* 功能:清空logger中的关键字段key及其value
* 输入: None
* 返回: None
* save(path)
* 功能:将logger中的所有的key-value信息保存到文件中
* 输入:
* path (str): 路径
* 返回: None
* ReprodDiffHelper
* 功能:对`ReprodLogger`保存的日志文件进行解析,打印与记录diff
* 初始化参数:无
* 方法
* load_info(path)
* 功能:加载
* 输入:
* path (str): 日志文件路径
* 返回: dict信息,key为str,value为numpy.ndarray
* compare_info(info1, info2)
* 功能:计算两个字典对于相同key的value的diff,具体计算方法为`diff = np.abs(info1[key] - info2[key])`
* 输入:
* info1/info2 (dict): PaddlePaddle与参考代码保存的文件信息
* 返回: diff的dict信息
* report(diff_method="mean", diff_threshold=1e-6, path="./diff.txt")
* 功能:可视化diff,保存到文件或者到屏幕
* 参数
* diff_method (str): diff计算方法,包括`mean``min``max``all`,默认为`mean`
* diff_threshold (float): 阈值,如果diff大于该阈值,则核验失败,默认为`1e-6`
* path (str): 日志保存的路径,默认为`./diff.txt`
### more
`ReprodLogger` 用于记录和报错复现过程中的中间变量
主要方法为
* add(key, val):添加key-val pair
* remove(key):移除key
* clear():清空字典
* save(path):保存字典
`ReprodDiffHelper` 用于对中间变量进行检查,主要为计算diff
主要方法为
* load_info(path): 加载字典文件
* compare_info(info1:dict, info2:dict): 对比diff
* report(diff_threshold=1e-6,path=None): 可视化diff,保存到文件或者到屏幕
模块 `compare` 提供了基础的网络前向和反向过程对比工具
* compare_forward 用于对比网络的反向过程,其参数为
* torch_model: torch.nn.Module,
* paddle_model: paddle.nn.Layer,
* input_dict: dict, dict值为numpy矩阵
* diff_threshold: float=1e-6
* diff_method: str = 'mean' 检查diff的函数,目前支持 min,max,mean,all四种形式,并且支持min,max,mean的相互组合成的list形式,如['min','max']
* compare_loss_and_backward 用于对比网络的反向过程,其参数为
* torch_model: torch.nn.Module,
* paddle_model: paddle.nn.Layer,
* torch_loss: torch.nn.Module,
* paddle_loss: paddle.nn.Layer,
* input_dict: dict, dict值为numpy矩阵
* lr: float=1e-3,
* steps: int=10,
* diff_threshold: float=1e-6
* diff_method: str = 'mean' 检查diff的函数,目前支持 min,max,mean,all四种形式,并且支持min,max,mean的相互组合成的list形式,如['min','max']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import sys
import numpy as np
from .utils import init_logger, check_print_diff
from .compare import compute_diff, check_data
class ReprodDiffHelper:
def load_info(self, path: str):
"""
加载字典文件
:param path:
:return:
"""
assert os.path.exists(path)
data = np.load(path, allow_pickle=True).tolist()
return data
def compare_info(self, info1: dict, info2: dict):
"""
对比diff
:param info1:
:param info2:
:return:
"""
assert isinstance(info1, dict) and isinstance(info2, dict)
check_data(info1, info2)
self.diff_dict = compute_diff(info1, info2)
def report(self,
diff_method='mean',
diff_threshold: float=1e-6,
path: str="./diff.txt"):
"""
可视化diff,保存到文件或者到屏幕
:param diff_threshold:
:param path:
:return:
"""
logger = init_logger(path)
passed = check_print_diff(
self.diff_dict,
diff_method=diff_method,
diff_threshold=diff_threshold,
print_func=logger.info)
if passed:
logger.info('diff check passed')
else:
logger.info('diff check failed')
return
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
class ReprodLogger(object):
def __init__(self):
self._data = dict()
@property
def data(self):
return self._data
def add(self, key, val):
"""
添加key-val pair
:param key:
:param val:
:return:
"""
msg = '''val must be np.ndarray, you can convert it by follow code:
1. Torch GPU: torch_tensor.cpu().detach().numpy()
2. Torch CPU: torch_tensor.detach().numpy()
3. Paddle: paddle_tensor.numpy()'''
assert isinstance(val, np.ndarray), msg
self._data[key] = val
def remove(self, key):
"""
移除key
:param key:
:return:
"""
if key in self._data:
self._data.pop(key)
else:
print('{} is not in {}'.format(key, self._data.keys()))
def clear(self):
"""
清空字典
:return:
"""
self._data.clear()
def save(self, path):
folder = os.path.dirname(path)
if len(folder) >= 1:
os.makedirs(folder, exist_ok=True)
np.save(path, self._data)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ReprodLogger import ReprodLogger
from .ReprodDiffHelper import ReprodDiffHelper
from . import utils, compare
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from .utils import np2torch, np2paddle, paddle2np, torch2np, check_print_diff
def check_data(data1: dict, data2: dict):
for k in data1:
if k not in data2:
assert k in data2, 'k in data1 but not found in data2'.format(
k, data2)
for k in data2:
if k not in data1:
assert k in data1, 'k in data2 but not found in data1'.format(
k, data2.keys())
def compute_diff(data1: dict, data2: dict):
out_dict = {}
for k in data1:
assert k in data2
sub_data1, sub_data2 = data1[k], data2[k]
assert type(sub_data1) == type(sub_data2)
if isinstance(sub_data1, dict):
out = compute_diff(sub_data1, sub_data2)
out_dict[k] = out
elif isinstance(sub_data1, np.ndarray):
if sub_data1.shape != sub_data2.shape and sub_data1.transpose(
).shape == sub_data2.shape:
print('transpose sub_data1')
sub_data1 = sub_data1.transpose()
diff = np.abs(sub_data1 - sub_data2)
out_dict[k] = {
'mean': diff.mean(),
'max': diff.max(),
'min': diff.min()
}
else:
raise NotImplementedError
return out_dict
def compare_forward(torch_model,
paddle_model: paddle.nn.Layer,
input_dict: dict,
diff_threshold: float=1e-6,
diff_method: str='mean'):
torch_input = np2torch(input_dict)
paddle_input = np2paddle(input_dict)
torch_model.eval()
paddle_model.eval()
torch_out = torch_model(**torch_input)
paddle_out = paddle_model(**paddle_input)
diff_dict = compute_diff(torch2np(torch_out), paddle2np(paddle_out))
passed = check_print_diff(
diff_dict,
diff_method=diff_method,
diff_threshold=diff_threshold,
print_func=print)
if passed:
print('diff check passed')
else:
print('diff check failed')
def compare_loss_and_backward(torch_model,
paddle_model: paddle.nn.Layer,
torch_loss,
paddle_loss: paddle.nn.Layer,
input_dict: dict,
lr: float=1e-3,
steps: int=10,
diff_threshold: float=1e-6,
diff_method: str='mean'):
import torch
torch_input = np2torch(input_dict)
paddle_input = np2paddle(input_dict)
torch_model.eval()
paddle_model.eval()
torch_optim = torch.optim.SGD(params=torch_model.parameters(), lr=lr)
paddle_optim = paddle.optimizer.SGD(parameters=paddle_model.parameters(),
learning_rate=lr)
for i in range(steps):
# paddle
paddle_outputs = paddle_model(**paddle_input)
paddle_loss_value = paddle_loss(paddle_input, paddle_outputs)
paddle_loss_value['loss'].backward()
paddle_optim.step()
paddle_grad_dict = {'loss': paddle_loss_value['loss'].numpy()}
for name, parms in paddle_model.named_parameters():
if not parms.stop_gradient and parms.grad is not None:
paddle_grad_dict[name] = parms.grad.numpy()
paddle_optim.clear_grad()
# torch
torch_outputs = torch_model(**torch_input)
torch_loss_value = torch_loss(torch_input, torch_outputs)
torch_loss_value['loss'].backward()
torch_optim.step()
torch_grad_dict = {'loss': torch_loss_value['loss'].detach().numpy()}
for name, parms in torch_model.named_parameters():
if parms.requires_grad and parms.grad is not None:
torch_grad_dict[name] = parms.grad.numpy()
torch_optim.zero_grad()
# compare
diff_dict = compute_diff(paddle_grad_dict, torch_grad_dict)
passed = check_print_diff(
diff_dict,
diff_method=diff_method,
diff_threshold=diff_threshold,
print_func=print)
if passed:
print('diff check passed in iter {}'.format(i))
else:
print('diff check failed in iter {}'.format(i))
return
print('diff check passed')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import paddle
import numpy as np
from typing import Union
def init_logger(log_file=None, name='root', log_level=logging.DEBUG):
logger = logging.getLogger(name)
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_file is not None:
dir_name = os.path.dirname(log_file)
if len(dir_name) > 0 and not os.path.exists(dir_name):
os.makedirs(dir_name)
file_handler = logging.FileHandler(log_file, 'w')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.setLevel(log_level)
return logger
def np2torch(data: dict):
import torch
assert isinstance(data, dict)
torch_input = {}
for k, v in data.items():
if isinstance(v, np.ndarray):
torch_input[k] = torch.Tensor(v)
else:
torch_input[k] = v
return torch_input
def np2paddle(data: dict):
assert isinstance(data, dict)
paddle_input = {}
for k, v in data.items():
if isinstance(v, np.ndarray):
paddle_input[k] = paddle.Tensor(v)
else:
paddle_input[k] = v
return paddle_input
def paddle2np(data: Union[paddle.Tensor, dict]=None):
if isinstance(data, dict):
np_data = {}
for k, v in data.items():
np_data[k] = v.numpy()
return np_data
else:
return {'output': data.numpy()}
def torch2np(data):
if isinstance(data, dict):
np_data = {}
for k, v in data.items():
np_data[k] = v.detach().numpy()
return np_data
else:
return {'output': data.detach().numpy()}
def check_print_diff(diff_dict,
diff_method='mean',
diff_threshold: float=1e-6,
print_func=print,
indent: str='\t',
level: int=0):
"""
对 diff 字典打印并进行检查的函数
:param diff_dict:
:param diff_method: 检查diff的函数,目前支持 min,max,mean,all四种形式,并且支持min,max,mean的相互组合成的list形式,如['min','max']
:param diff_threshold:
:param print_func:
:param indent:
:param level:
:return:
"""
if level == 0:
if isinstance(diff_method, str):
if diff_method == 'all':
diff_method = ['min', 'max', 'mean']
else:
diff_method = [diff_method]
for method in diff_method:
assert method in ['all', 'min', 'max', 'mean']
all_passed = True
cur_indent = indent * level
for k, v in diff_dict.items():
if 'mean' in v and 'min' in v and 'max' in v and len(v) == 3:
print_func('{}{}: '.format(cur_indent, k))
sub_passed = True
for method in diff_method:
if v[method] > diff_threshold:
sub_passed = False
print_func("{}{} diff: check passed: {}, value: {}".format(
cur_indent + indent, method, sub_passed, v[method]))
all_passed = all_passed and sub_passed
else:
print_func('{}{}'.format(cur_indent, k))
sub_passed = check_print_diff(v, diff_method, diff_threshold,
print_func, indent, level + 1)
all_passed = all_passed and sub_passed
return all_passed
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup, find_packages
with open('requirements.txt', encoding="utf-8-sig") as f:
requirements = f.readlines()
setup(
name='reprod_log',
version='1.0.1',
install_requires=requirements,
license='Apache License 2.0',
keywords='reprod_log',
description="TBD",
url='https://github.com/PaddlePaddle/models',
author='PaddlePaddle',
packages=find_packages(), )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册