未验证 提交 c54887dd 编写于 作者: K kinghuin 提交者: GitHub

Fix the Windows bug in ernie_gen (#905)

上级 a6ceff1c
...@@ -184,3 +184,7 @@ paddlehub >= 1.7.0 ...@@ -184,3 +184,7 @@ paddlehub >= 1.7.0
* 1.0.1 * 1.0.1
修复模型导出bug 修复模型导出bug
* 1.0.2
修复 Windows 运行中的 bug
...@@ -39,7 +39,7 @@ import ernie_gen.propeller.paddle as propeller ...@@ -39,7 +39,7 @@ import ernie_gen.propeller.paddle as propeller
@moduleinfo( @moduleinfo(
name="ernie_gen", name="ernie_gen",
version="1.0.1", version="1.0.2",
summary= summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning.", "ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning.",
author="baidu", author="baidu",
...@@ -371,10 +371,11 @@ class ErnieGen(hub.Module): ...@@ -371,10 +371,11 @@ class ErnieGen(hub.Module):
src_ids = src_ids[:self.max_encode_len] src_ids = src_ids[:self.max_encode_len]
tgt_ids = tgt_ids[:self.max_decode_len] tgt_ids = tgt_ids[:self.max_decode_len]
src_ids, src_sids = self.tokenizer.build_for_ernie(src_ids) src_ids, src_sids = self.tokenizer.build_for_ernie(src_ids)
src_pids = np.arange(len(src_ids)) src_pids = np.arange(len(src_ids), dtype=np.int64)
tgt_ids, tgt_sids = self.tokenizer.build_for_ernie(tgt_ids) tgt_ids, tgt_sids = self.tokenizer.build_for_ernie(tgt_ids)
tgt_pids = np.arange(len(tgt_ids)) + len(src_ids) # continues position tgt_pids = np.arange(
len(tgt_ids), dtype=np.int64) + len(src_ids) # continues position
tgt_sids = np.ones_like(tgt_sids) tgt_sids = np.ones_like(tgt_sids)
attn_ids = np.ones_like(tgt_ids) * self.tokenizer.vocab['[MASK]'] attn_ids = np.ones_like(tgt_ids) * self.tokenizer.vocab['[MASK]']
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
// model training or inference. // model training or inference.
syntax = "proto3"; syntax = "proto3";
import "propeller/paddle/data/feature.proto"; import "ernie_gen.propeller/paddle/data/feature.proto";
package propeller; package ernie_gen.propeller;
message Example { message Example {
Features features = 1; Features features = 1;
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
syntax = "proto3"; syntax = "proto3";
package propeller; package ernie_gen.propeller;
message BytesList { message BytesList {
repeated bytes value = 1; repeated bytes value = 1;
......
...@@ -125,7 +125,7 @@ class LabelColumn(Column): ...@@ -125,7 +125,7 @@ class LabelColumn(Column):
ids = int(raw) ids = int(raw)
else: else:
ids = self.vocab[raw] ids = self.vocab[raw]
return ids return np.array(ids, dtype=np.int64)
class TextColumn(Column): class TextColumn(Column):
......
...@@ -73,7 +73,7 @@ class TqdmProgressBarHook(RunHook): ...@@ -73,7 +73,7 @@ class TqdmProgressBarHook(RunHook):
"""doc""" """doc"""
self.tqdm = None self.tqdm = None
import tqdm import tqdm
from propeller import log as main_log from ernie_gen.propeller import log as main_log
hdl = main_log.handlers[0] hdl = main_log.handlers[0]
class _TqdmLogginHandler(logging.Handler): class _TqdmLogginHandler(logging.Handler):
...@@ -110,7 +110,7 @@ class TqdmNotebookProgressBarHook(RunHook): ...@@ -110,7 +110,7 @@ class TqdmNotebookProgressBarHook(RunHook):
"""doc""" """doc"""
self.tqdm = None self.tqdm = None
import tqdm import tqdm
from propeller import log as main_log from ernie_gen.propeller import log as main_log
hdl = main_log.handlers[0] hdl = main_log.handlers[0]
class _TqdmLogginHandler(logging.Handler): class _TqdmLogginHandler(logging.Handler):
...@@ -144,7 +144,7 @@ class TqdmNotebookProgressBarHook(RunHook): ...@@ -144,7 +144,7 @@ class TqdmNotebookProgressBarHook(RunHook):
class LoggingHook(RunHook): class LoggingHook(RunHook):
"""log tensor in to screan and tensorboard""" """log tensor in to screan and VisualDL"""
def __init__(self, def __init__(self,
loss, loss,
...@@ -205,7 +205,7 @@ class LoggingHook(RunHook): ...@@ -205,7 +205,7 @@ class LoggingHook(RunHook):
speed = -1. speed = -1.
self.last_state = state self.last_state = state
# log to tensorboard # log to VisualDL
if self.writer is not None: if self.writer is not None:
self.writer.add_scalar('loss', loss, state.gstep) self.writer.add_scalar('loss', loss, state.gstep)
for name, t in zip(self.s_name, s_np): for name, t in zip(self.s_name, s_np):
......
...@@ -48,11 +48,11 @@ __all__ = ['train_and_eval', 'Learner'] ...@@ -48,11 +48,11 @@ __all__ = ['train_and_eval', 'Learner']
def _get_summary_writer(path): def _get_summary_writer(path):
summary_writer = None summary_writer = None
try: try:
from tensorboardX import SummaryWriter from visualdl import LogWriter
if distribution.status.is_master: if distribution.status.is_master:
summary_writer = SummaryWriter(os.path.join(path)) summary_writer = LogWriter(os.path.join(path))
except ImportError: except ImportError:
log.warning('tensorboardX not installed, will not log to tensorboard') log.warning('VisualDL not installed, will not log to VisualDL')
return summary_writer return summary_writer
...@@ -69,7 +69,7 @@ def _log_eval_result(name, eval_result, swriter, state): ...@@ -69,7 +69,7 @@ def _log_eval_result(name, eval_result, swriter, state):
printable.append('{}\t{}'.format(n, val)) printable.append('{}\t{}'.format(n, val))
if swriter is not None: if swriter is not None:
swriter.add_scalar(n, val, state.gstep) swriter.add_scalar(n, val, state.gstep)
log.debug('write to tensorboard %s' % swriter.logdir) log.debug('write to VisualDL %s' % swriter.logdir)
if len(printable): if len(printable):
log.info('*** eval res: %10s ***' % name) log.info('*** eval res: %10s ***' % name)
...@@ -134,10 +134,10 @@ class Learner(object): ...@@ -134,10 +134,10 @@ class Learner(object):
if run_config.model_dir is None: if run_config.model_dir is None:
raise ValueError('model_dir should specified in run_config') raise ValueError('model_dir should specified in run_config')
if issubclass(model_class_or_model_fn, Model): if inspect.isfunction(model_class_or_model_fn):
_model_fn = _build_model_fn(model_class_or_model_fn)
elif inspect.isfunction(model_class_or_model_fn):
_model_fn = model_class_or_model_fn _model_fn = model_class_or_model_fn
elif issubclass(model_class_or_model_fn, Model):
_model_fn = _build_model_fn(model_class_or_model_fn)
else: else:
raise ValueError('unknown model %s' % model_class_or_model_fn) raise ValueError('unknown model %s' % model_class_or_model_fn)
......
...@@ -71,8 +71,8 @@ def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"): ...@@ -71,8 +71,8 @@ def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"):
"CUDA_VISIBLE_DEVICES").split(",")[device_idx] "CUDA_VISIBLE_DEVICES").split(",")[device_idx]
log.debug('cuda_env %s' % os.environ["CUDA_VISIBLE_DEVICES"]) log.debug('cuda_env %s' % os.environ["CUDA_VISIBLE_DEVICES"])
import paddle.fluid as F import paddle.fluid as F
from propeller.service import interface_pb2 from ernie_gen.propeller.service import interface_pb2
import propeller.service.utils as serv_utils import ernie_gen.propeller.service.utils as serv_utils
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.REP) socket = context.socket(zmq.REP)
socket.connect(endpoint) socket.connect(endpoint)
......
...@@ -26,7 +26,6 @@ import collections ...@@ -26,7 +26,6 @@ import collections
from distutils import dir_util from distutils import dir_util
import pickle import pickle
#from utils import print_arguments
import paddle.fluid as F import paddle.fluid as F
from paddle.fluid.proto import framework_pb2 from paddle.fluid.proto import framework_pb2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册