diff --git a/paddle/fluid/operators/cross_op.cc b/paddle/fluid/operators/cross_op.cc index 977d84e1e47c885e1174dc16ba32770c9ac3f808..bdef52d3dca3c1c3c5675f7c66b288932f363342 100644 --- a/paddle/fluid/operators/cross_op.cc +++ b/paddle/fluid/operators/cross_op.cc @@ -61,6 +61,18 @@ class CrossGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("Y"), ctx->GetInputDim("Y")); + + auto x_dims = ctx->GetInputsDim("X"); + auto y_dims = ctx->GetInputsDim("Y"); + for (size_t i = 0; i < x_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(x_dims[i], y_dims[i], + phi::errors::InvalidArgument( + "The 'shape' of Input(X) should be equal to " + "the 'shape' of Input(Y). But received " + "Input(X).dimensions = [%s], " + "Input(Y).dimensions = [%s]", + x_dims[i], y_dims[i])); + } } protected: diff --git a/paddle/phi/kernels/cpu/cross_grad_kernel.cc b/paddle/phi/kernels/cpu/cross_grad_kernel.cc index 8dddc6f6e4e95a1585e7375730d4e59216ffab44..af573cfacf3df59d66314adaeb4ca58e41ec6b59 100644 --- a/paddle/phi/kernels/cpu/cross_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/cross_grad_kernel.cc @@ -14,10 +14,105 @@ #include "paddle/phi/kernels/cross_grad_kernel.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/cross_grad_kernel_impl.h" +namespace phi { + +template +void CrossGradKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out_grad, + int axis, + DenseTensor *x_grad, + DenseTensor *y_grad) { + auto &input_x = x; + auto &input_y = y; + auto &input_out_grad = out_grad; + auto *output_x_grad = x_grad; + auto *output_y_grad = y_grad; + int dim = axis; + auto input_x_dims = input_x.dims(); + if (dim != DDim::kMaxRank) { + PADDLE_ENFORCE_EQ( + dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), + true, + errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + input_x_dims.size(), + input_x_dims.size() - 1, + dim)); + if (dim < 0) { + dim += input_x_dims.size(); + } + + PADDLE_ENFORCE_EQ( + input_x_dims[dim] == 3, + true, + errors::InvalidArgument( + "Input(X/Y).dims[dim] must be equal to 3. But received: " + "Input(X/Y).dims[dim] = [%d].", + input_x_dims[dim])); + } else { + for (auto i = 0; i < input_x_dims.size(); i++) { + if (input_x_dims[i] == 3) { + dim = i; + break; + } + } + PADDLE_ENFORCE_EQ( + dim == DDim::kMaxRank, + false, + errors::InvalidArgument("There must be at least one dimension 'd' " + "so that Input(X/Y).dims()[d] is equal to 3. " + "But received: Input(X/Y).dims() == [%s].", + input_x_dims)); + } + auto outer_loops = 1; + for (auto i = 0; i < dim; i++) { + outer_loops *= input_x_dims[i]; + } + auto slice_size = 1; + for (auto i = dim + 1; i < input_x_dims.size(); i++) { + slice_size *= input_x_dims[i]; + } + + std::vector input_x_vec, input_y_vec, input_dout_vec; + paddle::framework::TensorToVector(input_x, dev_ctx, &input_x_vec); + paddle::framework::TensorToVector(input_y, dev_ctx, &input_y_vec); + paddle::framework::TensorToVector(input_out_grad, dev_ctx, &input_dout_vec); + std::vector out_dx_vec(output_x_grad->numel()); + std::vector out_dy_vec(output_y_grad->numel()); + + dev_ctx.template Alloc(output_x_grad); + dev_ctx.template Alloc(output_y_grad); + + for (auto i = 0; i < outer_loops; i++) { + for (auto j = 0; j < 3; j++) { + auto dst_pos = (3 * i + j) * slice_size; + auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size; + auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size; + for (auto k = 0; k < slice_size; k++) { + out_dx_vec[dst_pos + k] = + input_dout_vec[in_pos2 + k] * input_y_vec[in_pos1 + k] - + input_dout_vec[in_pos1 + k] * input_y_vec[in_pos2 + k]; + out_dy_vec[dst_pos + k] = + input_dout_vec[in_pos1 + k] * input_x_vec[in_pos2 + k] - + input_dout_vec[in_pos2 + k] * input_x_vec[in_pos1 + k]; + } + } + } + paddle::framework::TensorFromVector(out_dx_vec, dev_ctx, output_x_grad); + paddle::framework::TensorFromVector(out_dy_vec, dev_ctx, output_y_grad); + output_x_grad->Resize(input_x_dims); + output_y_grad->Resize(input_x_dims); +} + +} // namespace phi PD_REGISTER_KERNEL(cross_grad, CPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/cpu/cross_kernel.cc b/paddle/phi/kernels/cpu/cross_kernel.cc index 1f3a8fe5a38790fb88c980af887e06f73d0c6ffc..a321617deab30455d60aea678f978bb14370ba0f 100644 --- a/paddle/phi/kernels/cpu/cross_kernel.cc +++ b/paddle/phi/kernels/cpu/cross_kernel.cc @@ -14,9 +14,97 @@ #include "paddle/phi/kernels/cross_kernel.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/cross_kernel_impl.h" +#include "paddle/phi/kernels/funcs/common_shape.h" + +namespace phi { + +template +void CrossKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + auto& input_x = x; + auto& input_y = y; + auto* output = out; + int dim = axis; + + auto input_x_dims = input_x.dims(); + + if (dim != DDim::kMaxRank) { + PADDLE_ENFORCE_EQ( + dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), + true, + phi::errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + input_x_dims.size(), + input_x_dims.size() - 1, + dim)); + if (dim < 0) { + dim += input_x_dims.size(); + } + + PADDLE_ENFORCE_EQ( + input_x_dims[dim] == 3, + true, + phi::errors::InvalidArgument( + "Input(X/Y).dims[dim] must be equal to 3. But received: " + "Input(X/Y).dims[dim] = [%d].", + input_x_dims[dim])); + } else { + for (auto i = 0; i < input_x_dims.size(); i++) { + if (input_x_dims[i] == 3) { + dim = i; + break; + } + } + PADDLE_ENFORCE_EQ(dim == DDim::kMaxRank, + false, + phi::errors::InvalidArgument( + "There must be at least one dimension 'd' so that " + "Input(X/Y).dims()[d] is equal to 3. " + "But received: Input(X/Y).dims() == [%s].", + input_x_dims)); + } + auto outer_loops = 1; + for (auto i = 0; i < dim; i++) { + outer_loops *= input_x_dims[i]; + } + auto slice_size = 1; + for (auto i = dim + 1; i < input_x_dims.size(); i++) { + slice_size *= input_x_dims[i]; + } + + std::vector input_x_vec, input_y_vec; + paddle::framework::TensorToVector(input_x, dev_ctx, &input_x_vec); + paddle::framework::TensorToVector(input_y, dev_ctx, &input_y_vec); + std::vector out_vec(output->numel()); + + dev_ctx.template Alloc(output); + + for (auto i = 0; i < outer_loops; i++) { + for (auto j = 0; j < 3; j++) { + auto dst_pos = (3 * i + j) * slice_size; + auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size; + auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size; + + for (auto k = 0; k < slice_size; k++) { + out_vec[dst_pos + k] = + input_x_vec[in_pos1 + k] * input_y_vec[in_pos2 + k] - + input_x_vec[in_pos2 + k] * input_y_vec[in_pos1 + k]; + } + } + } + paddle::framework::TensorFromVector(out_vec, dev_ctx, output); + output->Resize(input_x_dims); +} + +} // namespace phi PD_REGISTER_KERNEL( cross, CPU, ALL_LAYOUT, phi::CrossKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/cross_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_grad_kernel.cu index 1f83f05f81c77560f628fd20cdb7a33fdb12a057..97d6d6849ae0048feef3f0f0422b73321e8c635c 100644 --- a/paddle/phi/kernels/gpu/cross_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_grad_kernel.cu @@ -13,9 +13,141 @@ // limitations under the License. #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cross_grad_kernel.h" -#include "paddle/phi/kernels/impl/cross_grad_kernel_impl.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" + +namespace phi { + +using funcs::IndexCalculator; + +template +__global__ void CrossGrad(const T* x, + const T* y, + const T* out, + T* out_dx, + T* out_dy, + const int stride, + const int N, + IndexCalculator index_calculator) { + CUDA_KERNEL_LOOP(i, N) { + int offset = index_calculator(i); + + auto pos0 = offset + 0 * stride; + auto pos1 = offset + 1 * stride; + auto pos2 = offset + 2 * stride; + + out_dx[pos0] = out[pos2] * y[pos1] - out[pos1] * y[pos2]; + out_dy[pos0] = out[pos1] * x[pos2] - out[pos2] * x[pos1]; + + out_dx[pos1] = out[pos0] * y[pos2] - out[pos2] * y[pos0]; + out_dy[pos1] = out[pos2] * x[pos0] - out[pos0] * x[pos2]; + + out_dx[pos2] = out[pos1] * y[pos0] - out[pos0] * y[pos1]; + out_dy[pos2] = out[pos0] * x[pos1] - out[pos1] * x[pos0]; + } +} + +template +void CrossGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto& input_x = x; + auto& input_y = y; + auto& input_out_grad = out_grad; + auto* output_x_grad = x_grad; + auto* output_y_grad = y_grad; + int dim = axis; + + auto input_x_dims = input_x.dims(); + if (dim != DDim::kMaxRank) { + PADDLE_ENFORCE_EQ( + dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), + true, + errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + input_x_dims.size(), + input_x_dims.size() - 1, + dim)); + if (dim < 0) { + dim += input_x_dims.size(); + } + + PADDLE_ENFORCE_EQ( + input_x_dims[dim] == 3, + true, + errors::InvalidArgument( + "Input(X/Y).dims[dim] must be equal to 3. But received: " + "Input(X/Y).dims[dim] = [%d].", + input_x_dims[dim])); + } else { + for (auto i = 0; i < input_x_dims.size(); i++) { + if (input_x_dims[i] == 3) { + dim = i; + break; + } + } + PADDLE_ENFORCE_EQ( + dim == DDim::kMaxRank, + false, + errors::InvalidArgument("There must be at least one dimension 'd' " + "so that Input(X/Y).dims()[d] is equal to 3. " + "But received: Input(X/Y).dims() == [%s].", + input_x_dims)); + } + + std::vector cal_dims; + std::vector left_strides; + std::vector full_strides; + + int full_dim = 1; + int left_dim = 1; + for (auto i = 0; i < input_x_dims.size(); i++) { + full_strides.insert(full_strides.begin(), full_dim); + full_dim *= input_x_dims[input_x_dims.size() - i - 1]; + if (i == dim) { + continue; + } + cal_dims.push_back(i); + left_strides.insert(left_strides.begin(), left_dim); + left_dim *= input_x_dims[input_x_dims.size() - i - 1]; + } + + const auto* input_x_data = input_x.data(); + const auto* input_y_data = input_y.data(); + const auto* input_out_grad_data = input_out_grad.data(); + + auto* output_x_grad_data = dev_ctx.template Alloc(x_grad); + auto* output_y_grad_data = dev_ctx.template Alloc(y_grad); + + auto index_calculator = IndexCalculator( + input_x_dims.size() - 1, cal_dims, left_strides, full_strides); + + int64_t numel = x.numel(); + + backends::gpu::GpuLaunchConfig config = + backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3); + + CrossGrad<<>>(input_x_data, + input_y_data, + input_out_grad_data, + output_x_grad_data, + output_y_grad_data, + full_strides[dim], + numel / 3, + index_calculator); +} +} // namespace phi PD_REGISTER_KERNEL(cross_grad, GPU, diff --git a/paddle/phi/kernels/gpu/cross_kernel.cu b/paddle/phi/kernels/gpu/cross_kernel.cu index 4f3e5f0ca8c5d9b1ed87a8ecdacb30bc2388671f..4d4588f9b51c91cf1194b44e9bb1a2e955caeaa8 100644 --- a/paddle/phi/kernels/gpu/cross_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_kernel.cu @@ -13,9 +13,126 @@ // limitations under the License. #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cross_kernel.h" -#include "paddle/phi/kernels/impl/cross_kernel_impl.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" + +namespace phi { + +using funcs::IndexCalculator; + +template +__global__ void Cross(const T* x, + const T* y, + T* out, + const int stride, + const int N, + IndexCalculator index_calculator) { + CUDA_KERNEL_LOOP(i, N) { + int offset = index_calculator(i); + + auto pos0 = offset + 0 * stride; + auto pos1 = offset + 1 * stride; + auto pos2 = offset + 2 * stride; + + out[pos0] = x[pos1] * y[pos2] - x[pos2] * y[pos1]; + out[pos1] = x[pos2] * y[pos0] - x[pos0] * y[pos2]; + out[pos2] = x[pos0] * y[pos1] - x[pos1] * y[pos0]; + } +} + +template +void CrossKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + auto& input_x = x; + auto& input_y = y; + auto* output = out; + int dim = axis; + + auto input_x_dims = input_x.dims(); + if (dim != DDim::kMaxRank) { + PADDLE_ENFORCE_EQ( + dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), + true, + phi::errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + input_x_dims.size(), + input_x_dims.size() - 1, + dim)); + if (dim < 0) { + dim += input_x_dims.size(); + } + + PADDLE_ENFORCE_EQ( + input_x_dims[dim] == 3, + true, + phi::errors::InvalidArgument( + "Input(X/Y).dims[dim] must be equal to 3. But received: " + "Input(X/Y).dims[dim] = [%d].", + input_x_dims[dim])); + } else { + for (auto i = 0; i < input_x_dims.size(); i++) { + if (input_x_dims[i] == 3) { + dim = i; + break; + } + } + PADDLE_ENFORCE_EQ(dim == DDim::kMaxRank, + false, + phi::errors::InvalidArgument( + "There must be at least one dimension 'd' so that " + "Input(X/Y).dims()[d] is equal to 3. " + "But received: Input(X/Y).dims() == [%s].", + input_x_dims)); + } + + std::vector cal_dims; + std::vector left_strides; + std::vector full_strides; + + int dims0 = 1; + int dims1 = 1; + for (auto i = 0; i < input_x_dims.size(); i++) { + full_strides.insert(full_strides.begin(), dims0); + dims0 *= input_x_dims[input_x_dims.size() - i - 1]; + if (i == dim) { + continue; + } + cal_dims.push_back(i); + left_strides.insert(left_strides.begin(), dims1); + dims1 *= input_x_dims[input_x_dims.size() - i - 1]; + } + + const auto* input_x_data = input_x.data(); + const auto* input_y_data = input_y.data(); + + auto* out_data = dev_ctx.template Alloc(out); + + auto index_calculator = IndexCalculator( + input_x_dims.size() - 1, cal_dims, left_strides, full_strides); + + int64_t numel = x.numel(); + + backends::gpu::GpuLaunchConfig config = + backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3); + + Cross<<>>(input_x_data, + input_y_data, + out_data, + full_strides[dim], + numel / 3, + index_calculator); +} +} // namespace phi PD_REGISTER_KERNEL( cross, GPU, ALL_LAYOUT, phi::CrossKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/impl/cross_grad_kernel_impl.h b/paddle/phi/kernels/impl/cross_grad_kernel_impl.h deleted file mode 100644 index 99a79dc15c049d826c2bfb9a50efa866bc1e176d..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/impl/cross_grad_kernel_impl.h +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void CrossGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - int axis, - DenseTensor* x_grad, - DenseTensor* y_grad) { - auto& input_x = x; - auto& input_y = y; - auto& input_out_grad = out_grad; - auto* output_x_grad = x_grad; - auto* output_y_grad = y_grad; - int dim = axis; - auto input_x_dims = input_x.dims(); - if (dim != DDim::kMaxRank) { - PADDLE_ENFORCE_EQ( - dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), - true, - errors::OutOfRange( - "Attr(dim) is out of range, It's expected " - "to be in range of [-%d, %d]. But received Attr(dim) = %d.", - input_x_dims.size(), - input_x_dims.size() - 1, - dim)); - if (dim < 0) { - dim += input_x_dims.size(); - } - - PADDLE_ENFORCE_EQ( - input_x_dims[dim] == 3, - true, - errors::InvalidArgument( - "Input(X/Y).dims[dim] must be equal to 3. But received: " - "Input(X/Y).dims[dim] = [%d].", - input_x_dims[dim])); - } else { - for (auto i = 0; i < input_x_dims.size(); i++) { - if (input_x_dims[i] == 3) { - dim = i; - break; - } - } - PADDLE_ENFORCE_EQ( - dim == DDim::kMaxRank, - false, - errors::InvalidArgument("There must be at least one dimension 'd' " - "so that Input(X/Y).dims()[d] is equal to 3. " - "But received: Input(X/Y).dims() == [%s].", - input_x_dims)); - } - auto outer_loops = 1; - for (auto i = 0; i < dim; i++) { - outer_loops *= input_x_dims[i]; - } - auto slice_size = 1; - for (auto i = dim + 1; i < input_x_dims.size(); i++) { - slice_size *= input_x_dims[i]; - } - - std::vector input_x_vec, input_y_vec, input_dout_vec; - paddle::framework::TensorToVector(input_x, dev_ctx, &input_x_vec); - paddle::framework::TensorToVector(input_y, dev_ctx, &input_y_vec); - paddle::framework::TensorToVector(input_out_grad, dev_ctx, &input_dout_vec); - std::vector out_dx_vec(output_x_grad->numel()); - std::vector out_dy_vec(output_y_grad->numel()); - - dev_ctx.template Alloc(output_x_grad); - dev_ctx.template Alloc(output_y_grad); - - for (auto i = 0; i < outer_loops; i++) { - for (auto j = 0; j < 3; j++) { - auto dst_pos = (3 * i + j) * slice_size; - auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size; - auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size; - for (auto k = 0; k < slice_size; k++) { - out_dx_vec[dst_pos + k] = - input_dout_vec[in_pos2 + k] * input_y_vec[in_pos1 + k] - - input_dout_vec[in_pos1 + k] * input_y_vec[in_pos2 + k]; - out_dy_vec[dst_pos + k] = - input_dout_vec[in_pos1 + k] * input_x_vec[in_pos2 + k] - - input_dout_vec[in_pos2 + k] * input_x_vec[in_pos1 + k]; - } - } - } - paddle::framework::TensorFromVector(out_dx_vec, dev_ctx, output_x_grad); - paddle::framework::TensorFromVector(out_dy_vec, dev_ctx, output_y_grad); - output_x_grad->Resize(input_x_dims); - output_y_grad->Resize(input_x_dims); -} - -} // namespace phi diff --git a/paddle/phi/kernels/impl/cross_kernel_impl.h b/paddle/phi/kernels/impl/cross_kernel_impl.h deleted file mode 100644 index 6427d7f87193f2d952d838e1fcafe8b532d08598..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/impl/cross_kernel_impl.h +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/common_shape.h" - -namespace phi { - -template -void CrossKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int axis, - DenseTensor* out) { - auto& input_x = x; - auto& input_y = y; - auto* output = out; - int dim = axis; - - auto input_x_dims = input_x.dims(); - auto input_y_dims = input_y.dims(); - bool dims_match = phi::funcs::CheckDims(input_x_dims, input_y_dims); - PADDLE_ENFORCE_EQ( - dims_match, - true, - phi::errors::InvalidArgument("The 'shape' of Input(X) should be equal to " - "the 'shape' of Input(Y). But received " - "Input(X).dimensions = [%s], " - "Input(Y).dimensions = [%s]", - input_x_dims, - input_x_dims)); - - if (dim != DDim::kMaxRank) { - PADDLE_ENFORCE_EQ( - dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), - true, - phi::errors::OutOfRange( - "Attr(dim) is out of range, It's expected " - "to be in range of [-%d, %d]. But received Attr(dim) = %d.", - input_x_dims.size(), - input_x_dims.size() - 1, - dim)); - if (dim < 0) { - dim += input_x_dims.size(); - } - - PADDLE_ENFORCE_EQ( - input_x_dims[dim] == 3, - true, - phi::errors::InvalidArgument( - "Input(X/Y).dims[dim] must be equal to 3. But received: " - "Input(X/Y).dims[dim] = [%d].", - input_x_dims[dim])); - } else { - for (auto i = 0; i < input_x_dims.size(); i++) { - if (input_x_dims[i] == 3) { - dim = i; - break; - } - } - PADDLE_ENFORCE_EQ(dim == DDim::kMaxRank, - false, - phi::errors::InvalidArgument( - "There must be at least one dimension 'd' so that " - "Input(X/Y).dims()[d] is equal to 3. " - "But received: Input(X/Y).dims() == [%s].", - input_x_dims)); - } - auto outer_loops = 1; - for (auto i = 0; i < dim; i++) { - outer_loops *= input_x_dims[i]; - } - auto slice_size = 1; - for (auto i = dim + 1; i < input_x_dims.size(); i++) { - slice_size *= input_x_dims[i]; - } - - std::vector input_x_vec, input_y_vec; - paddle::framework::TensorToVector(input_x, dev_ctx, &input_x_vec); - paddle::framework::TensorToVector(input_y, dev_ctx, &input_y_vec); - std::vector out_vec(output->numel()); - - dev_ctx.template Alloc(output); - - for (auto i = 0; i < outer_loops; i++) { - for (auto j = 0; j < 3; j++) { - auto dst_pos = (3 * i + j) * slice_size; - auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size; - auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size; - - for (auto k = 0; k < slice_size; k++) { - out_vec[dst_pos + k] = - input_x_vec[in_pos1 + k] * input_y_vec[in_pos2 + k] - - input_x_vec[in_pos2 + k] * input_y_vec[in_pos1 + k]; - } - } - } - paddle::framework::TensorFromVector(out_vec, dev_ctx, output); - output->Resize(input_x_dims); -} - -} // namespace phi