diff --git a/paddle/fluid/operators/index_select_op.h b/paddle/fluid/operators/index_select_op.h
index 04775107033adc4f74d43f8110db33d889f2c28a..be76a66ef7c964836d5c1742827f976526c937dd 100644
--- a/paddle/fluid/operators/index_select_op.h
+++ b/paddle/fluid/operators/index_select_op.h
@@ -15,10 +15,8 @@
 #pragma once
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/jit/macro.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/platform/cpu_info.h"
 
 namespace paddle {
 namespace operators {
@@ -27,69 +25,69 @@
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 using DDim = framework::DDim;
 
-template <typename T, typename IndexT = int>
+template <typename DeviceContext, typename T, typename IndexT = int>
 void IndexSelectInner(const framework::ExecutionContext& context,
-                      const LoDTensor& input, const LoDTensor& index,
+                      LoDTensor* input, const LoDTensor& index,
                       LoDTensor* output, int dim) {
-  auto input_dim = input.dims();
+  auto input_dim = input->dims();
   auto input_dim_size = input_dim.size();
   auto output_dim = output->dims();
+  auto index_size = index.dims()[0];
+
+  LoDTensor index_cpu_copy;
+  if (!platform::is_cpu_place(index.place())) {
+    framework::TensorCopySync(index, platform::CPUPlace(), &index_cpu_copy);
+  }
+  const IndexT* index_data = platform::is_cpu_place(index.place())
+                                 ? index.data<IndexT>()
+                                 : index_cpu_copy.data<IndexT>();
+  output->mutable_data<T>(context.GetPlace());
 
   auto slice_size = 1;
   for (auto i = dim + 1; i < input_dim_size; i++) {
     slice_size *= input_dim[i];
   }
-  auto input_width = slice_size * input_dim[dim];
-  auto output_width = slice_size * output_dim[dim];
   auto outer_nums = 1;
   for (auto i = 0; i < dim; i++) {
     outer_nums *= input_dim[i];
   }
-  auto index_size = index.dims()[0];
-
-  std::vector<T> input_vec;
-  std::vector<IndexT> index_vec;
-  TensorToVector(input, context.device_context(), &input_vec);
-  TensorToVector(index, context.device_context(), &index_vec);
-  std::vector<T> out_vec(output->numel());
-
   for (int i = 0; i < index_size; i++) {
     PADDLE_ENFORCE_GE(
-        index_vec[i], 0,
+        index_data[i], 0,
         platform::errors::InvalidArgument(
             "Variable value (index) of OP(index_select) "
             "expected >= 0 and < %ld, but got %ld. Please check input "
            "value.",
-            input_dim[dim], index_vec[i]));
+            input_dim[dim], index_data[i]));
     PADDLE_ENFORCE_LT(
-        index_vec[i], input_dim[dim],
+        index_data[i], input_dim[dim],
         platform::errors::InvalidArgument(
             "Variable value (index) of OP(index_select) "
             "expected >= 0 and < %ld, but got %ld. Please check input "
            "value.",
-            input_dim[dim], index_vec[i]));
+            input_dim[dim], index_data[i]));
   }
 
   VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums
-          << "; slice_size: " << slice_size << "; input_width: " << input_width
-          << "; output_width: " << output_width
-          << "; index_size: " << index_size;
+          << "; slice_size: " << slice_size << "; index_size: " << index_size;
 
-  for (auto i = 0; i < outer_nums; i++) {
-    auto input_start_offset = i * input_width;
-    auto output_start_offset = i * output_width;
-    for (auto j = 0; j < index_size; j++) {
-      IndexT index_value = index_vec[j];
-      for (auto k = 0; k < slice_size; k++) {
-        out_vec[output_start_offset + j * slice_size + k] =
-            input_vec[input_start_offset + index_value * slice_size + k];
-      }
-    }
+  input->Resize(framework::make_ddim({outer_nums, input_dim[dim], slice_size}));
+  output->Resize(framework::make_ddim({outer_nums, index_size, slice_size}));
+
+  auto input_tensor = framework::EigenTensor<T, 3>::From(*input);
+  auto output_tensor = framework::EigenTensor<T, 3>::From(*output);
+
+  auto& place =
+      *context.template device_context<DeviceContext>().eigen_device();
+
+  for (auto j = 0; j < index_size; j++) {
+    IndexT index_value = index_data[j];
+    auto output_t = output_tensor.chip(j, 1);
+    output_t.device(place) = input_tensor.chip(index_value, 1);
   }
-  output->mutable_data<T>(context.GetPlace());
-  framework::TensorFromVector(out_vec, context.device_context(), output);
-
+  input->Resize(input_dim);
   output->Resize(output_dim);
 }
@@ -97,19 +95,15 @@ template <typename DeviceContext, typename T>
 class IndexSelectKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* inputs_var = context.InputVar("X");
-    auto* index_var = context.InputVar("Index");
-    auto* output_var = context.OutputVar("Out");
-    auto& inputs = inputs_var->Get<LoDTensor>();
-    auto& index = index_var->Get<LoDTensor>();
-    auto* output = output_var->GetMutable<LoDTensor>();
+    auto inputs = *context.Input<LoDTensor>("X");
+    auto* index = context.Input<LoDTensor>("Index");
+    auto* output = context.Output<LoDTensor>("Out");
 
     int dim = context.Attr<int>("dim");
     if (dim < 0) {
       dim += inputs.dims().size();
     }
-    const auto& index_type = index.type();
-
+    const auto& index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
     PADDLE_ENFORCE_EQ(index_type_match, true,
@@ -123,9 +117,11 @@ class IndexSelectKernel : public framework::OpKernel<T> {
                               framework::proto::VarType::INT64)));
 
     if (index_type == framework::proto::VarType::INT32) {
-      IndexSelectInner<T, int>(context, inputs, index, output, dim);
+      IndexSelectInner<DeviceContext, T, int>(context, &inputs, *index, output,
+                                              dim);
     } else if (index_type == framework::proto::VarType::INT64) {
-      IndexSelectInner<T, int64_t>(context, inputs, index, output, dim);
+      IndexSelectInner<DeviceContext, T, int64_t>(context, &inputs, *index,
+                                                  output, dim);
     }
   }
 };
@@ -152,13 +148,13 @@ struct IndexSelectAdd<
 
 template <typename DeviceContext, typename T, typename IndexT = int>
 void IndexSelectGradInner(const framework::ExecutionContext& context,
-                          const LoDTensor* out_grad, const LoDTensor* index,
+                          const LoDTensor& out_grad, const LoDTensor& index,
                           LoDTensor* x_grad, int dim) {
-  const T* input_data = out_grad->data<T>();
-  const IndexT* index_data = index->data<IndexT>();
+  const T* input_data = out_grad.data<T>();
+  const IndexT* index_data = index.data<IndexT>();
   const T* p_output = x_grad->mutable_data<T>(context.GetPlace());
   T* out_data = x_grad->mutable_data<T>(context.GetPlace());
-  auto input_dim = out_grad->dims();
+  auto input_dim = out_grad.dims();
   auto input_dim_size = input_dim.size();
   auto output_dim = x_grad->dims();
 
@@ -179,7 +175,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
     outer_nums *= input_dim[i];
   }
-  auto index_size = index->dims()[0];
+  auto index_size = index.dims()[0];
 
   VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
           << "; slice_size: " << slice_size << "; input_width: " << input_width
           << "; output_width: " << output_width
@@ -230,11 +226,11 @@ class IndexSelectGradKernel : public framework::OpKernel<T> {
                               framework::proto::VarType::INT64)));
 
     if (index_type == framework::proto::VarType::INT32) {
-      IndexSelectGradInner<DeviceContext, T, int>(context, out_grad, index,
+      IndexSelectGradInner<DeviceContext, T, int>(context, *out_grad, *index,
                                                   x_grad, dim);
     } else if (index_type == framework::proto::VarType::INT64) {
-      IndexSelectGradInner<DeviceContext, T, int64_t>(context, out_grad,
-                                                      index, x_grad, dim);
+      IndexSelectGradInner<DeviceContext, T, int64_t>(context, *out_grad,
+                                                      *index, x_grad, dim);
    }
  }
};
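Note (reviewer sketch, not part of the patch): the refactored IndexSelectInner above replaces the per-element copy through std::vector with Eigen's chip() on tensors viewed as (outer_nums, dim_size, slice_size), so each selected index copies a whole slab at once. The following is a minimal standalone illustration of that technique, assuming only Eigen's unsupported Tensor module; the shapes, values, and variable names are hypothetical, not Paddle's.

// Standalone sketch of the Eigen "chip" gather used by the new IndexSelectInner.
// Shapes and values here are illustrative examples only.
#include <iostream>
#include <vector>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // Input viewed as (outer_nums, dim_size, slice_size), mirroring the Resize
  // the kernel performs before chipping along axis 1.
  Eigen::Tensor<float, 3> input(2, 4, 3);
  for (int i = 0; i < static_cast<int>(input.size()); ++i) {
    input.data()[i] = static_cast<float>(i);
  }

  std::vector<int> index = {3, 0, 3};  // indices to gather along axis 1
  Eigen::Tensor<float, 3> output(2, static_cast<int>(index.size()), 3);

  // chip(j, 1) selects the j-th slice along axis 1; assigning one chip to
  // another copies an (outer_nums x slice_size) slab in a single expression.
  for (int j = 0; j < static_cast<int>(index.size()); ++j) {
    output.chip(j, 1) = input.chip(index[j], 1);
  }

  std::cout << "output(0, 0, 0) = " << output(0, 0, 0) << std::endl;
  return 0;
}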