// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <type_traits>  // std::enable_if / std::is_floating_point below

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/pten/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;

template <typename DeviceContext, typename T, typename IndexT = int>
void IndexSelectInner(const framework::ExecutionContext& context,
                      LoDTensor* input, const LoDTensor& index,
                      LoDTensor* output, int dim) {
  auto input_dim = input->dims();
  auto input_dim_size = input_dim.size();
  auto output_dim = output->dims();
  auto index_size = index.dims()[0];

  // The index tensor may live on a non-CPU place; copy it to the host so its
  // values can be read element by element below.
  LoDTensor index_cpu_copy;
  if (!platform::is_cpu_place(index.place())) {
    framework::TensorCopySync(index, platform::CPUPlace(), &index_cpu_copy);
  }
  const IndexT* index_data = platform::is_cpu_place(index.place())
                                 ? index.data<IndexT>()
                                 : index_cpu_copy.data<IndexT>();
  output->mutable_data<T>(context.GetPlace());

  // slice_size: number of elements in one slice along `dim`;
  // outer_nums: number of independent outer slices before `dim`.
  auto slice_size = 1;
  for (auto i = dim + 1; i < input_dim_size; i++) {
    slice_size *= input_dim[i];
  }
  auto outer_nums = 1;
  for (auto i = 0; i < dim; i++) {
    outer_nums *= input_dim[i];
  }

  // Every index value must fall inside [0, input_dim[dim]).
  for (int i = 0; i < index_size; i++) {
    PADDLE_ENFORCE_GE(
        index_data[i], 0,
        platform::errors::InvalidArgument(
            "Variable value (index) of OP(index_select) "
            "expected >= 0 and < %ld, but got %ld. Please check input "
            "value.",
            input_dim[dim], index_data[i]));
    PADDLE_ENFORCE_LT(
        index_data[i], input_dim[dim],
        platform::errors::InvalidArgument(
            "Variable value (index) of OP(index_select) "
            "expected >= 0 and < %ld, but got %ld. Please check input "
            "value.",
            input_dim[dim], index_data[i]));
  }
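
  // The code below is the core of the forward pass: view the input as
  // [outer_nums, input_dim[dim], slice_size] and the output as
  // [outer_nums, index_size, slice_size], then copy one chip (a slice along
  // the middle dimension) per index entry. For example, for an input of shape
  // [2, 4, 3] with dim = 1, outer_nums = 2 and slice_size = 3, and each index
  // value selects one of the 4 middle slices across both outer slices at once.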
Please check input " "value.", input_dim[dim], index_data[i])); } VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums << "; slice_size: " << slice_size << "; index_size: " << index_size; input->Resize(framework::make_ddim({outer_nums, input_dim[dim], slice_size})); output->Resize(framework::make_ddim({outer_nums, index_size, slice_size})); auto input_tensor = framework::EigenTensor::From(*input); auto output_tensor = framework::EigenTensor::From(*output); auto& place = *context.template device_context().eigen_device(); for (auto j = 0; j < index_size; j++) { IndexT index_value = index_data[j]; auto output_t = output_tensor.chip(j, 1); output_t.device(place) = input_tensor.chip(index_value, 1); } input->Resize(input_dim); output->Resize(output_dim); } template class IndexSelectKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto inputs = *context.Input("X"); auto* index = context.Input("Index"); auto* output = context.Output("Out"); int dim = context.Attr("dim"); if (dim < 0) { dim += inputs.dims().size(); } const auto& index_type = framework::TransToProtoVarType(index->dtype()); bool index_type_match = index_type == framework::proto::VarType::INT32 || index_type == framework::proto::VarType::INT64; PADDLE_ENFORCE_EQ(index_type_match, true, platform::errors::InvalidArgument( "Input(Index) holds the wrong type, it holds %s, but " "desires to be %s or %s", paddle::framework::DataTypeToString(index_type), paddle::framework::DataTypeToString( framework::proto::VarType::INT32), paddle::framework::DataTypeToString( framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { IndexSelectInner(context, &inputs, *index, output, dim); } else if (index_type == framework::proto::VarType::INT64) { IndexSelectInner(context, &inputs, *index, output, dim); } } }; template struct IndexSelectAdd { void operator()(const framework::ExecutionContext& ctx, int slice_size, const T* src_pointer, const T* p_pointer, T* dist_pointer) { for (int i = 0; i < slice_size; i++) { dist_pointer[i] = src_pointer[i] + p_pointer[i]; } } }; template struct IndexSelectAdd< DeviceContext, T, typename std::enable_if::value>::type> { void operator()(const framework::ExecutionContext& ctx, int slice_size, const T* src_pointer, const T* p_pointer, T* dist_pointer) { auto blas = math::GetBlas(ctx); blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer); } }; template void IndexSelectGradInner(const framework::ExecutionContext& context, const LoDTensor& out_grad, const LoDTensor& index, LoDTensor* x_grad, int dim) { const T* input_data = out_grad.data(); const IndexT* index_data = index.data(); const T* p_output = x_grad->mutable_data(context.GetPlace()); T* out_data = x_grad->mutable_data(context.GetPlace()); auto input_dim = out_grad.dims(); auto input_dim_size = input_dim.size(); auto output_dim = x_grad->dims(); auto& dev_ctx = context.template device_context(); pten::funcs::SetConstant set_constant; set_constant(dev_ctx, x_grad, static_cast(0.0)); auto slice_size = 1; for (auto i = dim + 1; i < input_dim_size; i++) { slice_size *= input_dim[i]; } auto input_width = slice_size * input_dim[dim]; auto output_width = slice_size * output_dim[dim]; auto outer_nums = 1; for (auto i = 0; i < dim; i++) { outer_nums *= input_dim[i]; } auto index_size = index.dims()[0]; VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums << "; slice_size: " << slice_size << "; input_width: " << input_width << "; output_width: " << 
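
  // Gradient scatter-add: within every outer slice, the j-th slice of
  // out_grad is added onto slice index_data[j] of x_grad. p_out and dst point
  // at the same x_grad slice, so repeated index values accumulate rather than
  // overwrite each other.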
output_width << "; index_size: " << index_size; for (auto i = 0; i < outer_nums; i++) { auto input_start_offset = i * input_width; auto output_start_offset = i * output_width; for (auto j = 0; j < index_size; j++) { IndexT index_value = index_data[j]; auto src = input_data + input_start_offset + j * slice_size; auto p_out = p_output + output_start_offset + index_value * slice_size; auto dst = out_data + output_start_offset + index_value * slice_size; IndexSelectAdd index_select_add; index_select_add(context, slice_size, src, p_out, dst); } } x_grad->Resize(output_dim); } template class IndexSelectGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x_grad = context.Output(framework::GradVarName("X")); auto* index = context.Input("Index"); auto* out_grad = context.Input(framework::GradVarName("Out")); int dim = context.Attr("dim"); if (dim < 0) { dim += out_grad->dims().size(); } const auto& index_type = framework::TransToProtoVarType(index->dtype()); bool index_type_match = index_type == framework::proto::VarType::INT32 || index_type == framework::proto::VarType::INT64; PADDLE_ENFORCE_EQ(index_type_match, true, platform::errors::InvalidArgument( "Input(Index) holds the wrong type, it holds %s, but " "desires to be %s or %s", paddle::framework::DataTypeToString(index_type), paddle::framework::DataTypeToString( framework::proto::VarType::INT32), paddle::framework::DataTypeToString( framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { IndexSelectGradInner(context, *out_grad, *index, x_grad, dim); } else if (index_type == framework::proto::VarType::INT64) { IndexSelectGradInner(context, *out_grad, *index, x_grad, dim); } } }; } // namespace operators } // namespace paddle
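
// A minimal registration sketch, not part of this header: kernels defined
// above are typically instantiated in the accompanying index_select_op.cc,
// assuming the usual PaddlePaddle CPU registration pattern (the exact type
// list used there may differ):
//
//   REGISTER_OP_CPU_KERNEL(
//       index_select,
//       ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, float>,
//       ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, double>,
//       ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int>,
//       ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int64_t>);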