diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index 9729537d1e4d7b4e4614dcd6778f0f0a72310f86..17f6461fcb816b23d100e8d5a61aac484c1a6ba3 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -25,9 +25,9 @@ class SamplingIdOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of RowConvOp should not be null."); + "Input(X) of SamplingIdOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of RowConvOp should not be null."); + "Output(Out) of SamplingIdOp should not be null."); auto input_dims = ctx->GetInputDim("X"); @@ -43,8 +43,7 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The input tensor of softmax. " "2-D with shape [batch_size, input_feature_dimensions]."); - AddOutput("Out", "Sliced data tensor."); - + AddOutput("Out", "SamplingId data tensor."); AddComment(R"DOC( SamplingId Operator. @brief A layer for sampling id from multinomial distribution from the diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index e467165b6d24cead7926e61a33e4dfb746be01dd..c0bb9c916cc9e4cae193ce20e72431c173b16288 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -16,19 +16,6 @@ limitations under the License. */ #include #include "paddle/fluid/operators/sampling_id_op.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -class SamplingIdOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override {} -} -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( sampling_id, diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 5bb1991fc57250bb31a34d8a978204e9ac1298c0..4d962b4809f5440288feecb46135c2862c3b2523 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -50,7 +50,7 @@ class SamplingIdKernel : public framework::OpKernel { std::vector out_dim; out_dim.push_back(static_cast(batch_size)); - Tensor* output = context.Output("Output"); + Tensor* output = context.Output("Out"); output->Resize(framework::make_ddim(out_dim)); output->mutable_data(context.GetPlace()); framework::TensorFromVector(ids, context.device_context(), output); diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py new file mode 100644 index 0000000000000000000000000000000000000000..86d86acfb521dc49ef7fb54c7bcc41a7e2ef0dc2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -0,0 +1,45 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + +import paddle.fluid.core as core +from paddle.fluid.op import Operator + + +class TestSamplingIdOp(OpTest): + def setUp(self): + self.op_type = "sampling_id" + self.use_mkldnn = False + self.init_kernel_type() + X = np.random.random((3, 4)).astype('float32') + self.inputs = {"X": X} + Y = np.random.random(3).astype('float32') + self.outputs = {'Out': Y} + self.attrs = {'use_mkldnn': self.use_mkldnn} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_kernel_type(self): + pass + + +if __name__ == "__main__": + unittest.main()