diff --git a/python/paddle/fluid/tests/unittests/prim/model/bert.py b/python/paddle/fluid/tests/unittests/prim/model/bert.py index 240179a0697aee509d65216b6f1f9605629f9db6..ce7723d78b0a44fbffaf8069c3c6f3217b877fcd 100644 --- a/python/paddle/fluid/tests/unittests/prim/model/bert.py +++ b/python/paddle/fluid/tests/unittests/prim/model/bert.py @@ -207,7 +207,7 @@ class BertPooler(nn.Layer): class BertModel(nn.Layer): - def __init__(self, config: BertConfig, to_static): + def __init__(self, config: BertConfig, to_static, enable_cinn): super(BertModel, self).__init__() self.config = config self.pad_token_id = config.pad_token_id @@ -248,7 +248,12 @@ class BertModel(nn.Layer): encoder_layer, config.num_hidden_layers ) if to_static: - self.encoder = paddle.jit.to_static(self.encoder) + build_strategy = paddle.static.BuildStrategy() + if enable_cinn: + build_strategy.build_cinn_pass = True + self.encoder = paddle.jit.to_static( + self.encoder, None, build_strategy + ) self.pooler = BertPooler(config) # self.apply(self.init_weights) @@ -366,10 +371,10 @@ class BertModel(nn.Layer): class Bert(nn.Layer): - def __init__(self, to_static): + def __init__(self, to_static, enable_cinn): super(Bert, self).__init__() config = BertConfig() - self.bert = BertModel(config, to_static) + self.bert = BertModel(config, to_static, enable_cinn) self.cls = BertPretrainingHeads( config, embedding_weights=self.bert.embeddings.word_embeddings.weight, diff --git a/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py b/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py index 04f72d486ec47c1e70a19cd2010448857b5e0c00..8bd89f48337a935206bc4c0950c1e7f73cc1c013 100644 --- a/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py +++ b/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py @@ -58,20 +58,9 @@ def train(to_static, enable_prim, enable_cinn): worker_init=None, ) - bert = Bert(to_static) + # Now only apply dy2st for encoder + bert = Bert(to_static, enable_cinn) criterion = BertPretrainingCriterion() - if to_static: - # input_sepc = [ - # InputSpec(shape=(-1, -1), dtype=paddle.int64, name='input_ids'), - # InputSpec(shape=(-1, -1), dtype=paddle.int64, name='segment_ids'), - # None, - # InputSpec(shape=(-1, 1, 1, -1), dtype=paddle.float32, name='input_mask'), - # InputSpec(shape=(-1,), dtype=paddle.int32, name='masked_lm_positions'), - # ] - input_sepc = None - build_strategy = paddle.static.BuildStrategy() - if enable_cinn: - build_strategy.build_cinn_pass = True optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())