Unverified · commit 6883403f · authored by crystal · committed by GitHub

optimization of index_select op backward (#32955)

Parent commit: 3a5f1f22
...@@ -15,6 +15,10 @@ ...@@ -15,6 +15,10 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -38,7 +42,6 @@ void IndexSelectInner(const framework::ExecutionContext& context, ...@@ -38,7 +42,6 @@ void IndexSelectInner(const framework::ExecutionContext& context,
auto input_width = slice_size * input_dim[dim]; auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim]; auto output_width = slice_size * output_dim[dim];
auto outer_nums = 1; auto outer_nums = 1;
for (auto i = 0; i < dim; i++) { for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i]; outer_nums *= input_dim[i];
...@@ -77,7 +80,6 @@ void IndexSelectInner(const framework::ExecutionContext& context, ...@@ -77,7 +80,6 @@ void IndexSelectInner(const framework::ExecutionContext& context,
for (auto i = 0; i < outer_nums; i++) { for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width; auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width; auto output_start_offset = i * output_width;
for (auto j = 0; j < index_size; j++) { for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j]; IndexT index_value = index_vec[j];
for (auto k = 0; k < slice_size; k++) { for (auto k = 0; k < slice_size; k++) {
...@@ -98,7 +100,6 @@ class IndexSelectKernel : public framework::OpKernel<T> { ...@@ -98,7 +100,6 @@ class IndexSelectKernel : public framework::OpKernel<T> {
auto* inputs_var = context.InputVar("X"); auto* inputs_var = context.InputVar("X");
auto* index_var = context.InputVar("Index"); auto* index_var = context.InputVar("Index");
auto* output_var = context.OutputVar("Out"); auto* output_var = context.OutputVar("Out");
auto& inputs = inputs_var->Get<LoDTensor>(); auto& inputs = inputs_var->Get<LoDTensor>();
auto& index = index_var->Get<LoDTensor>(); auto& index = index_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<framework::LoDTensor>(); auto* output = output_var->GetMutable<framework::LoDTensor>();
...@@ -107,8 +108,8 @@ class IndexSelectKernel : public framework::OpKernel<T> { ...@@ -107,8 +108,8 @@ class IndexSelectKernel : public framework::OpKernel<T> {
if (dim < 0) { if (dim < 0) {
dim += inputs.dims().size(); dim += inputs.dims().size();
} }
const auto& index_type = index.type(); const auto& index_type = index.type();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true, PADDLE_ENFORCE_EQ(index_type_match, true,
...@@ -129,19 +130,41 @@ class IndexSelectKernel : public framework::OpKernel<T> { ...@@ -129,19 +130,41 @@ class IndexSelectKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T, typename IndexT = int> template <typename DeviceContext, typename T, class Enable = void>
struct IndexSelectAdd {
void operator()(const framework::ExecutionContext& ctx, int slice_size,
const T* src_pointer, const T* p_pointer, T* dist_pointer) {
for (int i = 0; i < slice_size; i++) {
dist_pointer[i] = src_pointer[i] + p_pointer[i];
}
}
};
// Specialization of IndexSelectAdd for floating-point element types:
// accumulates a whole slice through the BLAS vector-add kernel (VADD)
// instead of a scalar loop, letting the backward accumulation of
// index_select use a vectorized code path.
template <typename DeviceContext, typename T>
struct IndexSelectAdd<
    DeviceContext, T,
    typename std::enable_if<std::is_floating_point<T>::value>::type> {
  // dist_pointer[i] = src_pointer[i] + p_pointer[i] for i in [0, slice_size).
  // `ctx` supplies the device context the BLAS handle is created from.
  void operator()(const framework::ExecutionContext& ctx, int slice_size,
                  const T* src_pointer, const T* p_pointer, T* dist_pointer) {
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
    blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
  }
};
template <typename DeviceContext, typename T, typename IndexT = int>
void IndexSelectGradInner(const framework::ExecutionContext& context, void IndexSelectGradInner(const framework::ExecutionContext& context,
const LoDTensor& out_grad, const LoDTensor& index, const LoDTensor* out_grad, const LoDTensor* index,
LoDTensor* x_grad, int dim) { LoDTensor* x_grad, int dim) {
std::vector<T> input_vec; const T* input_data = out_grad->data<T>();
std::vector<IndexT> index_vec; const IndexT* index_data = index->data<IndexT>();
TensorToVector(out_grad, context.device_context(), &input_vec); const T* p_output = x_grad->mutable_data<T>(context.GetPlace());
TensorToVector(index, context.device_context(), &index_vec); T* out_data = x_grad->mutable_data<T>(context.GetPlace());
auto input_dim = out_grad->dims();
auto input_dim = out_grad.dims();
auto input_dim_size = input_dim.size(); auto input_dim_size = input_dim.size();
auto output_dim = x_grad->dims(); auto output_dim = x_grad->dims();
std::vector<T> out_vec(x_grad->numel(), 0);
auto& dev_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_constant;
set_constant(dev_ctx, x_grad, static_cast<T>(0.0));
auto slice_size = 1; auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) { for (auto i = dim + 1; i < input_dim_size; i++) {
...@@ -156,7 +179,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context, ...@@ -156,7 +179,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
outer_nums *= input_dim[i]; outer_nums *= input_dim[i];
} }
auto index_size = index.dims()[0]; auto index_size = index->dims()[0];
VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; input_width: " << input_width << "; slice_size: " << slice_size << "; input_width: " << input_width
<< "; output_width: " << output_width << "; output_width: " << output_width
...@@ -167,15 +190,14 @@ void IndexSelectGradInner(const framework::ExecutionContext& context, ...@@ -167,15 +190,14 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
auto output_start_offset = i * output_width; auto output_start_offset = i * output_width;
for (auto j = 0; j < index_size; j++) { for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j]; IndexT index_value = index_data[j];
for (auto k = 0; k < slice_size; k++) { auto src = input_data + input_start_offset + j * slice_size;
out_vec[output_start_offset + index_value * slice_size + k] += auto p_out = p_output + output_start_offset + index_value * slice_size;
input_vec[input_start_offset + j * slice_size + k]; auto dst = out_data + output_start_offset + index_value * slice_size;
} IndexSelectAdd<DeviceContext, T> index_select_add;
index_select_add(context, slice_size, src, p_out, dst);
} }
} }
x_grad->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), x_grad);
x_grad->Resize(output_dim); x_grad->Resize(output_dim);
} }
...@@ -183,19 +205,18 @@ template <typename DeviceContext, typename T> ...@@ -183,19 +205,18 @@ template <typename DeviceContext, typename T>
class IndexSelectGradKernel : public framework::OpKernel<T> { class IndexSelectGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* index_var = context.InputVar("Index"); auto* x_grad =
auto* x_grad_var = context.OutputVar(framework::GradVarName("X")); context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* out_grad_var = context.InputVar(framework::GradVarName("Out")); auto* index = context.Input<framework::LoDTensor>("Index");
auto* out_grad =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto& index = index_var->Get<LoDTensor>();
auto& out_grad = out_grad_var->Get<LoDTensor>();
auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>();
int dim = context.Attr<int>("dim"); int dim = context.Attr<int>("dim");
if (dim < 0) { if (dim < 0) {
dim += out_grad.dims().size(); dim += out_grad->dims().size();
} }
const auto& index_type = index->type();
const auto& index_type = index.type();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true, PADDLE_ENFORCE_EQ(index_type_match, true,
...@@ -209,9 +230,11 @@ class IndexSelectGradKernel : public framework::OpKernel<T> { ...@@ -209,9 +230,11 @@ class IndexSelectGradKernel : public framework::OpKernel<T> {
framework::proto::VarType::INT64))); framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) { if (index_type == framework::proto::VarType::INT32) {
IndexSelectGradInner<T, int>(context, out_grad, index, x_grad, dim); IndexSelectGradInner<DeviceContext, T, int>(context, out_grad, index,
x_grad, dim);
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
IndexSelectGradInner<T, int64_t>(context, out_grad, index, x_grad, dim); IndexSelectGradInner<DeviceContext, T, int64_t>(context, out_grad, index,
x_grad, dim);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册