Commit 394a34e0 authored by chenguowei01

update

Parent 6a569885
@@ -12,13 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import random
from paddle.fluid.io import Dataset
import cv2
from utils.download import download_file_and_uncompress
......
@@ -12,10 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D
@@ -23,8 +19,8 @@ from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D
 class UNet(fluid.dygraph.Layer):
     def __init__(self, num_classes, ignore_index=255):
         super().__init__()
-        self.encode = Encoder()
-        self.decode = Decode()
+        self.encode = UnetEncoder()
+        self.decode = UnetDecode()
         self.get_logit = GetLogit(64, num_classes)
         self.ignore_index = ignore_index
         self.EPS = 1e-5
@@ -61,7 +57,7 @@ class UNet(fluid.dygraph.Layer):
         return avg_loss


-class Encoder(fluid.dygraph.Layer):
+class UnetEncoder(fluid.dygraph.Layer):
     def __init__(self):
         super().__init__()
         self.double_conv = DoubleConv(3, 64)
@@ -84,7 +80,7 @@ class Encoder(fluid.dygraph.Layer):
         return x, short_cuts


-class Decode(fluid.dygraph.Layer):
+class UnetDecode(fluid.dygraph.Layer):
     def __init__(self):
         super().__init__()
         self.up1 = Up(512, 256)
......
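Context for the rename above: as the `return x, short_cuts` line shows, the encoder returns its final feature map together with the per-stage features ("short_cuts") that the decoder consumes in reverse order. A minimal, framework-free NumPy sketch of that contract (the stage count, channel widths, and the pooling/upsampling stand-ins are illustrative assumptions, not taken from this diff):

import numpy as np

def double_conv(x, out_ch):
    # Stand-in for DoubleConv (two 3x3 conv + BN + ReLU): only the
    # channel change matters for the skip-connection bookkeeping.
    n, _, h, w = x.shape
    return np.zeros((n, out_ch, h, w), dtype=x.dtype)

def unet_encoder(x):
    short_cuts = []
    for ch in (64, 128, 256, 512):          # assumed stage widths
        x = double_conv(x, ch)
        short_cuts.append(x)                # saved for the decoder
        x = x[:, :, ::2, ::2]               # stand-in for 2x2 max pooling
    return double_conv(x, 512), short_cuts  # bottleneck + skips

def unet_decode(x, short_cuts):
    for short_cut in reversed(short_cuts):  # consume skips in reverse
        x = np.repeat(np.repeat(x, 2, axis=2), 2, axis=3)  # 2x upsample
        x = np.concatenate([short_cut, x], axis=1)         # skip concat
        x = double_conv(x, short_cut.shape[1])
    return x

feat, skips = unet_encoder(np.zeros((1, 3, 64, 64), np.float32))
print(unet_decode(feat, skips).shape)       # -> (1, 64, 64, 64)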
@@ -41,36 +41,8 @@ class DistributedBatchSampler(BatchSampler):
             batch indices. Default False.
         drop_last(bool): whether to drop the last incomplete batch when the
             dataset size is not divisible by the batch size. Default False.
-
-    Examples:
-        .. code-block:: python
-
-            import numpy as np
-            from hapi.datasets import MNIST
-            from hapi.distributed import DistributedBatchSampler
-
-            class MnistDataset(MNIST):
-                def __init__(self, mode, return_label=True):
-                    super(MnistDataset, self).__init__(mode=mode)
-                    self.return_label = return_label
-
-                def __getitem__(self, idx):
-                    img = np.reshape(self.images[idx], [1, 28, 28])
-                    if self.return_label:
-                        return img, np.array(self.labels[idx]).astype('int64')
-                    return img,
-
-                def __len__(self):
-                    return len(self.images)
-
-            train_dataset = MnistDataset(mode='train')
-            dist_train_dataloader = DistributedBatchSampler(train_dataset, batch_size=64)
-
-            for data in dist_train_dataloader:
-                # do something
-                break
     """

     def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
         self.dataset = dataset
@@ -111,9 +83,8 @@ class DistributedBatchSampler(BatchSampler):
                 indices = indices[len(indices) - last_batch_size:]
             subsampled_indices.extend(
-                indices[self.local_rank *
-                        last_local_batch_size:(self.local_rank + 1) *
-                        last_local_batch_size])
+                indices[self.local_rank * last_local_batch_size:
+                        (self.local_rank + 1) * last_local_batch_size])

             return subsampled_indices

         if self.nranks > 1:
......
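The hunk above is purely cosmetic: it re-wraps the slice that hands each rank its contiguous share of the final, incomplete batch, and the selected indices are unchanged. A standalone sketch of that arithmetic (variable names mirror the sampler; the values, and the ceil-division used to size each share, are illustrative assumptions):

# Leftover indices of the last, incomplete batch (made-up values).
indices = list(range(10))
nranks = 4          # number of distributed ranks
local_rank = 2      # this process's rank

# One plausible sizing: ceil-divide so every rank gets a slot.
last_local_batch_size = (len(indices) - 1) // nranks + 1

# The re-wrapped slice from the diff; both layouts pick the same range.
share = indices[local_rank * last_local_batch_size:
                (local_rank + 1) * last_local_batch_size]
print(share)        # [6, 7, 8]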