diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index f285a5d60b8c5f1348d029e5f018e0a8c479568b..0f6337d6f3627fcb4a5e7cfd63dc4d581a3e11dd 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -234,6 +234,7 @@ paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l
 paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', '132b6e74ff642a392bd6b14c10aedc65'))
 paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
 paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'a07a44c2bacdcd09c1f5f35a96a0514e'))
+paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', '3126e3039e752ce26077f1efaca355c6'))
 paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', 'adf285346e23316097f7789b572491e9'))
 paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cf12066a3139026119f97f9d4381a1bd'))
 paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e'))
diff --git a/paddle/fluid/operators/where_op.cc b/paddle/fluid/operators/where_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3b53ebec0b250c7181968e37f996ec9ef5cf2a2c
--- /dev/null
+++ b/paddle/fluid/operators/where_op.cc
@@ -0,0 +1,58 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/where_op.h"
+
+namespace paddle {
+namespace operators {
+
+class WhereOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Condition"),
+                   "Input(Condition) of WhereOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->GetInputDim("Condition").size() >= 1,
+        "Input(Condition) should have a rank of at least 1.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of WhereOp should not be null.");
+    ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto output_type = framework::proto::VarType::INT64;
+    return framework::OpKernelType(output_type, ctx.device_context());
+  }
+};
+
+class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Condition", "A bool tensor whose rank is at least 1");
+    AddOutput("Out", "An int64 tensor of rank 2");
+    AddComment(R"DOC(
+Return an int64 tensor with rank 2, specifying the coordinates of true elements in `Condition`.
+)DOC");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(where, ops::WhereOp, ops::WhereOpMaker);
+REGISTER_OP_CPU_KERNEL(where, ops::CPUWhereKernel<int64_t>);
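
For reference (this note is not part of the patch itself, just a sketch of the intended semantics): the op behaves like NumPy's np.argwhere, i.e. each output row is the coordinate of one true element, listed in row-major order, and the result has shape [num_true, rank].

    import numpy as np

    cond = np.array([[True, False], [False, True]])
    coords = np.argwhere(cond)   # integer array of shape [num_true, rank]
    print(coords)                # [[0 0]
                                 #  [1 1]]
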
diff --git a/paddle/fluid/operators/where_op.cu b/paddle/fluid/operators/where_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..27682f869c73c760bf475489a8bdd57e39cfaea5
--- /dev/null
+++ b/paddle/fluid/operators/where_op.cu
@@ -0,0 +1,81 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <thrust/device_vector.h>
+#include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/where_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+using CUDADeviceContext = paddle::platform::CUDADeviceContext;
+
+template <typename T>
+class CUDAWhereKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* condition = context.Input<framework::Tensor>("Condition");
+    auto* out = context.Output<framework::Tensor>("Out");
+
+    // TODO(zhoukunsheng): Should optimize to ensure GPU is faster than CPU.
+    framework::Tensor cond_cpu;
+    framework::TensorCopy(*condition, platform::CPUPlace(), &cond_cpu);
+
+    const bool* cond_data = cond_cpu.data<bool>();
+    int64_t numel = cond_cpu.numel();
+    auto dims = cond_cpu.dims();
+    int rank = dims.size();
+
+    thrust::host_vector<int> h_true_index;
+    for (int64_t i = 0; i < numel; i++) {
+      if (cond_data[i]) {
+        h_true_index.push_back(i);
+      }
+    }
+    thrust::device_vector<int> d_true_index = h_true_index;
+    int* ptr_true_index = thrust::raw_pointer_cast(d_true_index.data());
+
+    size_t true_num = h_true_index.size();
+
+    out->Resize(framework::make_ddim({static_cast<int64_t>(true_num), rank}));
+    auto out_ptr = out->mutable_data<int64_t>(context.GetPlace());
+
+    if (true_num == 0) {
+      return;
+    }
+
+    thrust::host_vector<int> h_stride(rank, 0);
+    h_stride[rank - 1] = 1;
+    for (int i = rank - 2; i >= 0; i--) {
+      h_stride[i] = h_stride[i + 1] * dims[i + 1];
+    }
+    thrust::device_vector<int> d_stride = h_stride;
+    int* ptr_stride = thrust::raw_pointer_cast(d_stride.data());
+
+    auto& dev_ctx = context.template device_context<CUDADeviceContext>();
+    WhereFunctor<int*> functor(ptr_true_index, true_num, ptr_stride, rank,
+                               out_ptr);
+    platform::ForRange<CUDADeviceContext> for_range(dev_ctx, true_num);
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(where, ops::CUDAWhereKernel<int64_t>);
diff --git a/paddle/fluid/operators/where_op.h b/paddle/fluid/operators/where_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..6a161a2668fa02f181ef99bfbfb501541988a333
--- /dev/null
+++ b/paddle/fluid/operators/where_op.h
@@ -0,0 +1,95 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <cstdint>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct WhereFunctor {
+  WhereFunctor(const T& true_index, int true_num, const T& stride, int rank,
+               int64_t* out)
+      : true_index_(true_index),
+        true_num_(true_num),
+        stride_(stride),
+        rank_(rank),
+        out_ptr_(out) {}
+
+  HOSTDEVICE void operator()(size_t idx) const {
+    int index = true_index_[idx];
+    for (int j = 0; j < rank_; j++) {
+      out_ptr_[idx * rank_ + j] = index / stride_[j];
+      index -= out_ptr_[idx * rank_ + j] * stride_[j];
+    }
+  }
+
+  const T true_index_;
+  int true_num_;
+  const T stride_;
+  int rank_;
+  int64_t* out_ptr_;
+};
+
+using CPUDeviceContext = paddle::platform::CPUDeviceContext;
+
+template <typename T>
+class CPUWhereKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* condition = context.Input<framework::Tensor>("Condition");
+    auto* out = context.Output<framework::Tensor>("Out");
+
+    const bool* cond_data = condition->data<bool>();
+    auto numel = condition->numel();
+    auto dims = condition->dims();
+    const int rank = dims.size();
+
+    std::vector<int64_t> true_index;
+    for (auto i = 0; i < numel; i++) {
+      if (cond_data[i]) {
+        true_index.push_back(i);
+      }
+    }
+    auto true_num = true_index.size();
+
+    out->Resize(framework::make_ddim({static_cast<int64_t>(true_num), rank}));
+    auto out_ptr = out->mutable_data<int64_t>(context.GetPlace());
+
+    if (true_num == 0) {
+      return;
+    }
+
+    std::vector<int64_t> stride(rank);
+    stride[rank - 1] = 1;
+    for (int i = rank - 2; i >= 0; i--) {
+      stride[i] = stride[i + 1] * dims[i + 1];
+    }
+
+    auto& dev_ctx = context.template device_context<CPUDeviceContext>();
+    WhereFunctor<int64_t*> functor(true_index.data(), true_num, stride.data(),
+                                   rank, out_ptr);
+    platform::ForRange<CPUDeviceContext> for_range(dev_ctx, true_num);
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
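
The WhereFunctor above turns each flat (row-major) index into a multi-dimensional coordinate by repeated division with precomputed strides, where stride[j] is the product of dims[j+1:]. A small Python sketch of the same arithmetic (illustrative only; flat_to_coords is a made-up helper, not part of the patch):

    import numpy as np

    def flat_to_coords(flat_index, dims):
        # stride[j] = product of dims[j+1:], i.e. row-major strides in elements
        stride = [1] * len(dims)
        for j in range(len(dims) - 2, -1, -1):
            stride[j] = stride[j + 1] * dims[j + 1]
        coords = []
        for s in stride:
            coords.append(flat_index // s)
            flat_index -= coords[-1] * s
        return coords

    print(flat_to_coords(5, [3, 2, 2]))    # [1, 0, 1]
    print(np.unravel_index(5, (3, 2, 2)))  # (1, 0, 1), the same coordinate
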
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 21eb4baa2b0b46f8cdbc813d6062b5d4347f718b..d179f56c6ca3fb482561fcda2b27316670c99696 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -200,6 +200,7 @@ __all__ = [
     'pixel_shuffle',
     'fsp_matrix',
     'continuous_value_model',
+    'where',
 ]
 
 kIgnoreIndex = -100
@@ -11341,3 +11342,38 @@ def continuous_value_model(input, cvm, use_cvm=True):
         outputs={'Y': [out]},
         attrs={"use_cvm": use_cvm})
     return out
+
+
+def where(condition):
+    """
+    Return an int64 tensor with rank 2, specifying the coordinates of true elements in `condition`.
+
+    The output's first dimension is the number of true elements; the second dimension is the rank (number of dimensions) of `condition`.
+    If there are no true elements, an empty tensor will be generated.
+
+    Args:
+        condition(Variable): A bool tensor with rank at least 1.
+
+    Returns:
+        Variable: A 2-D int64 tensor holding the coordinates of true elements.
+
+    Examples:
+        .. code-block:: python
+
+            # condition is a tensor [True, False, True]
+            out = fluid.layers.where(condition) # [[0], [2]]
+
+            # condition is a tensor [[True, False], [False, True]]
+            out = fluid.layers.where(condition) # [[0, 0], [1, 1]]
+
+            # condition is a tensor [False, False, False]
+            out = fluid.layers.where(condition) # [[]]
+    """
+    helper = LayerHelper("where", **locals())
+
+    out = helper.create_variable_for_type_inference(
+        dtype=core.VarDesc.VarType.INT64)
+
+    helper.append_op(
+        type='where', inputs={'Condition': condition}, outputs={'Out': [out]})
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_where.py b/python/paddle/fluid/tests/unittests/test_where.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee0fa1613093c982320337aaa453114cfb187db4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_where.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
+
+
+class TestWhereOp(OpTest):
+    def setUp(self):
+        self.op_type = "where"
+        self.init_config()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def init_config(self):
+        self.inputs = {'Condition': np.array([True, False, True]), }
+
+        self.outputs = {'Out': np.array([[0], [2]], dtype='int64')}
+
+
+class TestAllFalse(unittest.TestCase):
+    def setUp(self):
+        self.op_type = "where"
+        self.init_config()
+
+    def check_with_place(self, place):
+        scope = core.Scope()
+        condition = scope.var('Condition').get_tensor()
+        condition.set(self.cond_data, place)
+
+        out = scope.var("Out").get_tensor()
+        out.set(np.full(self.shape, 0).astype('int64'), place)
+
+        op = Operator("where", Condition="Condition", Out="Out")
+        op.run(scope, place)
+
+        out_array = np.array(out)
+        self.assertTrue((out_array == self.out_data).all())
+
+    def init_config(self):
+        self.cond_data = np.array([False, False, False])
+        self.shape = (3, 1)
+        self.out_data = np.array([], dtype='int64')
+
+    def test_all_false(self):
+        self.check_with_place(core.CPUPlace())
+
+        if core.is_compiled_with_cuda():
+            self.check_with_place(core.CUDAPlace(0))
+
+
+class TestRank2(TestWhereOp):
+    def init_config(self):
+        self.inputs = {'Condition': np.array([[True, False], [False, True]]), }
+
+        self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')}
+
+
+class TestRank3(TestWhereOp):
+    def init_config(self):
+        self.inputs = {
+            'Condition': np.array([[[True, False], [False, True]],
+                                   [[False, True], [True, False]],
+                                   [[False, False], [False, True]]]),
+        }
+
+        self.outputs = {
+            'Out': np.array(
+                [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0], [2, 1, 1]],
+                dtype='int64')
+        }
+
+
+if __name__ == "__main__":
+    unittest.main()
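
Finally, a rough end-to-end usage sketch of the new layer (illustrative only, not part of the patch; it assumes the usual fluid 1.x data/less_than/Executor APIs, and all names and values below are made up):

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[4], dtype='float32', append_batch_size=False)
    y = fluid.layers.data(name='y', shape=[4], dtype='float32', append_batch_size=False)
    cond = fluid.layers.less_than(x, y)   # bool tensor of shape [4]
    coords = fluid.layers.where(cond)     # int64 tensor of shape [num_true, 1]

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    out, = exe.run(feed={'x': np.array([1., 5., 2., 9.], dtype='float32'),
                         'y': np.array([3., 4., 6., 8.], dtype='float32')},
                   fetch_list=[coords])
    print(out)  # expected [[0], [2]]: x < y holds at indices 0 and 2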