提交 40141f74 编写于 作者: M minqiyang

Implement the unittest for hash op

test=develop
上级 accb7b5d
...@@ -16,7 +16,7 @@ ExternalProject_Add( ...@@ -16,7 +16,7 @@ ExternalProject_Add(
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_IN_SOURCE 1 BUILD_IN_SOURCE 1
PATCH_COMMAND PATCH_COMMAND
BUILD_COMMAND make lib BUILD_COMMAND sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib
INSTALL_COMMAND export PREFIX=${XXHASH_INSTALL_DIR}/ && make install INSTALL_COMMAND export PREFIX=${XXHASH_INSTALL_DIR}/ && make install
TEST_COMMAND "" TEST_COMMAND ""
) )
......
...@@ -46,7 +46,7 @@ class HashOp : public framework::OperatorWithKernel { ...@@ -46,7 +46,7 @@ class HashOp : public framework::OperatorWithKernel {
// keep the last dim to 1 // keep the last dim to 1
out_dims.emplace_back(1); out_dims.emplace_back(1);
ctx->SetOutputDim("Out", dims); ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
......
...@@ -22,13 +22,32 @@ class TestScaleOp(OpTest): ...@@ -22,13 +22,32 @@ class TestScaleOp(OpTest):
self.op_type = "hash" self.op_type = "hash"
self.init_test_case() self.init_test_case()
self.inputs = {'X': (self.in_seq, self.lod)} self.inputs = {'X': (self.in_seq, self.lod)}
self.attrs = {'num_hash': 8, 'mod_by': 10000} self.attrs = {'num_hash': 4, 'mod_by': 10000}
self.outputs = {'Out': (self.out_seq, self.lod)} self.outputs = {'Out': (self.out_seq, self.lod)}
def init_test_case(self): def init_test_case(self):
np.random.seed = 1
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
self.lod = [[9, 4, 11, 6]] self.lod = [[9, 4, 11, 6]]
self.out_seq = np.ones([30, 8], dtype=np.int32) # self.out_seq = np.ones([30, 4, 1], dtype=np.int32)
self.out_seq = [
[[9662], [9217], [1129], [8487]], [[9662], [9217], [1129], [8487]],
[[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]],
[[9407], [6715], [6949], [8094]], [[8473], [694], [5142], [2479]],
[[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]],
[[4372], [9456], [8204], [6695]], [[6897], [3218], [2013], [1241]],
[[8473], [694], [5142], [2479]], [[4372], [9456], [8204], [6695]],
[[4372], [9456], [8204], [6695]], [[8473], [694], [5142], [2479]],
[[9407], [6715], [6949], [8094]], [[9369], [4525], [8935], [9210]],
[[4372], [9456], [8204], [6695]], [[4372], [9456], [8204], [6695]],
[[9369], [4525], [8935], [9210]], [[6897], [3218], [2013], [1241]],
[[9038], [7951], [5953], [8657]], [[9407], [6715], [6949], [8094]],
[[9662], [9217], [1129], [8487]], [[9369], [4525], [8935], [9210]],
[[9038], [7951], [5953], [8657]], [[9662], [9217], [1129], [8487]],
[[9369], [4525], [8935], [9210]], [[1719], [5986], [9919], [3421]],
[[4372], [9456], [8204], [6695]], [[9038], [7951], [5953], [8657]]
]
self.out_seq = np.array(self.out_seq)
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册