未验证 提交 c4f3ebc3 编写于 作者: Z zhengya01 提交者: GitHub

add ce for dygraph_sentiment (#4287)

上级 6c14f65e
......@@ -22,6 +22,7 @@ from paddle.fluid.dygraph.base import to_variable
import nets
import reader
from utils import ArgumentGroup
from utils import get_cards
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
......@@ -148,6 +149,7 @@ def train():
steps = 0
total_cost, total_acc, total_num_seqs = [], [], []
gru_hidden_data = np.zeros((args.batch_size, 128), dtype='float32')
ce_time, ce_infor = [], []
for eop in range(args.epoch):
time_begin = time.time()
for batch_id, data in enumerate(train_data_generator()):
......@@ -186,6 +188,8 @@ def train():
np.sum(total_cost) / np.sum(total_num_seqs),
np.sum(total_acc) / np.sum(total_num_seqs),
args.skip_steps / used_time))
ce_time.append(used_time)
ce_infor.append(np.sum(total_acc) / np.sum(total_num_seqs))
total_cost, total_acc, total_num_seqs = [], [], []
time_begin = time.time()
......@@ -247,6 +251,17 @@ def train():
if enable_profile:
print('save profile result into /tmp/profile_file')
return
if args.ce:
card_num = get_cards()
_acc = 0
_time = 0
try:
_time = ce_time[-1]
_acc = ce_infor[-1]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time))
print("kpis\ttrain_acc_card%s\t%f" % (card_num, _acc))
def infer():
......
......@@ -16,6 +16,7 @@ from __future__ import division
from __future__ import print_function
import io
import os
import sys
import random
......@@ -80,3 +81,10 @@ def load_vocab(file_path):
wid += 1
vocab["<unk>"] = len(vocab)
return vocab
def get_cards():
num = 0
cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if cards != '':
num = len(cards.split(","))
return num
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册