From 6c358a7c22fe9acd35cae13f7debc8200350d0ee Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Thu, 24 Feb 2022 15:34:31 +0800 Subject: [PATCH] [Phi]Move cross OP to phi (#39829) * move cross forward OP * move cross grad op to phi * move infershape * refine infershape * rename ctx * set dtype and layout in InferMeta * refine code --- paddle/fluid/operators/cross_op.cc | 69 +----- paddle/fluid/operators/cross_op.cu | 28 --- paddle/fluid/operators/cross_op.h | 222 ------------------ paddle/phi/infermeta/binary.cc | 45 ++++ paddle/phi/infermeta/binary.h | 5 + paddle/phi/kernels/cpu/cross_grad_kernel.cc | 28 +++ paddle/phi/kernels/cpu/cross_kernel.cc | 22 ++ paddle/phi/kernels/cross_grad_kernel.h | 30 +++ paddle/phi/kernels/cross_kernel.h | 28 +++ paddle/phi/kernels/funcs/common_shape.h | 12 + paddle/phi/kernels/gpu/cross_grad_kernel.cu | 28 +++ paddle/phi/kernels/gpu/cross_kernel.cu | 22 ++ .../phi/kernels/impl/cross_grad_kernel_impl.h | 113 +++++++++ paddle/phi/kernels/impl/cross_kernel_impl.h | 116 +++++++++ paddle/phi/ops/compat/cross_sig.cc | 33 +++ 15 files changed, 491 insertions(+), 310 deletions(-) delete mode 100644 paddle/fluid/operators/cross_op.cu delete mode 100644 paddle/fluid/operators/cross_op.h create mode 100644 paddle/phi/kernels/cpu/cross_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/cross_kernel.cc create mode 100644 paddle/phi/kernels/cross_grad_kernel.h create mode 100644 paddle/phi/kernels/cross_kernel.h create mode 100644 paddle/phi/kernels/gpu/cross_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/cross_kernel.cu create mode 100644 paddle/phi/kernels/impl/cross_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/cross_kernel_impl.h create mode 100644 paddle/phi/ops/compat/cross_sig.cc diff --git a/paddle/fluid/operators/cross_op.cc b/paddle/fluid/operators/cross_op.cc index e6b30ba42f..fe00ee0660 100644 --- a/paddle/fluid/operators/cross_op.cc +++ b/paddle/fluid/operators/cross_op.cc @@ -12,67 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/cross_op.h" #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { using framework::Tensor; using framework::DDim; +const int kDefaultDim = framework::DDim::kMaxRank; class CrossOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::InvalidArgument( - "Input(X) of CrossOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, - platform::errors::InvalidArgument( - "Input(Index) of CrossOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument( - "Output(Out) of CrossOp should not be null.")); - - auto x_dim = ctx->GetInputDim("X"); - auto y_dim = ctx->GetInputDim("Y"); - auto dim = ctx->Attrs().Get("dim"); - - bool dims_match = CheckDims(x_dim, y_dim); - PADDLE_ENFORCE_EQ(dims_match, true, - platform::errors::InvalidArgument( - "The 'shape' of Input(X) should be equal to " - "the 'shape' of Input(Y). But received " - "Input(X).dimensions = [%s], " - "Input(Y).dimensions = [%s]", - x_dim, y_dim)); - - if (dim != kDefaultDim) { - PADDLE_ENFORCE_EQ( - dim < x_dim.size() && dim >= (0 - x_dim.size()), true, - platform::errors::OutOfRange( - "Attr(dim) is out of range, It's expected " - "to be in range of [-%d, %d]. But received Attr(dim) = %d.", - x_dim.size(), x_dim.size() - 1, dim)); - if (dim < 0) { - dim += x_dim.size(); - } - PADDLE_ENFORCE_EQ(x_dim[dim] == 3 && y_dim[dim] == 3, true, - platform::errors::InvalidArgument( - "Input(X/Y).dims()[dim] should be equal to 3." - "But received Input(X/Y).dims()[dim] = %d.", - x_dim[dim])); - } - - ctx->SetOutputDim("Out", x_dim); - auto type = ctx->GetInputsVarType("X")[0]; - if (type == framework::proto::VarType::LOD_TENSOR) { - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -153,17 +109,10 @@ class CrossGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(cross, CrossInferShapeFunctor, + PT_INFER_META(phi::CrossInferMeta)); REGISTER_OPERATOR(cross, ops::CrossOp, ops::CrossOpMaker, ops::CrossGradMaker, - ops::CrossGradMaker); + ops::CrossGradMaker, + CrossInferShapeFunctor); REGISTER_OPERATOR(cross_grad, ops::CrossGradOp); -REGISTER_OP_CPU_KERNEL( - cross, ops::CrossKernel, - ops::CrossKernel, - ops::CrossKernel, - ops::CrossKernel); -REGISTER_OP_CPU_KERNEL( - cross_grad, ops::CrossGradKernel, - ops::CrossGradKernel, - ops::CrossGradKernel, - ops::CrossGradKernel); diff --git a/paddle/fluid/operators/cross_op.cu b/paddle/fluid/operators/cross_op.cu deleted file mode 100644 index 78bbb3ea56..0000000000 --- a/paddle/fluid/operators/cross_op.cu +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/cross_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - cross, ops::CrossKernel, - ops::CrossKernel, - ops::CrossKernel, - ops::CrossKernel); -REGISTER_OP_CUDA_KERNEL( - cross_grad, - ops::CrossGradKernel, - ops::CrossGradKernel, - ops::CrossGradKernel, - ops::CrossGradKernel); diff --git a/paddle/fluid/operators/cross_op.h b/paddle/fluid/operators/cross_op.h deleted file mode 100644 index b1c5eb62fd..0000000000 --- a/paddle/fluid/operators/cross_op.h +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; -using DDim = framework::DDim; -const int kDefaultDim = framework::DDim::kMaxRank; - -inline bool CheckDims(const DDim& dims_x, const DDim& dims_y) { - if (dims_x.size() != dims_y.size()) { - return false; - } - for (int i = 0; i < dims_x.size(); i++) { - if (dims_x[i] != dims_y[i]) { - return false; - } - } - return true; -} - -template -class CrossKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input_x_var = context.InputVar("X"); - auto* input_y_var = context.InputVar("Y"); - auto* output_var = context.OutputVar("Out"); - - auto& input_x = input_x_var->Get(); - auto& input_y = input_y_var->Get(); - auto* output = output_var->GetMutable(); - int dim = context.Attr("dim"); - - auto input_x_dims = input_x.dims(); - auto input_y_dims = input_y.dims(); - bool dims_match = CheckDims(input_x_dims, input_y_dims); - PADDLE_ENFORCE_EQ(dims_match, true, - platform::errors::InvalidArgument( - "The 'shape' of Input(X) should be equal to " - "the 'shape' of Input(Y). But received " - "Input(X).dimensions = [%s], " - "Input(Y).dimensions = [%s]", - input_x_dims, input_x_dims)); - - if (dim != kDefaultDim) { - PADDLE_ENFORCE_EQ( - dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), true, - platform::errors::OutOfRange( - "Attr(dim) is out of range, It's expected " - "to be in range of [-%d, %d]. But received Attr(dim) = %d.", - input_x_dims.size(), input_x_dims.size() - 1, dim)); - if (dim < 0) { - dim += input_x_dims.size(); - } - - PADDLE_ENFORCE_EQ( - input_x_dims[dim] == 3, true, - platform::errors::InvalidArgument( - "Input(X/Y).dims[dim] must be equal to 3. But received: " - "Input(X/Y).dims[dim] = [%d].", - input_x_dims[dim])); - } else { - for (auto i = 0; i < input_x_dims.size(); i++) { - if (input_x_dims[i] == 3) { - dim = i; - break; - } - } - PADDLE_ENFORCE_EQ(dim == kDefaultDim, false, - platform::errors::InvalidArgument( - "There must be at least one dimension 'd' so that " - "Input(X/Y).dims()[d] is equal to 3. " - "But received: Input(X/Y).dims() == [%s].", - input_x_dims)); - } - auto outer_loops = 1; - for (auto i = 0; i < dim; i++) { - outer_loops *= input_x_dims[i]; - } - auto slice_size = 1; - for (auto i = dim + 1; i < input_x_dims.size(); i++) { - slice_size *= input_x_dims[i]; - } - - std::vector input_x_vec, input_y_vec; - framework::TensorToVector(input_x, context.device_context(), &input_x_vec); - framework::TensorToVector(input_y, context.device_context(), &input_y_vec); - std::vector out_vec(output->numel()); - - output->mutable_data(context.GetPlace()); - - for (auto i = 0; i < outer_loops; i++) { - for (auto j = 0; j < 3; j++) { - auto dst_pos = (3 * i + j) * slice_size; - auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size; - auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size; - - for (auto k = 0; k < slice_size; k++) { - out_vec[dst_pos + k] = - input_x_vec[in_pos1 + k] * input_y_vec[in_pos2 + k] - - input_x_vec[in_pos2 + k] * input_y_vec[in_pos1 + k]; - } - } - } - framework::TensorFromVector(out_vec, context.device_context(), output); - output->Resize(input_x_dims); - } -}; - -template -class CrossGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input_x_var = context.InputVar("X"); - auto* input_y_var = context.InputVar("Y"); - auto* input_out_grad_var = context.InputVar(framework::GradVarName("Out")); - auto* output_x_grad_var = context.OutputVar(framework::GradVarName("X")); - auto* output_y_grad_var = context.OutputVar(framework::GradVarName("Y")); - - auto& input_x = input_x_var->Get(); - auto& input_y = input_y_var->Get(); - auto& input_out_grad = input_out_grad_var->Get(); - auto* output_x_grad = output_x_grad_var->GetMutable(); - auto* output_y_grad = output_y_grad_var->GetMutable(); - - int dim = context.Attr("dim"); - auto input_x_dims = input_x.dims(); - if (dim != kDefaultDim) { - PADDLE_ENFORCE_EQ( - dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), true, - platform::errors::OutOfRange( - "Attr(dim) is out of range, It's expected " - "to be in range of [-%d, %d]. But received Attr(dim) = %d.", - input_x_dims.size(), input_x_dims.size() - 1, dim)); - if (dim < 0) { - dim += input_x_dims.size(); - } - - PADDLE_ENFORCE_EQ( - input_x_dims[dim] == 3, true, - platform::errors::InvalidArgument( - "Input(X/Y).dims[dim] must be equal to 3. But received: " - "Input(X/Y).dims[dim] = [%d].", - input_x_dims[dim])); - } else { - for (auto i = 0; i < input_x_dims.size(); i++) { - if (input_x_dims[i] == 3) { - dim = i; - break; - } - } - PADDLE_ENFORCE_EQ(dim == kDefaultDim, false, - platform::errors::InvalidArgument( - "There must be at least one dimension 'd' " - "so that Input(X/Y).dims()[d] is equal to 3. " - "But received: Input(X/Y).dims() == [%s].", - input_x_dims)); - } - auto outer_loops = 1; - for (auto i = 0; i < dim; i++) { - outer_loops *= input_x_dims[i]; - } - auto slice_size = 1; - for (auto i = dim + 1; i < input_x_dims.size(); i++) { - slice_size *= input_x_dims[i]; - } - - std::vector input_x_vec, input_y_vec, input_dout_vec; - framework::TensorToVector(input_x, context.device_context(), &input_x_vec); - framework::TensorToVector(input_y, context.device_context(), &input_y_vec); - framework::TensorToVector(input_out_grad, context.device_context(), - &input_dout_vec); - std::vector out_dx_vec(output_x_grad->numel()); - std::vector out_dy_vec(output_y_grad->numel()); - - output_x_grad->mutable_data(context.GetPlace()); - output_y_grad->mutable_data(context.GetPlace()); - - for (auto i = 0; i < outer_loops; i++) { - for (auto j = 0; j < 3; j++) { - auto dst_pos = (3 * i + j) * slice_size; - auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size; - auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size; - for (auto k = 0; k < slice_size; k++) { - out_dx_vec[dst_pos + k] = - input_dout_vec[in_pos2 + k] * input_y_vec[in_pos1 + k] - - input_dout_vec[in_pos1 + k] * input_y_vec[in_pos2 + k]; - out_dy_vec[dst_pos + k] = - input_dout_vec[in_pos1 + k] * input_x_vec[in_pos2 + k] - - input_dout_vec[in_pos2 + k] * input_x_vec[in_pos1 + k]; - } - } - } - framework::TensorFromVector(out_dx_vec, context.device_context(), - output_x_grad); - framework::TensorFromVector(out_dy_vec, context.device_context(), - output_y_grad); - output_x_grad->Resize(input_x_dims); - output_y_grad->Resize(input_x_dims); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index ab1fe5433f..58cd43998b 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -225,6 +225,51 @@ void HuberLossInferMeta(const MetaTensor& input, out->share_lod(input); } +void CrossInferMeta(const MetaTensor& x, + const MetaTensor& y, + int axis, + MetaTensor* out) { + auto x_dim = x.dims(); + auto y_dim = y.dims(); + auto dim = axis; + + bool dims_match = phi::funcs::CheckDims(x_dim, y_dim); + PADDLE_ENFORCE_EQ( + dims_match, + true, + phi::errors::InvalidArgument("The 'shape' of Input(X) should be equal to " + "the 'shape' of Input(Y). But received " + "Input(X).dimensions = [%s], " + "Input(Y).dimensions = [%s]", + x_dim, + y_dim)); + + if (dim != DDim::kMaxRank) { + PADDLE_ENFORCE_EQ( + dim < x_dim.size() && dim >= (0 - x_dim.size()), + true, + phi::errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + x_dim.size(), + x_dim.size() - 1, + dim)); + if (dim < 0) { + dim += x_dim.size(); + } + PADDLE_ENFORCE_EQ(x_dim[dim] == 3 && y_dim[dim] == 3, + true, + phi::errors::InvalidArgument( + "Input(X/Y).dims()[dim] should be equal to 3." + "But received Input(X/Y).dims()[dim] = %d.", + x_dim[dim])); + } + out->set_dims(x_dim); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + out->share_lod(x); +} + void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { auto in_dims = x.dims(); out->set_dims(in_dims); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index effa18c567..02750482dc 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -53,6 +53,11 @@ void HuberLossInferMeta(const MetaTensor& input_meta, MetaTensor* residual, MetaConfig config = MetaConfig()); +void CrossInferMeta(const MetaTensor& x, + const MetaTensor& y, + int axis, + MetaTensor* out); + void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); void BCELossInferMeta(const MetaTensor& input, const MetaTensor& label, diff --git a/paddle/phi/kernels/cpu/cross_grad_kernel.cc b/paddle/phi/kernels/cpu/cross_grad_kernel.cc new file mode 100644 index 0000000000..390420008e --- /dev/null +++ b/paddle/phi/kernels/cpu/cross_grad_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cross_grad_kernel.h" +#include "paddle/phi/kernels/impl/cross_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(cross_grad, + CPU, + ALL_LAYOUT, + phi::CrossGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/cross_kernel.cc b/paddle/phi/kernels/cpu/cross_kernel.cc new file mode 100644 index 0000000000..a63f33174e --- /dev/null +++ b/paddle/phi/kernels/cpu/cross_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cross_kernel.h" +#include "paddle/phi/kernels/impl/cross_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + cross, CPU, ALL_LAYOUT, phi::CrossKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/cross_grad_kernel.h b/paddle/phi/kernels/cross_grad_kernel.h new file mode 100644 index 0000000000..9ea0804a94 --- /dev/null +++ b/paddle/phi/kernels/cross_grad_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CrossGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad, + DenseTensor* y_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/cross_kernel.h b/paddle/phi/kernels/cross_kernel.h new file mode 100644 index 0000000000..567889e078 --- /dev/null +++ b/paddle/phi/kernels/cross_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CrossKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index 8bd9867f39..d5289dcc22 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -128,5 +128,17 @@ static void GetBroadcastDims(const DDim &in_dims, } } +inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) { + if (dims_x.size() != dims_y.size()) { + return false; + } + for (int i = 0; i < dims_x.size(); i++) { + if (dims_x[i] != dims_y[i]) { + return false; + } + } + return true; +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/cross_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_grad_kernel.cu new file mode 100644 index 0000000000..1bb0d42dad --- /dev/null +++ b/paddle/phi/kernels/gpu/cross_grad_kernel.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cross_grad_kernel.h" +#include "paddle/phi/kernels/impl/cross_grad_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(cross_grad, + GPU, + ALL_LAYOUT, + phi::CrossGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/cross_kernel.cu b/paddle/phi/kernels/gpu/cross_kernel.cu new file mode 100644 index 0000000000..aa944f8291 --- /dev/null +++ b/paddle/phi/kernels/gpu/cross_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cross_kernel.h" +#include "paddle/phi/kernels/impl/cross_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + cross, GPU, ALL_LAYOUT, phi::CrossKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/impl/cross_grad_kernel_impl.h b/paddle/phi/kernels/impl/cross_grad_kernel_impl.h new file mode 100644 index 0000000000..99a79dc15c --- /dev/null +++ b/paddle/phi/kernels/impl/cross_grad_kernel_impl.h @@ -0,0 +1,113 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CrossGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto& input_x = x; + auto& input_y = y; + auto& input_out_grad = out_grad; + auto* output_x_grad = x_grad; + auto* output_y_grad = y_grad; + int dim = axis; + auto input_x_dims = input_x.dims(); + if (dim != DDim::kMaxRank) { + PADDLE_ENFORCE_EQ( + dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), + true, + errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + input_x_dims.size(), + input_x_dims.size() - 1, + dim)); + if (dim < 0) { + dim += input_x_dims.size(); + } + + PADDLE_ENFORCE_EQ( + input_x_dims[dim] == 3, + true, + errors::InvalidArgument( + "Input(X/Y).dims[dim] must be equal to 3. But received: " + "Input(X/Y).dims[dim] = [%d].", + input_x_dims[dim])); + } else { + for (auto i = 0; i < input_x_dims.size(); i++) { + if (input_x_dims[i] == 3) { + dim = i; + break; + } + } + PADDLE_ENFORCE_EQ( + dim == DDim::kMaxRank, + false, + errors::InvalidArgument("There must be at least one dimension 'd' " + "so that Input(X/Y).dims()[d] is equal to 3. " + "But received: Input(X/Y).dims() == [%s].", + input_x_dims)); + } + auto outer_loops = 1; + for (auto i = 0; i < dim; i++) { + outer_loops *= input_x_dims[i]; + } + auto slice_size = 1; + for (auto i = dim + 1; i < input_x_dims.size(); i++) { + slice_size *= input_x_dims[i]; + } + + std::vector input_x_vec, input_y_vec, input_dout_vec; + paddle::framework::TensorToVector(input_x, dev_ctx, &input_x_vec); + paddle::framework::TensorToVector(input_y, dev_ctx, &input_y_vec); + paddle::framework::TensorToVector(input_out_grad, dev_ctx, &input_dout_vec); + std::vector out_dx_vec(output_x_grad->numel()); + std::vector out_dy_vec(output_y_grad->numel()); + + dev_ctx.template Alloc(output_x_grad); + dev_ctx.template Alloc(output_y_grad); + + for (auto i = 0; i < outer_loops; i++) { + for (auto j = 0; j < 3; j++) { + auto dst_pos = (3 * i + j) * slice_size; + auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size; + auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size; + for (auto k = 0; k < slice_size; k++) { + out_dx_vec[dst_pos + k] = + input_dout_vec[in_pos2 + k] * input_y_vec[in_pos1 + k] - + input_dout_vec[in_pos1 + k] * input_y_vec[in_pos2 + k]; + out_dy_vec[dst_pos + k] = + input_dout_vec[in_pos1 + k] * input_x_vec[in_pos2 + k] - + input_dout_vec[in_pos2 + k] * input_x_vec[in_pos1 + k]; + } + } + } + paddle::framework::TensorFromVector(out_dx_vec, dev_ctx, output_x_grad); + paddle::framework::TensorFromVector(out_dy_vec, dev_ctx, output_y_grad); + output_x_grad->Resize(input_x_dims); + output_y_grad->Resize(input_x_dims); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/cross_kernel_impl.h b/paddle/phi/kernels/impl/cross_kernel_impl.h new file mode 100644 index 0000000000..6427d7f871 --- /dev/null +++ b/paddle/phi/kernels/impl/cross_kernel_impl.h @@ -0,0 +1,116 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/common_shape.h" + +namespace phi { + +template +void CrossKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + auto& input_x = x; + auto& input_y = y; + auto* output = out; + int dim = axis; + + auto input_x_dims = input_x.dims(); + auto input_y_dims = input_y.dims(); + bool dims_match = phi::funcs::CheckDims(input_x_dims, input_y_dims); + PADDLE_ENFORCE_EQ( + dims_match, + true, + phi::errors::InvalidArgument("The 'shape' of Input(X) should be equal to " + "the 'shape' of Input(Y). But received " + "Input(X).dimensions = [%s], " + "Input(Y).dimensions = [%s]", + input_x_dims, + input_x_dims)); + + if (dim != DDim::kMaxRank) { + PADDLE_ENFORCE_EQ( + dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), + true, + phi::errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + input_x_dims.size(), + input_x_dims.size() - 1, + dim)); + if (dim < 0) { + dim += input_x_dims.size(); + } + + PADDLE_ENFORCE_EQ( + input_x_dims[dim] == 3, + true, + phi::errors::InvalidArgument( + "Input(X/Y).dims[dim] must be equal to 3. But received: " + "Input(X/Y).dims[dim] = [%d].", + input_x_dims[dim])); + } else { + for (auto i = 0; i < input_x_dims.size(); i++) { + if (input_x_dims[i] == 3) { + dim = i; + break; + } + } + PADDLE_ENFORCE_EQ(dim == DDim::kMaxRank, + false, + phi::errors::InvalidArgument( + "There must be at least one dimension 'd' so that " + "Input(X/Y).dims()[d] is equal to 3. " + "But received: Input(X/Y).dims() == [%s].", + input_x_dims)); + } + auto outer_loops = 1; + for (auto i = 0; i < dim; i++) { + outer_loops *= input_x_dims[i]; + } + auto slice_size = 1; + for (auto i = dim + 1; i < input_x_dims.size(); i++) { + slice_size *= input_x_dims[i]; + } + + std::vector input_x_vec, input_y_vec; + paddle::framework::TensorToVector(input_x, dev_ctx, &input_x_vec); + paddle::framework::TensorToVector(input_y, dev_ctx, &input_y_vec); + std::vector out_vec(output->numel()); + + dev_ctx.template Alloc(output); + + for (auto i = 0; i < outer_loops; i++) { + for (auto j = 0; j < 3; j++) { + auto dst_pos = (3 * i + j) * slice_size; + auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size; + auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size; + + for (auto k = 0; k < slice_size; k++) { + out_vec[dst_pos + k] = + input_x_vec[in_pos1 + k] * input_y_vec[in_pos2 + k] - + input_x_vec[in_pos2 + k] * input_y_vec[in_pos1 + k]; + } + } + } + paddle::framework::TensorFromVector(out_vec, dev_ctx, output); + output->Resize(input_x_dims); +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/cross_sig.cc b/paddle/phi/ops/compat/cross_sig.cc new file mode 100644 index 0000000000..307c2ac516 --- /dev/null +++ b/paddle/phi/ops/compat/cross_sig.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature CrossOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("cross", {"X", "Y"}, {"dim"}, {"Out"}); +} + +KernelSignature CrossGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("cross_grad", + {"X", "Y", GradVarName("Out")}, + {"dim"}, + {GradVarName("X"), GradVarName("Y")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(cross, phi::CrossOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(cross_grad, phi::CrossGradOpArgumentMapping); -- GitLab