未验证 提交 4717bdbc 编写于 作者: A Aurelius84 提交者: GitHub

Fix hang in seq_topk_avg_pooling op (#25522)

* fix topk_avg_pool hang test=develop

* refactor get_topk_pos test=develop

* add check of channel_num and num_k test=develop

* add TopKPosPaddingId test=develop
上级 b796d8fa
...@@ -24,24 +24,30 @@ class SequenceTopkAvgPoolingOp : public framework::OperatorWithKernel { ...@@ -24,24 +24,30 @@ class SequenceTopkAvgPoolingOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceTopkAvgPooling");
"Input(X) of SequencePoolOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("ROW"), "Input", "ROW",
PADDLE_ENFORCE_EQ(ctx->HasInput("ROW"), true, "SequenceTopkAvgPooling");
"Input(ROW) of SequencePoolOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("COLUMN"), "Input", "COLUMN",
PADDLE_ENFORCE_EQ(ctx->HasInput("COLUMN"), true, "SequenceTopkAvgPooling");
"Input(COLUMN) of SequencePoolOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, "SequenceTopkAvgPooling");
"Output(Out) of SequencePoolOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("pos"), "Output", "pos",
PADDLE_ENFORCE_EQ(ctx->HasOutput("pos"), true, "SequenceTopkAvgPooling");
"pos(out) should not be null");
auto attr = ctx->Attrs(); auto attr = ctx->Attrs();
auto channel_num = attr.Get<int>("channel_num"); auto channel_num = attr.Get<int>("channel_num");
PADDLE_ENFORCE_GT(
channel_num, 0,
platform::errors::InvalidArgument(
"Expected channel_num > 0, but received %d.", channel_num));
auto topks = attr.Get<std::vector<int>>("topks"); auto topks = attr.Get<std::vector<int>>("topks");
auto num_k = topks.size();
PADDLE_ENFORCE_GT(
num_k, 0, platform::errors::InvalidArgument(
"Expected topks.size() > 0, but received %zu.", num_k));
auto row_dim = ctx->GetInputDim("ROW"); auto row_dim = ctx->GetInputDim("ROW");
auto num_k = topks.size();
auto row_shape_0 = row_dim[0]; auto row_shape_0 = row_dim[0];
std::vector<int> vec_out_shape; std::vector<int> vec_out_shape;
...@@ -49,7 +55,7 @@ class SequenceTopkAvgPoolingOp : public framework::OperatorWithKernel { ...@@ -49,7 +55,7 @@ class SequenceTopkAvgPoolingOp : public framework::OperatorWithKernel {
vec_out_shape.push_back(channel_num * num_k); vec_out_shape.push_back(channel_num * num_k);
ctx->SetOutputDim("Out", framework::make_ddim(vec_out_shape)); ctx->SetOutputDim("Out", framework::make_ddim(vec_out_shape));
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("ROW", "Out");
} }
}; };
...@@ -78,10 +84,10 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel { ...@@ -78,10 +84,10 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Gradient of Out should not be null."); framework::GradVarName("Out"), "SequenceTopkAvgPoolingGrad");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"The input X should not be null."); "SequenceTopkAvgPoolingGrad");
ctx->ShareDim("X", /*->*/ framework::GradVarName("X")); ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
......
...@@ -13,52 +13,57 @@ See the License for the specific language governing permissions and ...@@ -13,52 +13,57 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <functional>
#include <limits> #include <limits>
#include <queue>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int TopKPosPaddingId = -1;
namespace details {
template <typename T> template <typename T>
void get_topk_pos(const T* data, int length, int k, int* pos) { static void get_topk_pos(const T* data, int length, int k, int* pos) {
size_t real_k = k < length ? k : length; VLOG(3) << "length: " << length << " , k : " << k;
std::vector<T> v(data, data + length); std::priority_queue<std::pair<T, int>, std::vector<std::pair<T, int>>,
std::greater<std::pair<T, int>>>
std::vector<int> topk_pos; topk_queue;
T min_val = std::numeric_limits<T>::lowest();
while (topk_pos.size() < real_k) { for (int i = 0; i < length; ++i) {
T max_val = min_val; T elem = data[i];
int max_pos = -1; if (topk_queue.size() < static_cast<size_t>(k)) {
for (int i = 0; i < length; ++i) { topk_queue.emplace(elem, i);
if (v[i] > max_val) { } else {
max_pos = i; if (elem >= topk_queue.top().first) {
max_val = v[i]; // replace top node if found a bigger value
topk_queue.pop();
topk_queue.emplace(elem, i);
} }
} }
assert(max_pos >= 0);
topk_pos.push_back(max_pos);
v[max_pos] = min_val;
} }
// reversely assign value
assert(topk_pos.size() > 0); int real_k = topk_queue.size();
while (topk_pos.size() < (size_t)k) { for (int i = real_k - 1; i >= 0; --i) {
topk_pos.push_back(-1); pos[i] = topk_queue.top().second;
topk_queue.pop();
} }
// if length of data is less than k, fill TopKPosPaddingId at the end of pos.
for (size_t i = 0; i < topk_pos.size(); ++i) { for (int i = real_k; i < k; ++i) {
pos[i] = topk_pos[i]; pos[i] = TopKPosPaddingId;
} }
} }
} // namespace details
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> { class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
...@@ -70,20 +75,29 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> { ...@@ -70,20 +75,29 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
auto* out = context.Output<LoDTensor>("Out"); auto* out = context.Output<LoDTensor>("Out");
auto* pos = context.Output<Tensor>("pos"); auto* pos = context.Output<Tensor>("pos");
PADDLE_ENFORCE_EQ(in->lod().empty(), false, PADDLE_ENFORCE_EQ(
"Input(X) Tensor of SequenceTopkAvgPoolingOp does not " in->lod().empty(), false,
"contain LoD information."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(row->lod().empty(), false, "Input(X) Tensor of SequenceTopkAvgPoolingOp does not "
"Input(ROW) Tensor of SequenceTopkAvgPoolingOp does not " "contain LoD information."));
"contain LoD information."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(col->lod().empty(), false, row->lod().empty(), false,
"Input(COLUMN) Tensor of SequenceTopkAvgPoolingOp does " platform::errors::InvalidArgument(
"not contain LoD information."); "Input(ROW) Tensor of SequenceTopkAvgPoolingOp does not "
"contain LoD information."));
PADDLE_ENFORCE_EQ(
col->lod().empty(), false,
platform::errors::InvalidArgument(
"Input(COLUMN) Tensor of SequenceTopkAvgPoolingOp does "
"not contain LoD information."));
auto channel_num = context.Attr<int>("channel_num"); auto channel_num = context.Attr<int>("channel_num");
auto topks = context.Attr<std::vector<int>>("topks"); auto topks = context.Attr<std::vector<int>>("topks");
auto k_num = topks.size(); auto k_num = topks.size();
auto max_k = topks[topks.size() - 1]; auto max_k = topks[topks.size() - 1];
PADDLE_ENFORCE_GE(max_k, 0,
platform::errors::InvalidArgument(
"Expected max_k >= 0, but received %d.", max_k));
std::vector<int> vec_pos_shape; std::vector<int> vec_pos_shape;
auto in_lod = in->lod()[0]; auto in_lod = in->lod()[0];
...@@ -116,7 +130,10 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> { ...@@ -116,7 +130,10 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
int row_size = row_lod[i + 1] - row_lod[i]; int row_size = row_lod[i + 1] - row_lod[i];
int col_size = col_lod[i + 1] - col_lod[i]; int col_size = col_lod[i + 1] - col_lod[i];
PADDLE_ENFORCE_EQ(total_size, channel_num * row_size * col_size, PADDLE_ENFORCE_EQ(total_size, channel_num * row_size * col_size,
"size wrong in sequence_topk_avg_pooling_op!"); platform::errors::PreconditionNotMet(
"Expected total_size == channel_num * row_size * "
"col_size, but got %d != %d.",
total_size, channel_num * row_size * col_size));
int feature_num = row_size * col_size; int feature_num = row_size * col_size;
for (int j = 0; j < channel_num; ++j) { for (int j = 0; j < channel_num; ++j) {
...@@ -130,14 +147,14 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> { ...@@ -130,14 +147,14 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
auto out_slice_data = dout_data + row_lod[i] * channel_num * k_num + auto out_slice_data = dout_data + row_lod[i] * channel_num * k_num +
r * channel_num * k_num + j * k_num; r * channel_num * k_num + j * k_num;
get_topk_pos<T>(row_data, col_size, max_k, pos_slice_data); details::get_topk_pos<T>(row_data, col_size, max_k, pos_slice_data);
if (pos_slice_data[0] == -1) { if (pos_slice_data[0] == TopKPosPaddingId) {
sum_data[0] = 0.0; sum_data[0] = 0.0;
} else { } else {
sum_data[0] = row_data[pos_slice_data[0]]; sum_data[0] = row_data[pos_slice_data[0]];
} }
for (int k = 1; k < max_k; ++k) { for (int k = 1; k < max_k; ++k) {
if (pos_slice_data[k] == -1) { if (pos_slice_data[k] == TopKPosPaddingId) {
sum_data[k] = sum_data[k - 1]; sum_data[k] = sum_data[k - 1];
} else { } else {
sum_data[k] = sum_data[k - 1] + row_data[pos_slice_data[k]]; sum_data[k] = sum_data[k - 1] + row_data[pos_slice_data[k]];
...@@ -206,7 +223,7 @@ class SequenceTopkAvgPoolingGradKernel : public framework::OpKernel<T> { ...@@ -206,7 +223,7 @@ class SequenceTopkAvgPoolingGradKernel : public framework::OpKernel<T> {
for (size_t m = 0; m < k_num; ++m) { for (size_t m = 0; m < k_num; ++m) {
for (int k = 0; k < topks[m]; ++k) { for (int k = 0; k < topks[m]; ++k) {
if (pos_slice_data[k] == -1) { if (pos_slice_data[k] == TopKPosPaddingId) {
break; break;
} else { } else {
in_slice_data[pos_slice_data[k]] += row_data[m] / topks[m]; in_slice_data[pos_slice_data[k]] += row_data[m] / topks[m];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册