Unverified Commit 5b5379b3 authored by Aurelius84, committed by GitHub

Add sequence_topk_avg_pooling Op (#19442)

* add topk_avg_pooling

* refine api doc and modify api.spec test=develop
Parent 1cdd3b69
......
@@ -256,6 +256,7 @@ paddle.fluid.layers.maxout (ArgSpec(args=['x', 'groups', 'name'], varargs=None,
paddle.fluid.layers.space_to_depth (ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '26decdea9376b6b9a0d3432d82ca207b'))
paddle.fluid.layers.affine_grid (ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'f85b263b7b6698d000977529a28f202b'))
paddle.fluid.layers.sequence_reverse (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '65c8362e48810b8226e311c5d046db51'))
paddle.fluid.layers.sequence_topk_avg_pooling (ArgSpec(args=['input', 'row', 'col', 'topks', 'channel_num'], varargs=None, keywords=None, defaults=None), ('document', '1cee1bbbba8b567ae50509a38d9ec42a'))
paddle.fluid.layers.affine_channel (ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name', 'act'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None, None)), ('document', '9f303c67538e468a36c5904a0a3aa110'))
paddle.fluid.layers.similarity_focus (ArgSpec(args=['input', 'axis', 'indexes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '18ec2e3afeb90e70c8b73d2b71c40fdb'))
paddle.fluid.layers.hash (ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)), ('document', 'a0b73c21be618cec0281e7903039e5e3'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
class SequenceTopkAvgPoolingOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("ROW"), true,
"Input(ROW) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("COLUMN"), true,
"Input(COLUMN) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("pos"), true,
"pos(out) should not be null");
auto attr = ctx->Attrs();
auto channel_num = attr.Get<int>("channel_num");
auto topks = attr.Get<std::vector<int>>("topks");
auto row_dim = ctx->GetInputDim("ROW");
auto num_k = topks.size();
auto row_shape_0 = row_dim[0];
std::vector<int> vec_out_shape;
vec_out_shape.push_back(row_shape_0);
vec_out_shape.push_back(channel_num * num_k);
ctx->SetOutputDim("Out", framework::make_ddim(vec_out_shape));
ctx->ShareLoD("X", "Out");
}
};
class SequenceTopkAvgPoolingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor) The variable-length input of SequenceTopkPoolingOp");
AddInput("ROW", "(LoDTensor) the row info");
AddInput("COLUMN", "(LoDTensor) the column info");
AddOutput(
"Out",
"(Tensor) The output of SequenceTopkPoolingOp does not contain LoD "
"infomation.");
AddOutput("pos", "(Tensor<int>) store the topk index ").AsIntermediate();
AddAttr<std::vector<int>>("topks", "topks");
AddAttr<int>("channel_num", "channel number");
AddComment(R"DOC(
sequecen topk average pooling op
)DOC");
}
};
class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Gradient of Out should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"The input X should not be null.");
ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
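// The grad op maker below wires X, ROW, COLUMN and the intermediate "pos"
// output into the backward op, so gradients can be routed back to the
// positions selected in the forward pass.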
class SequenceTopkAvgPoolGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op_desc_ptr = new framework::OpDesc();
op_desc_ptr->SetType("sequence_topk_avg_pooling_grad");
op_desc_ptr->SetInput("X", Input("X"));
op_desc_ptr->SetInput("ROW", Input("ROW"));
op_desc_ptr->SetInput("COLUMN", Input("COLUMN"));
op_desc_ptr->SetInput("pos", Output("pos"));
op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op_desc_ptr->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_topk_avg_pooling, ops::SequenceTopkAvgPoolingOp,
ops::SequenceTopkAvgPoolingOpMaker,
ops::SequenceTopkAvgPoolGradOpMaker);
REGISTER_OPERATOR(sequence_topk_avg_pooling_grad,
ops::SequenceTopkAvgPoolingGradOp);
REGISTER_OP_CPU_KERNEL(sequence_topk_avg_pooling,
ops::SequenceTopkAvgPoolingKernel<
paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(sequence_topk_avg_pooling_grad,
ops::SequenceTopkAvgPoolingGradKernel<
paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cassert>
#include <limits>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
template <typename T>
void get_topk_pos(const T* data, int length, int k, int* pos) {
size_t real_k = k < length ? k : length;
std::vector<T> v(data, data + length);
std::vector<int> topk_pos;
T min_val = std::numeric_limits<T>::lowest();
while (topk_pos.size() < real_k) {
T max_val = min_val;
int max_pos = -1;
for (int i = 0; i < length; ++i) {
if (v[i] > max_val) {
max_pos = i;
max_val = v[i];
}
}
assert(max_pos >= 0);
topk_pos.push_back(max_pos);
v[max_pos] = min_val;
}
assert(topk_pos.size() > 0);
while (topk_pos.size() < (size_t)k) {
topk_pos.push_back(-1);
}
for (size_t i = 0; i < topk_pos.size(); ++i) {
pos[i] = topk_pos[i];
}
}
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* row = context.Input<LoDTensor>("ROW");
auto* col = context.Input<LoDTensor>("COLUMN");
auto* out = context.Output<LoDTensor>("Out");
auto* pos = context.Output<Tensor>("pos");
auto channel_num = context.Attr<int>("channel_num");
auto topks = context.Attr<std::vector<int>>("topks");
auto k_num = topks.size();
auto max_k = topks[topks.size() - 1];
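// topks is documented as a list of incremental (ascending) values, so its
// last element is the largest k.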
std::vector<int> vec_pos_shape;
auto in_lod = in->lod()[0];
auto row_lod = row->lod()[0];
auto col_lod = col->lod()[0];
int batch_size = row_lod.size() - 1;
int pos_total_size = row_lod[batch_size] * channel_num * max_k;
vec_pos_shape.push_back(pos_total_size);
pos->Resize({framework::make_ddim(vec_pos_shape)});
auto pos_data = pos->mutable_data<int>(context.GetPlace());
int offset = 0;
framework::Vector<size_t> vec_out_lod;
vec_out_lod.reserve(batch_size + 1);
for (int i = 0; i <= batch_size; ++i) {
offset = row_lod[i];
vec_out_lod.push_back(offset);
}
framework::LoD lod_temp;
lod_temp.push_back(vec_out_lod);
out->set_lod(lod_temp);
auto din_data = in->data<T>();
auto dout_data = out->mutable_data<T>(context.GetPlace());
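// Per-sequence memory layout: X is stored as [channel_num, row_size, col_size]
// (channel-major), while Out and pos are indexed as [row, channel, k], so each
// output row groups the k averages of every channel for one input row.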
std::vector<T> sum_data(max_k);
for (int i = 0; i < batch_size; ++i) {
int total_size = in_lod[i + 1] - in_lod[i];
int row_size = row_lod[i + 1] - row_lod[i];
int col_size = col_lod[i + 1] - col_lod[i];
PADDLE_ENFORCE_EQ(total_size, channel_num * row_size * col_size,
"The size of Input(X) for each sequence should equal "
"channel_num * row_size * col_size in sequence_topk_avg_pooling_op.");
int feature_num = row_size * col_size;
for (int j = 0; j < channel_num; ++j) {
auto input_offset_feature_data = din_data + in_lod[i] + j * feature_num;
for (int r = 0; r < row_size; ++r) {
auto row_data = input_offset_feature_data + r * col_size;
auto pos_slice_data = pos_data + row_lod[i] * channel_num * max_k +
r * channel_num * max_k + j * max_k;
auto out_slice_data = dout_data + row_lod[i] * channel_num * k_num +
r * channel_num * k_num + j * k_num;
get_topk_pos<T>(row_data, col_size, max_k, pos_slice_data);
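// Seed sum_data[0], then build prefix sums of the top-max_k values; a -1
// position marks a padded slot, so the running sum is carried forward.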
if (pos_slice_data[0] == -1) {
sum_data[0] = 0.0;
} else {
sum_data[0] = row_data[pos_slice_data[0]];
}
for (int k = 1; k < max_k; ++k) {
if (pos_slice_data[k] == -1) {
sum_data[k] = sum_data[k - 1];
} else {
sum_data[k] = sum_data[k - 1] + row_data[pos_slice_data[k]];
}
}
for (size_t k = 0; k < k_num; ++k) {
out_slice_data[k] = sum_data[topks[k] - 1] / topks[k];
}
}
}
}
}
};
template <typename DeviceContext, typename T>
class SequenceTopkAvgPoolingGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_in = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* pos_input = context.Input<Tensor>("pos");
auto* row_input = context.Input<LoDTensor>("ROW");
auto* col_input = context.Input<LoDTensor>("COLUMN");
auto* forward_input = context.Input<LoDTensor>("X");
int batch_size = row_input->lod()[0].size() - 1;
auto channel_num = context.Attr<int>("channel_num");
auto topks = context.Attr<std::vector<int>>("topks");
auto k_num = topks.size();
auto max_k = topks[k_num - 1];
auto out_lod = forward_input->lod();
d_in->set_lod(out_lod);
d_in->mutable_data<T>(context.GetPlace());
auto pos_data = pos_input->data<int>();
auto dout_data = d_out->data<T>();
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SetConstant<paddle::platform::CPUDeviceContext, T> zero;
zero(dev_ctx, d_in, static_cast<T>(0.0));
auto din_data = d_in->data<T>();
auto out_offset = out_lod[0];
auto row_lod = row_input->lod()[0];
auto col_lod = col_input->lod()[0];
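// Backward pass: each output gradient d_out[row, channel, m] is scattered back
// to the topks[m] input positions recorded in pos, scaled by 1 / topks[m]; a
// -1 position marks a padded entry and receives no gradient.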
for (int i = 0; i < batch_size; ++i) {
int row_size = row_lod[i + 1] - row_lod[i];
int col_size = col_lod[i + 1] - col_lod[i];
int feature_num = row_size * col_size;
for (int j = 0; j < channel_num; ++j) {
auto in_offset_feature_data =
din_data + out_offset[i] + j * feature_num;
for (int r = 0; r < row_size; r++) {
auto row_data = dout_data + row_lod[i] * channel_num * k_num +
r * channel_num * k_num + j * k_num;
auto pos_slice_data = pos_data + row_lod[i] * channel_num * max_k +
r * channel_num * max_k + j * max_k;
auto in_slice_data = in_offset_feature_data + r * col_size;
for (size_t m = 0; m < k_num; ++m) {
for (int k = 0; k < topks[m]; ++k) {
if (pos_slice_data[k] == -1) {
break;
} else {
in_slice_data[pos_slice_data[k]] += row_data[m] / topks[m];
}
}
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
......
@@ -184,6 +184,7 @@ __all__ = [
'space_to_depth',
'affine_grid',
'sequence_reverse',
'sequence_topk_avg_pooling',
'affine_channel',
'similarity_focus',
'hash',
......
@@ -11133,6 +11134,73 @@ def sequence_reverse(x, name=None):
return out
def sequence_topk_avg_pooling(input, row, col, topks, channel_num):
"""
The :attr:`topks` is a list with incremental values in this function. For each topk,
it will average the topk features as an output feature for each channel of every
input sequence. Both :attr:`row` and :attr:`col` are LodTensor, which provide height
and width information for :attr:`input` tensor. If feature size of input sequence is less
than topk, it will padding 0 at the back.
.. code-block:: text
If channel_num is 2 and given row LoDTensor and col LoDTensor as follows:
row.lod = [[5, 4]]
col.lod = [[6, 7]]
input is a LoDTensor with input.lod[0][i] = channel_num * row.lod[0][i] * col.lod[0][i]
input.lod = [[60, 56]] # where 60 = channel_num * 5 * 6
input.dims = [116, 1] # where 116 = 60 + 56
If topks is [1, 3, 5], then we get a 1-level LoDTensor:
out.lod = [[5, 4]] # share Lod info with row LodTensor
out.dims = [9, 6] # where 6 = len(topks) * channel_num
Args:
input (Variable): The input should be 2D LodTensor with dims[1] equals 1.
row (Variable): The row shoud be 1-level LodTensor to provide the height information
of the input tensor data.
col (Variable): The col shoud be 1-level LodTensor to provide the width information
of the input tensor data.
topks (list): A list of incremental value to average the topk feature.
channel_num (int): The number of input channel.
Returns:
Variable: output LodTensor specified by this layer.
Examples:
.. code-block:: python
import numpy as np
from paddle.fluid import layers
x_lod_tensor = layers.data(name='x', shape=[1], lod_level=1)
row_lod_tensor = layers.data(name='row', shape=[6], lod_level=1)
col_lod_tensor = layers.data(name='col', shape=[6], lod_level=1)
out = layers.sequence_topk_avg_pooling(input=x_lod_tensor,
row=row_lod_tensor,
col=col_lod_tensor,
topks=[1, 3, 5],
channel_num=5)
"""
helper = LayerHelper('sequence_topk_avg_pooling', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
pos = helper.create_variable_for_type_inference(
dtype=helper.input_dtype(), stop_gradient=True)
helper.append_op(
type='sequence_topk_avg_pooling',
inputs={'X': input,
'ROW': row,
'COLUMN': col},
outputs={'Out': out,
'pos': pos},
attrs={'topks': topks,
'channel_num': channel_num})
return out
def affine_channel(x,
scale=None,
bias=None,
......
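To make the pooling arithmetic concrete, here is a minimal NumPy sketch of the per-(row, channel) computation performed by the layer and its C++ kernel; the helper name topk_avg_row is ours, for illustration only:
import numpy as np

def topk_avg_row(scores, topks):
    # scores: 1-D array with col_size entries for one (row, channel) slice.
    # topks: ascending top-k values, e.g. [1, 3, 5].
    max_k = topks[-1]
    order = np.argsort(scores)[::-1][:max_k]    # positions of the largest values
    vals = np.zeros(max_k, dtype=scores.dtype)  # zero-pad when col_size < max_k
    vals[:len(order)] = scores[order]
    prefix = np.cumsum(vals)                    # same prefix-sum trick as the kernel
    return np.array([prefix[k - 1] / k for k in topks])

# Three columns with topks = [1, 3, 5]: the k = 5 average still divides by 5,
# matching the zero-padding behaviour described in the docstring above.
print(topk_avg_row(np.array([0.9, 0.1, 0.5], dtype=np.float32), [1, 3, 5]))
# -> [0.9, 0.5, 0.3]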
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from copy import deepcopy
class TestSequenceTopkAvgPoolingOp(OpTest):
def setUp(self):
self.init_op_type()
self.set_data()
self.compute()
def init_op_type(self):
self.op_type = "sequence_topk_avg_pooling"
def set_data(self):
topks = [2]
channel_num = 3
dim = 10
row = [2, 4]
col = [3, 2]
self.init_data(topks, channel_num, row, col, dim)
def init_data(self, topks, channel_num, row, col, dim=10):
self.attrs = {"topks": topks, "channel_num": channel_num}
feature = [row[i] * col[i] for i in range(len(row))]
numel = sum(feature) * channel_num
x_data = np.random.random((numel, )).astype('float32')
x_lod = [[x * channel_num for x in feature]]
row_data = np.random.random((sum(row), dim)).astype('float32')
col_data = np.random.random((sum(col), dim)).astype('float32')
self.inputs = {
'X': (x_data, x_lod),
'ROW': (row_data, [row]),
'COLUMN': (col_data, [col])
}
def compute(self):
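# Reference implementation: for every (row, channel) slice of each sequence,
# take the top max_k values, build prefix sums, and average at each k in topks.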
topks = self.attrs['topks']
max_k = topks[-1]
x_data, x_lod = self.inputs['X']
row_data, row_lod = self.inputs['ROW']
col_data, col_lod = self.inputs['COLUMN']
channel_num = self.attrs['channel_num']
out = np.zeros((0, len(topks) * channel_num), dtype=x_data.dtype)
pos = np.zeros((0, ), dtype='int32')
out_lod = deepcopy(row_lod)
offset = 0
for idx in range(len(x_lod[0])):
x_len = x_lod[0][idx]
self.assertTrue(
x_len == channel_num * row_lod[0][idx] * col_lod[0][idx],
"x_len: %s != channel_num * row * col: %s" %
(x_len, channel_num * row_lod[0][idx] * col_lod[0][idx]))
out_tmp = np.zeros((0, ), dtype=x_data.dtype)
pos_tmp = np.zeros((0, ), dtype='int32')
for ch in range(channel_num):
for r_id in range(row_lod[0][idx]):
x_sub = x_data[offset:(offset + col_lod[0][idx])]
topk_val, topk_pos = self.get_topk(x_sub, max_k)
sum_data = self.topk_sum(topk_val, topk_pos, max_k)
new_feature = np.array(
[sum_data[topk] / topk for topk in topks])
out_tmp = np.hstack((out_tmp, new_feature))
pos_tmp = np.hstack((pos_tmp, topk_pos))
offset += col_lod[0][idx]
out_tmp = out_tmp.reshape([channel_num, -1, len(topks)]).transpose(
1, 0, 2)
pos_tmp = pos_tmp.reshape([channel_num, -1, max_k]).transpose(1, 0,
2)
out = np.vstack(
(out, out_tmp.reshape([-1, len(topks) * channel_num])))
pos = np.hstack((pos, pos_tmp.flatten()))
self.outputs = {'Out': (out.astype('float32'), out_lod), 'pos': pos}
def get_topk(self, x, topk):
real_topk = topk if topk < len(x) else len(x)
topk_pos = np.array(x).argsort()[-topk:][::-1]
topk_val = np.array(x)[topk_pos]
if real_topk < topk:
topk_pos = np.hstack((topk_pos, np.full((topk - real_topk, ), -1)))
topk_val = np.hstack((topk_val, np.full((topk - real_topk, ), 0.0)))
return topk_val, topk_pos
def topk_sum(self, x, pos, max_k):
sum_data = [0.] * (max_k + 1)
for i in range(1, max_k + 1):
if pos[i - 1] == -1:
sum_data[i] = sum_data[i - 1]
else:
sum_data[i] = sum_data[i - 1] + x[i - 1]
return sum_data
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.005)
class TestSequenceTopkAvgPoolingOpCase1(TestSequenceTopkAvgPoolingOp):
def set_data(self):
topks = [2, 3]
channel_num = 3
dim = 10
row = [3]
col = [4]
self.init_data(topks, channel_num, row, col, dim)
def test_api(self):
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[1], lod_level=1)
row = fluid.layers.data(name='row', shape=[10], lod_level=1)
col = fluid.layers.data(name='col', shape=[10], lod_level=1)
topk_avg = fluid.layers.sequence_topk_avg_pooling(
input=x, row=row, col=col, topks=[1, 3, 5], channel_num=5)
place = fluid.CPUPlace()
x_tensor = fluid.create_lod_tensor(
np.random.rand(45, 1).astype('float32'), [[30, 15]], place)
row_tensor = fluid.create_lod_tensor(
np.random.rand(5, 10).astype('float32'), [[2, 3]], place)
col_tensor = fluid.create_lod_tensor(
np.random.rand(4, 10).astype('float32'), [[3, 1]], place)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ret = exe.run(
feed={'x': x_tensor,
'row': row_tensor,
'col': col_tensor},
fetch_list=[topk_avg],
return_numpy=False)
if __name__ == '__main__':
unittest.main()