diff --git a/03.image_classification/train.py b/03.image_classification/train.py index 10ce9df1d0c91543e62172fb1800817ab87c3fe7..9290b6acf3edb8a3951066bb053289289660bdba 100644 --- a/03.image_classification/train.py +++ b/03.image_classification/train.py @@ -118,7 +118,11 @@ def infer(use_cuda, inference_program, params_dirname=None): # inference results = inferencer.infer({'pixel': img}) - print("infer results: ", results) + label_list = [ + "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", + "ship", "truck" + ] + print("infer results: %s" % label_list[np.argmax(results[0])]) def main(use_cuda):