未验证 提交 31628f60 编写于 作者: W wangguanzhong 提交者: GitHub

fix train (#1743)

上级 2a9e2559
...@@ -12,7 +12,7 @@ TrainReader: ...@@ -12,7 +12,7 @@ TrainReader:
- PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: true} - PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: true}
batch_size: 1 batch_size: 1
shuffle: true shuffle: true
drop_last: false drop_last: true
EvalReader: EvalReader:
......
...@@ -135,7 +135,7 @@ class TrainReader(BaseDataLoader): ...@@ -135,7 +135,7 @@ class TrainReader(BaseDataLoader):
batch_transforms=None, batch_transforms=None,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
drop_last=False, drop_last=True,
drop_empty=True, drop_empty=True,
num_classes=81, num_classes=81,
with_background=True): with_background=True):
......
...@@ -129,10 +129,9 @@ class COCODataSet(DetDataset): ...@@ -129,10 +129,9 @@ class COCODataSet(DetDataset):
gt_bbox[i, :] = box['clean_bbox'] gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd'] is_crowd[i][0] = box['iscrowd']
# check RLE format # check RLE format
if box['iscrowd'] == 1: if 'segmentation' in box and box['iscrowd'] == 1:
gt_poly[i] = [[0.0, 0.0], ] gt_poly[i] = [[0.0, 0.0], ]
continue elif 'segmentation' in box:
if 'segmentation' in box:
gt_poly[i] = box['segmentation'] gt_poly[i] = box['segmentation']
if not any(gt_poly): if not any(gt_poly):
......
...@@ -22,6 +22,7 @@ except Exception: ...@@ -22,6 +22,7 @@ except Exception:
from paddle.io import Dataset from paddle.io import Dataset
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path from ppdet.utils.download import get_dataset_path
import copy
@serializable @serializable
...@@ -45,7 +46,7 @@ class DetDataset(Dataset): ...@@ -45,7 +46,7 @@ class DetDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
# data batch # data batch
roidb = self.roidbs[idx] roidb = copy.deepcopy(self.roidbs[idx])
# data augment # data augment
roidb = self.transform(roidb) roidb = self.transform(roidb)
# data item # data item
......
...@@ -32,8 +32,7 @@ class BaseArch(nn.Layer): ...@@ -32,8 +32,7 @@ class BaseArch(nn.Layer):
def build_inputs(self, data, input_def): def build_inputs(self, data, input_def):
inputs = {} inputs = {}
for i, k in enumerate(input_def): for i, k in enumerate(input_def):
v = paddle.to_tensor(data[i]) inputs[k] = data[i]
inputs[k] = v
return inputs return inputs
def model_arch(self): def model_arch(self):
......
...@@ -126,7 +126,7 @@ def run(FLAGS, cfg, place): ...@@ -126,7 +126,7 @@ def run(FLAGS, cfg, place):
model = create(cfg.architecture) model = create(cfg.architecture)
# Optimizer # Optimizer
lr = create('LearningRate')(step_per_epoch / int(ParallelEnv().nranks)) lr = create('LearningRate')(step_per_epoch)
optimizer = create('OptimizerBuilder')(lr, model.parameters()) optimizer = create('OptimizerBuilder')(lr, model.parameters())
# Init Model & Optimzer # Init Model & Optimzer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册