From bb2f5d24a25b574e7becb017da163f54b078b840 Mon Sep 17 00:00:00 2001
From: hutuxian
Date: Thu, 18 Jul 2019 15:46:29 +0800
Subject: [PATCH] hash_op support int64 hash_size (#18674)

* hash_op support int64 hash_size

* add corresponding UT
---
 paddle/fluid/operators/hash_op.cc             |  2 +-
 paddle/fluid/operators/hash_op.h              |  4 +-
 .../fluid/tests/unittests/test_hash_op.py     | 44 +++++++++++++++++++
 3 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc
index 6679c109b15..5ef91dcb666 100644
--- a/paddle/fluid/operators/hash_op.cc
+++ b/paddle/fluid/operators/hash_op.cc
@@ -52,7 +52,7 @@ class HashOpMaker : public framework::OpProtoAndCheckerMaker {
 Execute `num_hash` times xxHash algorithm on all elements on second dimension of input.
 )DOC");
     AddAttr<int>("num_hash", "").SetDefault(1);
-    AddAttr<int>("mod_by", "").SetDefault(100000);
+    AddAttr<int64_t>("mod_by", "").SetDefault(100000);
     AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
                   "Skip calling InferShape() function in the runtime.")
         .SetDefault(true);
diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h
index 14a4660aac7..c2d53000491 100644
--- a/paddle/fluid/operators/hash_op.h
+++ b/paddle/fluid/operators/hash_op.h
@@ -43,7 +43,7 @@ class HashKernel : public framework::OpKernel<T> {
   virtual void Compute(const framework::ExecutionContext& context) const {
     auto* out_t = context.Output<framework::LoDTensor>("Out");
     auto* in_t = context.Input<framework::LoDTensor>("X");
-    int mod_by = context.Attr<int>("mod_by");
+    int64_t mod_by = context.Attr<int64_t>("mod_by");
     int num_hash = context.Attr<int>("num_hash");
     auto in_dims = in_t->dims();
 
@@ -59,7 +59,7 @@
     for (int idx = 0; idx < seq_length; ++idx) {
       for (int ihash = 0; ihash != num_hash; ++ihash) {
         output[idx * num_hash + ihash] =
-            XXH64(input, sizeof(int) * last_dim, ihash) % mod_by;
+            XXH64(input, sizeof(T) * last_dim, ihash) % mod_by;
       }
       input += last_dim;
     }
diff --git a/python/paddle/fluid/tests/unittests/test_hash_op.py b/python/paddle/fluid/tests/unittests/test_hash_op.py
index 7b4e9bf738b..75af02bd5f4 100644
--- a/python/paddle/fluid/tests/unittests/test_hash_op.py
+++ b/python/paddle/fluid/tests/unittests/test_hash_op.py
@@ -58,5 +58,49 @@ class TestHashNotLoDOp(TestHashOp):
         self.check_output()
 
 
+class TestHashOp2(TestHashOp):
+    """
+    Case:
+    int64 type input
+    """
+
+    def setUp(self):
+        self.op_type = "hash"
+        self.init_test_case()
+        self.inputs = {'X': self.in_seq}
+        self.attrs = {'num_hash': 2, 'mod_by': 10000}
+        self.outputs = {'Out': self.out_seq}
+
+    def init_test_case(self):
+        self.in_seq = np.array([1, 2**32 + 1]).reshape((2, 1)).astype("int64")
+        self.out_seq = np.array([1269, 9609, 3868, 7268]).reshape((2, 2, 1))
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestHashOp3(TestHashOp):
+    """
+    Case:
+    int64 type input
+    int64 type mod_by attr
+    """
+
+    def setUp(self):
+        self.op_type = "hash"
+        self.init_test_case()
+        self.inputs = {'X': self.in_seq}
+        self.attrs = {'num_hash': 2, 'mod_by': 2**32}
+        self.outputs = {'Out': self.out_seq}
+
+    def init_test_case(self):
+        self.in_seq = np.array([10, 5]).reshape((2, 1)).astype("int64")
+        self.out_seq = np.array(
+            [1204014882, 393011615, 3586283837, 2814821595]).reshape((2, 2, 1))
+
+    def test_check_output(self):
+        self.check_output()
+
+
 if __name__ == "__main__":
     unittest.main()
--
GitLab