diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 7d7e4720143c26e36343e3c8f94a0bf4b2caf892..9ec03396f95bd24704be014633916631ff98e627 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -37,7 +37,7 @@ logger = get_logger() class TextClassifier(object): def __init__(self, args): self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] - self.cls_batch_num = args.rec_batch_num + self.cls_batch_num = args.cls_batch_num self.cls_thresh = args.cls_thresh self.use_zero_copy_run = args.use_zero_copy_run postprocess_params = {