未验证 提交 4717bdbc 编写于 作者: A Aurelius84 提交者: GitHub

Fix hang in seq_topk_avg_pooling op (#25522)

* fix topk_avg_pool hang test=develop

* refactor get_topk_pos test=develop

* add check of channel_num and num_k test=develop

* add TopKPosPaddingId test=develop
上级 b796d8fa
......@@ -24,24 +24,30 @@ class SequenceTopkAvgPoolingOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("ROW"), true,
"Input(ROW) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("COLUMN"), true,
"Input(COLUMN) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("pos"), true,
"pos(out) should not be null");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceTopkAvgPooling");
OP_INOUT_CHECK(ctx->HasInput("ROW"), "Input", "ROW",
"SequenceTopkAvgPooling");
OP_INOUT_CHECK(ctx->HasInput("COLUMN"), "Input", "COLUMN",
"SequenceTopkAvgPooling");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"SequenceTopkAvgPooling");
OP_INOUT_CHECK(ctx->HasOutput("pos"), "Output", "pos",
"SequenceTopkAvgPooling");
auto attr = ctx->Attrs();
auto channel_num = attr.Get<int>("channel_num");
PADDLE_ENFORCE_GT(
channel_num, 0,
platform::errors::InvalidArgument(
"Expected channel_num > 0, but received %d.", channel_num));
auto topks = attr.Get<std::vector<int>>("topks");
auto num_k = topks.size();
PADDLE_ENFORCE_GT(
num_k, 0, platform::errors::InvalidArgument(
"Expected topks.size() > 0, but received %zu.", num_k));
auto row_dim = ctx->GetInputDim("ROW");
auto num_k = topks.size();
auto row_shape_0 = row_dim[0];
std::vector<int> vec_out_shape;
......@@ -49,7 +55,7 @@ class SequenceTopkAvgPoolingOp : public framework::OperatorWithKernel {
vec_out_shape.push_back(channel_num * num_k);
ctx->SetOutputDim("Out", framework::make_ddim(vec_out_shape));
ctx->ShareLoD("X", "Out");
ctx->ShareLoD("ROW", "Out");
}
};
......@@ -78,10 +84,10 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Gradient of Out should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"The input X should not be null.");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "SequenceTopkAvgPoolingGrad");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"SequenceTopkAvgPoolingGrad");
ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
......
......@@ -13,52 +13,57 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <limits>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int TopKPosPaddingId = -1;
namespace details {
template <typename T>
void get_topk_pos(const T* data, int length, int k, int* pos) {
size_t real_k = k < length ? k : length;
std::vector<T> v(data, data + length);
std::vector<int> topk_pos;
T min_val = std::numeric_limits<T>::lowest();
while (topk_pos.size() < real_k) {
T max_val = min_val;
int max_pos = -1;
for (int i = 0; i < length; ++i) {
if (v[i] > max_val) {
max_pos = i;
max_val = v[i];
static void get_topk_pos(const T* data, int length, int k, int* pos) {
VLOG(3) << "length: " << length << " , k : " << k;
std::priority_queue<std::pair<T, int>, std::vector<std::pair<T, int>>,
std::greater<std::pair<T, int>>>
topk_queue;
for (int i = 0; i < length; ++i) {
T elem = data[i];
if (topk_queue.size() < static_cast<size_t>(k)) {
topk_queue.emplace(elem, i);
} else {
if (elem >= topk_queue.top().first) {
// replace top node if found a bigger value
topk_queue.pop();
topk_queue.emplace(elem, i);
}
}
assert(max_pos >= 0);
topk_pos.push_back(max_pos);
v[max_pos] = min_val;
}
assert(topk_pos.size() > 0);
while (topk_pos.size() < (size_t)k) {
topk_pos.push_back(-1);
// reversely assign value
int real_k = topk_queue.size();
for (int i = real_k - 1; i >= 0; --i) {
pos[i] = topk_queue.top().second;
topk_queue.pop();
}
for (size_t i = 0; i < topk_pos.size(); ++i) {
pos[i] = topk_pos[i];
// if length of data is less than k, fill TopKPosPaddingId at the end of pos.
for (int i = real_k; i < k; ++i) {
pos[i] = TopKPosPaddingId;
}
}
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
} // namespace details
template <typename DeviceContext, typename T>
class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
......@@ -70,20 +75,29 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
auto* out = context.Output<LoDTensor>("Out");
auto* pos = context.Output<Tensor>("pos");
PADDLE_ENFORCE_EQ(in->lod().empty(), false,
"Input(X) Tensor of SequenceTopkAvgPoolingOp does not "
"contain LoD information.");
PADDLE_ENFORCE_EQ(row->lod().empty(), false,
"Input(ROW) Tensor of SequenceTopkAvgPoolingOp does not "
"contain LoD information.");
PADDLE_ENFORCE_EQ(col->lod().empty(), false,
"Input(COLUMN) Tensor of SequenceTopkAvgPoolingOp does "
"not contain LoD information.");
PADDLE_ENFORCE_EQ(
in->lod().empty(), false,
platform::errors::InvalidArgument(
"Input(X) Tensor of SequenceTopkAvgPoolingOp does not "
"contain LoD information."));
PADDLE_ENFORCE_EQ(
row->lod().empty(), false,
platform::errors::InvalidArgument(
"Input(ROW) Tensor of SequenceTopkAvgPoolingOp does not "
"contain LoD information."));
PADDLE_ENFORCE_EQ(
col->lod().empty(), false,
platform::errors::InvalidArgument(
"Input(COLUMN) Tensor of SequenceTopkAvgPoolingOp does "
"not contain LoD information."));
auto channel_num = context.Attr<int>("channel_num");
auto topks = context.Attr<std::vector<int>>("topks");
auto k_num = topks.size();
auto max_k = topks[topks.size() - 1];
PADDLE_ENFORCE_GE(max_k, 0,
platform::errors::InvalidArgument(
"Expected max_k >= 0, but received %d.", max_k));
std::vector<int> vec_pos_shape;
auto in_lod = in->lod()[0];
......@@ -116,7 +130,10 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
int row_size = row_lod[i + 1] - row_lod[i];
int col_size = col_lod[i + 1] - col_lod[i];
PADDLE_ENFORCE_EQ(total_size, channel_num * row_size * col_size,
"size wrong in sequence_topk_avg_pooling_op!");
platform::errors::PreconditionNotMet(
"Expected total_size == channel_num * row_size * "
"col_size, but got %d != %d.",
total_size, channel_num * row_size * col_size));
int feature_num = row_size * col_size;
for (int j = 0; j < channel_num; ++j) {
......@@ -130,14 +147,14 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
auto out_slice_data = dout_data + row_lod[i] * channel_num * k_num +
r * channel_num * k_num + j * k_num;
get_topk_pos<T>(row_data, col_size, max_k, pos_slice_data);
if (pos_slice_data[0] == -1) {
details::get_topk_pos<T>(row_data, col_size, max_k, pos_slice_data);
if (pos_slice_data[0] == TopKPosPaddingId) {
sum_data[0] = 0.0;
} else {
sum_data[0] = row_data[pos_slice_data[0]];
}
for (int k = 1; k < max_k; ++k) {
if (pos_slice_data[k] == -1) {
if (pos_slice_data[k] == TopKPosPaddingId) {
sum_data[k] = sum_data[k - 1];
} else {
sum_data[k] = sum_data[k - 1] + row_data[pos_slice_data[k]];
......@@ -206,7 +223,7 @@ class SequenceTopkAvgPoolingGradKernel : public framework::OpKernel<T> {
for (size_t m = 0; m < k_num; ++m) {
for (int k = 0; k < topks[m]; ++k) {
if (pos_slice_data[k] == -1) {
if (pos_slice_data[k] == TopKPosPaddingId) {
break;
} else {
in_slice_data[pos_slice_data[k]] += row_data[m] / topks[m];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册