From 0eaa11e473919546b5379943874a7d0609b54c8a Mon Sep 17 00:00:00 2001
From: ceci3 <ceci3@users.noreply.github.com>
Date: Fri, 29 Jan 2021 15:22:46 +0800
Subject: [PATCH] fix bug (#632)

---
 demo/ofa/bert/run_glue_ofa.py       | 25 +++++++++++++++++--------
 paddleslim/nas/ofa/convert_super.py |  9 ++++++---
 tests/test_convert_supernet.py      | 25 +++++++++++++++++++++++++
 3 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/demo/ofa/bert/run_glue_ofa.py b/demo/ofa/bert/run_glue_ofa.py
index 1618ee58..bc581f2f 100644
--- a/demo/ofa/bert/run_glue_ofa.py
+++ b/demo/ofa/bert/run_glue_ofa.py
@@ -27,7 +27,6 @@ from paddle.metric import Accuracy
 
 from paddlenlp.data import Stack, Tuple, Pad
 from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
-from paddlenlp.utils.log import logger
 from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
 import paddlenlp.datasets as datasets
 from paddleslim.nas.ofa import OFA, DistillConfig, utils
@@ -166,7 +165,8 @@ def set_seed(args):
     paddle.seed(args.seed + paddle.distributed.get_rank())
 
 
-def evaluate(model, criterion, metric, data_loader, width_mult=1.0):
+def evaluate(model, criterion, metric, data_loader, epoch, step,
+             width_mult=1.0):
     with paddle.no_grad():
         model.eval()
         metric.reset()
@@ -180,8 +180,9 @@ def evaluate(model, criterion, metric, data_loader, width_mult=1.0):
             metric.update(correct)
         results = metric.accumulate()
         print(
-            "width_mult: %f, eval loss: %f, %s: %s\n" %
-            (width_mult, loss.numpy(), metric.name(), results),
+            "epoch: %d, batch: %d, width_mult: %s, eval loss: %f, %s: %s\n" %
+            (epoch, step, 'teacher' if width_mult == 100 else str(width_mult),
+             loss.numpy(), metric.name(), results),
             end='')
         model.train()
 
@@ -485,7 +486,7 @@ def do_train(args):
 
             if global_step % args.logging_steps == 0:
                 if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
-                    logger.info(
+                    print(
                         "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                         % (global_step, epoch, step, loss,
                            args.logging_steps / (time.time() - tic_train)))
@@ -498,12 +499,16 @@ def do_train(args):
                         criterion,
                         metric,
                         dev_data_loader_matched,
+                        epoch,
+                        step,
                         width_mult=100)
                     evaluate(
                         teacher_model,
                         criterion,
                         metric,
                         dev_data_loader_mismatched,
+                        epoch,
+                        step,
                         width_mult=100)
                 else:
                     evaluate(
@@ -511,6 +516,8 @@ def do_train(args):
                         criterion,
                         metric,
                         dev_data_loader,
+                        epoch,
+                        step,
                         width_mult=100)
                 for idx, width_mult in enumerate(args.width_mult_list):
                     net_config = utils.dynabert_config(ofa_model, width_mult)
@@ -518,14 +525,16 @@ def do_train(args):
                     tic_eval = time.time()
                     if args.task_name == "mnli":
                         acc = evaluate(ofa_model, criterion, metric,
-                                       dev_data_loader_matched, width_mult)
+                                       dev_data_loader_matched, epoch, step,
+                                       width_mult)
                         evaluate(ofa_model, criterion, metric,
-                                 dev_data_loader_mismatched, width_mult)
+                                 dev_data_loader_mismatched, epoch, step,
+                                 width_mult)
                         print("eval done total : %s s" %
                               (time.time() - tic_eval))
                     else:
                         acc = evaluate(ofa_model, criterion, metric,
-                                       dev_data_loader, width_mult)
+                                       dev_data_loader, epoch, step, width_mult)
                         print("eval done total : %s s" %
                               (time.time() - tic_eval))
 
diff --git a/paddleslim/nas/ofa/convert_super.py b/paddleslim/nas/ofa/convert_super.py
index df460720..07619109 100644
--- a/paddleslim/nas/ofa/convert_super.py
+++ b/paddleslim/nas/ofa/convert_super.py
@@ -64,13 +64,15 @@ class Convert:
             w_attr = layer._param_attr if pd_ver == 185 else layer._weight_attr
 
         if isinstance(w_attr, ParamAttr):
-            if w_attr != None and not isinstance(w_attr, bool):
+            if w_attr != None and not isinstance(w_attr,
+                                                 bool) and w_attr.name != None:
                 w_attr.name = 'super_' + w_attr.name
 
         if has_bias:
             if isinstance(layer._bias_attr, ParamAttr):
-                if layer._bias_attr != None and not isinstance(layer._bias_attr,
-                                                               bool):
+                if layer._bias_attr != None and not isinstance(
+                        layer._bias_attr,
+                        bool) and layer._bias_attr.name != None:
                     layer._bias_attr.name = 'super_' + layer._bias_attr.name
 
     def convert(self, network):
@@ -429,6 +431,7 @@ class Convert:
                     new_attr_name = ['act', 'dtype']
                 else:
                     new_attr_name = ['weight_attr', 'bias_attr']
+                self._change_name(layer, pd_ver)
                 in_nc, out_nc = layer._parameters['weight'].shape
 
                 new_attr_dict = dict.fromkeys(new_attr_name, None)
diff --git a/tests/test_convert_supernet.py b/tests/test_convert_supernet.py
index a75ac18f..64f85e10 100644
--- a/tests/test_convert_supernet.py
+++ b/tests/test_convert_supernet.py
@@ -15,6 +15,8 @@
 import sys
 sys.path.append("../")
 import unittest
+import paddle
+import paddle.nn as nn
 from paddle.vision.models import mobilenet_v1
 from paddleslim.nas.ofa.convert_super import Convert, supernet
 
@@ -29,5 +31,28 @@ class TestConvertSuper(unittest.TestCase):
         assert len(sp_model.sublayers()) == 151
 
 
+class TestConvertSuper(unittest.TestCase):
+    def setUp(self):
+        class Model(nn.Layer):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.fc = nn.Linear(
+                    5,
+                    10,
+                    weight_attr=paddle.ParamAttr(
+                        initializer=nn.initializer.XavierNormal()),
+                    bias_attr=paddle.ParamAttr(
+                        initializer=nn.initializer.Constant(value=0.0)))
+
+            def forward(self, inputs):
+                return self.fc(inputs)
+
+        self.model = Model()
+
+    def test_convert(self):
+        sp_net_config = supernet(expand_ratio=[1, 2, 4])
+        sp_model = Convert(sp_net_config).convert(self.model)
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab