提交 4ab6cc59 编写于 作者: W wangxiao1021

update setup.py & fix bugs

上级 c21afb28
......@@ -25,11 +25,13 @@ from paddle.fluid import layers
def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
if not isinstance(rt_val, np.ndarray):
rt_val = np.array(rt_val)
print(message + 'int first if block.')
assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)."
if rt_val.dtype == np.dtype('float64'):
rt_val = rt_val.astype('float32')
shape, dtype = attr
# rt_val = np.array(rt_val, dtype=dtype)
assert rt_val.dtype == np.dtype(dtype), message+"yielded data type not consistent with attr settings. Expect: {}, receive: {}.".format(rt_val.dtype, np.dtype(dtype))
assert len(shape) == rt_val.ndim, message+"yielded data rank(ndim) not consistent with attr settings. Expect: {}, receive: {}.".format(len(shape), rt_val.ndim)
for rt, exp in zip(rt_val.shape, shape):
......@@ -161,14 +163,14 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
if 'token_ids' in outputs:
val1 = len(outputs['token_ids'])
val = _check_and_adapt_shape_dtype([val1], [[1], 'int64'])
val = _check_and_adapt_shape_dtype(np.array([val1], dtype='int64'), [[1], 'int64'], iterator_prefixes[id]+' tokenids: ')
results[outname_to_pos['batch_size']] = val
val2 = len(outputs['token_ids'][0])
val = _check_and_adapt_shape_dtype([val2], [[1], 'int64'])
val = _check_and_adapt_shape_dtype(np.array([val2], dtype='int64'), [[1], 'int64'])
results[outname_to_pos['seqlen']] = val
val = _check_and_adapt_shape_dtype([val1*val2], [[1], 'int64'])
val = _check_and_adapt_shape_dtype(np.array([val1*val2], dtype='int64'), [[1], 'int64'])
results[outname_to_pos['batchsize_x_seqlen']] = val
else:
if not has_show_warn:
......
......@@ -21,7 +21,7 @@ Authors: zhouxiangyang(zhouxiangyang@baidu.com)
Date: 2019/09/29 21:00:01
"""
import setuptools
with open("README.md", "r") as fh:
with open("README.md", "r", encoding='utf-8') as fh:
long_description = fh.read()
setuptools.setup(
name="paddlepalm",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册