提交 df80c2f0 编写于 作者: S shippingwang

fix bug

上级 50717d68
...@@ -151,7 +151,6 @@ def create_loss(out, ...@@ -151,7 +151,6 @@ def create_loss(out,
loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon) loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
return loss(out[1], out[0]) return loss(out[1], out[0])
print("++++++", use_mix)
if use_mix: if use_mix:
loss = MixCELoss(class_dim=classes_num, epsilon=epsilon) loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
feed_y_a = feeds['feed_y_a'] feed_y_a = feeds['feed_y_a']
...@@ -341,7 +340,7 @@ def build(config, main_prog, startup_prog, is_train=True): ...@@ -341,7 +340,7 @@ def build(config, main_prog, startup_prog, is_train=True):
use_mix = config.get('use_mix') and is_train use_mix = config.get('use_mix') and is_train
use_dali = config.get('use_dali') use_dali = config.get('use_dali')
use_distillation = config.get('use_distillation') use_distillation = config.get('use_distillation')
feeds = create_feeds(config.image_shape, use_mix=use_mix, use_dali) feeds = create_feeds(config.image_shape, use_mix, use_dali)
if use_dali and use_mix: if use_dali and use_mix:
import dali import dali
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册