提交 fe6f6146 编写于 作者: C cuicheng01

Update multilabel

上级 a90881c9
...@@ -46,8 +46,8 @@ DataLoader: ...@@ -46,8 +46,8 @@ DataLoader:
Train: Train:
dataset: dataset:
name: MultiLabelDataset name: MultiLabelDataset
image_root: ./dataset/NUS-SCENE-dataset/images/ image_root: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/
cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_train_list.txt cls_label_path: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_train_list.txt
transform_ops: transform_ops:
- DecodeImage: - DecodeImage:
to_rgb: True to_rgb: True
...@@ -74,8 +74,8 @@ DataLoader: ...@@ -74,8 +74,8 @@ DataLoader:
Eval: Eval:
dataset: dataset:
name: MultiLabelDataset name: MultiLabelDataset
image_root: ./dataset/NUS-SCENE-dataset/images/ image_root: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/
cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_test_list.txt cls_label_path: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_test_list.txt
transform_ops: transform_ops:
- DecodeImage: - DecodeImage:
to_rgb: True to_rgb: True
......
...@@ -50,7 +50,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -50,7 +50,7 @@ def classification_eval(engine, epoch_id=0):
time_info["reader_cost"].update(time.time() - tic) time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0] batch_size = batch[0].shape[0]
batch[0] = paddle.to_tensor(batch[0]).astype("float32") batch[0] = paddle.to_tensor(batch[0]).astype("float32")
if not evaler.config["Global"].get("use_multilabel", False): if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input # image input
out = engine.model(batch[0]) out = engine.model(batch[0])
......
...@@ -76,8 +76,8 @@ def train_epoch(engine, epoch_id, print_batch_step): ...@@ -76,8 +76,8 @@ def train_epoch(engine, epoch_id, print_batch_step):
tic = time.time() tic = time.time()
def forward(trainer, batch): def forward(engine, batch):
if not trainer.is_rec: if not engine.is_rec:
return trainer.model(batch[0]) return engine.model(batch[0])
else: else:
return trainer.model(batch[0], batch[1]) return engine.model(batch[0], batch[1])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册