Commit f3bb87d7 authored by baiyfbupt

fix conflicts

Parent 3d1bd8ae
@@ -91,19 +91,6 @@ class AdaBERTClassifier(Layer):
             use_fixed_gumbel=self.use_fixed_gumbel,
             gumbel_alphas=gumbel_alphas)
-        for s_emb, t_emb in zip(self.student.emb_names(),
-                                self.teacher.emb_names()):
-            t_emb.stop_gradient = True
-            if fix_emb:
-                s_emb.stop_gradient = True
-            print(
-                "Assigning embedding[{}] from teacher to embedding[{}] in student.".
-                format(t_emb.name, s_emb.name))
-            fluid.layers.assign(input=t_emb, output=s_emb)
-            print(
-                "Assigned embedding[{}] from teacher to embedding[{}] in student.".
-                format(t_emb.name, s_emb.name))
         fix_emb = False
         for s_emb, t_emb in zip(self.student.emb_names(),
                                 self.teacher.emb_names()):
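The loop kept at the end of this hunk (the duplicate above it is what the commit deletes) copies each teacher embedding parameter into the matching student embedding and can optionally freeze it. A minimal sketch of that copy pattern under the fluid API used here, assuming `emb_names()` yields aligned lists of embedding parameters (the helper name below is illustrative, not from the repo):

```python
import paddle.fluid as fluid

def copy_teacher_embeddings(student_embs, teacher_embs, fix_emb=False):
    """Copy teacher embedding parameters into the student model.

    student_embs / teacher_embs are assumed to be aligned lists of
    parameter Variables, as emb_names() appears to return here.
    """
    for s_emb, t_emb in zip(student_embs, teacher_embs):
        # The teacher side never receives gradients through this path.
        t_emb.stop_gradient = True
        if fix_emb:
            # Optionally freeze the student embedding after the copy.
            s_emb.stop_gradient = True
        # In-place copy of the teacher tensor into the student parameter.
        fluid.layers.assign(input=t_emb, output=s_emb)
```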
@@ -173,4 +160,3 @@ class AdaBERTClassifier(Layer):
         total_loss = (1 - self._gamma) * ce_loss + self._gamma * kd_loss
         return total_loss, accuracy, ce_loss, kd_loss, s_logits
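The returned loss blends the hard-label cross-entropy and the distillation term with a single weight `self._gamma`. A tiny numeric illustration with made-up values:

```python
# total = (1 - gamma) * ce + gamma * kd, as in the line above
gamma, ce_loss, kd_loss = 0.8, 1.2, 0.4    # hypothetical values
total_loss = (1 - gamma) * ce_loss + gamma * kd_loss
# (1 - 0.8) * 1.2 + 0.8 * 0.4 = 0.24 + 0.32 ≈ 0.56
```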
@@ -262,7 +262,6 @@ class EncoderLayer(Layer):
             default_initializer=NormalInitializer(
                 loc=0.0, scale=1e-3))
         self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
         self.bns = []
         self.outs = []
@@ -305,21 +304,6 @@ class EncoderLayer(Layer):
     def forward(self, enc_input_0, enc_input_1, epoch, flops=[],
                 model_size=[]):
-=======
-            self.outs.append(out)
-        self.use_fixed_gumbel = use_fixed_gumbel
-        self.gumbel_alphas = gumbel_softmax(self.alphas)
-        if gumbel_alphas is not None:
-            self.gumbel_alphas = np.array(gumbel_alphas).reshape(
-                self.alphas.shape)
-        else:
-            self.gumbel_alphas = gumbel_softmax(self.alphas)
-        self.gumbel_alphas.stop_gradient = True
-        print("gumbel_alphas: {}".format(self.gumbel_alphas))
-    def forward(self, enc_input_0, enc_input_1, flops=[], model_size=[]):
         alphas = self.gumbel_alphas if self.use_fixed_gumbel else gumbel_softmax(
             self.alphas, epoch)
......
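The final hunk removes a leftover `=======` conflict block; the code that stays either reuses a fixed, pre-sampled set of architecture weights (`self.gumbel_alphas`) or draws a fresh Gumbel-Softmax sample on each forward pass. Below is a self-contained NumPy sketch of standard Gumbel-Softmax sampling and of that selection logic; it only illustrates the technique and is not the repo's `gumbel_softmax`, which additionally takes an `epoch` argument.

```python
import numpy as np

def gumbel_softmax(logits, temperature=1.0):
    """Standard Gumbel-Softmax: softmax((logits + Gumbel noise) / temperature)."""
    u = np.random.uniform(1e-10, 1.0, size=np.shape(logits))
    g = -np.log(-np.log(u))                        # Gumbel(0, 1) noise
    y = (np.asarray(logits) + g) / temperature
    y = y - y.max(axis=-1, keepdims=True)          # numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)

class EncoderSketch:
    """Mirrors the fixed-vs-sampled alpha selection kept in forward() above."""

    def __init__(self, alphas, use_fixed_gumbel=False, gumbel_alphas=None):
        self.alphas = np.asarray(alphas)
        self.use_fixed_gumbel = use_fixed_gumbel
        if gumbel_alphas is not None:
            # Reuse externally supplied (fixed) samples.
            self.gumbel_alphas = np.array(gumbel_alphas).reshape(self.alphas.shape)
        else:
            # Otherwise freeze one sample drawn at construction time.
            self.gumbel_alphas = gumbel_softmax(self.alphas)

    def forward(self):
        # Fixed architecture weights when use_fixed_gumbel, else resample.
        return (self.gumbel_alphas if self.use_fixed_gumbel
                else gumbel_softmax(self.alphas))
```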