diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 250ea89b1239a27ae98cc0c1f31b6695ce7c405e..a155c7052be4f07b568af7954f270ea13bd70164 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -179,6 +179,7 @@ paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], vara paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) +paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None)) diff --git a/paddle/fluid/operators/similarity_focus_op.cc b/paddle/fluid/operators/similarity_focus_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9612f82b6d45dc4e08bfe288ddd1c7790875ee4d --- /dev/null +++ b/paddle/fluid/operators/similarity_focus_op.cc @@ -0,0 +1,87 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/similarity_focus_op.h" + +namespace paddle { +namespace operators { +class SimilarityFocusOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), a 4-D tensor with shape," + " [BatchSize, X, Y, Z]"); + AddOutput("Out", + "(Tensor, default Tensor), the similarity focus mask" + " with the same shape of input X."); + AddAttr("axis", + "(int32), indicating the dimension to be select. It can" + " only be 1, 2, or 3."); + AddAttr>("indexes", + "(std::vector), indicating the indexes" + " of the selected dimension."); + AddComment(R"DOC( +SimilarityFocus Operator. + +Generate a similarity focus mask with the same shape of input using the following method: +1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding + to the axis according to the indexes. For example, if axis=1 and indexes=[a], + it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X + is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C). +2. For each index, find the largest numbers in the tensor T, so that the same + row and same column has at most one number(what it means is that if the + largest number has been found in the i-th row and the j-th column, then + the numbers in the i-th row or j-th column will be skipped. And then the + next largest number will be selected from the remaining numbers. Obviously + there will be min(B, C) numbers), and mark the corresponding position of the + 3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for + each index. +3. Broadcast the 3-D similarity focus mask to the same shape of input X. + +Refer to `Similarity Focus Layer `_ +)DOC"); + } +}; + +class SimilarityFocusOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null."); + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4."); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(similarity_focus, ops::SimilarityFocusOp, + ops::SimilarityFocusOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(similarity_focus, ops::SimilarityFocusKernel, + ops::SimilarityFocusKernel); diff --git a/paddle/fluid/operators/similarity_focus_op.h b/paddle/fluid/operators/similarity_focus_op.h new file mode 100644 index 0000000000000000000000000000000000000000..bf3fed2aaf2cf92d5619ae5bce6dd70d9dfe9621 --- /dev/null +++ b/paddle/fluid/operators/similarity_focus_op.h @@ -0,0 +1,168 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +class SimilarityFocusKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + Tensor* out = context.Output("Out"); + const Tensor* x = context.Input("X"); + T* out_data = out->mutable_data(context.GetPlace()); + const T* x_data = x->data(); + + int axis = context.Attr("axis"); + std::vector indexes = context.Attr>("indexes"); + + int64_t batch_size = x->dims()[0]; + int64_t dim[4]; + for (int i = 1; i <= 3; ++i) { + dim[i] = x->dims()[i]; + } + + if (indexes.size() < 1) { + PADDLE_THROW("Indexes' size can not be 0."); + } + for (auto index : indexes) { + if (dim[axis] < index) { + PADDLE_THROW("Index exceeds tensor shape limit."); + } + } + + int64_t array_size = 1; + for (int i = 1; i <= 3; ++i) { + if (i != axis) { + array_size *= dim[i]; + } + } + + std::vector> array(array_size); + + bool (*cmp)(std::pair, std::pair) = []( + std::pair x, std::pair y) { + return x.first > y.first; + }; + + int64_t (*compute_index)(int64_t*, int, int, int, int) = []( + int64_t* dim, int d1, int d2, int d3, int d4) { + return d1 * dim[1] * dim[2] * dim[3] + d2 * dim[2] * dim[3] + + d3 * dim[3] + d4; + }; + + memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]); + for (int i = 0; i < batch_size; ++i) { + for (auto index : indexes) { + if (axis == 1) { + for (int j = 0; j < dim[2]; ++j) { + for (int k = 0; k < dim[3]; ++k) { + array[j * dim[3] + k] = std::make_pair( + x_data[compute_index(dim, i, index, j, k)], j * dim[3] + k); + } + } + + std::sort(array.begin(), array.end(), cmp); + int tag_num = 0; + std::vector tag2(dim[2]), tag3(dim[3]); + for (auto x : array) { + int idx2 = x.second / dim[3]; + int idx3 = x.second % dim[3]; + if (tag2[idx2] || tag3[idx3]) { + continue; + } + tag_num++; + tag2[idx2] = true; + tag3[idx3] = true; + for (int j = 0; j < dim[1]; ++j) { + out_data[compute_index(dim, i, j, idx2, idx3)] = 1; + } + if (tag_num == std::min(dim[2], dim[3])) { + break; + } + } + } else if (axis == 2) { + for (int j = 0; j < dim[1]; ++j) { + for (int k = 0; k < dim[3]; ++k) { + array[j * dim[3] + k] = std::make_pair( + x_data[compute_index(dim, i, j, index, k)], j * dim[3] + k); + } + } + + std::sort(array.begin(), array.end(), cmp); + int tag_num = 0; + std::vector tag1(dim[1]), tag3(dim[3]); + for (auto x : array) { + int idx1 = x.second / dim[3]; + int idx3 = x.second % dim[3]; + if (tag1[idx1] || tag3[idx3]) { + continue; + } + tag_num++; + tag1[idx1] = true; + tag3[idx3] = true; + for (int j = 0; j < dim[2]; ++j) { + out_data[compute_index(dim, i, idx1, j, idx3)] = 1; + } + if (tag_num == std::min(dim[1], dim[3])) { + break; + } + } + } else if (axis == 3) { + for (int j = 0; j < dim[1]; ++j) { + for (int k = 0; k < dim[2]; ++k) { + array[j * dim[2] + k] = std::make_pair( + x_data[compute_index(dim, i, j, k, index)], j * dim[2] + k); + } + } + + std::sort(array.begin(), array.end(), cmp); + int tag_num = 0; + std::vector tag1(dim[1]), tag2(dim[2]); + for (auto x : array) { + int idx1 = x.second / dim[2]; + int idx2 = x.second % dim[2]; + if (tag1[idx1] || tag2[idx2]) { + continue; + } + tag_num++; + tag1[idx1] = true; + tag2[idx2] = true; + for (int j = 0; j < dim[3]; ++j) { + out_data[compute_index(dim, i, idx1, idx2, j)] = 1; + } + if (tag_num == std::min(dim[1], dim[2])) { + break; + } + } + } else { + PADDLE_THROW("Axis must be 1 or 2 or 3"); + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0a39587574c06a093b9f9e407bb82fc28612ff0d..c1a93e831b6df02c2c4cfa34cd498b7b147a8b1b 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -160,6 +160,7 @@ __all__ = [ 'affine_grid', 'sequence_reverse', 'affine_channel', + 'similarity_focus', 'hash', 'grid_sampler', 'log_loss', @@ -7933,6 +7934,118 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): return out +def similarity_focus(input, axis, indexes, name=None): + """ + SimilarityFocus Operator + + Generate a similarity focus mask with the same shape of input using the following method: + 1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding + to the axis according to the indexes. For example, if axis=1 and indexes=[a], + it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X + is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C). + 2. For each index, find the largest numbers in the tensor T, so that the same + row and same column has at most one number(what it means is that if the + largest number has been found in the i-th row and the j-th column, then + the numbers in the i-th row or j-th column will be skipped. And then the + next largest number will be selected from the remaining numbers. Obviously + there will be min(B, C) numbers), and mark the corresponding position of the + 3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for + each index. + 3. Broadcast the 3-D similarity focus mask to the same shape of input X. + + Refer to `Similarity Focus Layer `_ + + .. code-block:: text + + * Example : + + Given a 4-D tensor x with the shape (BatchSize, C, A, B), where C is + the number of channels and the shape of feature map is (A, B): + x.shape = (2, 3, 2, 2) + x.data = [[[[0.8, 0.1], + [0.4, 0.5]], + + [[0.9, 0.7], + [0.9, 0.9]], + + [[0.8, 0.9], + [0.1, 0.2]]], + + + [[[0.2, 0.5], + [0.3, 0.4]], + + [[0.9, 0.7], + [0.8, 0.4]], + + [[0.0, 0.2], + [0.4, 0.7]]]] + + Given axis: 1 (the axis of the channel) + Given indexes: [0] + + then we get a 4-D tensor out with the same shape of input x: + out.shape = (2, 3, 2, 2) + out.data = [[[[1.0, 0.0], + [0.0, 1.0]], + + [[1.0, 0.0], + [0.0, 1.0]], + + [[1.0, 0.0], + [0.0, 1.0]]], + + [[[0.0, 1.0], + [1.0, 0.0]], + + [[0.0, 1.0], + [1.0, 0.0]], + + [[0.0, 1.0], + [1.0, 0.0]]]] + + Args: + input(Variable): The input tensor variable(default float). It should + be a 4-D tensor with shape [BatchSize, A, B, C]. + axis(int): Indicating the dimension to be selected. It can only be + 1, 2 or 3. + indexes(list): Indicating the indexes of the selected dimension. + + Returns: + Variable: A tensor variable with the same shape and same type + as the input. + + Examples: + .. code-block:: python + data = fluid.layers.data( + name='data', shape=[2, 3, 2, 2], dtype='float32') + x = fluid.layers.layer_norm(input=data, axis=1, indexes=[0]) + """ + helper = LayerHelper('similarity_focus', **locals()) + # check attrs + if isinstance(axis, int) is False: + raise TypeError("axis must be int type.") + if isinstance(indexes, list) is False: + raise TypeError("indexes must be list type.") + if axis != 1 and axis != 2 and axis != 3: + raise ValueError("axis must be 1, 2 or 3.") + if len(indexes) == 0: + raise ValueError("indexes can not be empty.") + + if name is None: + out = helper.create_variable_for_type_inference(dtype=input.dtype) + else: + out = helper.create_variable( + name=name, dtype=input.dtype, persistable=False) + helper.append_op( + type='similarity_focus', + inputs={'X': input}, + outputs={'Out': out}, + attrs={"axis": axis, + "indexes": indexes}) + return out + + def hash(input, hash_size, num_hash=1, name=None): """ Hash the input to an integer whose value is less than the given hash size. diff --git a/python/paddle/fluid/tests/unittests/test_similarity_focus_op.py b/python/paddle/fluid/tests/unittests/test_similarity_focus_op.py new file mode 100755 index 0000000000000000000000000000000000000000..b3833f05f1aa3aac7b5bcc5b6fdc138870cc8844 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_similarity_focus_op.py @@ -0,0 +1,217 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest + + +class TestSimilarityFocusOp(OpTest): + def setUp(self): + self.op_type = "similarity_focus" + batch_size = 2 + x_dim, y_dim, z_dim = 3, 2, 2 + self.inputs = { + 'X': np.array([[[[0.8, 0.1], [0.4, 0.5]], [[0.9, 0.7], [0.9, 0.9]], + [[0.8, 0.9], [0.1, 0.2]]], + [[[0.2, 0.5], [0.3, 0.4]], [[0.9, 0.7], [0.8, 0.4]], + [[0.0, 0.2], [0.4, 0.7]]]]), + } + self.attrs = { + 'axis': 1, + 'indexes': [0], + } + + output = None + for batch in range(batch_size): + res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1) + for index in self.attrs['indexes']: + channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy( + ) + tag1 = [0 for i in range(y_dim)] + tag2 = [0 for i in range(z_dim)] + cnt = 0 + for i in range(channel.size): + index = channel.argmax() + idx1 = index // z_dim + idx2 = index % z_dim + if tag1[idx1] + tag2[idx2] == 0: + tag1[idx1] = 1 + tag2[idx2] = 1 + res[index] = 1 + cnt += 1 + if cnt == min(y_dim, z_dim): + break + channel[index] = -1 + res = res.reshape(1, y_dim, z_dim).repeat([x_dim], axis=0) + res = res.reshape(1, x_dim, y_dim, z_dim) + if output is not None: + output = np.concatenate((output, res), axis=0) + else: + output = res + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +class TestSimilarityFocusOp_axis1(OpTest): + def setUp(self): + self.op_type = "similarity_focus" + batch_size = 3 + x_dim, y_dim, z_dim = 4, 5, 6 + self.inputs = { + 'X': np.random.random( + (batch_size, x_dim, y_dim, z_dim)).astype("float32"), + } + self.attrs = { + 'axis': 1, + 'indexes': [0, 3], + } + + output = None + for batch in range(batch_size): + res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1) + for index in self.attrs['indexes']: + channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy( + ) + tag1 = [0 for i in range(y_dim)] + tag2 = [0 for i in range(z_dim)] + cnt = 0 + for i in range(channel.size): + index = channel.argmax() + idx1 = index // z_dim + idx2 = index % z_dim + if tag1[idx1] + tag2[idx2] == 0: + tag1[idx1] = 1 + tag2[idx2] = 1 + res[index] = 1 + cnt += 1 + if cnt == min(y_dim, z_dim): + break + channel[index] = -1 + res = res.reshape(1, y_dim, z_dim) + res = res.repeat([x_dim], axis=0) + res = res.reshape(1, x_dim, y_dim, z_dim) + if output is not None: + output = np.concatenate((output, res), axis=0) + else: + output = res + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +class TestSimilarityFocusOp_axis2(OpTest): + def setUp(self): + self.op_type = "similarity_focus" + batch_size = 6 + x_dim, y_dim, z_dim = 7, 8, 9 + self.inputs = { + 'X': np.random.random( + (batch_size, x_dim, y_dim, z_dim)).astype("float32"), + } + self.attrs = { + 'axis': 2, + 'indexes': [0, 3, 5], + } + + output = None + for batch in range(batch_size): + res = np.zeros((x_dim, 1, z_dim)).astype("float32").reshape(-1) + for index in self.attrs['indexes']: + channel = self.inputs['X'][batch, :, index, :].reshape(-1).copy( + ) + tag1 = [0 for i in range(x_dim)] + tag2 = [0 for i in range(z_dim)] + cnt = 0 + for i in range(channel.size): + index = channel.argmax() + idx1 = index // z_dim + idx2 = index % z_dim + if tag1[idx1] + tag2[idx2] == 0: + tag1[idx1] = 1 + tag2[idx2] = 1 + res[index] = 1 + cnt += 1 + if cnt == min(x_dim, z_dim): + break + channel[index] = -1 + res = res.reshape(x_dim, 1, z_dim) + res = res.repeat([y_dim], axis=1) + res = res.reshape(1, x_dim, y_dim, z_dim) + if output is not None: + output = np.concatenate((output, res), axis=0) + else: + output = res + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +class TestSimilarityFocusOp_axis3(OpTest): + def setUp(self): + self.op_type = "similarity_focus" + batch_size = 64 + x_dim, y_dim, z_dim = 48, 48, 13 + self.inputs = { + 'X': np.random.random( + (batch_size, x_dim, y_dim, z_dim)).astype("float32"), + } + self.attrs = { + 'axis': 3, + 'indexes': [0, 2, 7, 9], + } + + output = None + for batch in range(batch_size): + res = np.zeros((x_dim, y_dim, 1)).astype("float32").reshape(-1) + for index in self.attrs['indexes']: + channel = self.inputs['X'][batch, :, :, index].reshape(-1).copy( + ) + tag1 = [0 for i in range(x_dim)] + tag2 = [0 for i in range(y_dim)] + cnt = 0 + for i in range(channel.size): + index = channel.argmax() + idx1 = index // y_dim + idx2 = index % y_dim + if tag1[idx1] + tag2[idx2] == 0: + tag1[idx1] = 1 + tag2[idx2] = 1 + res[index] = 1 + cnt += 1 + if cnt == min(x_dim, y_dim): + break + channel[index] = -1 + res = res.reshape(x_dim, y_dim, 1) + res = res.repeat([z_dim], axis=2) + res = res.reshape(1, x_dim, y_dim, z_dim) + if output is not None: + output = np.concatenate((output, res), axis=0) + else: + output = res + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main()