未验证 提交 bb13f3c4 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix single card dist (#1889)

* fix single card logit

* fix distillation yaml files
上级 c7a6fdda
......@@ -49,9 +49,8 @@ Loss:
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- DistillationGTCELoss:
- CELoss:
weight: 1.0
model_names: ["Student"]
Optimizer:
......
......@@ -88,10 +88,8 @@ Loss:
s_shapes: *s_shapes
t_shapes: *t_shapes
Eval:
- DistillationGTCELoss:
- CELoss:
weight: 1.0
model_names: ["Student"]
Optimizer:
name: Momentum
......
......@@ -80,22 +80,17 @@ def classification_eval(engine, epoch_id=0):
current_samples = batch_size * paddle.distributed.get_world_size()
accum_samples += current_samples
if isinstance(out, dict) and "Student" in out:
out = out["Student"]
if isinstance(out, dict) and "logits" in out:
out = out["logits"]
# gather Tensor when distributed
if paddle.distributed.get_world_size() > 1:
label_list = []
paddle.distributed.all_gather(label_list, batch[1])
labels = paddle.concat(label_list, 0)
if isinstance(out, dict):
if "Student" in out:
out = out["Student"]
if isinstance(out, dict):
out = out["logits"]
elif "logits" in out:
out = out["logits"]
else:
msg = "Error: Wrong key in out!"
raise Exception(msg)
if isinstance(out, list):
preds = []
for x in out:
......
......@@ -118,8 +118,6 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
if hasattr(model_list[i], optim_scope):
optim_model.append(getattr(model_list[i], optim_scope))
assert len(optim_model) == 1, \
"Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,
**optim_cfg)(model_list=optim_model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册