未验证 提交 5a575859 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle] remove compat module (long_type, int_type, get_exception_message,...

[CodeStyle] remove compat module (long_type, int_type, get_exception_message, floor_division) (#46686)
上级 eb6bcc26
...@@ -17,9 +17,6 @@ import math ...@@ -17,9 +17,6 @@ import math
__all__ = [] __all__ = []
int_type = int
long_type = int
# str and bytes related functions # str and bytes related functions
def to_text(obj, encoding='utf-8', inplace=False): def to_text(obj, encoding='utf-8', inplace=False):
...@@ -227,35 +224,3 @@ def round(x, d=0): ...@@ -227,35 +224,3 @@ def round(x, d=0):
else: else:
import __builtin__ import __builtin__
return __builtin__.round(x, d) return __builtin__.round(x, d)
def floor_division(x, y):
"""
Compatible division which act the same behaviour in Python3 and Python2,
whose result will be a int value of floor(x / y) in Python3 and value of
(x / y) in Python2.
Args:
x(int|float) : The number to divide.
y(int|float) : The number to be divided
Returns:
division result of x // y
"""
return x // y
# exception related functions
def get_exception_message(exc):
"""
Get the error message of a specific exception
Args:
exec(Exception) : The exception to get error message.
Returns:
the error message of exec
"""
assert exc is not None
return str(exc)
...@@ -47,18 +47,17 @@ except ImportError as e: ...@@ -47,18 +47,17 @@ except ImportError as e:
from .. import compat as cpt from .. import compat as cpt
if os.name == 'nt': if os.name == 'nt':
executable_path = os.path.abspath(os.path.dirname(sys.executable)) executable_path = os.path.abspath(os.path.dirname(sys.executable))
raise ImportError( raise ImportError("""NOTE: You may need to run \"set PATH=%s;%%PATH%%\"
"""NOTE: You may need to run \"set PATH=%s;%%PATH%%\"
if you encounters \"DLL load failed\" errors. If you have python if you encounters \"DLL load failed\" errors. If you have python
installed in other directory, replace \"%s\" with your own installed in other directory, replace \"%s\" with your own
directory. The original error is: \n %s""" % directory. The original error is: \n %s""" %
(executable_path, executable_path, cpt.get_exception_message(e))) (executable_path, executable_path, str(e)))
else: else:
raise ImportError( raise ImportError(
"""NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\" """NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\"
if you encounters \"libmkldnn.so not found\" errors. If you have python if you encounters \"libmkldnn.so not found\" errors. If you have python
installed in other directory, replace \"/usr/local/lib\" with your own installed in other directory, replace \"/usr/local/lib\" with your own
directory. The original error is: \n""" + cpt.get_exception_message(e)) directory. The original error is: \n""" + str(e))
except Exception as e: except Exception as e:
raise e raise e
...@@ -75,8 +74,7 @@ def avx_supported(): ...@@ -75,8 +74,7 @@ def avx_supported():
has_avx = os.popen('cat /proc/cpuinfo | grep -i avx').read() != '' has_avx = os.popen('cat /proc/cpuinfo | grep -i avx').read() != ''
except Exception as e: except Exception as e:
sys.stderr.write('Can not get the AVX flag from /proc/cpuinfo.\n' sys.stderr.write('Can not get the AVX flag from /proc/cpuinfo.\n'
'The original error is: %s\n' % 'The original error is: %s\n' % str(e))
cpt.get_exception_message(e))
return has_avx return has_avx
elif sysstr == 'darwin': elif sysstr == 'darwin':
try: try:
...@@ -85,7 +83,7 @@ def avx_supported(): ...@@ -85,7 +83,7 @@ def avx_supported():
except Exception as e: except Exception as e:
sys.stderr.write( sys.stderr.write(
'Can not get the AVX flag from machdep.cpu.features.\n' 'Can not get the AVX flag from machdep.cpu.features.\n'
'The original error is: %s\n' % cpt.get_exception_message(e)) 'The original error is: %s\n' % str(e))
if not has_avx: if not has_avx:
import subprocess import subprocess
pipe = subprocess.Popen( pipe = subprocess.Popen(
...@@ -155,8 +153,7 @@ def avx_supported(): ...@@ -155,8 +153,7 @@ def avx_supported():
ctypes.c_size_t(0), ONE_PAGE) ctypes.c_size_t(0), ONE_PAGE)
except Exception as e: except Exception as e:
sys.stderr.write('Failed getting the AVX flag on Windows.\n' sys.stderr.write('Failed getting the AVX flag on Windows.\n'
'The original error is: %s\n' % 'The original error is: %s\n' % str(e))
cpt.get_exception_message(e))
return (retval & (1 << avx_bit)) > 0 return (retval & (1 << avx_bit)) > 0
else: else:
sys.stderr.write('Do not get AVX flag on %s\n' % sysstr) sys.stderr.write('Do not get AVX flag on %s\n' % sysstr)
......
...@@ -27,7 +27,6 @@ import paddle.fluid as fluid ...@@ -27,7 +27,6 @@ import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP
import paddle.compat as cpt import paddle.compat as cpt
from paddle.compat import long_type
const_para_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant(0.001)) const_para_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant(0.001))
const_bias_attr = const_para_attr const_bias_attr = const_para_attr
...@@ -173,10 +172,10 @@ seq_len = ModelHyperParams.max_length ...@@ -173,10 +172,10 @@ seq_len = ModelHyperParams.max_length
input_descs = { input_descs = {
# The actual data shape of src_word is: # The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1] # [batch_size * max_src_len_in_batch, 1]
"src_word": [(batch_size, seq_len, long_type(1)), "int64", 2], "src_word": [(batch_size, seq_len, 1), "int64", 2],
# The actual data shape of src_pos is: # The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1] # [batch_size * max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len, long_type(1)), "int64"], "src_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings in the # This input is used to remove attention weights on paddings in the
# encoder. # encoder.
# The actual data shape of src_slf_attn_bias is: # The actual data shape of src_slf_attn_bias is:
...@@ -185,11 +184,11 @@ input_descs = { ...@@ -185,11 +184,11 @@ input_descs = {
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"], [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# The actual data shape of trg_word is: # The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1] # [batch_size * max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len, long_type(1)), "int64", "trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder. 2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is: # The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1] # [batch_size * max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len, long_type(1)), "int64"], "trg_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings and # This input is used to remove attention weights on paddings and
# subsequent words in the decoder. # subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is: # The actual data shape of trg_slf_attn_bias is:
...@@ -208,15 +207,15 @@ input_descs = { ...@@ -208,15 +207,15 @@ input_descs = {
"enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"], "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is: # The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1] # [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(batch_size * seq_len, long_type(1)), "int64"], "lbl_word": [(batch_size * seq_len, 1), "int64"],
# This input is used to mask out the loss of padding tokens. # This input is used to mask out the loss of padding tokens.
# The actual data shape of label_weight is: # The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1] # [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, long_type(1)), "float32"], "lbl_weight": [(batch_size * seq_len, 1), "float32"],
# These inputs are used to change the shape tensor in beam-search decoder. # These inputs are used to change the shape tensor in beam-search decoder.
"trg_slf_attn_pre_softmax_shape_delta": [(long_type(2), ), "int32"], "trg_slf_attn_pre_softmax_shape_delta": [(2, ), "int32"],
"trg_slf_attn_post_softmax_shape_delta": [(long_type(4), ), "int32"], "trg_slf_attn_post_softmax_shape_delta": [(4, ), "int32"],
"init_score": [(batch_size, long_type(1)), "float32"], "init_score": [(batch_size, 1), "float32"],
} }
# Names of word embedding table which might be reused for weight sharing. # Names of word embedding table which might be reused for weight sharing.
......
...@@ -18,10 +18,6 @@ import paddle.compat as cpt ...@@ -18,10 +18,6 @@ import paddle.compat as cpt
class TestCompatible(unittest.TestCase): class TestCompatible(unittest.TestCase):
def test_type(self):
self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, int)
def test_to_text(self): def test_to_text(self):
self.assertIsNone(cpt.to_text(None)) self.assertIsNone(cpt.to_text(None))
...@@ -252,30 +248,6 @@ class TestCompatible(unittest.TestCase): ...@@ -252,30 +248,6 @@ class TestCompatible(unittest.TestCase):
self.assertEqual(5.0, cpt.round(5)) self.assertEqual(5.0, cpt.round(5))
self.assertRaises(TypeError, cpt.round, None) self.assertRaises(TypeError, cpt.round, None)
def test_floor_division(self):
self.assertEqual(0.0, cpt.floor_division(3, 4))
self.assertEqual(1.0, cpt.floor_division(4, 3))
self.assertEqual(2.0, cpt.floor_division(6, 3))
self.assertEqual(-2.0, cpt.floor_division(-4, 3))
self.assertEqual(-2.0, cpt.floor_division(-6, 3))
self.assertRaises(ZeroDivisionError, cpt.floor_division, 3, 0)
self.assertRaises(TypeError, cpt.floor_division, None, None)
def test_get_exception_message(self):
exception_message = "test_message"
self.assertRaises(AssertionError, cpt.get_exception_message, None)
try:
raise RuntimeError(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertIsNotNone(e)
try:
raise Exception(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertIsNotNone(e)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -904,7 +904,7 @@ class TestDatasetWithFetchHandler(unittest.TestCase): ...@@ -904,7 +904,7 @@ class TestDatasetWithFetchHandler(unittest.TestCase):
print("warning: we skip trainer_desc_pb2 import problem in windows") print("warning: we skip trainer_desc_pb2 import problem in windows")
except RuntimeError as e: except RuntimeError as e:
error_msg = "dataset is need and should be initialized" error_msg = "dataset is need and should be initialized"
self.assertEqual(error_msg, cpt.get_exception_message(e)) self.assertEqual(error_msg, str(e))
except Exception as e: except Exception as e:
self.assertTrue(False) self.assertTrue(False)
...@@ -948,7 +948,7 @@ class TestDatasetWithFetchHandler(unittest.TestCase): ...@@ -948,7 +948,7 @@ class TestDatasetWithFetchHandler(unittest.TestCase):
print("warning: we skip trainer_desc_pb2 import problem in windows") print("warning: we skip trainer_desc_pb2 import problem in windows")
except RuntimeError as e: except RuntimeError as e:
error_msg = "dataset is need and should be initialized" error_msg = "dataset is need and should be initialized"
self.assertEqual(error_msg, cpt.get_exception_message(e)) self.assertEqual(error_msg, str(e))
except Exception as e: except Exception as e:
self.assertTrue(False) self.assertTrue(False)
......
...@@ -28,8 +28,7 @@ class TestException(unittest.TestCase): ...@@ -28,8 +28,7 @@ class TestException(unittest.TestCase):
try: try:
core.__unittest_throw_exception__() core.__unittest_throw_exception__()
except RuntimeError as ex: except RuntimeError as ex:
self.assertIn("This is a test of exception", self.assertIn("This is a test of exception", str(ex))
cpt.get_exception_message(ex))
exception = ex exception = ex
self.assertIsNotNone(exception) self.assertIsNotNone(exception)
......
...@@ -67,8 +67,7 @@ class TestDygraphDataLoaderWithException(unittest.TestCase): ...@@ -67,8 +67,7 @@ class TestDygraphDataLoaderWithException(unittest.TestCase):
for _ in loader(): for _ in loader():
print("test_single_process_with_thread_expection") print("test_single_process_with_thread_expection")
except core.EnforceNotMet as ex: except core.EnforceNotMet as ex:
self.assertIn("Blocking queue is killed", self.assertIn("Blocking queue is killed", str(ex))
cpt.get_exception_message(ex))
exception = ex exception = ex
self.assertIsNotNone(exception) self.assertIsNotNone(exception)
...@@ -130,8 +129,7 @@ class TestDygraphDataLoaderWithException(unittest.TestCase): ...@@ -130,8 +129,7 @@ class TestDygraphDataLoaderWithException(unittest.TestCase):
for image, _ in loader(): for image, _ in loader():
fluid.layers.relu(image) fluid.layers.relu(image)
except core.EnforceNotMet as ex: except core.EnforceNotMet as ex:
self.assertIn("Blocking queue is killed", self.assertIn("Blocking queue is killed", str(ex))
cpt.get_exception_message(ex))
exception = ex exception = ex
self.assertIsNotNone(exception) self.assertIsNotNone(exception)
......
...@@ -57,7 +57,7 @@ class TestRegisterExitFunc(unittest.TestCase): ...@@ -57,7 +57,7 @@ class TestRegisterExitFunc(unittest.TestCase):
try: try:
CleanupFuncRegistrar.register(5) CleanupFuncRegistrar.register(5)
except TypeError as ex: except TypeError as ex:
self.assertIn("is not callable", cpt.get_exception_message(ex)) self.assertIn("is not callable", str(ex))
exception = ex exception = ex
self.assertIsNotNone(exception) self.assertIsNotNone(exception)
......
...@@ -109,7 +109,7 @@ class TestEagerGrad(TestCase): ...@@ -109,7 +109,7 @@ class TestEagerGrad(TestCase):
# allow_unused is false in default # allow_unused is false in default
dx = fluid.dygraph.grad(out, [x, z]) dx = fluid.dygraph.grad(out, [x, z])
except ValueError as e: except ValueError as e:
error_msg = cpt.get_exception_message(e) error_msg = str(e)
assert error_msg.find("allow_unused") > 0 assert error_msg.find("allow_unused") > 0
def test_simple_example_eager_grad_not_allow_unused(self): def test_simple_example_eager_grad_not_allow_unused(self):
...@@ -133,7 +133,7 @@ class TestEagerGrad(TestCase): ...@@ -133,7 +133,7 @@ class TestEagerGrad(TestCase):
# duplicate input will arise RuntimeError errors # duplicate input will arise RuntimeError errors
dx = fluid.dygraph.grad(out, [x, x]) dx = fluid.dygraph.grad(out, [x, x])
except RuntimeError as e: except RuntimeError as e:
error_msg = cpt.get_exception_message(e) error_msg = str(e)
assert error_msg.find("duplicate") > 0 assert error_msg.find("duplicate") > 0
def test_simple_example_eager_grad_duplicate_input(self): def test_simple_example_eager_grad_duplicate_input(self):
...@@ -157,7 +157,7 @@ class TestEagerGrad(TestCase): ...@@ -157,7 +157,7 @@ class TestEagerGrad(TestCase):
# duplicate output will arise RuntimeError errors # duplicate output will arise RuntimeError errors
dx = fluid.dygraph.grad([out, out], [x]) dx = fluid.dygraph.grad([out, out], [x])
except RuntimeError as e: except RuntimeError as e:
error_msg = cpt.get_exception_message(e) error_msg = str(e)
assert error_msg.find("duplicate") > 0 assert error_msg.find("duplicate") > 0
def test_simple_example_eager_grad_duplicate_output(self): def test_simple_example_eager_grad_duplicate_output(self):
......
...@@ -55,7 +55,7 @@ class DygraphDataLoaderSingalHandler(unittest.TestCase): ...@@ -55,7 +55,7 @@ class DygraphDataLoaderSingalHandler(unittest.TestCase):
set_child_signal_handler(id(self), test_process.pid) set_child_signal_handler(id(self), test_process.pid)
time.sleep(5) time.sleep(5)
except SystemError as ex: except SystemError as ex:
self.assertIn("Fatal", cpt.get_exception_message(ex)) self.assertIn("Fatal", str(ex))
exception = ex exception = ex
return exception return exception
...@@ -88,8 +88,7 @@ class DygraphDataLoaderSingalHandler(unittest.TestCase): ...@@ -88,8 +88,7 @@ class DygraphDataLoaderSingalHandler(unittest.TestCase):
set_child_signal_handler(id(self), test_process.pid) set_child_signal_handler(id(self), test_process.pid)
time.sleep(5) time.sleep(5)
except SystemError as ex: except SystemError as ex:
self.assertIn("Segmentation fault", self.assertIn("Segmentation fault", str(ex))
cpt.get_exception_message(ex))
exception = ex exception = ex
return exception return exception
...@@ -122,7 +121,7 @@ class DygraphDataLoaderSingalHandler(unittest.TestCase): ...@@ -122,7 +121,7 @@ class DygraphDataLoaderSingalHandler(unittest.TestCase):
set_child_signal_handler(id(self), test_process.pid) set_child_signal_handler(id(self), test_process.pid)
time.sleep(5) time.sleep(5)
except SystemError as ex: except SystemError as ex:
self.assertIn("Bus error", cpt.get_exception_message(ex)) self.assertIn("Bus error", str(ex))
exception = ex exception = ex
return exception return exception
......
...@@ -82,7 +82,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): ...@@ -82,7 +82,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
flatten_idx = ids.flatten() flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0] padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': cpt.long_type(padding_idx)} self.attrs = {'padding_idx': padding_idx}
self.check_output() self.check_output()
...@@ -250,7 +250,7 @@ class TestLookupTableOpWithTensorIdsAndPaddingInt8( ...@@ -250,7 +250,7 @@ class TestLookupTableOpWithTensorIdsAndPaddingInt8(
flatten_idx = ids.flatten() flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0] padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': cpt.long_type(padding_idx)} self.attrs = {'padding_idx': padding_idx}
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
...@@ -380,7 +380,7 @@ class TestLookupTableOpWithTensorIdsAndPaddingInt16( ...@@ -380,7 +380,7 @@ class TestLookupTableOpWithTensorIdsAndPaddingInt16(
flatten_idx = ids.flatten() flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0] padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': cpt.long_type(padding_idx)} self.attrs = {'padding_idx': padding_idx}
self.check_output() self.check_output()
......
...@@ -123,7 +123,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): ...@@ -123,7 +123,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
flatten_idx = ids.flatten() flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0] padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': cpt.long_type(padding_idx)} self.attrs = {'padding_idx': padding_idx}
self.check_output() self.check_output()
......
...@@ -31,15 +31,14 @@ class TestOperator(unittest.TestCase): ...@@ -31,15 +31,14 @@ class TestOperator(unittest.TestCase):
self.assertFail() self.assertFail()
except ValueError as v_err: except ValueError as v_err:
self.assertEqual( self.assertEqual(
cpt.get_exception_message(v_err), str(v_err),
"`type` to initialized an Operator can not be None.") "`type` to initialized an Operator can not be None.")
try: try:
block.append_op(type="no_such_op") block.append_op(type="no_such_op")
self.assertFail() self.assertFail()
except ValueError as a_err: except ValueError as a_err:
self.assertEqual( self.assertEqual(
cpt.get_exception_message(a_err), str(a_err), "Operator \"no_such_op\" has not been registered.")
"Operator \"no_such_op\" has not been registered.")
def test_op_desc_creation(self): def test_op_desc_creation(self):
program = Program() program = Program()
......
...@@ -990,7 +990,7 @@ class TestRecomputeOptimizer(unittest.TestCase): ...@@ -990,7 +990,7 @@ class TestRecomputeOptimizer(unittest.TestCase):
except NotImplementedError as e: except NotImplementedError as e:
self.assertEqual( self.assertEqual(
"load function is not supported by Recompute Optimizer for now", "load function is not supported by Recompute Optimizer for now",
cpt.get_exception_message(e)) str(e))
def test_dropout(self): def test_dropout(self):
""" """
......
...@@ -100,7 +100,7 @@ class TestPrune(unittest.TestCase): ...@@ -100,7 +100,7 @@ class TestPrune(unittest.TestCase):
except ValueError as e: except ValueError as e:
self.assertIn( self.assertIn(
"All targets of Program._prune_with_input() can only be Variable or Operator", "All targets of Program._prune_with_input() can only be Variable or Operator",
cpt.get_exception_message(e)) str(e))
def mock(self, program, feed, fetch, optimize_ops): def mock(self, program, feed, fetch, optimize_ops):
......
...@@ -103,7 +103,7 @@ class ApiZerosError(unittest.TestCase): ...@@ -103,7 +103,7 @@ class ApiZerosError(unittest.TestCase):
shape = [-1, 5] shape = [-1, 5]
out = paddle.zeros(shape) out = paddle.zeros(shape)
except Exception as e: except Exception as e:
error_msg = cpt.get_exception_message(e) error_msg = str(e)
assert error_msg.find("expected to be no less than 0") > 0 assert error_msg.find("expected to be no less than 0") > 0
def test_eager(self): def test_eager(self):
......
...@@ -95,7 +95,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): ...@@ -95,7 +95,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
flatten_idx = ids.flatten() flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0] padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': cpt.long_type(padding_idx)} self.attrs = {'padding_idx': padding_idx}
self.check_output_with_place(place=paddle.XPUPlace(0)) self.check_output_with_place(place=paddle.XPUPlace(0))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册