提交 44bf33fb 编写于 作者: B baiyfbupt

minor fix

上级 638a980c
...@@ -12,14 +12,25 @@ ...@@ -12,14 +12,25 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import logging
import numpy as np
from six.moves.queue import Queue
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle.fluid.reader import DataLoaderBase from paddle.fluid.reader import DataLoaderBase
from paddle.fluid.core import EOFException from paddle.fluid.core import EOFException
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
from six.moves.queue import Queue
import numpy as np __all__ = ["HDFSClient"]
import os
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
__all__ = ['Knowledge'] __all__ = ['Knowledge']
...@@ -55,11 +66,12 @@ class Knowledge(object): ...@@ -55,11 +66,12 @@ class Knowledge(object):
self.path = path self.path = path
if isinstance(self.path, list): if isinstance(self.path, list):
self.write_type = 'HDFS/AFS' self.write_type = 'HDFS/AFS'
assert (len(self.path) == 4 and isinstance(self.path[0], str) and assert (
isinstance(self.path[1], str) and len(self.path) == 4 and isinstance(self.path[0], str) and
isinstance(self.path[2], str) and isinstance(self.path[1], str) and
isinstance(self.path[3], str)), "path should contains four \ isinstance(self.path[2], str) and
str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']" isinstance(self.path[3], str)
), "path should contains four str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']"
hadoop_home = self.path[0] hadoop_home = self.path[0]
configs = { configs = {
...@@ -67,9 +79,9 @@ class Knowledge(object): ...@@ -67,9 +79,9 @@ class Knowledge(object):
"hadoop.job.ugi": self.path[2] "hadoop.job.ugi": self.path[2]
} }
self.client = HDFSClient(hadoop_home, configs) self.client = HDFSClient(hadoop_home, configs)
assert (self.client.is_exist(self.path[3]) == True assert (
), "Plese make sure your hadoop \ self.client.is_exist(self.path[3]) == True
confiuration is correct and FS path exists" ), "Plese make sure your hadoop confiuration is correct and FS path exists"
self.hdfs_local_path = "./teacher_knowledge" self.hdfs_local_path = "./teacher_knowledge"
if not os.path.exists(self.hdfs_local_path): if not os.path.exists(self.hdfs_local_path):
...@@ -89,18 +101,19 @@ class Knowledge(object): ...@@ -89,18 +101,19 @@ class Knowledge(object):
def _write(self, data): def _write(self, data):
if self.write_type == 'HDFS/AFS': if self.write_type == 'HDFS/AFS':
file_name = 'knowledge_' + str(self.file_cnt) file_name = 'knowledge_' + str(self.file_cnt)
file_path = self.hdfs_local_path + file_name file_path = os.path.join(self.hdfs_local_path, file_name)
file_path += ".npy"
np.save(file_path, data) np.save(file_path, data)
self.file_cnt += 1 self.file_cnt += 1
self.client.upload(self.path[3], self.local_file_path) self.client.upload(self.path[3], file_path)
print('{}.npy pushed to HDFS/AFS: {}'.format(file_name, self.path[ logger.info('{}.npy pushed to HDFS/AFS: {}'.format(
3])) file_name, self.path[3]))
elif self.write_type == 'LocalFS': elif self.write_type == 'LocalFS':
file_name = 'knowledge_' + str(self.file_cnt) file_name = 'knowledge_' + str(self.file_cnt)
file_path = os.path.join(self.path, file_name) file_path = os.path.join(self.path, file_name)
np.save(file_path, data) np.save(file_path, data)
print('{}.npy saved'.format(file_name)) logger.info('{}.npy saved'.format(file_name))
self.file_cnt += 1 self.file_cnt += 1
else: else:
...@@ -128,11 +141,7 @@ class Knowledge(object): ...@@ -128,11 +141,7 @@ class Knowledge(object):
teacher_program, teacher_program,
fluid.Program)), "teacher_program should be a fluid.Program" fluid.Program)), "teacher_program should be a fluid.Program"
assert (isinstance(inputs, list)), "inputs should be a list" assert (isinstance(inputs, list)), "inputs should be a list"
if len(inputs) > 0:
assert (isinstance(inputs[0], str)), "inputs shoud be list<str>"
assert (isinstance(outputs, list)), "outputs should be a list" assert (isinstance(outputs, list)), "outputs should be a list"
if len(outputs) > 0:
assert (isinstance(outputs[0], str)), "outputs should be list<str>"
assert (len(self.items) == len(outputs) assert (len(self.items) == len(outputs)
), "the length of outputs list should be equal with items list" ), "the length of outputs list should be equal with items list"
assert (callable(call_back) or (call_back is None) assert (callable(call_back) or (call_back is None)
...@@ -145,22 +154,24 @@ class Knowledge(object): ...@@ -145,22 +154,24 @@ class Knowledge(object):
teacher_program) teacher_program)
teacher_knowledge = [] teacher_knowledge = []
self.file_cnt = 0 self.file_cnt = 0
if isinstance(reader, Variable) or ( if isinstance(reader,
isinstance(reader, DataLoaderBase) and (not reader.iterable)): Variable) or (isinstance(reader, DataLoaderBase) and
(not reader.iterable)):
reader.start() reader.start()
try: try:
batch_id = 0 batch_id = 0
while True: while True:
logits = exe.run(compiled_teacher_program, logits = exe.run(
scope=scope, compiled_teacher_program,
fetch_list=outputs, scope=scope,
feed=None) fetch_list=outputs,
feed=None)
knowledge = dict() knowledge = dict()
for index, array in enumerate(logits): for index, array in enumerate(logits):
knowledge[self.items[index]] = array knowledge[self.items[index]] = array
teacher_knowledge.append(knowledge) teacher_knowledge.append(knowledge)
if batch_id % 1 == 0: if batch_id % 1 == 0:
print('infer finish iter {}'.format(batch_id)) logger.info('infer finish iter {}'.format(batch_id))
if len(teacher_knowledge) >= 4: if len(teacher_knowledge) >= 4:
self._write(teacher_knowledge) self._write(teacher_knowledge)
teacher_knowledge = [] teacher_knowledge = []
...@@ -174,16 +185,17 @@ class Knowledge(object): ...@@ -174,16 +185,17 @@ class Knowledge(object):
feed_list=inputs, place=place, program=teacher_program) feed_list=inputs, place=place, program=teacher_program)
for batch_id, data in enumerate(reader()): for batch_id, data in enumerate(reader()):
feed = feeder.feed(data) feed = feeder.feed(data)
logits = exe.run(compiled_teacher_program, logits = exe.run(
scope=scope, compiled_teacher_program,
fetch_list=outputs, scope=scope,
feed=feed) fetch_list=outputs,
feed=feed)
knowledge = dict() knowledge = dict()
for index, array in enumerate(logits): for index, array in enumerate(logits):
knowledge[self.items[index]] = array knowledge[self.items[index]] = array
teacher_knowledge.append(knowledge) teacher_knowledge.append(knowledge)
if batch_id % 1 == 0: if batch_id % 1 == 0:
print('infer finish iter {}'.format(batch_id)) logger.info('infer finish iter {}'.format(batch_id))
if len(teacher_knowledge) >= 4: if len(teacher_knowledge) >= 4:
self._write(teacher_knowledge) self._write(teacher_knowledge)
teacher_knowledge = [] teacher_knowledge = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册