提交 a7f94ec7 编写于 作者: B barrierye

add similarity_focus op

上级 d0fdcb2f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/similarity_focus_op.h"
namespace paddle {
namespace operators {
class SimilarityFocusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 4-D tensor with shape,"
" [BatchSize, X, Y, Z]");
AddOutput("Out",
"(Tensor, default Tensor<float>), the similarity focus mask"
" with the same shape of input X.");
AddAttr<int>("axis",
"(int32), indicating the dimension to be select. It can"
" only be 1, 2, or 3.");
AddAttr<std::vector<int>>("indexes",
"(std::vector<int32>), indicating the indexes"
" of the selected dimension.");
AddComment(R"DOC(
SimilarityFocus Operator.
Generate a similarity focus mask with the same shape of input using the following method:
1. Extract the 3-D matrix(here the first dimension is BatchSize) corresponding
to the axis according to the indexes. For example, if axis=1 and indexes=[a],
it will get the matrix T=X[:, a, :, :]. In this casr, if the shape of input X
is (BatchSize, A, B, C), the shape of matrix T is (BatchSize, B, C).
2. For each index, find the largest numbers in the matrix T, so that the same
row and same column has at most one number(obviously there will be min(B, C)
numbers), and mark the corresponding position of the 3-D similarity focus mask
as 1, otherwise as 0. Do elementwise-or for each index.
3. Broadcast the 3-D similarity focus mask to the same shape of input X.
Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
)DOC");
}
};
class SimilarityFocusOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4.");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
platform::CPUPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(similarity_focus, ops::SimilarityFocusOp,
ops::SimilarityFocusOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(similarity_focus, ops::SimilarityFocusKernel<float>,
ops::SimilarityFocusKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cstring>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SimilarityFocusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
Tensor* out = context.Output<Tensor>("Out");
const Tensor* x = context.Input<Tensor>("X");
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int axis = context.Attr<int>("axis");
std::vector<int> indexes = context.Attr<std::vector<int>>("indexes");
int64_t batch_size = x->dims()[0];
int64_t dim[4];
for (int i = 1; i <= 3; ++i) {
dim[i] = x->dims()[i];
}
if (indexes.size() < 1) {
PADDLE_THROW("Indexes' size can not be 0.");
}
for (auto index : indexes) {
if (dim[axis] < index) {
PADDLE_THROW("Index exceeds tensor shape limit.");
}
}
int64_t array_size = 1;
for (int i = 1; i <= 3; ++i) {
if (i != axis) {
array_size *= dim[i];
}
}
std::vector<std::pair<T, int64_t>> array(array_size);
bool (*cmp)(std::pair<T, int64_t>, std::pair<T, int64_t>) = [](
std::pair<T, int64_t> x, std::pair<T, int64_t> y) {
return x.first > y.first;
};
int64_t (*compute_index)(int64_t*, int, int, int, int) = [](
int64_t* dim, int d1, int d2, int d3, int d4) {
return d1 * dim[1] * dim[2] * dim[3] + d2 * dim[2] * dim[3] +
d3 * dim[3] + d4;
};
memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]);
for (int i = 0; i < batch_size; ++i) {
for (auto index : indexes) {
if (axis == 1) {
for (int j = 0; j < dim[2]; ++j) {
for (int k = 0; k < dim[3]; ++k) {
array[j * dim[3] + k] = std::make_pair(
x_data[compute_index(dim, i, index, j, k)], j * dim[3] + k);
}
}
std::sort(array.begin(), array.end(), cmp);
int tag_num = 0;
std::vector<bool> tag2(dim[2]), tag3(dim[3]);
for (auto x : array) {
int idx2 = x.second / dim[3];
int idx3 = x.second % dim[3];
if (tag2[idx2] || tag3[idx3]) {
continue;
}
tag_num++;
tag2[idx2] = true;
tag3[idx3] = true;
for (int j = 0; j < dim[1]; ++j) {
out_data[compute_index(dim, i, j, idx2, idx3)] = 1;
}
if (tag_num == std::min(dim[2], dim[3])) {
break;
}
}
} else if (axis == 2) {
for (int j = 0; j < dim[1]; ++j) {
for (int k = 0; k < dim[3]; ++k) {
array[j * dim[3] + k] = std::make_pair(
x_data[compute_index(dim, i, j, index, k)], j * dim[3] + k);
}
}
std::sort(array.begin(), array.end(), cmp);
int tag_num = 0;
std::vector<bool> tag1(dim[1]), tag3(dim[3]);
for (auto x : array) {
int idx1 = x.second / dim[3];
int idx3 = x.second % dim[3];
if (tag1[idx1] || tag3[idx3]) {
continue;
}
tag_num++;
tag1[idx1] = true;
tag3[idx3] = true;
for (int j = 0; j < dim[2]; ++j) {
out_data[compute_index(dim, i, idx1, j, idx3)] = 1;
}
if (tag_num == std::min(dim[1], dim[3])) {
break;
}
}
} else if (axis == 3) {
for (int j = 0; j < dim[1]; ++j) {
for (int k = 0; k < dim[2]; ++k) {
array[j * dim[2] + k] = std::make_pair(
x_data[compute_index(dim, i, j, k, index)], j * dim[2] + k);
}
}
std::sort(array.begin(), array.end(), cmp);
int tag_num = 0;
std::vector<bool> tag1(dim[1]), tag2(dim[2]);
for (auto x : array) {
int idx1 = x.second / dim[2];
int idx2 = x.second % dim[2];
if (tag1[idx1] || tag2[idx2]) {
continue;
}
tag_num++;
tag1[idx1] = true;
tag2[idx2] = true;
for (int j = 0; j < dim[3]; ++j) {
out_data[compute_index(dim, i, idx1, idx2, j)] = 1;
}
if (tag_num == std::min(dim[1], dim[2])) {
break;
}
}
} else {
PADDLE_THROW("Axis must be 1 or 2 or 3");
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -155,6 +155,7 @@ __all__ = [
'sigmoid_cross_entropy_with_logits',
'maxout',
'affine_channel',
'similarity_focus',
]
......@@ -7494,3 +7495,58 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
attrs={"data_layout": data_layout},
outputs={"Out": out})
return out
def similarity_focus(input, axis, indexes, name=None):
"""
**SimilarityFocus Operator**
Generate a similarity focus mask with the same shape of input using the following method:
1. Extract the 3-D matrix(here the first dimension is BatchSize) corresponding
to the axis according to the indexes. For example, if axis=1 and indexes=[a],
it will get the matrix T=X[:, a, :, :]. In this casr, if the shape of input X
is (BatchSize, A, B, C), the shape of matrix T is (BatchSize, B, C).
2. For each index, find the largest numbers in the matrix T, so that the same
row and same column has at most one number(obviously there will be min(B, C)
numbers), and mark the corresponding position of the 3-D similarity focus mask
as 1, otherwise as 0. Do elementwise-or for each index.
3. Broadcast the 3-D similarity focus mask to the same shape of input X.
Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
Args:
input(Variable): The input tensor variable(default float). It should
be a 4-D tensor with shape [BatchSize, A, B, C].
axis(int): Indicating the dimension to be select. It can only be
1, 2, or 3.
indexes(list): indicating the indexes of the selected dimension.
Returns:
Variable: A tensor variable with the same shape and same type
as the input.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[128, 13, 48, 48], dtype='float32')
x = fluid.layers.layer_norm(input=data, axis=1, indexes=[9, 10])
"""
helper = LayerHelper('similarity_focus', **locals())
# check attrs
if isinstance(axis, int) is False:
raise TypeError("axis must be int type.")
if isinstance(indexes, list) is False:
raise TypeError("indexes must be list type.")
if axis != 1 and axis != 2 and axis != 3:
raise ValueError("axis must be 1, 2 or 3.")
if len(indexes) == 0:
raise ValueError("indexes can not be empty.")
out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='similarity_focus',
inputs={'X': input},
outputs={'Out': out},
attrs={"axis": axis,
"indexes": indexes})
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
class TestSimilarityFocusOp_axis1(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 3
x_dim, y_dim, z_dim = 4, 5, 6
self.inputs = {
'X': np.random.random(
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
}
self.attrs = {
'axis': 1,
'indexes': [0, 3],
}
output = None
for batch in range(batch_size):
res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy(
)
tag1 = [0 for i in range(y_dim)]
tag2 = [0 for i in range(z_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index / z_dim
idx2 = index % z_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(y_dim, z_dim):
break
channel[index] = -1
res = res.reshape(1, y_dim, z_dim)
res = res.repeat([x_dim], axis=0)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestSimilarityFocusOp_axis2(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 6
x_dim, y_dim, z_dim = 7, 8, 9
self.inputs = {
'X': np.random.random(
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
}
self.attrs = {
'axis': 2,
'indexes': [0, 3, 5],
}
output = None
for batch in range(batch_size):
res = np.zeros((x_dim, 1, z_dim)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, :, index, :].reshape(-1).copy(
)
tag1 = [0 for i in range(x_dim)]
tag2 = [0 for i in range(z_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index / z_dim
idx2 = index % z_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(x_dim, z_dim):
break
channel[index] = -1
res = res.reshape(x_dim, 1, z_dim)
res = res.repeat([y_dim], axis=1)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestSimilarityFocusOp_axis3(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 64
x_dim, y_dim, z_dim = 48, 48, 13
self.inputs = {
'X': np.random.random(
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
}
self.attrs = {
'axis': 3,
'indexes': [0, 2, 7, 9],
}
output = None
for batch in range(batch_size):
res = np.zeros((x_dim, y_dim, 1)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, :, :, index].reshape(-1).copy(
)
tag1 = [0 for i in range(x_dim)]
tag2 = [0 for i in range(y_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index / y_dim
idx2 = index % y_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(x_dim, y_dim):
break
channel[index] = -1
res = res.reshape(x_dim, y_dim, 1)
res = res.repeat([z_dim], axis=2)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册