未验证 提交 2e36a01d 编写于 作者: W whs 提交者: GitHub

Fix ce of dygraph quant (#873)

上级 fd5084c6
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import cv2
import math
import random
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from paddle.io import Dataset
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms from paddle.vision.transforms import transforms
class ImageNetDataset(DatasetFolder): class ImageNetDataset(Dataset):
def __init__(self, def __init__(self,
path, data_dir,
mode='train', mode='train',
image_size=224, image_size=224,
resize_short_size=256): resize_short_size=256):
super(ImageNetDataset, self).__init__(path) super(ImageNetDataset, self).__init__()
train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt')
test_file_list = os.path.join(data_dir, 'test_list.txt')
self.data_dir = data_dir
self.mode = mode self.mode = mode
normalize = transforms.Normalize( normalize = transforms.Normalize(
...@@ -47,11 +33,35 @@ class ImageNetDataset(DatasetFolder): ...@@ -47,11 +33,35 @@ class ImageNetDataset(DatasetFolder):
normalize normalize
]) ])
def __getitem__(self, idx): if mode == 'train':
img_path, label = self.samples[idx] with open(train_file_list) as flist:
full_lines = [line.strip() for line in flist]
np.random.shuffle(full_lines)
if os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(
trainer_id + 1) * per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
else:
lines = full_lines
self.data = [line.split() for line in lines]
else:
with open(val_file_list) as flist:
lines = [line.strip() for line in flist]
self.data = [line.split() for line in lines]
def __getitem__(self, index):
img_path, label = self.data[index]
img_path = os.path.join(self.data_dir, img_path)
img = Image.open(img_path).convert('RGB') img = Image.open(img_path).convert('RGB')
label = np.array([label]).astype(np.int64) label = np.array([label]).astype(np.int64)
return self.transform(img), label return self.transform(img), label
def __len__(self): def __len__(self):
return len(self.samples) return len(self.data)
...@@ -60,8 +60,7 @@ def main(): ...@@ -60,8 +60,7 @@ def main():
fp32_model = models.__dict__[FLAGS.arch](pretrained=True) fp32_model = models.__dict__[FLAGS.arch](pretrained=True)
fp32_model.eval() fp32_model.eval()
val_dataset = ImageNetDataset( val_dataset = ImageNetDataset(FLAGS.data, mode='val')
os.path.join(FLAGS.data, FLAGS.val_dir), mode='val')
# 2 quantizations # 2 quantizations
ptq = PTQ() ptq = PTQ()
......
...@@ -86,10 +86,8 @@ def main(): ...@@ -86,10 +86,8 @@ def main():
print("Resume from " + FLAGS.resume) print("Resume from " + FLAGS.resume)
model.load(FLAGS.resume) model.load(FLAGS.resume)
train_dataset = ImageNetDataset( train_dataset = ImageNetDataset(FLAGS.data, mode='train')
os.path.join(FLAGS.data, 'train'), mode='train') val_dataset = ImageNetDataset(FLAGS.data, mode='val')
val_dataset = ImageNetDataset(
os.path.join(FLAGS.data, FLAGS.val_dir), mode='val')
optim = make_optimizer( optim = make_optimizer(
np.ceil( np.ceil(
...@@ -152,10 +150,6 @@ if __name__ == '__main__': ...@@ -152,10 +150,6 @@ if __name__ == '__main__':
default="", default="",
help='path to dataset ' help='path to dataset '
'(should have subdirectories named "train" and "val"') '(should have subdirectories named "train" and "val"')
parser.add_argument(
'--val_dir',
default="val",
help='the dir that saves val images for paddle.Model')
# train # train
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册