提交 44bf33fb 编写于 作者: B baiyfbupt

minor fix

上级 638a980c
......@@ -12,14 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import logging
import numpy as np
from six.moves.queue import Queue
import paddle.fluid as fluid
from paddle.fluid.framework import Variable
from paddle.fluid.reader import DataLoaderBase
from paddle.fluid.core import EOFException
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
from six.moves.queue import Queue
import numpy as np
import os
__all__ = ["HDFSClient"]
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
__all__ = ['Knowledge']
......@@ -55,11 +66,12 @@ class Knowledge(object):
self.path = path
if isinstance(self.path, list):
self.write_type = 'HDFS/AFS'
assert (len(self.path) == 4 and isinstance(self.path[0], str) and
isinstance(self.path[1], str) and
isinstance(self.path[2], str) and
isinstance(self.path[3], str)), "path should contains four \
str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']"
assert (
len(self.path) == 4 and isinstance(self.path[0], str) and
isinstance(self.path[1], str) and
isinstance(self.path[2], str) and
isinstance(self.path[3], str)
), "path should contains four str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']"
hadoop_home = self.path[0]
configs = {
......@@ -67,9 +79,9 @@ class Knowledge(object):
"hadoop.job.ugi": self.path[2]
}
self.client = HDFSClient(hadoop_home, configs)
assert (self.client.is_exist(self.path[3]) == True
), "Plese make sure your hadoop \
confiuration is correct and FS path exists"
assert (
self.client.is_exist(self.path[3]) == True
), "Plese make sure your hadoop confiuration is correct and FS path exists"
self.hdfs_local_path = "./teacher_knowledge"
if not os.path.exists(self.hdfs_local_path):
......@@ -89,18 +101,19 @@ class Knowledge(object):
def _write(self, data):
if self.write_type == 'HDFS/AFS':
file_name = 'knowledge_' + str(self.file_cnt)
file_path = self.hdfs_local_path + file_name
file_path = os.path.join(self.hdfs_local_path, file_name)
file_path += ".npy"
np.save(file_path, data)
self.file_cnt += 1
self.client.upload(self.path[3], self.local_file_path)
print('{}.npy pushed to HDFS/AFS: {}'.format(file_name, self.path[
3]))
self.client.upload(self.path[3], file_path)
logger.info('{}.npy pushed to HDFS/AFS: {}'.format(
file_name, self.path[3]))
elif self.write_type == 'LocalFS':
file_name = 'knowledge_' + str(self.file_cnt)
file_path = os.path.join(self.path, file_name)
np.save(file_path, data)
print('{}.npy saved'.format(file_name))
logger.info('{}.npy saved'.format(file_name))
self.file_cnt += 1
else:
......@@ -128,11 +141,7 @@ class Knowledge(object):
teacher_program,
fluid.Program)), "teacher_program should be a fluid.Program"
assert (isinstance(inputs, list)), "inputs should be a list"
if len(inputs) > 0:
assert (isinstance(inputs[0], str)), "inputs shoud be list<str>"
assert (isinstance(outputs, list)), "outputs should be a list"
if len(outputs) > 0:
assert (isinstance(outputs[0], str)), "outputs should be list<str>"
assert (len(self.items) == len(outputs)
), "the length of outputs list should be equal with items list"
assert (callable(call_back) or (call_back is None)
......@@ -145,22 +154,24 @@ class Knowledge(object):
teacher_program)
teacher_knowledge = []
self.file_cnt = 0
if isinstance(reader, Variable) or (
isinstance(reader, DataLoaderBase) and (not reader.iterable)):
if isinstance(reader,
Variable) or (isinstance(reader, DataLoaderBase) and
(not reader.iterable)):
reader.start()
try:
batch_id = 0
while True:
logits = exe.run(compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=None)
logits = exe.run(
compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=None)
knowledge = dict()
for index, array in enumerate(logits):
knowledge[self.items[index]] = array
teacher_knowledge.append(knowledge)
if batch_id % 1 == 0:
print('infer finish iter {}'.format(batch_id))
logger.info('infer finish iter {}'.format(batch_id))
if len(teacher_knowledge) >= 4:
self._write(teacher_knowledge)
teacher_knowledge = []
......@@ -174,16 +185,17 @@ class Knowledge(object):
feed_list=inputs, place=place, program=teacher_program)
for batch_id, data in enumerate(reader()):
feed = feeder.feed(data)
logits = exe.run(compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=feed)
logits = exe.run(
compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=feed)
knowledge = dict()
for index, array in enumerate(logits):
knowledge[self.items[index]] = array
teacher_knowledge.append(knowledge)
if batch_id % 1 == 0:
print('infer finish iter {}'.format(batch_id))
logger.info('infer finish iter {}'.format(batch_id))
if len(teacher_knowledge) >= 4:
self._write(teacher_knowledge)
teacher_knowledge = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册