diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc index 6679c109b15a7db7ba9c54e3af21875dd06fc68b..5ef91dcb66638d5786e9769802bfc3790ffc6079 100644 --- a/paddle/fluid/operators/hash_op.cc +++ b/paddle/fluid/operators/hash_op.cc @@ -52,7 +52,7 @@ class HashOpMaker : public framework::OpProtoAndCheckerMaker { Execute `num_hash` times xxHash algorithm on all elements on second dimension of input. )DOC"); AddAttr("num_hash", "").SetDefault(1); - AddAttr("mod_by", "").SetDefault(100000); + AddAttr("mod_by", "").SetDefault(100000); AddAttr(framework::kAllKernelsMustComputeRuntimeShape, "Skip calling InferShape() function in the runtime.") .SetDefault(true); diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h index 14a4660aac738563232d8bdede4007f8825c5b9a..c2d530004912287b0720ab5d00da90c4e1b5cbc7 100644 --- a/paddle/fluid/operators/hash_op.h +++ b/paddle/fluid/operators/hash_op.h @@ -43,7 +43,7 @@ class HashKernel : public framework::OpKernel { virtual void Compute(const framework::ExecutionContext& context) const { auto* out_t = context.Output("Out"); auto* in_t = context.Input("X"); - int mod_by = context.Attr("mod_by"); + int64_t mod_by = context.Attr("mod_by"); int num_hash = context.Attr("num_hash"); auto in_dims = in_t->dims(); @@ -59,7 +59,7 @@ class HashKernel : public framework::OpKernel { for (int idx = 0; idx < seq_length; ++idx) { for (int ihash = 0; ihash != num_hash; ++ihash) { output[idx * num_hash + ihash] = - XXH64(input, sizeof(int) * last_dim, ihash) % mod_by; + XXH64(input, sizeof(T) * last_dim, ihash) % mod_by; } input += last_dim; } diff --git a/python/paddle/fluid/tests/unittests/test_hash_op.py b/python/paddle/fluid/tests/unittests/test_hash_op.py index 7b4e9bf738b212059faf2e3046510497f820a723..75af02bd5f46ea61f0bf4bc2494cb941fb1f64b4 100644 --- a/python/paddle/fluid/tests/unittests/test_hash_op.py +++ b/python/paddle/fluid/tests/unittests/test_hash_op.py @@ -58,5 +58,49 @@ class TestHashNotLoDOp(TestHashOp): self.check_output() +class TestHashOp2(TestHashOp): + """ + Case: + int64 type input + """ + + def setUp(self): + self.op_type = "hash" + self.init_test_case() + self.inputs = {'X': self.in_seq} + self.attrs = {'num_hash': 2, 'mod_by': 10000} + self.outputs = {'Out': self.out_seq} + + def init_test_case(self): + self.in_seq = np.array([1, 2**32 + 1]).reshape((2, 1)).astype("int64") + self.out_seq = np.array([1269, 9609, 3868, 7268]).reshape((2, 2, 1)) + + def test_check_output(self): + self.check_output() + + +class TestHashOp3(TestHashOp): + """ + Case: + int64 type input + int64 type mod_by attr + """ + + def setUp(self): + self.op_type = "hash" + self.init_test_case() + self.inputs = {'X': self.in_seq} + self.attrs = {'num_hash': 2, 'mod_by': 2**32} + self.outputs = {'Out': self.out_seq} + + def init_test_case(self): + self.in_seq = np.array([10, 5]).reshape((2, 1)).astype("int64") + self.out_seq = np.array( + [1204014882, 393011615, 3586283837, 2814821595]).reshape((2, 2, 1)) + + def test_check_output(self): + self.check_output() + + if __name__ == "__main__": unittest.main()