未验证 提交 064d0ce3 编写于 作者: D danleifeng 提交者: GitHub

fix check_type bug and example code in hash api; test=develop (#25253)

上级 9825a9f3
...@@ -12796,16 +12796,14 @@ def hash(input, hash_size, num_hash=1, name=None): ...@@ -12796,16 +12796,14 @@ def hash(input, hash_size, num_hash=1, name=None):
place = fluid.core.CPUPlace() place = fluid.core.CPUPlace()
x = fluid.data(name="x", shape=[1], dtype="int32", lod_level=1) x = fluid.data(name="x", shape=[2,2], dtype="int32", lod_level=1)
res = fluid.layers.hash(name="res",input=x, hash_size=1000, num_hash=4) res = fluid.layers.hash(name="res", input=x, hash_size=1000, num_hash=4)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
in1 = np.array([[1,2],[3,4]]).astype("int32") in1 = np.array([[1,2],[3,4]]).astype("int32")
print(in1) print(in1)
x_i = fluid.core.LoDTensor() x_i = fluid.create_lod_tensor(in1, [[0, 2]], place)
x_i.set(in1,place)
x_i.set_recursive_sequence_lengths([[0,2]])
res = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res], return_numpy=False) res = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res], return_numpy=False)
print(np.array(res[0])) print(np.array(res[0]))
# [[[722] # [[[722]
...@@ -12818,8 +12816,8 @@ def hash(input, hash_size, num_hash=1, name=None): ...@@ -12818,8 +12816,8 @@ def hash(input, hash_size, num_hash=1, name=None):
# [901]]] # [901]]]
""" """
check_variable_and_dtype(input, 'input', ['int32', 'int64'], 'hash') check_variable_and_dtype(input, 'input', ['int32', 'int64'], 'hash')
check_type(hash_size, 'hash_size', ['int32', 'int64'], 'hash') check_type(hash_size, 'hash_size', int, 'hash')
check_type(num_hash, 'num_hash', ['int32', 'int64'], 'hash') check_type(num_hash, 'num_hash', int, 'hash')
helper = LayerHelper('hash', **locals()) helper = LayerHelper('hash', **locals())
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
helper.input_dtype(), stop_gradient=True) helper.input_dtype(), stop_gradient=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册