diff --git a/paddle/fluid/operators/index_sample_op.cc b/paddle/fluid/operators/index_sample_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..139954b0e5e8b71a26a7d87df9ed30ba4fa39ada --- /dev/null +++ b/paddle/fluid/operators/index_sample_op.cc @@ -0,0 +1,154 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/index_sample_op.h" +#include +#include "paddle/fluid/framework/no_need_buffer_vars_inference.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input(Tensor), dtype support int32/int64/float/double"); + AddInput("Index", "Index(Tensor), dtype support int32/int64"); + AddOutput("Out", "Return the element of input at index"); + + AddComment(R"DOC( + IndexSample OP returns the element of the specified location of X, + and the location is specified by Index. + + X tensor and Index tensor's shape must be 2-D, + dimension at 0 which usually is batch size must be equal. + + The returned tensor has the same shape and dimensions as the Index tensor. + )DOC"); + } +}; + +class IndexSampleOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + platform::errors::InvalidArgument( + "Inputs(Input) of FindByIndex should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true, + platform::errors::InvalidArgument( + "Inputs(Index) of FindByIndex should not be null.")); + + auto input_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ( + input_dims.size(), 2, + platform::errors::InvalidArgument( + "Inputs(X) shape of IndexSample op should be 2-D, but " + "got X's shape = [%s], please check X shape.", + input_dims)); + + auto index_dims = ctx->GetInputDim("Index"); + PADDLE_ENFORCE_EQ( + input_dims.size(), 2, + platform::errors::InvalidArgument( + "Inputs(Index) shape of IndexSample op should be 2-D, but " + "got Index's shape [%s] , please check index shape.", + input_dims)); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(input_dims[0], index_dims[0], + platform::errors::InvalidArgument( + "Inputs(X)'s value of dimension 0 must same with " + "Inputs(Index)'s value of dimension 0, but " + "got %d of Inputs(X), and got %d of Inputs(Index), " + "please check Inputs shape.", + input_dims[0], index_dims[0])); + } + ctx->SetOutputDim("Out", index_dims); + auto type = ctx->GetInputsVarType("Index")[0]; + if (type == framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("Index", /*->*/ "Out"); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class IndexSampleGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("Index"), true, + platform::errors::InvalidArgument("Input(Index) should be not null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::InvalidArgument( + "Input(Out@GRAD) should be not null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true, + platform::errors::InvalidArgument( + "Output(X@GRAD) should be not null.")); + + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +template +class IndexSampleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("index_sample_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Index", this->Input("Index")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSampleGradNoNeedBufferVarInferer, "X"); +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(index_sample, ops::IndexSampleOp, ops::IndexSampleOpMaker, + ops::IndexSampleGradMaker, + ops::IndexSampleGradMaker); +REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp, + ops::IndexSampleGradNoNeedBufferVarInferer); +REGISTER_OP_CPU_KERNEL( + index_sample, ops::IndexSampleKernel, + ops::IndexSampleKernel, + ops::IndexSampleKernel, + ops::IndexSampleKernel); +REGISTER_OP_CPU_KERNEL( + index_sample_grad, + ops::IndexSampleGradKernel, + ops::IndexSampleGradKernel, + ops::IndexSampleGradKernel, + ops::IndexSampleGradKernel); diff --git a/paddle/fluid/operators/index_sample_op.h b/paddle/fluid/operators/index_sample_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9618ad5cb67be05ebc709167eb82f29100d737a6 --- /dev/null +++ b/paddle/fluid/operators/index_sample_op.h @@ -0,0 +1,186 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DDim = framework::DDim; + +template +void IndexSampleInner(const framework::ExecutionContext &context, + const LoDTensor &input, const LoDTensor &index, + LoDTensor *output) { + auto input_dims = input.dims(); + auto index_dims = index.dims(); + + int batch_size = input_dims[0]; + auto value_length = input_dims[1]; + auto index_length = index_dims[1]; + int index_ids_num = index.numel(); + auto *input_data = input.data(); + auto *index_data = index.data(); + + std::vector res{}; + for (int i = 0; i < index_ids_num; i++) { + int b = floor(i / index_length); + PADDLE_ENFORCE_GE( + index_data[i], 0, + platform::errors::InvalidArgument( + "Variable value (index) of OP(index_sample) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + value_length, index_data[i])); + PADDLE_ENFORCE_LT( + index_data[i], value_length, + platform::errors::InvalidArgument( + "Variable value (index) of OP(index_sample) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + value_length, index_data[i])); + + int v_i = b * value_length + static_cast(index_data[i]); + T v = input_data[v_i]; + VLOG(4) << "Index Sample: batch = " << b << " index = " << v_i + << " value = " << v; + res.push_back(v); + } + + auto ddim = framework::make_ddim({batch_size, index_length}); + output->Resize(ddim); + T *out_data = output->mutable_data(context.GetPlace()); + + memcpy(out_data, &res[0], sizeof(T) * index_ids_num); +} + +template +class IndexSampleKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *input_var = ctx.InputVar("X"); + auto *index_var = ctx.InputVar("Index"); + + auto &input_tensor = input_var->Get(); + auto &index_tensor = index_var->Get(); + + auto *out_var = ctx.OutputVar("Out"); + auto *out_tensor = out_var->GetMutable(); + + const auto &index_type = index_tensor.type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + if (index_type == framework::proto::VarType::INT32) { + IndexSampleInner(ctx, input_tensor, index_tensor, out_tensor); + } else if (index_type == framework::proto::VarType::INT64) { + IndexSampleInner(ctx, input_tensor, index_tensor, out_tensor); + } + } +}; + +template +void IndexSampleGradInner(const framework::ExecutionContext &context, + const LoDTensor &out_grad, const LoDTensor &index, + LoDTensor *x_grad) { + auto index_dims = index.dims(); + auto x_grad_dims = x_grad->dims(); + + int batch_size = x_grad_dims[0]; + auto value_length = x_grad_dims[1]; + auto index_length = index_dims[1]; + int index_ids_num = index.numel(); + + T *x_grad_data = x_grad->mutable_data(context.GetPlace()); + auto *out_grad_data = out_grad.data(); + auto *index_data = index.data(); + + memset(x_grad_data, 0, batch_size * value_length * sizeof(T)); + + for (int i = 0; i < index_ids_num; i++) { + int b = floor(i / index_length); + PADDLE_ENFORCE_GE( + index_data[i], 0, + platform::errors::InvalidArgument( + "Variable value (index) of OP(index_sample_grad) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + value_length, index_data[i])); + PADDLE_ENFORCE_LT( + index_data[i], value_length, + platform::errors::InvalidArgument( + "Variable value (index) of OP(index_sample_grad) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + value_length, index_data[i])); + int v_i = b * value_length + static_cast(index_data[i]); + x_grad_data[v_i] += out_grad_data[i]; + } +} + +template +class IndexSampleGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *index_var = context.InputVar("Index"); + auto *x_grad_var = context.OutputVar(framework::GradVarName("X")); + auto *out_grad_var = context.InputVar(framework::GradVarName("Out")); + + auto &index_tensor = index_var->Get(); + auto &out_grad_tensor = out_grad_var->Get(); + auto *x_grad_tensor = x_grad_var->GetMutable(); + + const auto &index_type = index_tensor.type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + if (index_type == framework::proto::VarType::INT32) { + IndexSampleGradInner(context, out_grad_tensor, index_tensor, + x_grad_tensor); + } else if (index_type == framework::proto::VarType::INT64) { + IndexSampleGradInner(context, out_grad_tensor, index_tensor, + x_grad_tensor); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index c9606f6fb4fd4ca4cacd86881860bb50d942ab21..95da09d1b35d1134311b342b20f258cdb3babac8 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os from paddle.check_import_scipy import check_import_scipy @@ -33,10 +34,10 @@ import paddle.compat import paddle.distributed batch = batch.batch import paddle.sysconfig -import paddle.nn import paddle.tensor +import paddle.nn -#TODO: define alias in tensor and framework directory +# TODO: define alias in tensor and framework directory # from .tensor.creation import create_.tensor #DEFINE_ALIAS # from .tensor.creation import create_lod_.tensor #DEFINE_ALIAS # from .tensor.creation import create_random_int_lod.tensor #DEFINE_ALIAS @@ -191,6 +192,7 @@ from .tensor.search import argmax #DEFINE_ALIAS # from .tensor.search import topk #DEFINE_ALIAS # from .tensor.search import where #DEFINE_ALIAS # from .tensor.search import index_select #DEFINE_ALIAS +from .tensor.search import index_sample #DEFINE_ALIAS # from .tensor.search import nonzero #DEFINE_ALIAS from .tensor.search import sort #DEFINE_ALIAS # from .framework.framework import set_default_dtype #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_index_sample_op.py b/python/paddle/fluid/tests/unittests/test_index_sample_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e3fc219cdf5b05729ae9c1e5e269a0e037745f68 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_index_sample_op.py @@ -0,0 +1,127 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestIndexSampleOp(OpTest): + def setUp(self): + self.op_type = "index_sample" + self.config() + xnp = np.random.random(self.x_shape).astype(self.x_type) + indexnp = np.random.randint( + low=0, high=self.x_shape[1], + size=self.index_shape).astype(self.index_type) + self.inputs = {'X': xnp, 'Index': indexnp} + index_array = [] + for i in range(self.index_shape[0]): + for j in indexnp[i]: + index_array.append(xnp[i, j]) + out = np.reshape(index_array, self.index_shape) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 20) + self.x_type = "float64" + self.index_shape = (10, 10) + self.index_type = "int32" + + +class TestCase1(TestIndexSampleOp): + def config(self): + """ + For one dimension input + """ + self.x_shape = (100, 1) + self.x_type = "float64" + self.index_shape = (100, 1) + self.index_type = "int32" + + +class TestCase2(TestIndexSampleOp): + def config(self): + """ + For int64_t index type + """ + self.x_shape = (10, 100) + self.x_type = "float64" + self.index_shape = (10, 10) + self.index_type = "int64" + + +class TestCase3(TestIndexSampleOp): + def config(self): + """ + For int index type + """ + self.x_shape = (10, 100) + self.x_type = "float64" + self.index_shape = (10, 10) + self.index_type = "int32" + + +class TestCase4(TestIndexSampleOp): + def config(self): + """ + For int64 index type + """ + self.x_shape = (10, 100) + self.x_type = "float64" + self.index_shape = (10, 10) + self.index_type = "int64" + + +class TestIndexSampleShape(unittest.TestCase): + def test_shape(self): + import paddle.fluid as fluid + import paddle + + # create x value + x_shape = (2, 5) + x_type = "float64" + x_np = np.random.random(x_shape).astype(x_type) + + # create index value + index_shape = (2, 3) + index_type = "int32" + index_np = np.random.randint( + low=0, high=x_shape[1], size=index_shape).astype(index_type) + + x = fluid.data(name='x', shape=[-1, 5], dtype='float64') + index = fluid.data(name='index', shape=[-1, 3], dtype='int32') + output = paddle.index_sample(x=x, index=index) + + place = fluid.CPUPlace() + exe = fluid.Executor(place=place) + exe.run(fluid.default_startup_program()) + + feed = {'x': x_np, 'index': index_np} + res = exe.run(feed=feed, fetch_list=[output]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 68b487107c325f88c1c01456f93c1a3ad96e1019..d9918dc832a2b4d949ed311f3254ec5be5cee32c 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -#TODO: define alias in tensor and framework directory +# TODO: define alias in tensor and framework directory # from .creation import create_tensor #DEFINE_ALIAS # from .creation import create_lod_tensor #DEFINE_ALIAS # from .creation import create_random_int_lod #DEFINE_ALIAS @@ -29,7 +29,7 @@ from .creation import linspace #DEFINE_ALIAS # from .creation import zeros_like #DEFINE_ALIAS # from .creation import arrange #DEFINE_ALIAS # from .creation import eye #DEFINE_ALIAS -from .creation import full #DEFINE_ALIAS +from .creation import full # DEFINE_ALIAS # from .creation import linspace #DEFINE_ALIAS # from .creation import full_like #DEFINE_ALIAS from .creation import triu #DEFINE_ALIAS @@ -167,5 +167,6 @@ from .search import argmax #DEFINE_ALIAS # from .search import topk #DEFINE_ALIAS # from .search import where #DEFINE_ALIAS # from .search import index_select #DEFINE_ALIAS +from .search import index_sample # DEFINE_ALIAS # from .search import nonzero #DEFINE_ALIAS from .search import sort #DEFINE_ALIAS diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 5d75e56dde0cfc52fd17e9db20b36fd05f4f8bba..3a4ca7ab7bdfbcef82fd089c72396cebba2572b1 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function +from ..fluid.layer_helper import LayerHelper +from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype -# TODO: define searching & indexing functions of a tensor +# TODO: define searching & indexing functions of a tensor __all__ = [ 'argmax', # 'argmin', @@ -24,7 +27,8 @@ __all__ = [ # 'where', # 'index_select', # 'nonzero', - 'sort' + 'sort', + 'index_sample' ] from paddle.common_ops_import import * @@ -125,7 +129,7 @@ def sort(input, axis=-1, descending=False, out=None, name=None): This OP sorts the input along the given axis, and returns sorted output data Varibale and its corresponding index Variable with the same shape as :attr:`input`. - + **NOTICE**: The Variable in the output of this OP has gradient. You could\ set Variable :attr:`stop_gradient`. Args: @@ -207,3 +211,75 @@ def sort(input, axis=-1, descending=False, out=None, name=None): attrs={'axis': axis, 'descending': descending}) return out, ids + + +def index_sample(x, index): + """ + **IndexSample Layer** + + IndexSample OP returns the element of the specified location of X, + and the location is specified by Index. + + .. code-block:: text + + + Given: + + X = [[1, 2, 3, 4, 5], + [6, 7, 8, 9, 10]] + + Index = [[0, 1, 3], + [0, 2, 4]] + + Then: + + Out = [[1, 2, 4], + [6, 8, 10]] + + Args: + x (Variable): The source input tensor with 2-D shape. Supported data type is + int32, int64, float32, float64. + index (Variable): The index input tensor with 2-D shape, first dimension should be same with X. + Data type is int32 or int64. + + Returns: + output (Variable): The output is a tensor with the same shape as index. + + Examples: + + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + # create x value + x_shape = (2, 5) + x_type = "float64" + x_np = np.random.random(x_shape).astype(x_type) + + # create index value + index_shape = (2, 3) + index_type = "int32" + index_np = np.random.randint(low=0, + high=x_shape[1], + size=index_shape).astype(index_type) + + x = fluid.data(name='x', shape=[-1, 5], dtype='float64') + index = fluid.data(name='index', shape=[-1, 3], dtype='int32') + output = paddle.index_sample(x=x, index=index) + + """ + helper = LayerHelper("index_sample", **locals()) + check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], + 'paddle.tensor.search.index_sample') + check_variable_and_dtype(index, 'index', ['int32', 'int64'], + 'paddle.tensor.search.index_sample') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='index_sample', + inputs={'X': x, + 'Index': index}, + outputs={'Out': out}) + return out