提交 188f92cd 编写于 作者: D dengkaipeng

fix unittests

上级 59d07373
...@@ -502,4 +502,4 @@ class ColorJitter(object): ...@@ -502,4 +502,4 @@ class ColorJitter(object):
self.transforms = Compose(transforms) self.transforms = Compose(transforms)
def __call__(self, img, lbl): def __call__(self, img, lbl):
return self.transforms(img), lbl return self.transforms(img, lbl)
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import time import time
import random import random
from callbacks import config_callbacks from hapi.callbacks import config_callbacks
class TestCallbacks(unittest.TestCase): class TestCallbacks(unittest.TestCase):
......
...@@ -18,12 +18,12 @@ ...@@ -18,12 +18,12 @@
import unittest import unittest
import numpy as np import numpy as np
from vision.datasets import * from hapi.datasets import *
class TestFolderDatasets(unittest.TestCase): class TestFolderDatasets(unittest.TestCase):
def test_dataset(self): def test_dataset(self):
dataset_folder = DatasetFolder('test_data') dataset_folder = DatasetFolder('tests/test_data')
for _ in dataset_folder: for _ in dataset_folder:
pass pass
......
...@@ -28,11 +28,12 @@ import contextlib ...@@ -28,11 +28,12 @@ import contextlib
import paddle import paddle
from paddle import fluid from paddle import fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy, Input, Loss, set_device from paddle.io import BatchSampler, DataLoader
from metrics import Accuracy
from callbacks import ProgBarLogger from hapi.model import Model, CrossEntropy, Input, Loss, set_device
from paddle.fluid.io import BatchSampler, DataLoader from hapi.metrics import Accuracy
from paddle.fluid.io import MNIST as MnistDataset from hapi.callbacks import ProgBarLogger
from hapi.datasets import MNIST as MnistDataset
class SimpleImgConvPool(fluid.dygraph.Layer): class SimpleImgConvPool(fluid.dygraph.Layer):
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import random import random
import time import time
from progressbar import ProgressBar from hapi.progressbar import ProgressBar
class TestProgressBar(unittest.TestCase): class TestProgressBar(unittest.TestCase):
......
...@@ -16,13 +16,13 @@ ...@@ -16,13 +16,13 @@
# export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH # export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH
import unittest import unittest
from datasets.folder import DatasetFolder from hapi.datasets import DatasetFolder
from transform import transforms import hapi.vision.transforms as transforms
class TestTransforms(unittest.TestCase): class TestTransforms(unittest.TestCase):
def do_transform(self, trans): def do_transform(self, trans):
dataset_folder = DatasetFolder('test_data', transform=trans) dataset_folder = DatasetFolder('tests/test_data', transform=trans)
for _ in dataset_folder: for _ in dataset_folder:
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册