import os

import torch
from tqdm import tqdm

from utils.utils import get_lr
def _batch_to_device(batch, cuda, local_rank):
    """Unpack the first three items of ``batch``; move them to GPU ``local_rank`` if ``cuda``."""
    images, targets, y_trues = batch[0], batch[1], batch[2]
    with torch.no_grad():
        if cuda:
            images  = images.cuda(local_rank)
            targets = [ann.cuda(local_rank) for ann in targets]
            y_trues = [ann.cuda(local_rank) for ann in y_trues]
    return images, targets, y_trues


def _total_loss(net, yolo_loss, images, targets, y_trues):
    """Run ``net`` on ``images`` and sum ``yolo_loss`` over every element of its output."""
    outputs = net(images)
    return sum(yolo_loss(l, output, targets, y_trues[l]) for l, output in enumerate(outputs))


def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
    """Train for one epoch, validate, and (on rank 0) log and save weights.

    Args:
        model_train: Model that is called on the images and back-propagated through.
        model: Model whose ``state_dict()`` is saved when ``ema`` is falsy.
        ema: Optional EMA wrapper. Its ``.ema`` attribute is used as the
            validation model and as the source of the saved weights.
        yolo_loss: Callable ``(l, outputs[l], targets, y_trues[l])`` returning
            the loss for the ``l``-th model output.
        loss_history: Object with ``append_loss(epoch, loss, val_loss)`` and a
            ``val_loss`` sequence of previously recorded validation losses.
        eval_callback: Object with ``on_epoch_end(epoch, model)``.
        optimizer: Optimizer stepping the parameters of ``model_train``.
        epoch: Zero-based index of the current epoch.
        epoch_step: Number of training batches to consume from ``gen``.
        epoch_step_val: Number of validation batches to consume from ``gen_val``.
        gen: Iterable of training batches ``(images, targets, y_trues, ...)``.
        gen_val: Iterable of validation batches with the same layout.
        Epoch: Total number of epochs; the final epoch is always checkpointed.
        cuda: If true, batches are moved to ``cuda(local_rank)``.
        fp16: If true, train under ``autocast`` using ``scaler``.
        scaler: Gradient scaler used only when ``fp16`` is true.
        save_period: Write a numbered checkpoint every ``save_period`` epochs.
        save_dir: Directory that receives the ``.pth`` files.
        local_rank: Process rank; only rank 0 prints, logs, evaluates and saves.
    """
    loss        = 0
    val_loss    = 0

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)

    # The non-EMA validation path below leaves model_train in eval mode, so
    # switch back to training mode at the start of every epoch.
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break

        images, targets, y_trues = _batch_to_device(batch, cuda, local_rank)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        if not fp16:
            # Full-precision forward pass, loss, backward pass and update.
            loss_value = _total_loss(model_train, yolo_loss, images, targets, y_trues)

            loss_value.backward()
            optimizer.step()
        else:
            from torch.cuda.amp import autocast
            with autocast():
                # Mixed-precision forward pass and loss.
                loss_value = _total_loss(model_train, yolo_loss, images, targets, y_trues)

            # Scale the loss before backward so fp16 gradients do not underflow.
            scaler.scale(loss_value).backward()
            scaler.step(optimizer)
            scaler.update()

        if ema:
            # NOTE(review): assumes the EMA wrapper exposes update(model) — confirm.
            ema.update(model_train)

        loss += loss_value.item()
        if local_rank == 0:
            pbar.set_postfix(**{'loss'  : loss / (iteration + 1), 
                                'lr'    : get_lr(optimizer)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)

    # Validate with the EMA weights when available, otherwise with the
    # training model switched to eval mode.
    if ema:
        model_train_eval = ema.ema
    else:
        model_train_eval = model_train.eval()

    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break

        images, targets, y_trues = _batch_to_device(batch, cuda, local_rank)
        with torch.no_grad():
            optimizer.zero_grad()
            # Forward pass and loss only; no gradients are tracked here.
            loss_value = _total_loss(model_train_eval, yolo_loss, images, targets, y_trues)

        val_loss += loss_value.item()
        if local_rank == 0:
            pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
        eval_callback.on_epoch_end(epoch + 1, model_train_eval)
        print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
        print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))

        # Save weights: prefer the EMA weights when an EMA wrapper is in use.
        if ema:
            save_state_dict = ema.ema.state_dict()
        else:
            save_state_dict = model.state_dict()

        # Numbered checkpoint every save_period epochs and on the final epoch.
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))

        # This epoch's loss was already appended above, so "<= min" is true
        # exactly when this epoch ties or beats every recorded epoch.
        if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
            print('Save best model to best_epoch_weights.pth')
            torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))

        torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))