Unverified commit f05e444a, authored by crystal, committed by GitHub

optimization of index_select forward op (#32863)

Parent 81e702ac
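The diff below drops the TensorToVector copies and the triple nested copy loop in the forward kernel: the input and output are reshaped to a 3-D view (outer_nums, input_dim[dim], slice_size) and each selected slice is written with one Eigen chip() assignment evaluated on the kernel's Eigen device. As a rough standalone illustration of that gather pattern (a sketch only, using plain Eigen's unsupported Tensor module instead of Paddle's framework::EigenTensor and device machinery; the shapes and index values are invented):

// Minimal sketch of the chip()-based gather used by the new forward path.
// Assumptions: plain Eigen stands in for framework::EigenTensor, and the
// toy shapes/index values below are made up for illustration.
#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>
#include <vector>

int main() {
  // Toy input viewed as (outer_nums = 2, dim_size = 3, slice_size = 2).
  Eigen::Tensor<float, 3> input(2, 3, 2);
  for (int i = 0; i < static_cast<int>(input.size()); ++i) {
    input.data()[i] = static_cast<float>(i);
  }

  // Hypothetical index values along the selected dimension.
  std::vector<int> index = {2, 0};

  // Output keeps the outer and slice extents; the middle extent is index size.
  Eigen::Tensor<float, 3> output(2, static_cast<int>(index.size()), 2);

  // One chip() per index entry: copy input[:, index[j], :] into output[:, j, :].
  for (int j = 0; j < static_cast<int>(index.size()); ++j) {
    output.chip(j, 1) = input.chip(index[j], 1);
  }

  std::cout << output << std::endl;
  return 0;
}

In the kernel itself the same per-index assignment goes through .device(place), so the copy is dispatched to the configured Eigen device, and the tensors are resized back to their original dims afterwards.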
@@ -15,10 +15,8 @@
 #pragma once
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/jit/macro.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/platform/cpu_info.h"
 namespace paddle {
 namespace operators {
@@ -27,69 +25,69 @@ using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 using DDim = framework::DDim;
-template <typename T, typename IndexT = int>
+template <typename DeviceContext, typename T, typename IndexT = int>
 void IndexSelectInner(const framework::ExecutionContext& context,
-                      const LoDTensor& input, const LoDTensor& index,
+                      LoDTensor* input, const LoDTensor& index,
                       LoDTensor* output, int dim) {
-  auto input_dim = input.dims();
+  auto input_dim = input->dims();
   auto input_dim_size = input_dim.size();
   auto output_dim = output->dims();
+  auto index_size = index.dims()[0];
+  LoDTensor index_cpu_copy;
+  if (!platform::is_cpu_place(index.place())) {
+    framework::TensorCopySync(index, platform::CPUPlace(), &index_cpu_copy);
+  }
+  const IndexT* index_data = platform::is_cpu_place(index.place())
+                                 ? index.data<IndexT>()
+                                 : index_cpu_copy.data<IndexT>();
+  output->mutable_data<T>(context.GetPlace());
   auto slice_size = 1;
   for (auto i = dim + 1; i < input_dim_size; i++) {
     slice_size *= input_dim[i];
   }
-  auto input_width = slice_size * input_dim[dim];
-  auto output_width = slice_size * output_dim[dim];
   auto outer_nums = 1;
   for (auto i = 0; i < dim; i++) {
     outer_nums *= input_dim[i];
   }
-  auto index_size = index.dims()[0];
-  std::vector<T> input_vec;
-  std::vector<IndexT> index_vec;
-  TensorToVector(input, context.device_context(), &input_vec);
-  TensorToVector(index, context.device_context(), &index_vec);
-  std::vector<T> out_vec(output->numel());
   for (int i = 0; i < index_size; i++) {
     PADDLE_ENFORCE_GE(
-        index_vec[i], 0,
+        index_data[i], 0,
         platform::errors::InvalidArgument(
            "Variable value (index) of OP(index_select) "
            "expected >= 0 and < %ld, but got %ld. Please check input "
            "value.",
-            input_dim[dim], index_vec[i]));
+            input_dim[dim], index_data[i]));
     PADDLE_ENFORCE_LT(
-        index_vec[i], input_dim[dim],
+        index_data[i], input_dim[dim],
        platform::errors::InvalidArgument(
            "Variable value (index) of OP(index_select) "
            "expected >= 0 and < %ld, but got %ld. Please check input "
            "value.",
-            input_dim[dim], index_vec[i]));
+            input_dim[dim], index_data[i]));
   }
   VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums
-          << "; slice_size: " << slice_size << "; input_width: " << input_width
-          << "; output_width: " << output_width
-          << "; index_size: " << index_size;
+          << "; slice_size: " << slice_size << "; index_size: " << index_size;
-  for (auto i = 0; i < outer_nums; i++) {
-    auto input_start_offset = i * input_width;
-    auto output_start_offset = i * output_width;
-    for (auto j = 0; j < index_size; j++) {
-      IndexT index_value = index_vec[j];
-      for (auto k = 0; k < slice_size; k++) {
-        out_vec[output_start_offset + j * slice_size + k] =
-            input_vec[input_start_offset + index_value * slice_size + k];
-      }
-    }
+  input->Resize(framework::make_ddim({outer_nums, input_dim[dim], slice_size}));
+  output->Resize(framework::make_ddim({outer_nums, index_size, slice_size}));
+  auto input_tensor = framework::EigenTensor<T, 3>::From(*input);
+  auto output_tensor = framework::EigenTensor<T, 3>::From(*output);
+  auto& place =
+      *context.template device_context<DeviceContext>().eigen_device();
+  for (auto j = 0; j < index_size; j++) {
+    IndexT index_value = index_data[j];
+    auto output_t = output_tensor.chip(j, 1);
+    output_t.device(place) = input_tensor.chip(index_value, 1);
   }
-  output->mutable_data<T>(context.GetPlace());
-  framework::TensorFromVector(out_vec, context.device_context(), output);
+  input->Resize(input_dim);
   output->Resize(output_dim);
 }
@@ -97,19 +95,15 @@ template <typename DeviceContext, typename T>
 class IndexSelectKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* inputs_var = context.InputVar("X");
-    auto* index_var = context.InputVar("Index");
-    auto* output_var = context.OutputVar("Out");
-    auto& inputs = inputs_var->Get<LoDTensor>();
-    auto& index = index_var->Get<LoDTensor>();
-    auto* output = output_var->GetMutable<framework::LoDTensor>();
+    auto inputs = *context.Input<framework::LoDTensor>("X");
+    auto* index = context.Input<framework::LoDTensor>("Index");
+    auto* output = context.Output<framework::LoDTensor>("Out");
     int dim = context.Attr<int>("dim");
     if (dim < 0) {
       dim += inputs.dims().size();
     }
-    const auto& index_type = index.type();
+    const auto& index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
     PADDLE_ENFORCE_EQ(index_type_match, true,
@@ -123,9 +117,11 @@ class IndexSelectKernel : public framework::OpKernel<T> {
                           framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
-      IndexSelectInner<T, int>(context, inputs, index, output, dim);
+      IndexSelectInner<DeviceContext, T, int>(context, &inputs, *index, output,
+                                              dim);
     } else if (index_type == framework::proto::VarType::INT64) {
-      IndexSelectInner<T, int64_t>(context, inputs, index, output, dim);
+      IndexSelectInner<DeviceContext, T, int64_t>(context, &inputs, *index,
+                                                  output, dim);
     }
   }
 };
@@ -152,13 +148,13 @@ struct IndexSelectAdd<
 template <typename DeviceContext, typename T, typename IndexT = int>
 void IndexSelectGradInner(const framework::ExecutionContext& context,
-                          const LoDTensor* out_grad, const LoDTensor* index,
+                          const LoDTensor& out_grad, const LoDTensor& index,
                           LoDTensor* x_grad, int dim) {
-  const T* input_data = out_grad->data<T>();
-  const IndexT* index_data = index->data<IndexT>();
+  const T* input_data = out_grad.data<T>();
+  const IndexT* index_data = index.data<IndexT>();
   const T* p_output = x_grad->mutable_data<T>(context.GetPlace());
   T* out_data = x_grad->mutable_data<T>(context.GetPlace());
-  auto input_dim = out_grad->dims();
+  auto input_dim = out_grad.dims();
   auto input_dim_size = input_dim.size();
   auto output_dim = x_grad->dims();
@@ -179,7 +175,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
     outer_nums *= input_dim[i];
   }
-  auto index_size = index->dims()[0];
+  auto index_size = index.dims()[0];
   VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
           << "; slice_size: " << slice_size << "; input_width: " << input_width
           << "; output_width: " << output_width
@@ -230,11 +226,11 @@ class IndexSelectGradKernel : public framework::OpKernel<T> {
                           framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
-      IndexSelectGradInner<DeviceContext, T, int>(context, out_grad, index,
+      IndexSelectGradInner<DeviceContext, T, int>(context, *out_grad, *index,
                                                   x_grad, dim);
     } else if (index_type == framework::proto::VarType::INT64) {
-      IndexSelectGradInner<DeviceContext, T, int64_t>(context, out_grad, index,
-                                                      x_grad, dim);
+      IndexSelectGradInner<DeviceContext, T, int64_t>(context, *out_grad,
+                                                      *index, x_grad, dim);
     }
   }
 };