diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 4d962b4809f5440288feecb46135c2862c3b2523..3d724e3ae726fb5b91843d603d434103ae6201d8 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include +#include #include +#include #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" @@ -34,17 +38,17 @@ class SamplingIdKernel : public framework::OpKernel { std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); - std::vector ids(batch_size); + std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { double r = this->get_rand(); - int id = width - 1; + int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { - id = j; + idx = j; break; } } - ids[i] = id; + ids[i] = ins_vector[i * width + idx]; } std::vector out_dim; diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py index 86d86acfb521dc49ef7fb54c7bcc41a7e2ef0dc2..e3e71530498172b97b481ba66c7df9078a36ba13 100644 --- a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -25,17 +25,18 @@ class TestSamplingIdOp(OpTest): self.op_type = "sampling_id" self.use_mkldnn = False self.init_kernel_type() - X = np.random.random((3, 4)).astype('float32') - self.inputs = {"X": X} - Y = np.random.random(3).astype('float32') - self.outputs = {'Out': Y} + self.X = np.random.random((8, 4)).astype('float32') + self.inputs = {"X": self.X} + self.Y = np.random.random(8).astype('float32') + self.outputs = {'Out': self.Y} self.attrs = {'use_mkldnn': self.use_mkldnn} def test_check_output(self): - self.check_output() + self.check_output_customized(self.verify_output) - def test_check_grad(self): - self.check_grad(['X'], 'Out') + def verify_output(self, outs): + out = np.array(outs[0]) + self.assertEqual(len(out), len(self.Y)) def init_kernel_type(self): pass