From acce0c4297d877fecdab3e6c6c5bc293950485bd Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Wed, 29 Jul 2020 11:29:52 +0800 Subject: [PATCH] update train/eval judge --- dygraph/infer.py | 2 +- dygraph/models/hrnet.py | 2 +- dygraph/models/unet.py | 4 ++-- dygraph/train.py | 4 ++-- dygraph/val.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dygraph/infer.py b/dygraph/infer.py index 0b25a48f..3d4c5d5d 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -99,7 +99,7 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'): logging.info("Start to predict...") for im, im_info, im_path in tqdm.tqdm(test_dataset): im = to_variable(im) - pred, _ = model(im, mode='test') + pred, _ = model(im) pred = pred.numpy() pred = np.squeeze(pred).astype('uint8') for info in im_info[::-1]: diff --git a/dygraph/models/hrnet.py b/dygraph/models/hrnet.py index 3c8139d1..069e6c74 100644 --- a/dygraph/models/hrnet.py +++ b/dygraph/models/hrnet.py @@ -171,7 +171,7 @@ class HRNet(fluid.dygraph.Layer): logit = self.conv_last_1(x) logit = fluid.layers.resize_bilinear(logit, input_shape) - if mode == 'train': + if self.training: if label is None: raise Exception('Label is need during training') return self._get_loss(logit, label) diff --git a/dygraph/models/unet.py b/dygraph/models/unet.py index 78d1a394..8474c2cf 100644 --- a/dygraph/models/unet.py +++ b/dygraph/models/unet.py @@ -29,11 +29,11 @@ class UNet(fluid.dygraph.Layer): self.ignore_index = ignore_index self.EPS = 1e-5 - def forward(self, x, label=None, mode='train'): + def forward(self, x, label=None): encode_data, short_cuts = self.encode(x) decode_data = self.decode(encode_data, short_cuts) logit = self.get_logit(decode_data) - if mode == 'train': + if self.training: return self._get_loss(logit, label) else: score_map = fluid.layers.softmax(logit, axis=1) diff --git a/dygraph/train.py b/dygraph/train.py index 70b61aaf..55048165 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -188,12 +188,12 @@ def train(model, images = data[0] labels = data[1].astype('int64') if nranks > 1: - loss = model_parallel(images, labels, mode='train') + loss = model_parallel(images, labels) loss = model_parallel.scale_loss(loss) loss.backward() model_parallel.apply_collective_grads() else: - loss = model(images, labels, mode='train') + loss = model(images, labels) loss.backward() optimizer.minimize(loss) model.clear_gradients() diff --git a/dygraph/val.py b/dygraph/val.py index ca36a6fe..60fbd17e 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -94,7 +94,7 @@ def evaluate(model, timer.start() for step, (im, im_info, label) in enumerate(eval_dataset): im = to_variable(im) - pred, _ = model(im, mode='eval') + pred, _ = model(im) pred = pred.numpy().astype('float32') pred = np.squeeze(pred) for info in im_info[::-1]: -- GitLab