import functools


def clas_forward_decorator(forward_func):
    """Adapt ``forward_func(model, x)`` so it can be called with a whole batch.

    The returned wrapper is called as ``wrapper(model, batch)``. It takes
    ``batch[0]`` as the model input and returns ``forward_func(model, batch[0])``.
    Any further elements of ``batch`` (e.g. a label) are ignored, so batches
    that carry only the input are accepted too. Extra positional and keyword
    arguments given to the wrapper are passed through to ``forward_func``.

    Args:
        forward_func: Callable invoked as
            ``forward_func(model, x, *args, **kwargs)``.

    Returns:
        The wrapper function. ``functools.wraps`` copies the name, docstring
        and other metadata of ``forward_func`` onto it.
    """

    @functools.wraps(forward_func)
    def parse_batch_wrapper(model, batch, *args, **kwargs):
        # Only the input is needed for the forward pass. The label (batch[1])
        # used to be unpacked but was never used, and that unpacking needlessly
        # rejected batches with fewer than two elements.
        x = batch[0]
        return forward_func(model, x, *args, **kwargs)

    return parse_batch_wrapper