未验证 提交 66b99dc8 编写于 作者: W WangZhen 提交者: GitHub

Run cinn when enable_cinn=True (#51354)

上级 d8b8c2d8
......@@ -207,7 +207,7 @@ class BertPooler(nn.Layer):
class BertModel(nn.Layer):
def __init__(self, config: BertConfig, to_static):
def __init__(self, config: BertConfig, to_static, enable_cinn):
super(BertModel, self).__init__()
self.config = config
self.pad_token_id = config.pad_token_id
......@@ -248,7 +248,12 @@ class BertModel(nn.Layer):
encoder_layer, config.num_hidden_layers
)
if to_static:
self.encoder = paddle.jit.to_static(self.encoder)
build_strategy = paddle.static.BuildStrategy()
if enable_cinn:
build_strategy.build_cinn_pass = True
self.encoder = paddle.jit.to_static(
self.encoder, None, build_strategy
)
self.pooler = BertPooler(config)
# self.apply(self.init_weights)
......@@ -366,10 +371,10 @@ class BertModel(nn.Layer):
class Bert(nn.Layer):
def __init__(self, to_static):
def __init__(self, to_static, enable_cinn):
super(Bert, self).__init__()
config = BertConfig()
self.bert = BertModel(config, to_static)
self.bert = BertModel(config, to_static, enable_cinn)
self.cls = BertPretrainingHeads(
config,
embedding_weights=self.bert.embeddings.word_embeddings.weight,
......
......@@ -58,20 +58,9 @@ def train(to_static, enable_prim, enable_cinn):
worker_init=None,
)
bert = Bert(to_static)
# Now only apply dy2st for encoder
bert = Bert(to_static, enable_cinn)
criterion = BertPretrainingCriterion()
if to_static:
# input_sepc = [
# InputSpec(shape=(-1, -1), dtype=paddle.int64, name='input_ids'),
# InputSpec(shape=(-1, -1), dtype=paddle.int64, name='segment_ids'),
# None,
# InputSpec(shape=(-1, 1, 1, -1), dtype=paddle.float32, name='input_mask'),
# InputSpec(shape=(-1,), dtype=paddle.int32, name='masked_lm_positions'),
# ]
input_sepc = None
build_strategy = paddle.static.BuildStrategy()
if enable_cinn:
build_strategy.build_cinn_pass = True
optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册