Unverified commit bf6b9802 authored by Olatunji Ruwase, committed by GitHub

ZeRO3 handling frozen weights (#2653)

Parent 35575bce
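
In short, ZeRO stage 3 now builds fp16/fp32 partitions and optimizer state only for parameters that still require gradients, rather than for every parameter in optimizer.param_groups. A minimal standalone sketch of that filtering idea follows; the helper name here is illustrative and only mirrors the _get_trainable_parameter_groups method added in the diff below.

import torch


def get_trainable_parameter_groups(param_groups):
    # Keep only parameters with requires_grad=True in each optimizer group,
    # mirroring the helper this commit adds to DeepSpeedZeroOptimizer_Stage3.
    return [{"params": [p for p in group["params"] if p.requires_grad]}
            for group in param_groups]


model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
for p in model[1].parameters():  # freeze the second layer
    p.requires_grad = False

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
trainable = get_trainable_parameter_groups(optimizer.param_groups)
# 110 trainable elements (first layer only) out of 220 total
print(sum(p.numel() for g in trainable for p in g["params"]))
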
@@ -252,11 +252,15 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         self.sub_group_size = sub_group_size

         self.sub_group_to_group_id = {}

-        see_memory_usage("Before creating fp16 partitions", force=False)
-        self._create_fp16_partitions_with_defragmentation()
+        # Trainable parameters
+        self.trainable_param_groups = self._get_trainable_parameter_groups()
+
+        see_memory_usage("Before creating fp16 partitions", force=True)
+        self._create_fp16_partitions_with_defragmentation(self.trainable_param_groups)
         num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
         see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
-                         force=False)
+                         force=True)

         # Optimizer tensor swapping
         if self.swap_optimizer:
@@ -350,19 +354,28 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
     def destroy(self):
         self.parameter_offload.destroy()

+    def _get_trainable_parameter_groups(self):
+        param_groups = []
+        for param_group in self.optimizer.param_groups:
+            trainable_params = {
+                "params": [p for p in param_group["params"] if p.requires_grad]
+            }
+            param_groups.append(trainable_params)
+        return param_groups
+
     def _setup_for_real_optimizer(self):
-        see_memory_usage("Before creating fp32 partitions", force=False)
+        see_memory_usage("Before creating fp32 partitions", force=True)
         self._create_fp32_partitions()
-        see_memory_usage("After creating fp32 partitions", force=False)
+        see_memory_usage("After creating fp32 partitions", force=True)
         dist.barrier()

         # To support pipelined optimizer swapping
         self._create_next_swappable_fp32_groups()

-        see_memory_usage("Before initializing optimizer states", force=False)
+        see_memory_usage("Before initializing optimizer states", force=True)
         self.initialize_optimizer_states()
-        see_memory_usage("After initializing optimizer states", force=False)
+        see_memory_usage("After initializing optimizer states", force=True)
         dist.barrier()

         if dist.get_rank() == 0:
@@ -523,7 +536,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         aggregate_params_count = 0

-        for j, param_group in enumerate(self.optimizer.param_groups):
+        for j, param_group in enumerate(self.trainable_param_groups):
             params_in_group = sum([p.partition_numel() for p in param_group['params']])

             flat_buffer_size = params_in_group
@@ -552,11 +565,12 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                     torch.empty(1,
                                 dtype=self.dtype))

-    def _create_fp16_partitions_with_defragmentation(self):
+    def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
         dist.barrier()
         param_groups: List[List[Parameter]] = tuple(
             self._create_fp16_sub_groups(param_group["params"])
-            for param_group in self.optimizer.param_groups)
+            for param_group in fp16_param_groups)

         # bookkeeping related to param groups
         for param_group_idx, param_group in enumerate(param_groups):
@@ -884,7 +898,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                                       dtype=gradient_dtype,
                                       device=self.device)

-        timers = self.timers
         timer_names = set()

         if self.swap_optimizer:
@@ -2122,6 +2135,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
     def _set_param_groups(self, value):
         self.optimizer.param_groups = value
+        self.trainable_param_groups = self._get_trainable_parameter_groups()

     param_groups = property(_get_param_groups, _set_param_groups)
@@ -1313,3 +1313,63 @@ class TestZeroAdamOptimizerStepCount(DistributedTest):
             state = optimizer.optimizer.state[param]
             step_counts.append(state['step'])
         assert all(step == step_counts[0] for step in step_counts)
+
+
+class TestZeroFrozenWeights(DistributedTest):
+    world_size = 1
+
+    def test(self):
+        config_dict = {
+            "train_batch_size": 4,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-4
+                }
+            },
+            "fp16": {
+                "enabled": True
+            },
+            "zero_optimization": {
+                "stage": 3
+            }
+        }
+        hidden_dim = 10
+
+        class MyModel(torch.nn.Module):
+            def __init__(self, hidden_dim):
+                super(MyModel, self).__init__()
+                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
+                self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)
+                self.act = torch.nn.ReLU()
+                self.cel = torch.nn.CrossEntropyLoss()
+
+                # freeze one fc
+                self.l2.weight.requires_grad = False
+                self.l2.bias.requires_grad = False
+
+            def forward(self, x, y):
+                x = self.l1(x)
+                x = self.act(x)
+                x = self.l2(x)
+                loss = self.cel(x, y)
+                val = (x, loss)
+                return val
+
+        with deepspeed.zero.Init(config_dict_or_path=config_dict):
+            model = MyModel(hidden_dim)
+
+        model, _, _, _ = deepspeed.initialize(model=model,
+                                              model_parameters=model.parameters(),
+                                              config=config_dict)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        dist.barrier()
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            loss = loss[1]
+            model.backward(loss)
+            model.step()