提交 188f92cd 编写于 作者: D dengkaipeng

fix unittests

上级 59d07373
......@@ -502,4 +502,4 @@ class ColorJitter(object):
self.transforms = Compose(transforms)
def __call__(self, img, lbl):
return self.transforms(img), lbl
return self.transforms(img, lbl)
......@@ -18,7 +18,7 @@ import unittest
import time
import random
from callbacks import config_callbacks
from hapi.callbacks import config_callbacks
class TestCallbacks(unittest.TestCase):
......
......@@ -18,12 +18,12 @@
import unittest
import numpy as np
from vision.datasets import *
from hapi.datasets import *
class TestFolderDatasets(unittest.TestCase):
def test_dataset(self):
dataset_folder = DatasetFolder('test_data')
dataset_folder = DatasetFolder('tests/test_data')
for _ in dataset_folder:
pass
......
......@@ -28,11 +28,12 @@ import contextlib
import paddle
from paddle import fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy, Input, Loss, set_device
from metrics import Accuracy
from callbacks import ProgBarLogger
from paddle.fluid.io import BatchSampler, DataLoader
from paddle.fluid.io import MNIST as MnistDataset
from paddle.io import BatchSampler, DataLoader
from hapi.model import Model, CrossEntropy, Input, Loss, set_device
from hapi.metrics import Accuracy
from hapi.callbacks import ProgBarLogger
from hapi.datasets import MNIST as MnistDataset
class SimpleImgConvPool(fluid.dygraph.Layer):
......
......@@ -18,7 +18,7 @@ import unittest
import random
import time
from progressbar import ProgressBar
from hapi.progressbar import ProgressBar
class TestProgressBar(unittest.TestCase):
......
......@@ -16,13 +16,13 @@
# export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH
import unittest
from datasets.folder import DatasetFolder
from transform import transforms
from hapi.datasets import DatasetFolder
import hapi.vision.transforms as transforms
class TestTransforms(unittest.TestCase):
def do_transform(self, trans):
dataset_folder = DatasetFolder('test_data', transform=trans)
dataset_folder = DatasetFolder('tests/test_data', transform=trans)
for _ in dataset_folder:
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册