提交 bc6e7f5c 编写于 作者: H HydrogenSulfate

update code

......@@ -7,7 +7,7 @@
飞桨图像识别套件PaddleClas是飞桨为工业界和学术界所准备的一个图像识别任务的工具集,助力使用者训练出更好的视觉模型和应用落地。
**近期更新**
- 2022.4.21 新增 CVPR2022 oral论文 [MixFormmer](https://arxiv.org/pdf/2204.02557.pdf) 相关[代码](https://github.com/PaddlePaddle/PaddleClas/pull/1820/files)
- 2022.1.27 全面升级文档;新增[PaddleServing C++ pipeline部署方式](./deploy/paddleserving)[18M图像识别安卓部署Demo](./deploy/lite_shitu)
- 2021.11.1 发布[PP-ShiTu技术报告](https://arxiv.org/pdf/2111.00775.pdf),新增饮料识别demo
- 2021.10.23 发布轻量级图像识别系统PP-ShiTu,CPU上0.2s即可完成在10w+库的图像识别。
......@@ -35,10 +35,11 @@ Res2Net200_vd预训练模型Top-1精度高达85.1%。
## 欢迎加入技术交流群
* 您可以扫描下面的微信群二维码, 加入PaddleClas 微信交流群。获得更高效的问题答疑,与各行各业开发者充分交流,期待您的加入。
* 您可以扫描下面的QQ/微信二维码(添加小助手微信并回复“C”),加入PaddleClas微信交流群,获得更高效的问题答疑,与各行各业开发者充分交流,期待您的加入。
<div align="center">
<img src="https://user-images.githubusercontent.com/12560511/162710270-8a249aca-4fa9-46f9-95e5-66d906fe6d66.jpg" width="200"/>
<img src="https://user-images.githubusercontent.com/80816848/164383225-e375eb86-716e-41b4-a9e0-4b8a3976c1aa.jpg" width="200"/>
<img src="https://user-images.githubusercontent.com/48054808/160531099-9811bbe6-cfbb-47d5-8bdb-c2b40684d7dd.png" width="200"/>
</div>
## 快速体验
......
......@@ -8,6 +8,8 @@ PaddleClas is an image recognition toolset for industry and academia, helping us
**Recent updates**
- 2022.4.21 Added the related [code](https://github.com/PaddlePaddle/PaddleClas/pull/1820/files) of the CVPR2022 oral paper [MixFormmer](https://arxiv.org/pdf/2204.02557.pdf).
- 2021.09.17 Add PP-LCNet series model developed by PaddleClas, these models show strong competitiveness on Intel CPUs.
For the introduction of PP-LCNet, please refer to [paper](https://arxiv.org/pdf/2109.15099.pdf) or [PP-LCNet model introduction](docs/en/models/PP-LCNet_en.md). The metrics and pretrained model are available [here](docs/en/ImageNet_models_en.md).
......@@ -38,10 +40,11 @@ Four sample solutions are provided, including product recognition, vehicle recog
## Welcome to Join the Technical Exchange Group
* You can also scan the QR code below to join the PaddleClas WeChat group to get more efficient answers to your questions and to communicate with developers from all walks of life. We look forward to hearing from you.
* You can also scan the QR code below to join the PaddleClas QQ group and WeChat group (add and replay "C") to get more efficient answers to your questions and to communicate with developers from all walks of life. We look forward to hearing from you.
<div align="center">
<img src="https://user-images.githubusercontent.com/12560511/162710270-8a249aca-4fa9-46f9-95e5-66d906fe6d66.jpg" width="200"/>
<img src="https://user-images.githubusercontent.com/80816848/164383225-e375eb86-716e-41b4-a9e0-4b8a3976c1aa.jpg" width="200"/>
<img src="https://user-images.githubusercontent.com/48054808/160531099-9811bbe6-cfbb-47d5-8bdb-c2b40684d7dd.png" width="200"/>
</div>
## Quick Start
......
......@@ -75,9 +75,23 @@ python3 -m paddle.distributed.launch \
The highest accuracy of the validation set is around 0.415.
* ** Note**
Here, multiple GPUs are used for training. If only one GPU is used, please specify the GPU with the `CUDA_VISIBLE_DEVICES` setting, and specify the GPU with the `--gpus` setting, the same below. For example, to train with only GPU 0:
```shell
export CUDA_VISIBLE_DEVICES=0
python3 -m paddle.distributed.launch \
--gpus="0" \
tools/train.py \
-c ./ppcls/configs/quick_start/professional/ResNet50_vd_CIFAR100.yaml \
-o Global.output_dir="output_CIFAR" \
-o Optimizer.lr.learning_rate=0.01
```
* **Notice**:
* The GPUs specified in `--gpus` can be a subset of the GPUs specified in `CUDA_VISIBLE_DEVICES`.
* Since the initial learning rate and batch-size need to maintain a linear relationship, when training is switched from 4 GPUs to 1 GPU, the total batch-size is reduced to 1/4 of the original, and the learning rate also needs to be reduced to 1/4 of the original, so changed the default learning rate from 0.04 to 0.01.
* If the number of GPU cards is not 4, the accuracy of the validation set may be different from 0.415. To maintain a comparable accuracy, you need to change the learning rate in the configuration file to `the current learning rate / 4 \* current card number`. The same below.
<a name="2.1.2"></a>
......
......@@ -75,9 +75,22 @@ python3 -m paddle.distributed.launch \
验证集的最高准确率为 0.415 左右。
* **注意**
此处使用了多个 GPU 训练,如果只使用一个 GPU,请将 `CUDA_VISIBLE_DEVICES` 设置指定 GPU,`--gpus`设置指定 GPU,下同。例如,只使用 0 号 GPU 训练:
```shell
export CUDA_VISIBLE_DEVICES=0
python3 -m paddle.distributed.launch \
--gpus="0" \
tools/train.py \
-c ./ppcls/configs/quick_start/professional/ResNet50_vd_CIFAR100.yaml \
-o Global.output_dir="output_CIFAR" \
-o Optimizer.lr.learning_rate=0.01
```
* **注意**:
* 如果 GPU 卡数不是 4,验证集的准确率可能与 0.415 有差异,若需保持相当的准确率,需要将配置文件中的学习率改为`当前学习率 / 4 \* 当前卡数`。下同。
* `--gpus`中指定的 GPU 可以是 `CUDA_VISIBLE_DEVICES` 指定的 GPU 的子集。
* 由于初始学习率和 batch-size 需要保持线性关系,所以训练从 4 个 GPU 切换到 1 个 GPU 训练时,总 batch-size 缩减为原来的 1/4,学习率也需要缩减为原来的 1/4,所以改变了默认的学习率从 0.04 到 0.01。
<a name="2.1.2"></a>
......@@ -157,7 +170,7 @@ python3 -m paddle.distributed.launch \
* **注意**
* 其他数据增广的配置文件可以参考 `ppcls/configs/ImageNet/DataAugment/` 中的配置文件。
* 训练 CIFAR100 的迭代轮数较少,因此进行训练时,验证集的精度指标可能会有 1% 左右的波动。
* 训练 CIFAR100 的迭代轮数较少,因此进行训练时,验证集的精度指标可能会有 1% 左右的波动。
<a name="4"></a>
......
......@@ -21,6 +21,15 @@ from .common_dataset import CommonDataset
class ImageNetDataset(CommonDataset):
def __init__(
self,
image_root,
cls_label_path,
transform_ops=None,
delimiter=None):
self.delimiter = delimiter if delimiter is not None else " "
super(ImageNetDataset, self).__init__(image_root, cls_label_path, transform_ops)
def _load_anno(self, seed=None):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
......@@ -32,7 +41,7 @@ class ImageNetDataset(CommonDataset):
if seed is not None:
np.random.RandomState(seed).shuffle(lines)
for l in lines:
l = l.strip().split(" ")
l = l.strip().split(self.delimiter)
self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(np.int64(l[1]))
assert os.path.exists(self.images[-1])
......@@ -19,10 +19,11 @@ import paddle.nn.functional as F
class Topk(object):
def __init__(self, topk=1, class_id_map_file=None):
def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
assert isinstance(topk, (int, ))
self.class_id_map = self.parse_class_id_map(class_id_map_file)
self.topk = topk
self.delimiter = delimiter if delimiter is not None else " "
def parse_class_id_map(self, class_id_map_file):
if class_id_map_file is None:
......@@ -38,7 +39,7 @@ class Topk(object):
with open(class_id_map_file, "r") as fin:
lines = fin.readlines()
for line in lines:
partition = line.split("\n")[0].partition(" ")
partition = line.split("\n")[0].partition(self.delimiter)
class_id_map[int(partition[0])] = str(partition[-1])
except Exception as ex:
print(ex)
......
......@@ -224,7 +224,7 @@ class Engine(object):
# build optimizer
if self.mode == 'train':
self.optimizer, self.lr_sch = build_optimizer(
self.config, self.config["Global"]["epochs"],
self.config["Optimizer"], self.config["Global"]["epochs"],
len(self.train_dataloader),
[self.model, self.train_loss_func])
......
......@@ -39,7 +39,7 @@ def update_loss(trainer, loss_dict, batch_size):
def log_info(trainer, batch_size, epoch_id, iter_id):
lr_msg = ", ".join([
"lr_{}: {:.8f}".format(i + 1, lr.get_lr())
"lr({}): {:.8f}".format(lr.__class__.__name__, lr.get_lr())
for i, lr in enumerate(trainer.lr_sch)
])
metric_msg = ", ".join([
......@@ -64,7 +64,7 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
for i, lr in enumerate(trainer.lr_sch):
logger.scaler(
name="lr_{}".format(i + 1),
name="lr({})".format(lr.__class__.__name__),
value=lr.get_lr(),
step=trainer.global_step,
writer=trainer.vdl_writer)
......
......@@ -44,8 +44,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
config = copy.deepcopy(config)
optim_config = config["Optimizer"]
optim_config = copy.deepcopy(config)
if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{'name': {'scope': xxx, **optim_cfg}}]
optim_name = optim_config.pop("name")
......@@ -61,19 +60,20 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
"""NOTE:
Currently only support optim objets below.
1. single optimizer config.
2. next level uner Arch, such as Arch.backbone, Arch.neck, Arch.head.
3. loss which has parameters, such as CenterLoss.
2. model(entire Arch), backbone, neck, head.
3. loss(entire Loss), specific loss listed in ppcls/loss/__init__.py.
"""
for optim_item in optim_config:
# optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}}
# step1 build lr
optim_name = list(optim_item.keys())[0] # get optim_name
optim_scope = optim_item[optim_name].pop('scope') # get optim_scope
optim_scope_list = optim_item[optim_name].pop('scope').split(
' ') # get optim_scope list
optim_cfg = optim_item[optim_name] # get optim_cfg
lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
logger.info("build lr ({}) for scope ({}) success..".format(
lr.__class__.__name__, optim_scope))
lr.__class__.__name__, optim_scope_list))
# step2 build regularization
if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
if 'weight_decay' in optim_cfg:
......@@ -85,46 +85,53 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
optim_cfg["weight_decay"] = reg
logger.info("build regularizer ({}) for scope ({}) success..".
format(reg.__class__.__name__, optim_scope))
format(reg.__class__.__name__, optim_scope_list))
# step3 build optimizer
if 'clip_norm' in optim_cfg:
clip_norm = optim_cfg.pop('clip_norm')
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
logger.info("build gradclip ({}) for scope ({}) success..".format(
grad_clip.__class__.__name__, optim_scope_list))
else:
grad_clip = None
optim_model = []
for i in range(len(model_list)):
if len(model_list[i].parameters()) == 0:
continue
if optim_scope == "all":
# optimizer for all
optim_model.append(model_list[i])
else:
if "Loss" in optim_scope:
# optimizer for loss
if hasattr(model_list[i], 'loss_func'):
for j in range(len(model_list[i].loss_func)):
if model_list[i].loss_func[
j].__class__.__name__ == optim_scope:
optim_model.append(model_list[i].loss_func[j])
elif optim_scope == "model":
# opmizer for entire model
if not model_list[i].__class__.__name__.lower().endswith(
"loss"):
optim_model.append(model_list[i])
else:
# opmizer for module in model, such as backbone, neck, head...
if hasattr(model_list[i], optim_scope):
optim_model.append(getattr(model_list[i], optim_scope))
assert len(optim_model) == 1, \
"Invalid optim model for optim scope({}), number of optim_model={}".\
format(optim_scope, [m.__class__.__name__ for m in optim_model])
# for static graph
if model_list is None:
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,
**optim_cfg)(model_list=optim_model)
return optim, lr
# for dynamic graph
for scope in optim_scope_list:
if scope == "all":
optim_model += model_list
elif scope == "model":
optim_model += [model_list[0], ]
elif scope in ["backbone", "neck", "head"]:
optim_model += [getattr(model_list[0], scope, None), ]
elif scope == "loss":
optim_model += [model_list[1], ]
else:
optim_model += [
model_list[1].loss_func[i]
for i in range(len(model_list[1].loss_func))
if model_list[1].loss_func[i].__class__.__name__ == scope
]
# remove invalid items
optim_model = [
optim_model[i] for i in range(len(optim_model))
if (optim_model[i] is not None
) and (len(optim_model[i].parameters()) > 0)
]
assert len(optim_model) > 0, \
f"optim_model is empty for optim_scope({optim_scope_list})"
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,
**optim_cfg)(model_list=optim_model)
logger.info("build optimizer ({}) for scope ({}) success..".format(
optim.__class__.__name__, optim_scope))
optim.__class__.__name__, optim_scope_list))
optim_list.append(optim)
lr_list.append(lr)
return optim_list, lr_list
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:32
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:32
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,384,384]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:32
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,384,384]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:64
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
......@@ -56,3 +56,5 @@ fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
......@@ -56,3 +56,5 @@ fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,256,256]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:128
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,256,256]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:128
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,256,256]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:128
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
\ No newline at end of file
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:128
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
\ No newline at end of file
......@@ -56,3 +56,5 @@ fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:128
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:128
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:128
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:128
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
......@@ -50,9 +50,5 @@ inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.tran
-o Global.benchmark:True
null:null
null:null
===========================train_benchmark_params==========================
batch_size:128
fp_items:fp32
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册