未验证 提交 efe06caa 编写于 作者: Q Qiao Longfei 提交者: GitHub

change data type of beam_search op (#7374)

上级 91f80f79
...@@ -39,7 +39,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids, ...@@ -39,7 +39,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
std::map<size_t /*offset*/, std::vector<Item>> hash; std::map<size_t /*offset*/, std::vector<Item>> hash;
framework::LoD new_lod; framework::LoD new_lod;
auto *ids_data = selected_ids->mutable_data<int>(platform::CPUPlace()); auto *ids_data = selected_ids->mutable_data<int64_t>(platform::CPUPlace());
auto *scores_data = auto *scores_data =
selected_scores->mutable_data<float>(platform::CPUPlace()); selected_scores->mutable_data<float>(platform::CPUPlace());
...@@ -66,7 +66,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids, ...@@ -66,7 +66,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
void BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids, void BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
std::vector<std::vector<Item>> *items) { std::vector<std::vector<Item>> *items) {
auto *pre_ids_data = pre_ids.data<int>(); auto *pre_ids_data = pre_ids.data<int64_t>();
for (size_t offset = 0; offset < items->size(); offset++) { for (size_t offset = 0; offset < items->size(); offset++) {
auto prefix_id = pre_ids_data[offset]; auto prefix_id = pre_ids_data[offset];
...@@ -127,7 +127,7 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) { ...@@ -127,7 +127,7 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
auto abs_lod = framework::ToAbsOffset(ids.lod()); auto abs_lod = framework::ToAbsOffset(ids.lod());
PADDLE_ENFORCE_GE(source_abs_two_level_lod.size(), 2UL); PADDLE_ENFORCE_GE(source_abs_two_level_lod.size(), 2UL);
auto *ids_data = ids.data<int>(); auto *ids_data = ids.data<int64_t>();
auto *scores_data = scores.data<float>(); auto *scores_data = scores.data<float>();
size_t instance_dim = 1; size_t instance_dim = 1;
......
...@@ -37,13 +37,13 @@ class BeamSearchOpTester(unittest.TestCase): ...@@ -37,13 +37,13 @@ class BeamSearchOpTester(unittest.TestCase):
print 'lod', selected_ids.lod() print 'lod', selected_ids.lod()
def _create_pre_ids(self): def _create_pre_ids(self):
np_data = np.array([[1, 2, 3, 4]], dtype='int32') np_data = np.array([[1, 2, 3, 4]], dtype='int64')
tensor = create_tensor(self.scope, "pre_ids", np_data) tensor = create_tensor(self.scope, "pre_ids", np_data)
def _create_ids(self): def _create_ids(self):
self.lod = [[0, 1, 4], [0, 1, 2, 3, 4]] self.lod = [[0, 1, 4], [0, 1, 2, 3, 4]]
np_data = np.array( np_data = np.array(
[[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int32') [[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int64')
tensor = create_tensor(self.scope, "ids", np_data) tensor = create_tensor(self.scope, "ids", np_data)
tensor.set_lod(self.lod) tensor.set_lod(self.lod)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册