// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #define MAX_RANK_SUPPORTED 6 namespace phi { using Tensor = DenseTensor; template void Expand(const Context& ctx, const DenseTensor& x, const IntArray& shape, DenseTensor* out) { auto in_dims = x.dims(); auto expand_shape = shape.GetData(); auto vec_in_dims = phi::vectorize(in_dims); auto diff = expand_shape.size() - vec_in_dims.size(); vec_in_dims.insert(vec_in_dims.begin(), diff, 1); std::vector repeat_times(vec_in_dims.size()); for (size_t i = 0; i < vec_in_dims.size(); ++i) { PADDLE_ENFORCE_NE( expand_shape[i], 0, phi::errors::InvalidArgument("The expanded size cannot be zero.")); if (i < diff) { PADDLE_ENFORCE_GT( expand_shape[i], 0, phi::errors::InvalidArgument( "The expanded size (%d) for non-existing dimensions must be " "positive for expand_v2 op.", expand_shape[i])); repeat_times[i] = expand_shape[i]; } else if (expand_shape[i] > 0) { if (vec_in_dims[i] != 1) { PADDLE_ENFORCE_EQ( vec_in_dims[i], expand_shape[i], phi::errors::InvalidArgument( "The value (%d) of the non-singleton dimension does not match" " the corresponding value (%d) in shape for expand_v2 op.", vec_in_dims[i], expand_shape[i])); repeat_times[i] = 1; } else { repeat_times[i] = expand_shape[i]; } } else { PADDLE_ENFORCE_EQ( expand_shape[i], -1, phi::errors::InvalidArgument( "When the value in shape is negative for expand_v2 op, " "only -1 is supported, but the value received is %d.", expand_shape[i])); repeat_times[i] = 1; } } Eigen::DSizes bcast_dims; for (size_t i = 0; i < repeat_times.size(); ++i) { bcast_dims[i] = repeat_times[i]; } DDim new_in_dims = phi::make_ddim(vec_in_dims); DDim out_dims(new_in_dims); for (size_t i = 0; i < repeat_times.size(); ++i) { out_dims[i] *= repeat_times[i]; } out->Resize(out_dims); auto x0 = EigenTensor::From(x, new_in_dims); ctx.template Alloc(out); out->data(); auto y = EigenTensor::From(*out, out_dims); auto& place = *ctx.eigen_device(); // use 32-bit index to speed up bool use_32bit_index = y.size() < Eigen::NumTraits::highest(); if (use_32bit_index) { phi::funcs::EigenBroadcast, T, Rank>::Eval( place, To32BitIndex(y), To32BitIndex(x0), bcast_dims); } else { phi::funcs::EigenBroadcast, T, Rank>::Eval( place, y, x0, bcast_dims); } } template void ExpandKernel(const Context& ctx, const DenseTensor& x, const IntArray& shape, DenseTensor* out) { auto rank = x.dims().size(); PADDLE_ENFORCE_GE( rank, 1, phi::errors::InvalidArgument( "The rank of the input 'X' for expand_v2 op must be positive, " "but the value received is %d.", rank)); PADDLE_ENFORCE_LE( rank, MAX_RANK_SUPPORTED, phi::errors::InvalidArgument( "The rank of the input 'X' for expand_v2 op must be less than " "or equal to %d, but the value received is %d.", MAX_RANK_SUPPORTED, rank)); auto expand_shape = shape.GetData(); auto shape_size = expand_shape.size(); PADDLE_ENFORCE_GE( shape_size, rank, phi::errors::InvalidArgument( "The number (%d) of elements of 'shape' for expand_v2 op must be " "greater than or equal to the rank (%d) of the input 'X'.", shape_size, rank)); PADDLE_ENFORCE_LE( shape_size, MAX_RANK_SUPPORTED, phi::errors::InvalidArgument( "The number (%d) of elements of 'shape' for expand_v2 op must be " "less than or equal to %d.", shape_size, MAX_RANK_SUPPORTED)); rank = std::max(rank, static_cast(shape_size)); switch (rank) { case 1: Expand(ctx, x, shape, out); break; case 2: Expand(ctx, x, shape, out); break; case 3: Expand(ctx, x, shape, out); break; case 4: Expand(ctx, x, shape, out); break; case 5: Expand(ctx, x, shape, out); break; case 6: Expand(ctx, x, shape, out); break; } } } // namespace phi