提交 4d0c99f7 编写于 作者: C chenxuyi 提交者: Meiyim

fix python2 compat

上级 7677bced
...@@ -24,6 +24,7 @@ import six ...@@ -24,6 +24,7 @@ import six
import logging import logging
import paddle.fluid as fluid import paddle.fluid as fluid
from io import open from io import open
from paddle.fluid.layers import core
from model.transformer_encoder import encoder, pre_process_layer from model.transformer_encoder import encoder, pre_process_layer
...@@ -85,8 +86,8 @@ class ErnieModel(object): ...@@ -85,8 +86,8 @@ class ErnieModel(object):
self._pos_emb_name = "pos_embedding" self._pos_emb_name = "pos_embedding"
self._sent_emb_name = "sent_embedding" self._sent_emb_name = "sent_embedding"
self._task_emb_name = "task_embedding" self._task_emb_name = "task_embedding"
self._dtype = "float16" if use_fp16 else "float32" self._dtype = core.VarDesc.VarType.FP16 if use_fp16 else core.VarDesc.VarType.FP32
self._emb_dtype = 'float32' self._emb_dtype = core.VarDesc.VarType.FP32
# Initialize all weigths by truncated normal initializer, and all biases # Initialize all weigths by truncated normal initializer, and all biases
# will be initialized by constant zero by default. # will be initialized by constant zero by default.
...@@ -138,7 +139,7 @@ class ErnieModel(object): ...@@ -138,7 +139,7 @@ class ErnieModel(object):
emb_out = pre_process_layer( emb_out = pre_process_layer(
emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder') emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')
if self._dtype == 'float16': if self._dtype == core.VarDesc.VarType.FP16:
emb_out = fluid.layers.cast(x=emb_out, dtype=self._dtype) emb_out = fluid.layers.cast(x=emb_out, dtype=self._dtype)
input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype) input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)
self_attn_mask = fluid.layers.matmul( self_attn_mask = fluid.layers.matmul(
...@@ -167,7 +168,7 @@ class ErnieModel(object): ...@@ -167,7 +168,7 @@ class ErnieModel(object):
postprocess_cmd="dan", postprocess_cmd="dan",
param_initializer=self._param_initializer, param_initializer=self._param_initializer,
name='encoder') name='encoder')
if self._dtype == 'float16': if self._dtype == core.VarDesc.VarType.FP16:
self._enc_out = fluid.layers.cast( self._enc_out = fluid.layers.cast(
x=self._enc_out, dtype=self._emb_dtype) x=self._enc_out, dtype=self._emb_dtype)
......
...@@ -24,6 +24,7 @@ import logging ...@@ -24,6 +24,7 @@ import logging
import six import six
import paddle.fluid as fluid import paddle.fluid as fluid
from io import open from io import open
from paddle.fluid.layers import core
from model.transformer_encoder import encoder, pre_process_layer from model.transformer_encoder import encoder, pre_process_layer
...@@ -76,7 +77,7 @@ class ErnieModel(object): ...@@ -76,7 +77,7 @@ class ErnieModel(object):
self._word_emb_name = "word_embedding" self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding" self._pos_emb_name = "pos_embedding"
self._sent_emb_name = "sent_embedding" self._sent_emb_name = "sent_embedding"
self._dtype = 'float16' if use_fp16 else 'float32' self._dtype = core.VarDesc.VarType.FP16 if use_fp16 else core.VarDesc.VarType.FP32
# Initialize all weigths by truncated normal initializer, and all biases # Initialize all weigths by truncated normal initializer, and all biases
# will be initialized by constant zero by default. # will be initialized by constant zero by default.
...@@ -114,7 +115,7 @@ class ErnieModel(object): ...@@ -114,7 +115,7 @@ class ErnieModel(object):
emb_out = pre_process_layer( emb_out = pre_process_layer(
emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder') emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')
if self._dtype == 'float16': if self._dtype == core.VarDesc.VarType.FP16:
input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype) input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)
self_attn_mask = fluid.layers.matmul( self_attn_mask = fluid.layers.matmul(
x=input_mask, y=input_mask, transpose_y=True) x=input_mask, y=input_mask, transpose_y=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册