/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/kernels/functions/general/elementwise_base.h" namespace pten { inline void UpdateElementwiseIndexArray(const int *out_dims_array, const int max_dim, int *index_array) { for (int i = max_dim - 1; i >= 0; --i) { ++index_array[i]; if (index_array[i] >= out_dims_array[i]) { index_array[i] -= out_dims_array[i]; } else { break; } } } inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim, const int *index_array) { int index_ = 0; for (int i = 0; i < max_dim; i++) { if (x_dims_array[i] > 1) { index_ = index_ * x_dims_array[i] + index_array[i]; } } return index_; } template void CommonForwardBroadcastCPU(const DenseTensor &x, const DenseTensor &y, DenseTensor *z, int *x_dims_array, int *y_dims_array, int *out_dims_array, int max_dim, const paddle::platform::CPUDeviceContext &ctx, Functor func, const bool is_xsize_larger = true) { std::vector index_array(max_dim, 0); const T *x_data = x.data(); const T *y_data = y.data(); PADDLE_ENFORCE_NOT_NULL(x_data, paddle::platform::errors::InvalidArgument( "The input X should not be empty.")); PADDLE_ENFORCE_NOT_NULL(y_data, paddle::platform::errors::InvalidArgument( "The input Y should not be empty.")); OutType *out_data = z->mutable_data(); const int out_size = std::accumulate( out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); int x_index, y_index; for (int out_index = 0; out_index < out_size; ++out_index) { x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); if (is_xsize_larger) { out_data[out_index] = func(x_data[x_index], y_data[y_index]); } else { out_data[out_index] = func(y_data[y_index], x_data[x_index]); } UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); } } template void CommonElementwiseBroadcastForward( const paddle::platform::CPUDeviceContext &dev_ctx, const DenseTensor &x, const DenseTensor &y, DenseTensor *z, const DDim &x_dims, const DDim &y_dims, Functor func, int axis, const bool is_xsize_larger = true) { int max_dim = (std::max)(x_dims.size(), y_dims.size()); axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); PADDLE_ENFORCE_GE( axis, 0, paddle::platform::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LT(axis, max_dim, paddle::platform::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); std::vector x_dims_array(max_dim); std::vector y_dims_array(max_dim); std::vector out_dims_array(max_dim); general::GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(), max_dim, axis); CommonForwardBroadcastCPU(x, y, z, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(), max_dim, dev_ctx, func, is_xsize_larger); } // It is a common implementation to compute binary calculation with the support // of broadcast, supporting both CPU and GPU. // - CPU implementation cannot support the case when x needs broadcast, thus // this function need to be called with XxxFunctor and XxxInverseFunctor, // like paddle/fluid/operators/elementwise/elementwise_add_op.h#L49 - L55. // - GPU implementation supports all the broadcast cases, thus there is no need // to define and call with XxxInverseFunctor. // TODO(liuyiqun): optimize the CPU implementation to support all broadcast // cases and avoid the need of XxxInverseFunctor. template void ElementwiseCompute(const paddle::platform::CPUDeviceContext &dev_ctx, const DenseTensor &x, const DenseTensor &y, int axis, Functor func, DenseTensor *z) { z->mutable_data(); auto x_dims = x.dims(); auto y_dims = y.dims(); bool is_xsize_larger = true; int max_dim = x_dims.size(); if (x_dims.size() < y_dims.size()) { is_xsize_larger = false; max_dim = y_dims.size(); } general:: TransformFunctor functor(x, y, z, dev_ctx, func, is_xsize_larger); if (x_dims == y_dims) { functor.Run(); return; } axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); PADDLE_ENFORCE_GE( axis, 0, paddle::platform::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LT(axis, max_dim, paddle::platform::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); int pre, n, post, is_run_common_broadcast, axis_trim = 0; if (is_xsize_larger) { auto y_dims_trimed = general::trim_trailing_singular_dims(y_dims); axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; general::get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post, &is_run_common_broadcast); } else { auto x_dims_trimed = general::trim_trailing_singular_dims(x_dims); axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; general::get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post, &is_run_common_broadcast); } // special case for common implementation. // case 1: x=[2,3,1,5], y=[2,1,4,1] // case 2: x=[2,3,4], y=[1,1,4] if (is_run_common_broadcast == 1) { CommonElementwiseBroadcastForward( dev_ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger); return; } if (post == 1) { functor.RunRowWise(n, pre); return; } else { functor.RunMidWise(n, pre, post); return; } } template struct SameDimsElementwiseCompute { void operator()(const paddle::platform::CPUDeviceContext &dev_ctx, const DenseTensor &x, const DenseTensor &y, DenseTensor *z) { Functor()(dev_ctx, x, y, z); } }; } // namespace pten