提交 acce0c42 编写于 作者: C chenguowei01

update train/eval judge

上级 0a9249c9
...@@ -99,7 +99,7 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'): ...@@ -99,7 +99,7 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
logging.info("Start to predict...") logging.info("Start to predict...")
for im, im_info, im_path in tqdm.tqdm(test_dataset): for im, im_info, im_path in tqdm.tqdm(test_dataset):
im = to_variable(im) im = to_variable(im)
pred, _ = model(im, mode='test') pred, _ = model(im)
pred = pred.numpy() pred = pred.numpy()
pred = np.squeeze(pred).astype('uint8') pred = np.squeeze(pred).astype('uint8')
for info in im_info[::-1]: for info in im_info[::-1]:
......
...@@ -171,7 +171,7 @@ class HRNet(fluid.dygraph.Layer): ...@@ -171,7 +171,7 @@ class HRNet(fluid.dygraph.Layer):
logit = self.conv_last_1(x) logit = self.conv_last_1(x)
logit = fluid.layers.resize_bilinear(logit, input_shape) logit = fluid.layers.resize_bilinear(logit, input_shape)
if mode == 'train': if self.training:
if label is None: if label is None:
raise Exception('Label is need during training') raise Exception('Label is need during training')
return self._get_loss(logit, label) return self._get_loss(logit, label)
......
...@@ -29,11 +29,11 @@ class UNet(fluid.dygraph.Layer): ...@@ -29,11 +29,11 @@ class UNet(fluid.dygraph.Layer):
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.EPS = 1e-5 self.EPS = 1e-5
def forward(self, x, label=None, mode='train'): def forward(self, x, label=None):
encode_data, short_cuts = self.encode(x) encode_data, short_cuts = self.encode(x)
decode_data = self.decode(encode_data, short_cuts) decode_data = self.decode(encode_data, short_cuts)
logit = self.get_logit(decode_data) logit = self.get_logit(decode_data)
if mode == 'train': if self.training:
return self._get_loss(logit, label) return self._get_loss(logit, label)
else: else:
score_map = fluid.layers.softmax(logit, axis=1) score_map = fluid.layers.softmax(logit, axis=1)
......
...@@ -188,12 +188,12 @@ def train(model, ...@@ -188,12 +188,12 @@ def train(model,
images = data[0] images = data[0]
labels = data[1].astype('int64') labels = data[1].astype('int64')
if nranks > 1: if nranks > 1:
loss = model_parallel(images, labels, mode='train') loss = model_parallel(images, labels)
loss = model_parallel.scale_loss(loss) loss = model_parallel.scale_loss(loss)
loss.backward() loss.backward()
model_parallel.apply_collective_grads() model_parallel.apply_collective_grads()
else: else:
loss = model(images, labels, mode='train') loss = model(images, labels)
loss.backward() loss.backward()
optimizer.minimize(loss) optimizer.minimize(loss)
model.clear_gradients() model.clear_gradients()
......
...@@ -94,7 +94,7 @@ def evaluate(model, ...@@ -94,7 +94,7 @@ def evaluate(model,
timer.start() timer.start()
for step, (im, im_info, label) in enumerate(eval_dataset): for step, (im, im_info, label) in enumerate(eval_dataset):
im = to_variable(im) im = to_variable(im)
pred, _ = model(im, mode='eval') pred, _ = model(im)
pred = pred.numpy().astype('float32') pred = pred.numpy().astype('float32')
pred = np.squeeze(pred) pred = np.squeeze(pred)
for info in im_info[::-1]: for info in im_info[::-1]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册