diff --git a/paddle/fluid/operators/similarity_focus_op.cc b/paddle/fluid/operators/similarity_focus_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0750fc737af093a1cbf3640d640057173721cb4d --- /dev/null +++ b/paddle/fluid/operators/similarity_focus_op.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/similarity_focus_op.h" + +namespace paddle { +namespace operators { +class SimilarityFocusOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), a 4-D tensor with shape," + " [BatchSize, X, Y, Z]"); + AddOutput("Out", + "(Tensor, default Tensor), the similarity focus mask" + " with the same shape of input X."); + AddAttr("axis", + "(int32), indicating the dimension to be select. It can" + " only be 1, 2, or 3."); + AddAttr>("indexes", + "(std::vector), indicating the indexes" + " of the selected dimension."); + AddComment(R"DOC( +SimilarityFocus Operator. + +Generate a similarity focus mask with the same shape of input using the following method: +1. Extract the 3-D matrix(here the first dimension is BatchSize) corresponding + to the axis according to the indexes. For example, if axis=1 and indexes=[a], + it will get the matrix T=X[:, a, :, :]. In this casr, if the shape of input X + is (BatchSize, A, B, C), the shape of matrix T is (BatchSize, B, C). +2. For each index, find the largest numbers in the matrix T, so that the same + row and same column has at most one number(obviously there will be min(B, C) + numbers), and mark the corresponding position of the 3-D similarity focus mask + as 1, otherwise as 0. Do elementwise-or for each index. +3. Broadcast the 3-D similarity focus mask to the same shape of input X. + +Refer to `Similarity Focus Layer `_ +)DOC"); + } +}; + +class SimilarityFocusOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null."); + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4."); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(similarity_focus, ops::SimilarityFocusOp, + ops::SimilarityFocusOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(similarity_focus, ops::SimilarityFocusKernel, + ops::SimilarityFocusKernel); diff --git a/paddle/fluid/operators/similarity_focus_op.h b/paddle/fluid/operators/similarity_focus_op.h new file mode 100644 index 0000000000000000000000000000000000000000..bf3fed2aaf2cf92d5619ae5bce6dd70d9dfe9621 --- /dev/null +++ b/paddle/fluid/operators/similarity_focus_op.h @@ -0,0 +1,168 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +class SimilarityFocusKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + Tensor* out = context.Output("Out"); + const Tensor* x = context.Input("X"); + T* out_data = out->mutable_data(context.GetPlace()); + const T* x_data = x->data(); + + int axis = context.Attr("axis"); + std::vector indexes = context.Attr>("indexes"); + + int64_t batch_size = x->dims()[0]; + int64_t dim[4]; + for (int i = 1; i <= 3; ++i) { + dim[i] = x->dims()[i]; + } + + if (indexes.size() < 1) { + PADDLE_THROW("Indexes' size can not be 0."); + } + for (auto index : indexes) { + if (dim[axis] < index) { + PADDLE_THROW("Index exceeds tensor shape limit."); + } + } + + int64_t array_size = 1; + for (int i = 1; i <= 3; ++i) { + if (i != axis) { + array_size *= dim[i]; + } + } + + std::vector> array(array_size); + + bool (*cmp)(std::pair, std::pair) = []( + std::pair x, std::pair y) { + return x.first > y.first; + }; + + int64_t (*compute_index)(int64_t*, int, int, int, int) = []( + int64_t* dim, int d1, int d2, int d3, int d4) { + return d1 * dim[1] * dim[2] * dim[3] + d2 * dim[2] * dim[3] + + d3 * dim[3] + d4; + }; + + memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]); + for (int i = 0; i < batch_size; ++i) { + for (auto index : indexes) { + if (axis == 1) { + for (int j = 0; j < dim[2]; ++j) { + for (int k = 0; k < dim[3]; ++k) { + array[j * dim[3] + k] = std::make_pair( + x_data[compute_index(dim, i, index, j, k)], j * dim[3] + k); + } + } + + std::sort(array.begin(), array.end(), cmp); + int tag_num = 0; + std::vector tag2(dim[2]), tag3(dim[3]); + for (auto x : array) { + int idx2 = x.second / dim[3]; + int idx3 = x.second % dim[3]; + if (tag2[idx2] || tag3[idx3]) { + continue; + } + tag_num++; + tag2[idx2] = true; + tag3[idx3] = true; + for (int j = 0; j < dim[1]; ++j) { + out_data[compute_index(dim, i, j, idx2, idx3)] = 1; + } + if (tag_num == std::min(dim[2], dim[3])) { + break; + } + } + } else if (axis == 2) { + for (int j = 0; j < dim[1]; ++j) { + for (int k = 0; k < dim[3]; ++k) { + array[j * dim[3] + k] = std::make_pair( + x_data[compute_index(dim, i, j, index, k)], j * dim[3] + k); + } + } + + std::sort(array.begin(), array.end(), cmp); + int tag_num = 0; + std::vector tag1(dim[1]), tag3(dim[3]); + for (auto x : array) { + int idx1 = x.second / dim[3]; + int idx3 = x.second % dim[3]; + if (tag1[idx1] || tag3[idx3]) { + continue; + } + tag_num++; + tag1[idx1] = true; + tag3[idx3] = true; + for (int j = 0; j < dim[2]; ++j) { + out_data[compute_index(dim, i, idx1, j, idx3)] = 1; + } + if (tag_num == std::min(dim[1], dim[3])) { + break; + } + } + } else if (axis == 3) { + for (int j = 0; j < dim[1]; ++j) { + for (int k = 0; k < dim[2]; ++k) { + array[j * dim[2] + k] = std::make_pair( + x_data[compute_index(dim, i, j, k, index)], j * dim[2] + k); + } + } + + std::sort(array.begin(), array.end(), cmp); + int tag_num = 0; + std::vector tag1(dim[1]), tag2(dim[2]); + for (auto x : array) { + int idx1 = x.second / dim[2]; + int idx2 = x.second % dim[2]; + if (tag1[idx1] || tag2[idx2]) { + continue; + } + tag_num++; + tag1[idx1] = true; + tag2[idx2] = true; + for (int j = 0; j < dim[3]; ++j) { + out_data[compute_index(dim, i, idx1, idx2, j)] = 1; + } + if (tag_num == std::min(dim[1], dim[2])) { + break; + } + } + } else { + PADDLE_THROW("Axis must be 1 or 2 or 3"); + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index cca618b9ad2fef9bf4870f0f94d17fbc529fb83c..463200fb7215c1a5e8a1a893da5fc70690caf700 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -155,6 +155,7 @@ __all__ = [ 'sigmoid_cross_entropy_with_logits', 'maxout', 'affine_channel', + 'similarity_focus', ] @@ -7494,3 +7495,58 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): attrs={"data_layout": data_layout}, outputs={"Out": out}) return out + + +def similarity_focus(input, axis, indexes, name=None): + """ + **SimilarityFocus Operator** + + Generate a similarity focus mask with the same shape of input using the following method: + 1. Extract the 3-D matrix(here the first dimension is BatchSize) corresponding + to the axis according to the indexes. For example, if axis=1 and indexes=[a], + it will get the matrix T=X[:, a, :, :]. In this casr, if the shape of input X + is (BatchSize, A, B, C), the shape of matrix T is (BatchSize, B, C). + 2. For each index, find the largest numbers in the matrix T, so that the same + row and same column has at most one number(obviously there will be min(B, C) + numbers), and mark the corresponding position of the 3-D similarity focus mask + as 1, otherwise as 0. Do elementwise-or for each index. + 3. Broadcast the 3-D similarity focus mask to the same shape of input X. + + Refer to `Similarity Focus Layer `_ + + Args: + input(Variable): The input tensor variable(default float). It should + be a 4-D tensor with shape [BatchSize, A, B, C]. + axis(int): Indicating the dimension to be select. It can only be + 1, 2, or 3. + indexes(list): indicating the indexes of the selected dimension. + + Returns: + Variable: A tensor variable with the same shape and same type + as the input. + + Examples: + .. code-block:: python + data = fluid.layers.data( + name='data', shape=[128, 13, 48, 48], dtype='float32') + x = fluid.layers.layer_norm(input=data, axis=1, indexes=[9, 10]) + """ + helper = LayerHelper('similarity_focus', **locals()) + # check attrs + if isinstance(axis, int) is False: + raise TypeError("axis must be int type.") + if isinstance(indexes, list) is False: + raise TypeError("indexes must be list type.") + if axis != 1 and axis != 2 and axis != 3: + raise ValueError("axis must be 1, 2 or 3.") + if len(indexes) == 0: + raise ValueError("indexes can not be empty.") + + out = helper.create_tmp_variable(dtype=helper.input_dtype()) + helper.append_op( + type='similarity_focus', + inputs={'X': input}, + outputs={'Out': out}, + attrs={"axis": axis, + "indexes": indexes}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_similarity_focus_op.py b/python/paddle/fluid/tests/unittests/test_similarity_focus_op.py new file mode 100755 index 0000000000000000000000000000000000000000..21308a7e0ccd9e0bab1faf1f593643ea4229d3f4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_similarity_focus_op.py @@ -0,0 +1,168 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest + + +class TestSimilarityFocusOp_axis1(OpTest): + def setUp(self): + self.op_type = "similarity_focus" + batch_size = 3 + x_dim, y_dim, z_dim = 4, 5, 6 + self.inputs = { + 'X': np.random.random( + (batch_size, x_dim, y_dim, z_dim)).astype("float32"), + } + self.attrs = { + 'axis': 1, + 'indexes': [0, 3], + } + + output = None + for batch in range(batch_size): + res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1) + for index in self.attrs['indexes']: + channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy( + ) + tag1 = [0 for i in range(y_dim)] + tag2 = [0 for i in range(z_dim)] + cnt = 0 + for i in range(channel.size): + index = channel.argmax() + idx1 = index / z_dim + idx2 = index % z_dim + if tag1[idx1] + tag2[idx2] == 0: + tag1[idx1] = 1 + tag2[idx2] = 1 + res[index] = 1 + cnt += 1 + if cnt == min(y_dim, z_dim): + break + channel[index] = -1 + res = res.reshape(1, y_dim, z_dim) + res = res.repeat([x_dim], axis=0) + res = res.reshape(1, x_dim, y_dim, z_dim) + if output is not None: + output = np.concatenate((output, res), axis=0) + else: + output = res + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +class TestSimilarityFocusOp_axis2(OpTest): + def setUp(self): + self.op_type = "similarity_focus" + batch_size = 6 + x_dim, y_dim, z_dim = 7, 8, 9 + self.inputs = { + 'X': np.random.random( + (batch_size, x_dim, y_dim, z_dim)).astype("float32"), + } + self.attrs = { + 'axis': 2, + 'indexes': [0, 3, 5], + } + + output = None + for batch in range(batch_size): + res = np.zeros((x_dim, 1, z_dim)).astype("float32").reshape(-1) + for index in self.attrs['indexes']: + channel = self.inputs['X'][batch, :, index, :].reshape(-1).copy( + ) + tag1 = [0 for i in range(x_dim)] + tag2 = [0 for i in range(z_dim)] + cnt = 0 + for i in range(channel.size): + index = channel.argmax() + idx1 = index / z_dim + idx2 = index % z_dim + if tag1[idx1] + tag2[idx2] == 0: + tag1[idx1] = 1 + tag2[idx2] = 1 + res[index] = 1 + cnt += 1 + if cnt == min(x_dim, z_dim): + break + channel[index] = -1 + res = res.reshape(x_dim, 1, z_dim) + res = res.repeat([y_dim], axis=1) + res = res.reshape(1, x_dim, y_dim, z_dim) + if output is not None: + output = np.concatenate((output, res), axis=0) + else: + output = res + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +class TestSimilarityFocusOp_axis3(OpTest): + def setUp(self): + self.op_type = "similarity_focus" + batch_size = 64 + x_dim, y_dim, z_dim = 48, 48, 13 + self.inputs = { + 'X': np.random.random( + (batch_size, x_dim, y_dim, z_dim)).astype("float32"), + } + self.attrs = { + 'axis': 3, + 'indexes': [0, 2, 7, 9], + } + + output = None + for batch in range(batch_size): + res = np.zeros((x_dim, y_dim, 1)).astype("float32").reshape(-1) + for index in self.attrs['indexes']: + channel = self.inputs['X'][batch, :, :, index].reshape(-1).copy( + ) + tag1 = [0 for i in range(x_dim)] + tag2 = [0 for i in range(y_dim)] + cnt = 0 + for i in range(channel.size): + index = channel.argmax() + idx1 = index / y_dim + idx2 = index % y_dim + if tag1[idx1] + tag2[idx2] == 0: + tag1[idx1] = 1 + tag2[idx2] = 1 + res[index] = 1 + cnt += 1 + if cnt == min(x_dim, y_dim): + break + channel[index] = -1 + res = res.reshape(x_dim, y_dim, 1) + res = res.repeat([z_dim], axis=2) + res = res.reshape(1, x_dim, y_dim, z_dim) + if output is not None: + output = np.concatenate((output, res), axis=0) + else: + output = res + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main()