未验证 提交 ebbc5431 编写于 作者: G Guanghua Yu 提交者: GitHub

update YOLO ptq dataloader (#1340)

上级 ca12971c
...@@ -17,13 +17,10 @@ import sys ...@@ -17,13 +17,10 @@ import sys
import numpy as np import numpy as np
import argparse import argparse
import paddle import paddle
from ppdet.core.workspace import load_config, merge_config from paddleslim.auto_compression.config_helpers import load_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.quant import quant_post_static
from paddleslim.common import load_onnx_model from paddleslim.common import load_onnx_model
from paddleslim.quant import quant_post_static
from dataset import COCOTrainDataset
def argsparser(): def argsparser():
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -49,32 +46,17 @@ def argsparser(): ...@@ -49,32 +46,17 @@ def argsparser():
return parser return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def main(): def main():
global global_config global global_config
all_config = load_slim_config(FLAGS.config_path) all_config = load_config(FLAGS.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"] global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'], dataset = COCOTrainDataset(
reader_cfg['worker_num'], dataset_dir=global_config['dataset_dir'],
return_list=True) image_dir=global_config['val_image_dir'],
train_loader = reader_wrapper(train_loader, global_config['input_list']) anno_path=global_config['val_anno_path'])
train_loader = paddle.io.DataLoader(
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
......
...@@ -17,13 +17,10 @@ import sys ...@@ -17,13 +17,10 @@ import sys
import numpy as np import numpy as np
import argparse import argparse
import paddle import paddle
from ppdet.core.workspace import load_config, merge_config from paddleslim.auto_compression.config_helpers import load_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.quant import quant_post_static
from paddleslim.common import load_onnx_model from paddleslim.common import load_onnx_model
from paddleslim.quant import quant_post_static
from dataset import COCOTrainDataset
def argsparser(): def argsparser():
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -49,32 +46,17 @@ def argsparser(): ...@@ -49,32 +46,17 @@ def argsparser():
return parser return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def main(): def main():
global global_config global global_config
all_config = load_slim_config(FLAGS.config_path) all_config = load_config(FLAGS.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"] global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'], dataset = COCOTrainDataset(
reader_cfg['worker_num'], dataset_dir=global_config['dataset_dir'],
return_list=True) image_dir=global_config['val_image_dir'],
train_loader = reader_wrapper(train_loader, global_config['input_list']) anno_path=global_config['val_anno_path'])
train_loader = paddle.io.DataLoader(
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
......
...@@ -17,13 +17,10 @@ import sys ...@@ -17,13 +17,10 @@ import sys
import numpy as np import numpy as np
import argparse import argparse
import paddle import paddle
from ppdet.core.workspace import load_config, merge_config from paddleslim.auto_compression.config_helpers import load_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.quant import quant_post_static
from paddleslim.common import load_onnx_model from paddleslim.common import load_onnx_model
from paddleslim.quant import quant_post_static
from dataset import COCOTrainDataset
def argsparser(): def argsparser():
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -49,32 +46,17 @@ def argsparser(): ...@@ -49,32 +46,17 @@ def argsparser():
return parser return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def main(): def main():
global global_config global global_config
all_config = load_slim_config(FLAGS.config_path) all_config = load_config(FLAGS.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"] global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'], dataset = COCOTrainDataset(
reader_cfg['worker_num'], dataset_dir=global_config['dataset_dir'],
return_list=True) image_dir=global_config['val_image_dir'],
train_loader = reader_wrapper(train_loader, global_config['input_list']) anno_path=global_config['val_anno_path'])
train_loader = paddle.io.DataLoader(
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册