提交 98049ba1 编写于 作者: W wangwenjin

update

上级 e7968881
...@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
"""Implementation of GaAN""" """Implementation of GaAN"""
def send_func(src_feat, dst_feat, edge_feat): def send_func(src_feat, dst_feat, edge_feat):
# 计算每条边上的注意力分数 # attention score of each edge
# E * (M * D1), 每个 dst 点都查询它的全部邻边的 src 点 # E * (M * D1)
feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key'] feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
# E * M * D1 # E * M * D1
old = feat_query old = feat_query
...@@ -281,16 +281,11 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -281,16 +281,11 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
'feat_gate': src_feat['feat_gate']} 'feat_gate': src_feat['feat_gate']}
def recv_func(message): def recv_func(message):
# 每条边的终点的特征
dst_feat = message['dst_node_feat'] dst_feat = message['dst_node_feat']
# 每条边的出发点的特征
src_feat = message['src_node_feat'] src_feat = message['src_node_feat']
# 每个中心点自己的特征
x = fluid.layers.sequence_pool(dst_feat, 'average') x = fluid.layers.sequence_pool(dst_feat, 'average')
# 每个中心点的邻居的特征的平均值
z = fluid.layers.sequence_pool(src_feat, 'average') z = fluid.layers.sequence_pool(src_feat, 'average')
# 计算 gate
feat_gate = message['feat_gate'] feat_gate = message['feat_gate']
g_max = fluid.layers.sequence_pool(feat_gate, 'max') g_max = fluid.layers.sequence_pool(feat_gate, 'max')
g = fluid.layers.concat([x, g_max, z], axis=1) g = fluid.layers.concat([x, g_max, z], axis=1)
...@@ -318,10 +313,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -318,10 +313,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
return output return output
# feature N * D
# 计算每个点自己需要发送出去的内容
# 投影后的特征向量
# N * (D1 * M) # N * (D1 * M)
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False, feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key')) param_attr=fluid.ParamAttr(name=name + '_project_key'))
...@@ -335,8 +326,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -335,8 +326,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False, feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate')) param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send 阶段 # send stage
message = gw.send( message = gw.send(
send_func, send_func,
nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value), nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value),
...@@ -344,7 +334,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -344,7 +334,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
efeat_list=None, efeat_list=None,
) )
# 聚合邻居特征 # recv stage
output = gw.recv(message, recv_func) output = gw.recv(message, recv_func)
output = fluid.layers.fc(output, hidden_size_o, bias_attr=False, output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_output')) param_attr=fluid.ParamAttr(name=name + '_project_output'))
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid from paddle import fluid
from pgl.utils import paddle_helper from pgl.utils import paddle_helper
......
""" # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
将 ogb_proteins 的数据处理为 PGL 的 graph 数据,并返回 graph, label, train/valid/test 等信息 #
""" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ssl import ssl
ssl._create_default_https_context = ssl._create_unverified_context ssl._create_default_https_context = ssl._create_unverified_context
from ogb.nodeproppred import NodePropPredDataset, Evaluator from ogb.nodeproppred import NodePropPredDataset, Evaluator
...@@ -17,7 +27,7 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -17,7 +27,7 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
d_name: name of dataset d_name: name of dataset
mini_data: if mini_data==True, only use a small dataset (for test) mini_data: if mini_data==True, only use a small dataset (for test)
""" """
# 导入 ogb 数据 # import ogb data
dataset = NodePropPredDataset(name = d_name) dataset = NodePropPredDataset(name = d_name)
num_tasks = dataset.num_tasks # obtaining the number of prediction tasks in a dataset num_tasks = dataset.num_tasks # obtaining the number of prediction tasks in a dataset
...@@ -25,10 +35,10 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -25,10 +35,10 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
graph, label = dataset[0] graph, label = dataset[0]
# 调整维度,符合 PGL 的 Graph 要求 # reshape
graph["edge_index"] = graph["edge_index"].T graph["edge_index"] = graph["edge_index"].T
# 使用小规模数据,500个节点 # mini dataset
if mini_data: if mini_data:
graph['num_nodes'] = 500 graph['num_nodes'] = 500
mask = (graph['edge_index'][:, 0] < 500)*(graph['edge_index'][:, 1] < 500) mask = (graph['edge_index'][:, 0] < 500)*(graph['edge_index'][:, 1] < 500)
...@@ -39,19 +49,9 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -39,19 +49,9 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
valid_idx = np.arange(400,450) valid_idx = np.arange(400,450)
test_idx = np.arange(450,500) test_idx = np.arange(450,500)
# 输出 dataset 的信息
print(graph.keys())
print("节点个数 ", graph["num_nodes"]) # read/compute node feature
print("节点最小编号", graph['edge_index'][0].min())
print("边个数 ", graph["edge_index"].shape[1])
print("边索引 shape ", graph["edge_index"].shape)
print("边特征 shape ", graph["edge_feat"].shape)
print("节点特征是 ", graph["node_feat"])
print("species shape", graph['species'].shape)
print("label shape ", label.shape)
# 读取/计算 node feature
# 确定读取文件的路径
if mini_data: if mini_data:
node_feat_path = './dataset/ogbn_proteins_node_feat_small.npy' node_feat_path = './dataset/ogbn_proteins_node_feat_small.npy'
else: else:
...@@ -59,14 +59,11 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -59,14 +59,11 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
new_node_feat = None new_node_feat = None
if os.path.exists(node_feat_path): if os.path.exists(node_feat_path):
# 如果文件存在,直接读取 print("Begin: read node feature".center(50, '='))
print("读取 node feature 开始".center(50, '='))
new_node_feat = np.load(node_feat_path) new_node_feat = np.load(node_feat_path)
print("读取 node feature 成功".center(50, '=')) print("End: read node feature".center(50, '='))
else: else:
# 如果文件不存在,则计算 print("Begin: compute node feature".center(50, '='))
# 每个节点 i 的特征为其邻边特征的均值
print("计算 node feature 开始".center(50, '='))
start = time.perf_counter() start = time.perf_counter()
for i in range(graph['num_nodes']): for i in range(graph['num_nodes']):
if i % 100 == 0: if i % 100 == 0:
...@@ -74,8 +71,8 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -74,8 +71,8 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
print("{}/{}({}%), times: {:.2f}s".format( print("{}/{}({}%), times: {:.2f}s".format(
i, graph['num_nodes'], i/graph['num_nodes']*100, dur i, graph['num_nodes'], i/graph['num_nodes']*100, dur
)) ))
mask = (graph['edge_index'][:, 0] == i) # 选择 i 的所有邻边 mask = (graph['edge_index'][:, 0] == i)
# 计算均值
current_node_feat = np.mean(np.compress(mask, graph['edge_feat'], axis=0), current_node_feat = np.mean(np.compress(mask, graph['edge_feat'], axis=0),
axis=0, keepdims=True) axis=0, keepdims=True)
if i == 0: if i == 0:
...@@ -84,23 +81,23 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -84,23 +81,23 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
new_node_feat.append(current_node_feat) new_node_feat.append(current_node_feat)
new_node_feat = np.concatenate(new_node_feat, axis=0) new_node_feat = np.concatenate(new_node_feat, axis=0)
print("计算 node feature 结束".center(50,'=')) print("End: compute node feature".center(50,'='))
print("存储 node feature 中,在"+node_feat_path.center(50, '=')) print("Saving node feature in "+node_feat_path.center(50, '='))
np.save(node_feat_path, new_node_feat) np.save(node_feat_path, new_node_feat)
print("存储 node feature 结束".center(50,'=')) print("Saving finish".center(50,'='))
print(new_node_feat) print(new_node_feat)
# 构造 Graph 对象 # create graph
g = pgl.graph.Graph( g = pgl.graph.Graph(
num_nodes=graph["num_nodes"], num_nodes=graph["num_nodes"],
edges = graph["edge_index"], edges = graph["edge_index"],
node_feat = {'node_feat': new_node_feat}, node_feat = {'node_feat': new_node_feat},
edge_feat = None edge_feat = None
) )
print("创建 Graph 对象成功") print("Create graph")
print(g) print(g)
return g, label, train_idx, valid_idx, test_idx, Evaluator(d_name) return g, label, train_idx, valid_idx, test_idx, Evaluator(d_name)
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from preprocess import get_graph_data from preprocess import get_graph_data
import pgl import pgl
import argparse import argparse
import numpy as np import numpy as np
import time import time
from paddle import fluid from paddle import fluid
from visualdl import LogWriter
import reader import reader
from train_tool import train_epoch, valid_epoch from train_tool import train_epoch, valid_epoch
...@@ -50,9 +62,7 @@ if __name__ == "__main__": ...@@ -50,9 +62,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
# d_name = "ogbn-proteins" print("Parameters Setting".center(50, "="))
print("超参数配置".center(50, "="))
print("lr = {}, rc = {}, epochs = {}, batch_size = {}".format(args.lr, args.rc, args.epochs, print("lr = {}, rc = {}, epochs = {}, batch_size = {}".format(args.lr, args.rc, args.epochs,
args.batch_size)) args.batch_size))
print("Experiment ID: {}".format(args.exp_id).center(50, "=")) print("Experiment ID: {}".format(args.exp_id).center(50, "="))
...@@ -63,20 +73,6 @@ if __name__ == "__main__": ...@@ -63,20 +73,6 @@ if __name__ == "__main__":
g, label, train_idx, valid_idx, test_idx, evaluator = get_graph_data(d_name=d_name, g, label, train_idx, valid_idx, test_idx, evaluator = get_graph_data(d_name=d_name,
mini_data=eval(args.mini_data)) mini_data=eval(args.mini_data))
# create log writer
log_writer = LogWriter(args.log_path+'/'+str(args.exp_id), sync_cycle=10)
with log_writer.mode("train") as logger:
log_train_loss_epoch = logger.scalar("loss")
log_train_rocauc_epoch = logger.scalar("rocauc")
with log_writer.mode("valid") as logger:
log_valid_loss_epoch = logger.scalar("loss")
log_valid_rocauc_epoch = logger.scalar("rocauc")
log_text = log_writer.text("text")
log_time = log_writer.scalar("time")
log_test_loss = log_writer.scalar("test_loss")
log_test_rocauc = log_writer.scalar("test_rocauc")
if args.model == "GaAN": if args.model == "GaAN":
graph_model = GaANModel(112, 3, args.hidden_size_a, args.hidden_size_v, args.hidden_size_m, graph_model = GaANModel(112, 3, args.hidden_size_a, args.hidden_size_v, args.hidden_size_m,
args.hidden_size_o, args.heads) args.hidden_size_o, args.heads)
...@@ -162,17 +158,13 @@ if __name__ == "__main__": ...@@ -162,17 +158,13 @@ if __name__ == "__main__":
start = time.time() start = time.time()
print("Training Begin".center(50, "=")) print("Training Begin".center(50, "="))
log_text.add_record(0, "Training Begin".center(50, "="))
best_valid = -1.0 best_valid = -1.0
for epoch in range(args.epochs): for epoch in range(args.epochs):
start_e = time.time() start_e = time.time()
# print("Train Epoch {}".format(epoch).center(50, "="))
train_loss, train_rocauc = train_epoch( train_loss, train_rocauc = train_epoch(
train_iter, program=train_program, exe=exe, loss=loss, score=score, train_iter, program=train_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch evaluator=evaluator, epoch=epoch
) )
print("Valid Epoch {}".format(epoch).center(50, "="))
valid_loss, valid_rocauc = valid_epoch( valid_loss, valid_rocauc = valid_epoch(
val_iter, program=val_program, exe=exe, loss=loss, score=score, val_iter, program=val_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch) evaluator=evaluator, epoch=epoch)
...@@ -180,15 +172,6 @@ if __name__ == "__main__": ...@@ -180,15 +172,6 @@ if __name__ == "__main__":
print("Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format( print("Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format(
epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e
)) ))
log_text.add_record(epoch+1,
"Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format(
epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e
))
log_train_loss_epoch.add_record(epoch, train_loss)
log_valid_loss_epoch.add_record(epoch, valid_loss)
log_train_rocauc_epoch.add_record(epoch, train_rocauc)
log_valid_rocauc_epoch.add_record(epoch, valid_rocauc)
log_time.add_record(epoch, end_e-start_e)
if valid_rocauc > best_valid: if valid_rocauc > best_valid:
print("Update: new {}, old {}".format(valid_rocauc, best_valid)) print("Update: new {}, old {}".format(valid_rocauc, best_valid))
...@@ -198,23 +181,14 @@ if __name__ == "__main__": ...@@ -198,23 +181,14 @@ if __name__ == "__main__":
print("Test Stage".center(50, "=")) print("Test Stage".center(50, "="))
log_text.add_record(args.epochs+1, "Test Stage".center(50, "="))
fluid.io.load_params(executor=exe, dirname='./params/'+str(args.exp_id), main_program=val_program) fluid.io.load_params(executor=exe, dirname='./params/'+str(args.exp_id), main_program=val_program)
test_loss, test_rocauc = valid_epoch( test_loss, test_rocauc = valid_epoch(
test_iter, program=val_program, exe=exe, loss=loss, score=score, test_iter, program=val_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch) evaluator=evaluator, epoch=epoch)
log_test_loss.add_record(0, test_loss)
log_test_rocauc.add_record(0, test_rocauc)
end = time.time() end = time.time()
print("test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format( print("test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format(
test_loss, test_rocauc, end-start test_loss, test_rocauc, end-start
)) ))
print("End".center(50, "=")) print("End".center(50, "="))
log_text.add_record(args.epochs+2, "test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format(
test_loss, test_rocauc, end-start
))
log_text.add_record(args.epochs+3, "End".center(50, "="))
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time import time
from pgl.utils.logger import log from pgl.utils.logger import log
...@@ -15,50 +28,12 @@ def train_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per ...@@ -15,50 +28,12 @@ def train_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
total_sample += num_samples total_sample += num_samples
input_dict = { input_dict = {
"y_true": batch_feed_dict["node_label"], "y_true": batch_feed_dict["node_label"],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred": y_pred "y_pred": y_pred
} }
result += evaluator.eval(input_dict)["rocauc"] result += evaluator.eval(input_dict)["rocauc"]
# if batch % log_per_step == 0:
# print("Batch {}: Loss={}".format(batch, batch_loss))
# log.info("Batch %s %s-Loss %s %s-Acc %s" %
# (batch, prefix, batch_loss, prefix, batch_acc))
# print("Epoch {} Train: Loss={}, rocauc={}, Speed(per batch)={}".format(
# epoch, total_loss/total_sample, result/batch, (end-start)/batch))
return total_loss.item()/total_sample, result/batch return total_loss.item()/total_sample, result/batch
def inference(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1):
batch = 0
total_sample = 0
total_loss = 0
result = 0
start = time.time()
for batch_feed_dict in batch_iter():
batch += 1
y_pred = exe.run(program, fetch_list=[score], feed=batch_feed_dict)[0]
input_dict = {
"y_true": batch_feed_dict["node_label"],
"y_pred": y_pred[batch_feed_dict["node_index"]]
}
result += evaluator.eval(input_dict)["rocauc"]
if batch % log_per_step == 0:
print(batch, result/batch)
num_samples = len(batch_feed_dict["node_index"])
# total_loss += batch_loss * num_samples
# total_acc += batch_acc * num_samples
total_sample += num_samples
end = time.time()
print("Epoch {} Valid: Loss={}, Speed(per batch)={}".format(epoch, total_loss/total_sample,
(end-start)/batch))
return total_loss/total_sample, result/batch
def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1): def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1):
batch = 0 batch = 0
total_sample = 0 total_sample = 0
...@@ -69,53 +44,13 @@ def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per ...@@ -69,53 +44,13 @@ def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
batch_loss, y_pred = exe.run(program, fetch_list=[loss, score], feed=batch_feed_dict) batch_loss, y_pred = exe.run(program, fetch_list=[loss, score], feed=batch_feed_dict)
input_dict = { input_dict = {
"y_true": batch_feed_dict["node_label"], "y_true": batch_feed_dict["node_label"],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred": y_pred "y_pred": y_pred
} }
# print(evaluator.eval(input_dict))
result += evaluator.eval(input_dict)["rocauc"] result += evaluator.eval(input_dict)["rocauc"]
# if batch % log_per_step == 0:
# print(batch, result/batch)
num_samples = len(batch_feed_dict["node_index"]) num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples total_loss += batch_loss * num_samples
# total_acc += batch_acc * num_samples
total_sample += num_samples total_sample += num_samples
# print("Epoch {} Valid: Loss={}, Speed(per batch)={}".format(epoch, total_loss/total_sample, (end-start)/batch))
return total_loss.item()/total_sample, result/batch return total_loss.item()/total_sample, result/batch
def run_epoch(batch_iter, exe, program, prefix, model_loss, model_acc, epoch, log_per_step=100):
"""
已废弃
"""
batch = 0
total_loss = 0.
total_acc = 0.
total_sample = 0
start = time.time()
for batch_feed_dict in batch_iter():
batch += 1
batch_loss, batch_acc = exe.run(program,
fetch_list=[model_loss, model_acc],
feed=batch_feed_dict)
if batch % log_per_step == 0:
log.info("Batch %s %s-Loss %s %s-Acc %s" %
(batch, prefix, batch_loss, prefix, batch_acc))
num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples
total_acc += batch_acc * num_samples
total_sample += num_samples
end = time.time()
log.info("%s Epoch %s Loss %.5lf Acc %.5lf Speed(per batch) %.5lf sec" %
(prefix, epoch, total_loss / total_sample,
total_acc / total_sample, (end - start) / batch))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册