diff --git a/paddle/fluid/operators/index_select_op.h b/paddle/fluid/operators/index_select_op.h index 70714b7f3a0644e55c9bac27a88edb0b8a9921a4..04775107033adc4f74d43f8110db33d889f2c28a 100644 --- a/paddle/fluid/operators/index_select_op.h +++ b/paddle/fluid/operators/index_select_op.h @@ -15,6 +15,10 @@ #pragma once #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/jit/macro.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -38,7 +42,6 @@ void IndexSelectInner(const framework::ExecutionContext& context, auto input_width = slice_size * input_dim[dim]; auto output_width = slice_size * output_dim[dim]; - auto outer_nums = 1; for (auto i = 0; i < dim; i++) { outer_nums *= input_dim[i]; @@ -77,7 +80,6 @@ void IndexSelectInner(const framework::ExecutionContext& context, for (auto i = 0; i < outer_nums; i++) { auto input_start_offset = i * input_width; auto output_start_offset = i * output_width; - for (auto j = 0; j < index_size; j++) { IndexT index_value = index_vec[j]; for (auto k = 0; k < slice_size; k++) { @@ -98,7 +100,6 @@ class IndexSelectKernel : public framework::OpKernel { auto* inputs_var = context.InputVar("X"); auto* index_var = context.InputVar("Index"); auto* output_var = context.OutputVar("Out"); - auto& inputs = inputs_var->Get(); auto& index = index_var->Get(); auto* output = output_var->GetMutable(); @@ -107,8 +108,8 @@ class IndexSelectKernel : public framework::OpKernel { if (dim < 0) { dim += inputs.dims().size(); } - const auto& index_type = index.type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || index_type == framework::proto::VarType::INT64; PADDLE_ENFORCE_EQ(index_type_match, true, @@ -129,19 +130,41 @@ class IndexSelectKernel : public framework::OpKernel { } }; -template +template +struct IndexSelectAdd { + void operator()(const framework::ExecutionContext& ctx, int slice_size, + const T* src_pointer, const T* p_pointer, T* dist_pointer) { + for (int i = 0; i < slice_size; i++) { + dist_pointer[i] = src_pointer[i] + p_pointer[i]; + } + } +}; +template +struct IndexSelectAdd< + DeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext& ctx, int slice_size, + const T* src_pointer, const T* p_pointer, T* dist_pointer) { + auto blas = math::GetBlas(ctx); + blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer); + } +}; + +template void IndexSelectGradInner(const framework::ExecutionContext& context, - const LoDTensor& out_grad, const LoDTensor& index, + const LoDTensor* out_grad, const LoDTensor* index, LoDTensor* x_grad, int dim) { - std::vector input_vec; - std::vector index_vec; - TensorToVector(out_grad, context.device_context(), &input_vec); - TensorToVector(index, context.device_context(), &index_vec); - - auto input_dim = out_grad.dims(); + const T* input_data = out_grad->data(); + const IndexT* index_data = index->data(); + const T* p_output = x_grad->mutable_data(context.GetPlace()); + T* out_data = x_grad->mutable_data(context.GetPlace()); + auto input_dim = out_grad->dims(); auto input_dim_size = input_dim.size(); auto output_dim = x_grad->dims(); - std::vector out_vec(x_grad->numel(), 0); + + auto& dev_ctx = context.template device_context(); + math::SetConstant set_constant; + set_constant(dev_ctx, x_grad, static_cast(0.0)); auto slice_size = 1; for (auto i = dim + 1; i < input_dim_size; i++) { @@ -156,7 +179,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context, outer_nums *= input_dim[i]; } - auto index_size = index.dims()[0]; + auto index_size = index->dims()[0]; VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums << "; slice_size: " << slice_size << "; input_width: " << input_width << "; output_width: " << output_width @@ -167,15 +190,14 @@ void IndexSelectGradInner(const framework::ExecutionContext& context, auto output_start_offset = i * output_width; for (auto j = 0; j < index_size; j++) { - IndexT index_value = index_vec[j]; - for (auto k = 0; k < slice_size; k++) { - out_vec[output_start_offset + index_value * slice_size + k] += - input_vec[input_start_offset + j * slice_size + k]; - } + IndexT index_value = index_data[j]; + auto src = input_data + input_start_offset + j * slice_size; + auto p_out = p_output + output_start_offset + index_value * slice_size; + auto dst = out_data + output_start_offset + index_value * slice_size; + IndexSelectAdd index_select_add; + index_select_add(context, slice_size, src, p_out, dst); } } - x_grad->mutable_data(context.GetPlace()); - framework::TensorFromVector(out_vec, context.device_context(), x_grad); x_grad->Resize(output_dim); } @@ -183,19 +205,18 @@ template class IndexSelectGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* index_var = context.InputVar("Index"); - auto* x_grad_var = context.OutputVar(framework::GradVarName("X")); - auto* out_grad_var = context.InputVar(framework::GradVarName("Out")); + auto* x_grad = + context.Output(framework::GradVarName("X")); + auto* index = context.Input("Index"); + auto* out_grad = + context.Input(framework::GradVarName("Out")); - auto& index = index_var->Get(); - auto& out_grad = out_grad_var->Get(); - auto* x_grad = x_grad_var->GetMutable(); int dim = context.Attr("dim"); if (dim < 0) { - dim += out_grad.dims().size(); + dim += out_grad->dims().size(); } + const auto& index_type = index->type(); - const auto& index_type = index.type(); bool index_type_match = index_type == framework::proto::VarType::INT32 || index_type == framework::proto::VarType::INT64; PADDLE_ENFORCE_EQ(index_type_match, true, @@ -209,9 +230,11 @@ class IndexSelectGradKernel : public framework::OpKernel { framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { - IndexSelectGradInner(context, out_grad, index, x_grad, dim); + IndexSelectGradInner(context, out_grad, index, + x_grad, dim); } else if (index_type == framework::proto::VarType::INT64) { - IndexSelectGradInner(context, out_grad, index, x_grad, dim); + IndexSelectGradInner(context, out_grad, index, + x_grad, dim); } } };