From 73d785572697f0cc0ebb03791048001dd52174d1 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 30 Oct 2017 10:11:30 -0700 Subject: [PATCH] Fix a type error top_k_op (#5201) * Fix Type error * Fix error --- paddle/operators/top_k_op.h | 4 ++-- python/paddle/v2/framework/tests/test_top_k_op.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/operators/top_k_op.h b/paddle/operators/top_k_op.h index 4b248faa120..bc8563717a2 100644 --- a/paddle/operators/top_k_op.h +++ b/paddle/operators/top_k_op.h @@ -40,7 +40,7 @@ class TopkKernel : public framework::OpKernel { const size_t k = static_cast(ctx.Attr("k")); T* output_data = output->mutable_data(ctx.GetPlace()); - T* indices_data = indices->mutable_data(ctx.GetPlace()); + int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); auto eg_input = EigenMatrix::From(*input); @@ -66,7 +66,7 @@ class TopkKernel : public framework::OpKernel { }); for (size_t j = 0; j < k; j++) { output_data[i * k + j] = vec[j].first; - indices_data[i * k + j] = vec[j].second; + indices_data[i * k + j] = int64_t(vec[j].second); } } } diff --git a/python/paddle/v2/framework/tests/test_top_k_op.py b/python/paddle/v2/framework/tests/test_top_k_op.py index 694f37d612d..6e8fbefa6ea 100644 --- a/python/paddle/v2/framework/tests/test_top_k_op.py +++ b/python/paddle/v2/framework/tests/test_top_k_op.py @@ -9,7 +9,7 @@ class TestTopkOp(OpTest): k = 1 input = np.random.random((32, 84)).astype("float32") output = np.ndarray((32, k)) - indices = np.ndarray((32, k)) + indices = np.ndarray((32, k)).astype("int64") self.inputs = {'X': input} self.attrs = {'k': k} @@ -32,7 +32,7 @@ class TestTopkOp3d(OpTest): input = np.random.random((32, 2, 84)).astype("float32") input_flat_2d = input.reshape(64, 84) output = np.ndarray((64, k)) - indices = np.ndarray((64, k)).astype("int") + indices = np.ndarray((64, k)).astype("int64") # FIXME: should use 'X': input for a 3d input self.inputs = {'X': input_flat_2d} -- GitLab