From 323d4233f3cb0f72ddac36977941e84880a7eedc Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Tue, 15 Aug 2017 23:50:56 +0000 Subject: [PATCH] gather op added with python unittest --- paddle/operators/gather_op.cu | 20 ++++++++++++++++ .../v2/framework/tests/test_gather_op.py | 23 +++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 paddle/operators/gather_op.cu create mode 100644 python/paddle/v2/framework/tests/test_gather_op.py diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu new file mode 100644 index 0000000000..3f04a7b3f8 --- /dev/null +++ b/paddle/operators/gather_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/gather_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(gather, + ops::GatherOpKernel); diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/framework/tests/test_gather_op.py new file mode 100644 index 0000000000..2ffbf17236 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gather_op.py @@ -0,0 +1,23 @@ +import unittest + +import numpy +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator + +from op_test_util import OpTestMeta + + +class TestGatherOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "gather" + self.inputs = { + 'X': numpy.random.random((10, 20)).astype("float32"), + 'Index': numpy.array([1, 3, 5]).astype("int") + } + self.outputs = {'Y': self.input['X'][self.input['Index']]} + + +if __name__ == "__main__": + unittest.main() -- GitLab