提交 b556097a 编写于 作者: X xiexionghang

commit kagle for paddle

上级 62812dd3
"""
util for file_system io
"""
import os import os
import time import time
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
def is_afs_path(path): def is_afs_path(path):
"""R
"""
if path.startswith("afs") or path.startswith("hdfs"): if path.startswith("afs") or path.startswith("hdfs"):
return True return True
return False return False
class LocalFSClient: class LocalFSClient:
"""
Util for local disk file_system io
"""
def __init__(self): def __init__(self):
"""R
"""
pass pass
def write(self, content, path, mode): def write(self, content, path, mode):
"""
write to file
Args:
content(string)
path(string)
mode(string): w/a w:clear_write a:append_write
"""
temp_dir = os.path.dirname(path) temp_dir = os.path.dirname(path)
if not os.path.exists(temp_dir): if not os.path.exists(temp_dir):
os.makedirs(temp_dir) os.makedirs(temp_dir)
...@@ -20,35 +39,52 @@ class LocalFSClient: ...@@ -20,35 +39,52 @@ class LocalFSClient:
f.close() f.close()
def cp(self, org_path, dest_path): def cp(self, org_path, dest_path):
"""R
"""
temp_dir = os.path.dirname(dest_path) temp_dir = os.path.dirname(dest_path)
if not os.path.exists(temp_dir): if not os.path.exists(temp_dir):
os.makedirs(temp_dir) os.makedirs(temp_dir)
return os.system("cp -r " + org_path + " " + dest_path) return os.system("cp -r " + org_path + " " + dest_path)
def cat(self, file_path): def cat(self, file_path):
"""R
"""
f = open(file_path) f = open(file_path)
content = f.read() content = f.read()
f.close() f.close()
return content return content
def mkdir(self, dir_name): def mkdir(self, dir_name):
os.system("mkdir -p " + path) """R
"""
os.makedirs(dir_name)
def remove(self, path): def remove(self, path):
"""R
"""
os.system("rm -rf " + path) os.system("rm -rf " + path)
def is_exist(self, path): def is_exist(self, path):
"""R
"""
if os.system("ls " + path) == 0: if os.system("ls " + path) == 0:
return True return True
return False return False
def ls(self, path): def ls(self, path):
"""R
"""
files = os.listdir(path) files = os.listdir(path)
files = [ path + '/' + fi for fi in files ] files = [ path + '/' + fi for fi in files ]
return files return files
class FileHandler: class FileHandler:
"""
A Smart file handler. auto judge local/afs by path
"""
def __init__(self, config): def __init__(self, config):
"""R
"""
if 'fs_name' in config: if 'fs_name' in config:
hadoop_home="$HADOOP_HOME" hadoop_home="$HADOOP_HOME"
hdfs_configs = { hdfs_configs = {
...@@ -59,16 +95,22 @@ class FileHandler: ...@@ -59,16 +95,22 @@ class FileHandler:
self._local_fs_client = LocalFSClient() self._local_fs_client = LocalFSClient()
def is_exist(self, path): def is_exist(self, path):
"""R
"""
if is_afs_path(path): if is_afs_path(path):
return self._hdfs_client.is_exist(path) return self._hdfs_client.is_exist(path)
else: else:
return self._local_fs_client.is_exist(path) return self._local_fs_client.is_exist(path)
def get_file_name(self, path): def get_file_name(self, path):
"""R
"""
sub_paths = path.split('/') sub_paths = path.split('/')
return sub_paths[-1] return sub_paths[-1]
def write(self, content, dest_path, mode='w'): def write(self, content, dest_path, mode='w'):
"""R
"""
if is_afs_path(dest_path): if is_afs_path(dest_path):
file_name = self.get_file_name(dest_path) file_name = self.get_file_name(dest_path)
temp_local_file = "./tmp/" + file_name temp_local_file = "./tmp/" + file_name
...@@ -88,6 +130,8 @@ class FileHandler: ...@@ -88,6 +130,8 @@ class FileHandler:
def cat(self, path): def cat(self, path):
"""R
"""
if is_afs_path(path): if is_afs_path(path):
print("xxh go cat " + path) print("xxh go cat " + path)
hdfs_cat = self._hdfs_client.cat(path) hdfs_cat = self._hdfs_client.cat(path)
...@@ -97,6 +141,8 @@ class FileHandler: ...@@ -97,6 +141,8 @@ class FileHandler:
return self._local_fs_client.cat(path) return self._local_fs_client.cat(path)
def ls(self, path): def ls(self, path):
"""R
"""
if is_afs_path(path): if is_afs_path(path):
return self._hdfs_client.ls(path) return self._hdfs_client.ls(path)
else: else:
...@@ -104,6 +150,8 @@ class FileHandler: ...@@ -104,6 +150,8 @@ class FileHandler:
def cp(self, org_path, dest_path): def cp(self, org_path, dest_path):
"""R
"""
org_is_afs = is_afs_path(org_path) org_is_afs = is_afs_path(org_path)
dest_is_afs = is_afs_path(dest_path) dest_is_afs = is_afs_path(dest_path)
if not org_is_afs and not dest_is_afs: if not org_is_afs and not dest_is_afs:
......
...@@ -4,12 +4,13 @@ Do metric jobs. calculate AUC, MSE, COCP ... ...@@ -4,12 +4,13 @@ Do metric jobs. calculate AUC, MSE, COCP ...
import math import math
import time import time
import numpy as np import numpy as np
import kagle_util import kagle.kagle_util
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
class Metric(object): class Metric(object):
""" """ """R
"""
__metaclass__=abc.ABCMeta __metaclass__=abc.ABCMeta
def __init__(self, config): def __init__(self, config):
...@@ -52,6 +53,7 @@ class Metric(object): ...@@ -52,6 +53,7 @@ class Metric(object):
""" """
pass pass
class PaddleAUCMetric(Metric): class PaddleAUCMetric(Metric):
""" """
Metric For Paddle Model Metric For Paddle Model
...@@ -117,7 +119,8 @@ class PaddleAUCMetric(Metric): ...@@ -117,7 +119,8 @@ class PaddleAUCMetric(Metric):
return result return result
def calculate_auc(self, global_pos, global_neg): def calculate_auc(self, global_pos, global_neg):
""" """ """R
"""
num_bucket = len(global_pos) num_bucket = len(global_pos)
area = 0.0 area = 0.0
pos = 0.0 pos = 0.0
...@@ -142,7 +145,8 @@ class PaddleAUCMetric(Metric): ...@@ -142,7 +145,8 @@ class PaddleAUCMetric(Metric):
return auc_value return auc_value
def calculate_bucket_error(self, global_pos, global_neg): def calculate_bucket_error(self, global_pos, global_neg):
""" """ """R
"""
num_bucket = len(global_pos) num_bucket = len(global_pos)
last_ctr = -1.0 last_ctr = -1.0
impression_sum = 0.0 impression_sum = 0.0
......
"""
Define A Trainer Base
"""
import abc
import sys import sys
import time import time
from abc import ABCMeta, abstractmethod
class Trainer(object): class Trainer(object):
__metaclass__=ABCMeta """R
"""
__metaclass__ = self.ABCMeta
def __init__(self, config): def __init__(self, config):
"""R
"""
self._status_processor = {} self._status_processor = {}
self._context = {'status': 'uninit', 'is_exit': False} self._context = {'status': 'uninit', 'is_exit': False}
def regist_context_processor(self, status_name, processor): def regist_context_processor(self, status_name, processor):
"""
regist a processor for specify status
"""
self._status_processor[status_name] = processor self._status_processor[status_name] = processor
def context_process(self, context): def context_process(self, context):
"""
select a processor to deal specify context
Args:
context : context with status
Return:
None : run a processor for this status
"""
if context['status'] in self._status_processor: if context['status'] in self._status_processor:
self._status_processor[context['status']](context) self._status_processor[context['status']](context)
else: else:
self.other_status_processor(context) self.other_status_processor(context)
def other_status_processor(self, context): def other_status_processor(self, context):
"""
if no processor match context.status, use defalut processor
Return:
None, just sleep in base
"""
print('unknow context_status:%s, do nothing' % context['status']) print('unknow context_status:%s, do nothing' % context['status'])
time.sleep(60) time.sleep(60)
def reload_train_context(self): def reload_train_context(self):
"""
context maybe update timely, reload for update
"""
pass pass
def run(self): def run(self):
"""
keep running by statu context.
"""
while True: while True:
self.reload_train_context() self.reload_train_context()
self.context_process(self._context) self.context_process(self._context)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册