提交 4ab6cc59 编写于 作者: W wangxiao1021

update setup.py & fix bugs

上级 c21afb28
...@@ -25,11 +25,13 @@ from paddle.fluid import layers ...@@ -25,11 +25,13 @@ from paddle.fluid import layers
def _check_and_adapt_shape_dtype(rt_val, attr, message=""): def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
if not isinstance(rt_val, np.ndarray): if not isinstance(rt_val, np.ndarray):
rt_val = np.array(rt_val) rt_val = np.array(rt_val)
print(message + 'int first if block.')
assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)." assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)."
if rt_val.dtype == np.dtype('float64'): if rt_val.dtype == np.dtype('float64'):
rt_val = rt_val.astype('float32') rt_val = rt_val.astype('float32')
shape, dtype = attr shape, dtype = attr
# rt_val = np.array(rt_val, dtype=dtype)
assert rt_val.dtype == np.dtype(dtype), message+"yielded data type not consistent with attr settings. Expect: {}, receive: {}.".format(rt_val.dtype, np.dtype(dtype)) assert rt_val.dtype == np.dtype(dtype), message+"yielded data type not consistent with attr settings. Expect: {}, receive: {}.".format(rt_val.dtype, np.dtype(dtype))
assert len(shape) == rt_val.ndim, message+"yielded data rank(ndim) not consistent with attr settings. Expect: {}, receive: {}.".format(len(shape), rt_val.ndim) assert len(shape) == rt_val.ndim, message+"yielded data rank(ndim) not consistent with attr settings. Expect: {}, receive: {}.".format(len(shape), rt_val.ndim)
for rt, exp in zip(rt_val.shape, shape): for rt, exp in zip(rt_val.shape, shape):
...@@ -161,14 +163,14 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype ...@@ -161,14 +163,14 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
if 'token_ids' in outputs: if 'token_ids' in outputs:
val1 = len(outputs['token_ids']) val1 = len(outputs['token_ids'])
val = _check_and_adapt_shape_dtype([val1], [[1], 'int64']) val = _check_and_adapt_shape_dtype(np.array([val1], dtype='int64'), [[1], 'int64'], iterator_prefixes[id]+' tokenids: ')
results[outname_to_pos['batch_size']] = val results[outname_to_pos['batch_size']] = val
val2 = len(outputs['token_ids'][0]) val2 = len(outputs['token_ids'][0])
val = _check_and_adapt_shape_dtype([val2], [[1], 'int64']) val = _check_and_adapt_shape_dtype(np.array([val2], dtype='int64'), [[1], 'int64'])
results[outname_to_pos['seqlen']] = val results[outname_to_pos['seqlen']] = val
val = _check_and_adapt_shape_dtype([val1*val2], [[1], 'int64']) val = _check_and_adapt_shape_dtype(np.array([val1*val2], dtype='int64'), [[1], 'int64'])
results[outname_to_pos['batchsize_x_seqlen']] = val results[outname_to_pos['batchsize_x_seqlen']] = val
else: else:
if not has_show_warn: if not has_show_warn:
......
...@@ -21,7 +21,7 @@ Authors: zhouxiangyang(zhouxiangyang@baidu.com) ...@@ -21,7 +21,7 @@ Authors: zhouxiangyang(zhouxiangyang@baidu.com)
Date: 2019/09/29 21:00:01 Date: 2019/09/29 21:00:01
""" """
import setuptools import setuptools
with open("README.md", "r") as fh: with open("README.md", "r", encoding='utf-8') as fh:
long_description = fh.read() long_description = fh.read()
setuptools.setup( setuptools.setup(
name="paddlepalm", name="paddlepalm",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册