提交 8793a586 编写于 作者: W wuzewu

update text classifer demo

上级 b3e2b88f
......@@ -20,10 +20,39 @@ parser.add_argument("--max_seq_len", type=int, default=512,
# yapf: enable.
class TransformerClassifier(fluid.dygraph.Layer):
def __init__(self, num_classes, transformer):
super(TransformerClassifier, self).__init__()
self.num_classes = num_classes
self.transformer = transformer
self.fc = Linear(input_dim=768, output_dim=num_classes)
def forward(self, input_ids, position_ids, segment_ids, input_mask):
pooled_output, sequence_output = self.transformer(
input_ids, position_ids, segment_ids, input_mask)
cls_feats = fluid.layers.dropout(
pooled_output,
dropout_prob=0.1,
dropout_implementation="upscale_in_train")
cls_feats = fluid.layers.reshape(cls_feats, shape=[-1, 768])
pred = self.fc(cls_feats)
return fluid.layers.softmax(pred)
def finetune(args):
with fluid.dygraph.guard():
ernie = hub.Module(name="ernie")
dataset = hub.dataset.ChnSentiCorp()
tc = TransformerClassifier(
num_classes=dataset.num_labels, transformer=ernie)
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=tc.parameters())
print(len(tc.parameters()))
state_dict_path = os.path.join(args.checkpoint_dir,
'dygraph_state_dict')
if os.path.exists(state_dict_path + '.pdparams'):
state_dict, _ = fluid.load_dygraph(state_dict_path)
tc.load_dict(state_dict)
reader = hub.reader.ClassifyReader(
dataset=dataset,
......@@ -34,14 +63,36 @@ def finetune(args):
train_reader = reader.data_generator(
batch_size=args.batch_size, phase='train')
for data_id, data in enumerate(train_reader()):
loss_sum = acc_sum = cnt = 0
# 执行epoch_num次训练
for epoch in range(args.num_epoch):
# 读取训练数据进行训练
for batch_id, data in enumerate(train_reader()):
input_ids = np.array(data[0][0]).astype(np.int64)
position_ids = np.array(data[0][1]).astype(np.int64)
segment_ids = np.array(data[0][2]).astype(np.int64)
input_mask = np.array(data[0][3]).astype(np.float32)
labels = np.array(data[0][4]).astype(np.int64)
pooled_output, sequence_output = ernie(position_ids, input_mask,
input_ids, segment_ids)
pred = tc(input_ids, position_ids, segment_ids, input_mask)
acc = fluid.layers.accuracy(pred, to_variable(labels))
loss = fluid.layers.cross_entropy(pred, to_variable(labels))
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
# 参数更新
adam.minimize(avg_loss)
loss_sum += avg_loss.numpy() * labels.shape[0]
acc_sum += acc.numpy() * labels.shape[0]
cnt += labels.shape[0]
if batch_id % args.log_interval == 0:
print('epoch {}: loss {}, acc {}'.format(
epoch, loss_sum / cnt, acc_sum / cnt))
loss_sum = acc_sum = cnt = 0
if batch_id % args.save_interval == 0:
state_dict = tc.state_dict()
fluid.save_dygraph(state_dict, state_dict_path)
if __name__ == "__main__":
......
......@@ -131,11 +131,11 @@ class Module(fluid.dygraph.Layer):
module_dir=None,
version=None,
**kwargs):
super(Module, self).__init__()
# Avoid module being initialized multiple times
if "_is_initialize" in self.__dict__ and self._is_initialize:
return
super(Module, self).__init__()
_run_func_name = self._get_func_name(self.__class__,
_module_runnable_func)
self._run_func = getattr(self,
......@@ -146,14 +146,12 @@ class Module(fluid.dygraph.Layer):
self._initialize(**kwargs)
self._is_initialize = True
self._code_version = "v2"
self._model_runner = None
self.model_runner = fluid.dygraph.StaticModelRunner(
self.pretrained_model_path)
@property
def model_runner(self):
if not self._model_runner:
self._model_runner = fluid.dygraph.StaticModelRunner(
self.default_pretrained_model_path)
return self._model_runner
def pretrained_model_path(self):
return self.default_pretrained_model_path
def _get_func_name(self, current_cls, module_func_dict):
mod = current_cls.__module__ + "." + current_cls.__name__
......
......@@ -246,6 +246,10 @@ class TransformerModule(NLPBaseModule):
Tranformer Module base class can be used by BERT, ERNIE, RoBERTa and so on.
"""
@property
def pretrained_model_path(self):
return self.params_path
def init_pretraining_params(self, exe, pretraining_params_path,
main_program):
assert os.path.exists(
......@@ -353,12 +357,13 @@ class TransformerModule(NLPBaseModule):
return inputs, outputs, module_program
@property
def model_runner(self):
if not self._model_runner:
self._model_runner = fluid.dygraph.StaticModelRunner(
self.params_path)
return self._model_runner
# @property
# def model_runner(self):
# if not self._model_runner:
# self._model_runner = fluid.dygraph.StaticModelRunner(
# self.params_path)
# return self._model_runner
def get_embedding(self, texts, use_gpu=False, batch_size=1):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册