未验证 提交 66b99dc8 编写于 作者: W WangZhen 提交者: GitHub

Run cinn when enable_cinn=True (#51354)

上级 d8b8c2d8
...@@ -207,7 +207,7 @@ class BertPooler(nn.Layer): ...@@ -207,7 +207,7 @@ class BertPooler(nn.Layer):
class BertModel(nn.Layer): class BertModel(nn.Layer):
def __init__(self, config: BertConfig, to_static): def __init__(self, config: BertConfig, to_static, enable_cinn):
super(BertModel, self).__init__() super(BertModel, self).__init__()
self.config = config self.config = config
self.pad_token_id = config.pad_token_id self.pad_token_id = config.pad_token_id
...@@ -248,7 +248,12 @@ class BertModel(nn.Layer): ...@@ -248,7 +248,12 @@ class BertModel(nn.Layer):
encoder_layer, config.num_hidden_layers encoder_layer, config.num_hidden_layers
) )
if to_static: if to_static:
self.encoder = paddle.jit.to_static(self.encoder) build_strategy = paddle.static.BuildStrategy()
if enable_cinn:
build_strategy.build_cinn_pass = True
self.encoder = paddle.jit.to_static(
self.encoder, None, build_strategy
)
self.pooler = BertPooler(config) self.pooler = BertPooler(config)
# self.apply(self.init_weights) # self.apply(self.init_weights)
...@@ -366,10 +371,10 @@ class BertModel(nn.Layer): ...@@ -366,10 +371,10 @@ class BertModel(nn.Layer):
class Bert(nn.Layer): class Bert(nn.Layer):
def __init__(self, to_static): def __init__(self, to_static, enable_cinn):
super(Bert, self).__init__() super(Bert, self).__init__()
config = BertConfig() config = BertConfig()
self.bert = BertModel(config, to_static) self.bert = BertModel(config, to_static, enable_cinn)
self.cls = BertPretrainingHeads( self.cls = BertPretrainingHeads(
config, config,
embedding_weights=self.bert.embeddings.word_embeddings.weight, embedding_weights=self.bert.embeddings.word_embeddings.weight,
......
...@@ -58,20 +58,9 @@ def train(to_static, enable_prim, enable_cinn): ...@@ -58,20 +58,9 @@ def train(to_static, enable_prim, enable_cinn):
worker_init=None, worker_init=None,
) )
bert = Bert(to_static) # Now only apply dy2st for encoder
bert = Bert(to_static, enable_cinn)
criterion = BertPretrainingCriterion() criterion = BertPretrainingCriterion()
if to_static:
# input_sepc = [
# InputSpec(shape=(-1, -1), dtype=paddle.int64, name='input_ids'),
# InputSpec(shape=(-1, -1), dtype=paddle.int64, name='segment_ids'),
# None,
# InputSpec(shape=(-1, 1, 1, -1), dtype=paddle.float32, name='input_mask'),
# InputSpec(shape=(-1,), dtype=paddle.int32, name='masked_lm_positions'),
# ]
input_sepc = None
build_strategy = paddle.static.BuildStrategy()
if enable_cinn:
build_strategy.build_cinn_pass = True
optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters()) optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册