未验证 提交 c54887dd 编写于 作者: K kinghuin 提交者: GitHub

fix the windows bug in ernie_gen (#905)

上级 a6ceff1c
......@@ -184,3 +184,7 @@ paddlehub >= 1.7.0
* 1.0.1
修复模型导出bug
* 1.0.2
修复windows运行中的bug
......@@ -39,7 +39,7 @@ import ernie_gen.propeller.paddle as propeller
@moduleinfo(
name="ernie_gen",
version="1.0.1",
version="1.0.2",
summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning.",
author="baidu",
......@@ -371,10 +371,11 @@ class ErnieGen(hub.Module):
src_ids = src_ids[:self.max_encode_len]
tgt_ids = tgt_ids[:self.max_decode_len]
src_ids, src_sids = self.tokenizer.build_for_ernie(src_ids)
src_pids = np.arange(len(src_ids))
src_pids = np.arange(len(src_ids), dtype=np.int64)
tgt_ids, tgt_sids = self.tokenizer.build_for_ernie(tgt_ids)
tgt_pids = np.arange(len(tgt_ids)) + len(src_ids) # continues position
tgt_pids = np.arange(
len(tgt_ids), dtype=np.int64) + len(src_ids) # continues position
tgt_sids = np.ones_like(tgt_sids)
attn_ids = np.ones_like(tgt_ids) * self.tokenizer.vocab['[MASK]']
......
......@@ -16,8 +16,8 @@
// model training or inference.
syntax = "proto3";
import "propeller/paddle/data/feature.proto";
package propeller;
import "ernie_gen.propeller/paddle/data/feature.proto";
package ernie_gen.propeller;
message Example {
Features features = 1;
......
......@@ -13,7 +13,7 @@
// limitations under the License.
syntax = "proto3";
package propeller;
package ernie_gen.propeller;
message BytesList {
repeated bytes value = 1;
......
......@@ -125,7 +125,7 @@ class LabelColumn(Column):
ids = int(raw)
else:
ids = self.vocab[raw]
return ids
return np.array(ids, dtype=np.int64)
class TextColumn(Column):
......
......@@ -73,7 +73,7 @@ class TqdmProgressBarHook(RunHook):
"""doc"""
self.tqdm = None
import tqdm
from propeller import log as main_log
from ernie_gen.propeller import log as main_log
hdl = main_log.handlers[0]
class _TqdmLogginHandler(logging.Handler):
......@@ -110,7 +110,7 @@ class TqdmNotebookProgressBarHook(RunHook):
"""doc"""
self.tqdm = None
import tqdm
from propeller import log as main_log
from ernie_gen.propeller import log as main_log
hdl = main_log.handlers[0]
class _TqdmLogginHandler(logging.Handler):
......@@ -144,7 +144,7 @@ class TqdmNotebookProgressBarHook(RunHook):
class LoggingHook(RunHook):
"""log tensor in to screan and tensorboard"""
"""log tensor in to screan and VisualDL"""
def __init__(self,
loss,
......@@ -205,7 +205,7 @@ class LoggingHook(RunHook):
speed = -1.
self.last_state = state
# log to tensorboard
# log to VisualDL
if self.writer is not None:
self.writer.add_scalar('loss', loss, state.gstep)
for name, t in zip(self.s_name, s_np):
......
......@@ -48,11 +48,11 @@ __all__ = ['train_and_eval', 'Learner']
def _get_summary_writer(path):
summary_writer = None
try:
from tensorboardX import SummaryWriter
from visualdl import LogWriter
if distribution.status.is_master:
summary_writer = SummaryWriter(os.path.join(path))
summary_writer = LogWriter(os.path.join(path))
except ImportError:
log.warning('tensorboardX not installed, will not log to tensorboard')
log.warning('VisualDL not installed, will not log to VisualDL')
return summary_writer
......@@ -69,7 +69,7 @@ def _log_eval_result(name, eval_result, swriter, state):
printable.append('{}\t{}'.format(n, val))
if swriter is not None:
swriter.add_scalar(n, val, state.gstep)
log.debug('write to tensorboard %s' % swriter.logdir)
log.debug('write to VisualDL %s' % swriter.logdir)
if len(printable):
log.info('*** eval res: %10s ***' % name)
......@@ -134,10 +134,10 @@ class Learner(object):
if run_config.model_dir is None:
raise ValueError('model_dir should specified in run_config')
if issubclass(model_class_or_model_fn, Model):
_model_fn = _build_model_fn(model_class_or_model_fn)
elif inspect.isfunction(model_class_or_model_fn):
if inspect.isfunction(model_class_or_model_fn):
_model_fn = model_class_or_model_fn
elif issubclass(model_class_or_model_fn, Model):
_model_fn = _build_model_fn(model_class_or_model_fn)
else:
raise ValueError('unknown model %s' % model_class_or_model_fn)
......
......@@ -71,8 +71,8 @@ def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"):
"CUDA_VISIBLE_DEVICES").split(",")[device_idx]
log.debug('cuda_env %s' % os.environ["CUDA_VISIBLE_DEVICES"])
import paddle.fluid as F
from propeller.service import interface_pb2
import propeller.service.utils as serv_utils
from ernie_gen.propeller.service import interface_pb2
import ernie_gen.propeller.service.utils as serv_utils
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.connect(endpoint)
......
......@@ -26,7 +26,6 @@ import collections
from distutils import dir_util
import pickle
#from utils import print_arguments
import paddle.fluid as F
from paddle.fluid.proto import framework_pb2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册