// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/view_kernel.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/strided_reshape_utils.h" namespace phi { template void ViewShapeKernel(const Context& dev_ctx, const DenseTensor& input, const std::vector& dims, DenseTensor* out) { DDim new_dims = DDim(dims.data(), static_cast(dims.size())); DDim stride; if (ReshapeStride(input.dims(), input.strides(), new_dims, stride)) { auto meta = input.meta(); meta.dims = new_dims; meta.strides = stride; meta.offset = input.offset(); out->set_meta(meta); out->ResetHolder(input.Holder()); } else { PADDLE_THROW(phi::errors::InvalidArgument( "The Tensor can not be viewed, please call reshape.")); } } template void ViewDtypeKernel(const Context& dev_ctx, const DenseTensor& input, DataType dtype, DenseTensor* out) { size_t input_dtype_size = phi::SizeOf(input.dtype()); size_t output_dtype_size = phi::SizeOf(dtype); if (input_dtype_size == output_dtype_size) { auto meta = input.meta(); meta.dtype = dtype; out->set_meta(meta); out->ResetHolder(input.Holder()); } else if (input_dtype_size == 0) { PADDLE_THROW(phi::errors::InvalidArgument( "The Tensor's shape is [] can not be viewed.")); } else if (input_dtype_size > output_dtype_size) { PADDLE_ENFORCE_EQ( input.strides()[input.strides().size() - 1], 1, phi::errors::InvalidArgument( "input.strides[-1] must be 1 to view %s as %s, but got %d", input.dtype(), dtype, input.strides()[input.strides().size() - 1])); size_t times = input_dtype_size / output_dtype_size; DDim output_dims = input.dims(); output_dims[output_dims.size() - 1] = output_dims[output_dims.size() - 1] * times; // NOLINT DDim output_stride = input.strides(); for (int i = 0; i < output_stride.size(); i++) { output_stride[i] = output_stride[i] * times; // NOLINT } output_stride[output_stride.size() - 1] = 1; auto meta = input.meta(); meta.dtype = dtype; meta.dims = output_dims; meta.strides = output_stride; meta.offset = input.offset() * times; out->set_meta(meta); out->ResetHolder(input.Holder()); } else { PADDLE_ENFORCE_EQ( input.strides()[input.strides().size() - 1], 1, phi::errors::InvalidArgument( "input.strides[%d] must be 1 to view %s as %s, but got %d", input.strides().size() - 1, input.dtype(), dtype, input.strides()[input.strides().size() - 1])); size_t times = input_dtype_size / output_dtype_size; PADDLE_ENFORCE_EQ( input.dims()[input.dims().size() - 1] % times, 0, phi::errors::InvalidArgument( "input.shape[%d](%d) must be multiple of %d to view %s as %s", input.dims().size() - 1, input.dims()[input.dims().size() - 1], times, input.dtype(), dtype)); PADDLE_ENFORCE_EQ( input.offset() % times, 0, phi::errors::InvalidArgument( "input.offset(%d) must be multiple of %d to view %s as %s", input.offset(), times, input.dtype(), dtype)); DDim output_dims = input.dims(); output_dims[output_dims.size() - 1] = output_dims[output_dims.size() - 1] / times; // NOLINT DDim output_stride = input.strides(); for (int i = 0; i < output_stride.size(); i++) { PADDLE_ENFORCE_EQ( output_stride[i] % times, 0, phi::errors::InvalidArgument("input.strides[%d](%d) must be be " "multiple of %d to view %s as %s", i, output_stride[i], times, input.dtype(), dtype)); output_stride[i] = output_stride[i] / times; // NOLINT } output_stride[output_stride.size() - 1] = 1; auto meta = input.meta(); meta.dtype = dtype; meta.dims = output_dims; meta.strides = output_stride; meta.offset = input.offset() / times; out->set_meta(meta); out->ResetHolder(input.Holder()); } } } // namespace phi PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(view_shape, STRIDED, phi::ViewShapeKernel) {} PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(view_dtype, STRIDED, phi::ViewDtypeKernel) {}