import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange

from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import torch
from torch import nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.strategies import DeepSpeedStrategy
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator
from src.dataset import load_prompt_data_4_ppo
from src.dataset import ExperienceDataset

# actor critic

PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
    'actions',
    'sequence',
    'mask',
    'prompt_mask',
    'action_logits',
    'values'
])

@beartype
class ActorCritic(nn.Module):
    def __init__(
        self,
        args,
        actor: RWKV,
        critic: RWKV,
        pooled_values = False
    ):
        super().__init__()

        self.args = args
        self.actor = actor
        self.critic = critic

        self.pooled_values = pooled_values
        self.value_head = nn.Sequential(
            nn.Linear(self.args.n_embd, 1),
            Rearrange('... 1 -> ...')
        )

        nn.init.zeros_(self.value_head[0].bias)
        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        state,
        max_seq_len,
        return_values = False
    ):
        # generate a response, which amounts to taking one action
        actions = self.actor.generate(
            max_seq_len,
            prompt = state
        )

        # concatenate the prompt (state) and the response (actions)
        sequence = torch.cat((state, actions), dim = -1)
        action_len = actions.shape[-1]
        state_len = state.shape[-1]

        # build the prompt_mask (state mask) and the response mask (action_mask)
        prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
        prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])

        action_mask = ~prompt_mask

        # account for the eos token
        mask = None
        if exists(self.args.eos_token):
            mask = ((sequence == self.args.eos_token).cumsum(dim = -1) == 0)
            mask = F.pad(mask, (1, -1), value = True) # include eos token
            action_mask &= mask

        # feed the generated sequence to the actor to get action_logits
        # feed the generated sequence to the critic to get the values
        action_logits, value = self.forward(
            sequence,
            mask = action_mask,
            return_values = return_values
        )        

        return PPOActionCriticReturn(
            actions,
            sequence,
            mask,
            prompt_mask,
            action_logits,
            value
        )

    def forward(
        self,
        x,
        mask = None,
        return_values = True
    ):
        action_logits, _ = self.actor(
            x,
            ppo_train = True
        )

        if not return_values:
            return action_logits, None

        _, critic_embeds = self.critic(
            x,
            return_only_embedding = True,
            ppo_train = True
        )

        if self.pooled_values:
            critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
            critic_embeds = masked_mean(critic_embeds, mask, dim = 1)

        values = self.value_head(critic_embeds)

        return action_logits, values
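
# The mask construction in ActorCritic.generate above is easier to picture on a tiny
# example. The sketch below reproduces the same tensor ops with made-up numbers; the
# function name and values are illustrative only, and nothing in the training code calls it.
def _mask_construction_example():
    state_len = 3
    sequence = torch.tensor([[10, 11, 12, 20, 21]])             # 3 prompt tokens followed by 2 response tokens

    # positions before state_len belong to the prompt (state) ...
    prompt_mask = torch.arange(sequence.shape[-1]) < state_len  # [True, True, True, False, False]
    prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])

    # ... and the remaining positions are the action (response)
    action_mask = ~prompt_mask                                  # [False, False, False, True, True]
    return prompt_mask, action_mask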

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
    dim = default(dim, tuple(range(t.ndim)))
    kwargs = dict(dim = dim, keepdim = True)

    mean = masked_mean(t, mask = mask, **kwargs)
    mean_centered = t - mean
    var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)

    return mean_centered * var.clamp(min = eps).rsqrt()

def pad_sequence_fixed(sequences, *args, **kwargs):
    first_el = sequences[0]
    has_no_dimension = first_el.ndim == 0

    # if no dimensions, add a single dimension
    if has_no_dimension:
        sequences = tuple(map(lambda t: t[None], sequences))

    out = pad_sequence(sequences, *args, **kwargs)

    if has_no_dimension:
        out = rearrange(out, '... 1 -> ...')

    return out

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def log_prob(prob, indices):
    assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
    return log(prob.gather(-1, indices[..., None])).squeeze(-1)

def shift(t, value = 0, shift = 1, dim = -1):
    zeros = (0, 0) * (-dim - 1)
    return F.pad(t, (*zeros, shift, -shift), value = value)

def masked_entropy(prob, dim = -1, mask = None):
    # entropy H(p) = -sum_i p_i * log(p_i); used as an exploration bonus in the policy loss
    entropies = -(prob * log(prob)).sum(dim = dim)
    return masked_mean(entropies, mask = mask).mean()

def masked_kl_div(prob1, prob2, mask = None):
    """
    need to account for variable sequence lengths, therefore not using the built-in functional version
    """
    # KL(prob1 || prob2) = sum prob1 * (log(prob1) - log(prob2)), summed over the vocab dimension
    kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)

    if not exists(mask):
        return kl_divs.mean()

    return masked_mean(kl_divs, mask).mean()

def clipped_value_loss(values, rewards, old_values, clip):
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
    value_loss_2 = (values.flatten() - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))
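
# A minimal, self-contained sketch of how the helpers above combine into the two PPO
# losses computed in RLHF.training_step below. The tensor values and the 0.2 clip
# ranges are made up for illustration; this function is documentation only and is
# never called by the training code.
def _ppo_loss_example():
    eps_clip = value_clip = 0.2

    # pretend per-token log-probs for 4 action tokens under the new and old policies,
    # plus rewards and value estimates of matching shape
    action_log_probs = torch.tensor([[-1.0, -0.9, -1.2, -0.8]])
    old_log_probs    = torch.tensor([[-1.1, -1.0, -1.0, -0.9]])
    rewards          = torch.tensor([[0.5, 0.5, 0.5, 0.5]])
    old_values       = torch.tensor([[0.2, 0.3, 0.1, 0.4]])
    values           = torch.tensor([[0.25, 0.35, 0.15, 0.45]])
    action_masks     = torch.ones_like(rewards).bool()

    # clipped surrogate objective: ratios r_t = exp(log pi_new - log pi_old), clipped to [1 - eps, 1 + eps]
    ratios = (action_log_probs - old_log_probs).exp()
    advantages = masked_normalize(rewards - old_values, dim = -1, mask = action_masks)
    surr1 = ratios * advantages
    surr2 = ratios.clamp(1 - eps_clip, 1 + eps_clip) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()

    # clipped value loss keeps the critic update close to its previous predictions
    value_loss = clipped_value_loss(values, rewards.flatten(), old_values, value_clip)
    return policy_loss, value_loss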

# rlhf
@beartype
class RLHF(pl.LightningModule):
    def __init__(
        self,
        args,
        actor: RWKV,
        critic: RWKV,
        reward_model: RewardModel
    ):
        super().__init__()

        self.args = args

        # load the prompt data
        self.prompts = load_prompt_data_4_ppo(args)

        # buffers for the interaction data collected from the environment, used to train the actor_critic (agent)
        self.sequence_batch = []
        self.prompt_mask_batch = []
        self.mask_batch = []
        self.action_prob_batch = []
        self.action_log_prob_batch = []
        self.reward_batch = []
        self.value_batch = []

        # initialize the actor_critic (agent) from the RWKV actor and critic
        actor_critic = ActorCritic(
            args=self.args,
            actor=actor,
            critic=critic,
            pooled_values = args.critic_pooled_values
        ).to(actor.device)

        self.actor_critic = actor_critic

        # put the reward_model in evaluation mode
        self.reward_model = reward_model.eval()

    def save(self, filepath = './checkpoint.pt'):
        torch.save(self.actor_critic.state_dict(), filepath)

    def load(self, filepath = './checkpoint.pt'):
        state_dict = torch.load(filepath)
        self.actor_critic.load_state_dict(state_dict)
    
    def configure_optimizers(self):
        args = self.args

        optim_groups_actor = []
        optim_groups_critic = []

        if args.layerwise_lr > 0:
            lr_1x_actor = set()
            lr_2x_actor = set()
            lr_3x_actor = set()

            lr_1x_critic = set()
            lr_2x_critic = set()
            lr_3x_critic = set()

            for n, p in self.named_parameters():
                if "time_mix" in n:
                    if args.my_pile_stage == 2:
                        if "actor" in n:
                            lr_2x_actor.add(n)
                        elif "critic" in n:
                            lr_2x_critic.add(n)
                    else:
                        if "actor" in n:
                            lr_1x_actor.add(n)
                        elif "critic" in n:
                            lr_1x_critic.add(n)
                elif "time_decay" in n:
                    if args.my_pile_stage == 2:
                        if "actor" in n:
                            lr_3x_actor.add(n)
                        elif "critic" in n:
                            lr_3x_critic.add(n)
                    else:
                        if "actor" in n:
                            lr_2x_actor.add(n)
                        elif "critic" in n:
                            lr_2x_critic.add(n)
                elif "time_first" in n:
                    if "actor" in n:
                        lr_3x_actor.add(n)
                    elif "critic" in n:
                        lr_3x_critic.add(n)
                else:
                    if "actor" in n:
                        lr_1x_actor.add(n)
                    elif "critic" in n:
                        lr_1x_critic.add(n)

            lr_1x_actor = sorted(list(lr_1x_actor))
            lr_2x_actor = sorted(list(lr_2x_actor))
            lr_3x_actor = sorted(list(lr_3x_actor))

            lr_1x_critic = sorted(list(lr_1x_critic))
            lr_2x_critic = sorted(list(lr_2x_critic))
            lr_3x_critic = sorted(list(lr_3x_critic))

            param_dict = {n: p for n, p in self.named_parameters()}
            if args.my_pile_stage == 2:
                optim_groups_actor = [
                    {"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
                    {"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
                ]

                optim_groups_critic = [
                    {"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
                    {"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
                ]
            else:
                optim_groups_actor = [
                    {"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 2.0},
                    {"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 3.0},
                ]

                optim_groups_critic = [
                    {"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 2.0},
                    {"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 3.0},
                ]
        else:
            optim_groups_actor = [
                {"params": [p for n, p in self.named_parameters() if "actor" in n], "weight_decay": 0.0},
            ]

            optim_groups_critic = [
                {"params": [p for n, p in self.named_parameters() if "critic" in n], "weight_decay": 0.0},
            ]

        if self.deepspeed_offload:
            actor_optimizer = DeepSpeedCPUAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
            critic_optimizer = DeepSpeedCPUAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
            
            return actor_optimizer, critic_optimizer
        
        actor_optimizer =  FusedAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        critic_optimizer = FusedAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)

        return actor_optimizer, critic_optimizer

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            return cfg.get("offload_optimizer") or cfg.get("offload_param")
        return False

    @torch.no_grad()
    def generate(
        self,
        max_seq_len,
        prompt,
        num_samples = 4  # sample 4 per prompt and select the one with highest reward
    ):
        ''' Not involved in training; used for inference only.
        '''

        assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
        prompt = repeat(prompt, 'n -> b n', b = num_samples)

        self.actor_critic.eval()
        (
            actions,
            sequences,
            mask,
            prompt_mask,
            action_logits,
            _
        ) = self.actor_critic.generate(
            prompt,
            max_seq_len = max_seq_len,
            return_values = False
        )

        rewards = self.reward_model(
            sequences,
            prompt_mask = prompt_mask,
            mask = mask,
            sample = True
        )

        best_sequence_index = rewards.topk(1, dim = -1).indices

        best_sequence = sequences[best_sequence_index]
        best_sequence = rearrange(best_sequence, '1 ... -> ...')

        return best_sequence

    def training_step(self, batch, batch_idx, optimizer_idx):
        sequences, \
        prompt_masks, \
        masks, \
        old_action_probs, \
        old_log_probs, \
        rewards, \
        old_values = batch

        # PPO training
        action_masks = ~prompt_masks & masks

        action_logits, values = self.actor_critic(
            sequences,
            mask = action_masks
        )

        action_logits = shift(action_logits, shift=1, dim=-2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
        action_len = old_log_probs.shape[-1]

        action_probs = action_logits.softmax(dim = -1)
        action_log_probs = log_prob(action_probs, sequences)
        action_log_probs = action_log_probs[:, -action_len:]

        # calculate entropies, taking into account which part of the sequence is actually an action

        entropies = masked_entropy(action_probs, mask = action_masks)

        # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not

        kl_div_loss = 0.

        if self.args.kl_div_loss_weight > 0:
            kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight

        # handle non-pooled values

        normalize_kwargs = dict()

        if old_values.ndim == 2:
            old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))

            old_values = old_values[:, -action_len:]
            values = values[:, -action_len:]
            rewards = rearrange(rewards, 'b -> b 1')
            normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])

        if values.ndim < rewards.ndim:
            values = rearrange(values, '... -> ... 1')

        # calculate clipped surrogate objective, classic PPO loss

        ratios = (action_log_probs - old_log_probs).exp()
        advantages = masked_normalize(rewards - old_values, **normalize_kwargs)

        if advantages.ndim == 1:
            advantages = rearrange(advantages, 'b -> b 1')

        surr1 = ratios * advantages
        surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
        policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

        # actor loss (also called the policy loss; this is the loss for the model that will ultimately be used)
        if optimizer_idx == 0:
            actor_loss = policy_loss.mean() + kl_div_loss
            return actor_loss

        # critic loss (also called the value loss)
        # update the value network separately from the policy network
        if optimizer_idx == 1:
            critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
            critic_loss = critic_loss.mean()
            return critic_loss

    def gen_experience_dataset(self):
        ''' Generate training data by interacting with the environment.
        '''
        
        device = self.device

        time_cnt = 0
        for eps in tqdm(range(self.args.num_episodes), desc = 'episodes'):
            for timestep in range(self.args.max_timesteps):
                time_cnt += 1

                # select a bunch of random states (prompts)
                # and get the action (sampled sequence from rwkv as well as the action probs)
                # also calculate the reward using reward model and store
                # randomly pick one prompt
                rand_prompt_index = randrange(0, len(self.prompts))
                state = self.prompts[rand_prompt_index]

                # remove padding from state
                state_mask = state != self.args.pad_value
                state = state[state_mask]

                # get predicted sequence
                # interact with the environment; in the returned values:
                #   actions is the response,
                #   sequence is prompt + response
                (
                    actions,
                    sequence,
                    mask,
                    prompt_mask,
                    action_logits,
                    value
                ) = self.actor_critic.generate(
                    rearrange(state, 'n -> 1 n'),
                    max_seq_len = self.args.ctx_len,
                    return_values = True
                )
                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

                action_prob = action_logits.softmax(dim = -1)

                action_len = actions.shape[-1]
                action_log_prob = log_prob(action_prob, sequence)
                action_log_prob = action_log_prob[:, -action_len:]

                actions = rearrange(actions, '1 ... -> ...')

                # get reward as given by supervised trained reward model
                sequence = torch.cat((state, actions), dim = 0)

                prompt_length = len(state)
                prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length

                sequence = rearrange(sequence, 'n -> 1 n')
                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
                mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)

                reward = self.reward_model(
                    sequence,
                    prompt_mask = prompt_mask,
                    mask = mask,
                    sample = True
                )

                self.sequence_batch.append(sequence)
                self.prompt_mask_batch.append(prompt_mask)
                self.mask_batch.append(mask)
                self.action_prob_batch.append(action_prob)
                self.action_log_prob_batch.append(action_log_prob)
                self.reward_batch.append(reward)
                self.value_batch.append(value)

                if time_cnt % self.args.update_timesteps == 0:
                    train_data = zip(
                        self.sequence_batch, self.prompt_mask_batch, self.mask_batch, 
                        self.action_prob_batch, self.action_log_prob_batch, self.reward_batch, 
                        self.value_batch
                    )

                    for _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value in train_data:
                        yield _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value

                    self.sequence_batch.clear()
                    self.prompt_mask_batch.clear()
                    self.mask_batch.clear()
                    self.action_prob_batch.clear()
                    self.action_log_prob_batch.clear()
                    self.reward_batch.clear()
                    self.value_batch.clear()


    def _dataloader(self) -> DataLoader:
        ''' Initialize the Replay Buffer dataset used for retrieving experiences '''

        dataset = ExperienceDataset(self.gen_experience_dataset)
        dataloader = DataLoader(dataset=dataset, batch_size=self.args.micro_bsz)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        ''' Get train loader '''

        return self._dataloader()
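
# Rough sketch of how the pieces above are intended to fit together. The real entry
# point lives elsewhere in this repo; the way `args`, the two RWKV models and the
# RewardModel are constructed, and the Trainer flags shown, are assumptions for
# illustration rather than the project's actual training script.
def _example_trainer_wiring(args, actor: RWKV, critic: RWKV, reward_model: RewardModel):
    # the LightningModule bundles the PPO agent, the frozen reward model and the
    # experience-generating dataloader defined above
    rlhf_model = RLHF(args, actor, critic, reward_model)

    # a plain single-device Trainer; the original project drives DeepSpeed, precision,
    # gradient accumulation, etc. through `args`, which is omitted here
    trainer = pl.Trainer(max_epochs = 1, accelerator = 'auto', devices = 1)

    # fit() pulls experience tuples streamed by gen_experience_dataset() via train_dataloader()
    trainer.fit(rlhf_model)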