提交 5a81d8c2 编写于 作者: Z zhengya01 提交者: hutuxian

add ce (#2011)

上级 0bc2cac1
#!/bin/bash
# Continuous-evaluation (CE) driver: trains once on a single card and once on
# four cards, piping each training log into _ce.py, which extracts the KPIs.

# Keep the math libraries single-threaded for both runs.
export MKL_NUM_THREADS=1
export OMP_NUM_THREADS=1

# Run one CE training pass on the GPUs listed in $1 (comma-separated ids).
run_ce_training() {
    export CUDA_VISIBLE_DEVICES=$1
    FLAGS_benchmark=true python -u train.py --config_path 'data/config.txt' --train_dir 'data/paddle_train.txt' --batch_size 32 --epoch_num 1 --use_cuda 1 --enable_ce --batch_num 10000 | python _ce.py
}

# Single-card run: use the 0-th card unless face_detection is already set.
cudaid=${face_detection:=0}
run_ce_training "$cudaid"

# Four-card run: use cards 0,1,2,3 unless face_detection_4 is already set.
cudaid=${face_detection_4:=0,1,2,3}
run_ce_training "$cudaid"
# This file is only used for continuous evaluation (CE) tests.
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi
from kpi import DurationKpi
# KPI objects for the CE runs: one duration KPI and one loss KPI per GPU
# configuration. The "card1"/"card4" suffixes match the KPI names that
# train.py prints on its "kpis\t<name>\t<value>" lines.
# NOTE(review): the meaning of the positional arguments (0.08, 0) and of the
# `actived` flag is defined by the external `kpi` module and is not visible
# here -- confirm against kpi.py before changing them.
each_pass_duration_card1_kpi = DurationKpi('each_pass_duration_card1', 0.08, 0, actived=True)
train_loss_card1_kpi = CostKpi('train_loss_card1', 0.08, 0)
each_pass_duration_card4_kpi = DurationKpi('each_pass_duration_card4', 0.08, 0, actived=True)
train_loss_card4_kpi = CostKpi('train_loss_card4', 0.08, 0)
# Registry consumed by log_to_ce(), which indexes these objects by `name`.
tracking_kpis = [
each_pass_duration_card1_kpi,
train_loss_card1_kpi,
each_pass_duration_card4_kpi,
train_loss_card4_kpi,
]
def parse_log(log):
    """Yield ``(kpi_name, kpi_value)`` pairs found in a training log.

    The log is split on newlines; each line is stripped and split on tabs.
    Only lines with exactly three fields whose first field is ``kpis`` are
    reported, for example::

        kpis\ttrain_loss_card1\t0.25
        kpis\teach_pass_duration_card1\t12.5

    The split fields of every line (matching or not) are echoed to stdout.

    Args:
        log: Full text of the training log.

    Yields:
        Tuples of the KPI name (str) and its value converted to ``float``.

    Raises:
        ValueError: If the third field of a ``kpis`` line is not a number.
    """
    for raw_line in log.split('\n'):
        fields = raw_line.strip().split('\t')
        print(fields)
        # Skip everything that is not a "kpis<TAB>name<TAB>value" record.
        if len(fields) != 3 or fields[0] != 'kpis':
            continue
        _, kpi_name, raw_value = fields
        yield kpi_name, float(raw_value)
def log_to_ce(log):
    """Record every KPI found in ``log`` on its tracked KPI object.

    Each ``(name, value)`` pair produced by ``parse_log`` is printed, added
    to the matching entry of ``tracking_kpis`` and persisted immediately.

    Args:
        log: Full text of the training log.

    Raises:
        KeyError: If the log reports a KPI name that is not in
            ``tracking_kpis``.
    """
    # Index the tracked KPI objects by name for direct lookup.
    kpi_tracker = {kpi.name: kpi for kpi in tracking_kpis}

    for kpi_name, kpi_value in parse_log(log):
        print(kpi_name, kpi_value)
        tracked = kpi_tracker[kpi_name]
        tracked.add_record(kpi_value)
        tracked.persist()
if __name__ == '__main__':
    # The whole training log is read from stdin and handed to the KPI recorder.
    log_to_ce(sys.stdin.read())
......@@ -12,6 +12,7 @@
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import logging
import time
......@@ -49,6 +50,10 @@ def parse_args():
'--base_lr', type=float, default=0.85, help='based learning rate')
parser.add_argument(
'--num_devices', type=int, default=1, help='Number of GPU devices')
parser.add_argument(
'--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.')
parser.add_argument(
'--batch_num', type=int, help="batch num for ce")
args = parser.parse_args()
return args
......@@ -56,6 +61,11 @@ def parse_args():
def train():
args = parse_args()
if args.enable_ce:
SEED = 102
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
config_path = args.config_path
train_path = args.train_dir
epoch_num = args.epoch_num
......@@ -101,6 +111,8 @@ def train():
global_step = 0
PRINT_STEP = 1000
total_time = []
ce_info = []
start_time = time.time()
loss_sum = 0.0
for id in range(epoch_num):
......@@ -113,6 +125,8 @@ def train():
loss_sum += results[0].mean()
if global_step % PRINT_STEP == 0:
ce_info.append(loss_sum / PRINT_STEP)
total_time.append(time.time() - start_time)
logger.info(
"epoch: %d\tglobal_step: %d\ttrain_loss: %.4f\t\ttime: %.2f"
% (epoch, global_step, loss_sum / PRINT_STEP,
......@@ -133,6 +147,31 @@ def train():
fluid.io.save_inference_model(save_dir, feed_var_name,
fetch_vars, exe)
logger.info("model saved in " + save_dir)
if args.enable_ce and global_step >= args.batch_num:
break
# only for ce
if args.enable_ce:
gpu_num = get_cards(args)
ce_loss = 0
ce_time = 0
try:
ce_loss = ce_info[-1]
ce_time = total_time[-1]
except:
print("ce info error")
print("kpis\teach_pass_duration_card%s\t%s" %
(gpu_num, ce_time))
print("kpis\ttrain_loss_card%s\t%s" %
(gpu_num, ce_loss))
def get_cards(args):
    """Return the number of GPU cards used by the current run.

    In continuous-evaluation mode (``args.enable_ce`` is truthy) the count is
    the number of non-blank, comma-separated entries in the
    ``CUDA_VISIBLE_DEVICES`` environment variable. If that variable is unset
    or lists no device, ``args.num_devices`` is used instead of failing.
    Outside CE mode ``args.num_devices`` is returned directly.

    Args:
        args: Parsed command-line namespace; ``enable_ce`` and
            ``num_devices`` are read.

    Returns:
        The number of cards as an int.
    """
    if not args.enable_ce:
        return args.num_devices

    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is None:
        # Variable not exported: there is nothing to count, so trust the flag.
        return args.num_devices

    # Drop blank entries so "" or a trailing comma ("0,1,") is not counted
    # as an extra card.
    device_ids = [entry for entry in visible.split(',') if entry.strip()]
    if not device_ids:
        return args.num_devices
    return len(device_ids)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册