提交 7ae7505c 编写于 作者: Z zhouyaqiang

move argmax from host to device

上级 34864fbc
......@@ -381,6 +381,7 @@ class DeepLabV3(nn.Cell):
self.concat = P.Concat(axis=2)
self.expand_dims = P.ExpandDims()
self.reduce_mean = P.ReduceMean()
self.argmax = P.Argmax(axis=1)
self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
int(feature_shape[3])),
align_corners=True)
......@@ -419,6 +420,8 @@ class DeepLabV3(nn.Cell):
logits_i = self.expand_dims(logits_i, 2)
logits = self.concat((logits, logits_i))
logits = self.reduce_mean(logits, 2)
if not self.training:
logits = self.argmax(logits)
return logits
......
......@@ -42,6 +42,8 @@ class OhemLoss(nn.Cell):
self.loss_weight = 1.0
def construct(self, logits, labels):
if not self.training:
return 0
logits = self.transpose(logits, (0, 2, 3, 1))
logits = self.reshape(logits, (-1, self.num))
labels = F.cast(labels, mstype.int32)
......
......@@ -50,10 +50,7 @@ class MiouPrecision(Metric):
raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
predict_in = self._convert_data(inputs[0])
label_in = self._convert_data(inputs[1])
if predict_in.shape[1] != self._num_class:
raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
'classes'.format(self._num_class, predict_in.shape[1]))
pred = np.argmax(predict_in, axis=1)
pred = predict_in
label = label_in
if len(label.flatten()) != len(pred.flatten()):
print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册