From 783c4aba035542e66f187b1d504d49f8154f2a3f Mon Sep 17 00:00:00 2001 From: Linjie Chen <40840292+linjieccc@users.noreply.github.com> Date: Fri, 25 Feb 2022 14:46:43 +0800 Subject: [PATCH] move diag_v2 to phi (#39914) --- paddle/fluid/operators/diag_v2_op.cc | 96 ++--------------- paddle/fluid/operators/diag_v2_op.cu | 128 ---------------------- paddle/fluid/operators/diag_v2_op.h | 34 ------ paddle/phi/core/compat/op_utils.h | 3 +- paddle/phi/infermeta/binary.cc | 1 + paddle/phi/infermeta/unary.cc | 40 +++++++ paddle/phi/infermeta/unary.h | 5 + paddle/phi/kernels/cpu/diag_kernel.cc | 66 ++++++++++++ paddle/phi/kernels/diag_kernel.h | 28 +++++ paddle/phi/kernels/funcs/diag_functor.h | 29 +++++ paddle/phi/kernels/gpu/diag_kernel.cu | 134 ++++++++++++++++++++++++ paddle/phi/ops/compat/diag_sig.cc | 27 +++++ 12 files changed, 340 insertions(+), 251 deletions(-) delete mode 100644 paddle/fluid/operators/diag_v2_op.cu delete mode 100644 paddle/fluid/operators/diag_v2_op.h create mode 100644 paddle/phi/kernels/cpu/diag_kernel.cc create mode 100644 paddle/phi/kernels/diag_kernel.h create mode 100644 paddle/phi/kernels/funcs/diag_functor.h create mode 100644 paddle/phi/kernels/gpu/diag_kernel.cu create mode 100644 paddle/phi/ops/compat/diag_sig.cc diff --git a/paddle/fluid/operators/diag_v2_op.cc b/paddle/fluid/operators/diag_v2_op.cc index 30ea323733..0160277dc7 100644 --- a/paddle/fluid/operators/diag_v2_op.cc +++ b/paddle/fluid/operators/diag_v2_op.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/diag_v2_op.h" #include + +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -23,44 +25,6 @@ namespace operators { class DiagV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "diag_v2"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "diag_v2"); - - auto x_dims = ctx->GetInputDim("X"); - auto offset = ctx->Attrs().Get("offset"); - - if (x_dims.size() == 1UL) { - int64_t size_ = x_dims[0] + std::abs(offset); - ctx->SetOutputDim("Out", {size_, size_}); - } else if (x_dims.size() == 2UL) { - int64_t size_ = 0; - if (offset >= 0) { - // Note(LutaoChu): Do not use std::min here, otherwise the calculation - // of `size_` will have unexpected result on Windows Python3.8 - if (x_dims[0] < x_dims[1] - offset) { - size_ = x_dims[0]; - } else { - size_ = x_dims[1] - offset; - } - } else { - // Note(LutaoChu): Do not use std::min here, otherwise the calculation - // of `size_` will have unexpected result on Windows Python3.8 - if (x_dims[0] + offset < x_dims[1]) { - size_ = x_dims[0] + offset; - } else { - size_ = x_dims[1]; - } - } - ctx->SetOutputDim("Out", {size_}); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The input tensor X's dimensions of DiagV2Op should be either 1 or " - "2, but received %d.", - x_dims.size())); - } - } }; class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker { @@ -94,59 +58,15 @@ class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker { } }; -template -class DiagV2Kernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* x_data = X->data(); - auto x_dims = X->dims(); - int offset = context.Attr("offset"); - auto* out = context.Output("Out"); - T* out_data = out->mutable_data(context.GetPlace()); - auto out_dims = out->dims(); - - int64_t i; - if (x_dims.size() == 1) { - float padding_value = context.Attr("padding_value"); - phi::funcs::SetConstant set_padding_value; - auto& dev_ctx = context.template device_context(); - set_padding_value(dev_ctx, out, static_cast(padding_value)); - - auto x_length = x_dims[0]; - const int& x_stride = ComputeStride(0, x_dims); - - auto out_stride_0 = ComputeStride(0, out_dims); - auto out_stride_1 = ComputeStride(1, out_dims); - out_data += - (offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0); - - for (i = 0; i < x_length; i++) { - out_data[i * (out_stride_0 + out_stride_1)] = x_data[i * x_stride]; - } - } else { - auto out_length = out_dims[0]; - const int& x_stride_0 = ComputeStride(0, x_dims); - const int& x_stride_1 = ComputeStride(1, x_dims); - - auto out_stride_0 = ComputeStride(0, out_dims); - x_data += (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0); - for (i = 0; i < out_length; i++) { - out_data[i * out_stride_0] = x_data[i * (x_stride_0 + x_stride_1)]; - } - } - } -}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(diag_v2, DiagInferShapeFunctor, + PT_INFER_META(phi::DiagInferMeta)); + REGISTER_OPERATOR( diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - diag_v2, ops::DiagV2Kernel, - ops::DiagV2Kernel, - ops::DiagV2Kernel, - ops::DiagV2Kernel); + paddle::framework::EmptyGradOpMaker, + DiagInferShapeFunctor); diff --git a/paddle/fluid/operators/diag_v2_op.cu b/paddle/fluid/operators/diag_v2_op.cu deleted file mode 100644 index 9b83b68bea..0000000000 --- a/paddle/fluid/operators/diag_v2_op.cu +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/diag_v2_op.h" - -namespace paddle { -namespace operators { - -// Extract the diagonal of a matrix 'x' to a vector 'out'. -template -__global__ void ExtractDiagonalKernel(T* out, const T* x, std::ptrdiff_t start, - std::ptrdiff_t size, - const std::ptrdiff_t sumStride, - const std::ptrdiff_t outStride) { - for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; - idx += gridDim.x * blockDim.x) { - const std::ptrdiff_t xOffset = start + sumStride * idx; - out[outStride * idx] = x[xOffset]; - } -} - -// Paste a vector 'x' to the diagonal of a matrix 'out' -template -__global__ void PasteDiagonalKernel(T* out, const T* x, std::ptrdiff_t start, - std::ptrdiff_t x_length, - const std::ptrdiff_t sumStride, - const std::ptrdiff_t xStride) { - for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; - idx < x_length; idx += gridDim.x * blockDim.x) { - const std::ptrdiff_t outOffset = start + sumStride * idx; - out[outOffset] = x[xStride * idx]; - } -} - -template -class DiagV2CUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* x_data = X->data(); - auto x_dims = X->dims(); - int offset = context.Attr("offset"); - auto* out = context.Output("Out"); - T* out_data = out->mutable_data(context.GetPlace()); - auto out_dims = out->dims(); - auto& dev_ctx = context.template device_context(); - - auto GetBlockGridSize = [&dev_ctx](int64_t size) { - const int64_t block_size = - std::min(size, static_cast(dev_ctx.GetMaxThreadsPerBlock())); - int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), - static_cast(1)); - const int64_t grid_size = - std::min(max_blocks, (size + block_size - 1) / block_size); - return std::tuple{block_size, grid_size}; - }; - - if (x_dims.size() == 1) { - float padding_value = context.Attr("padding_value"); - phi::funcs::SetConstant set_padding_value; - set_padding_value(dev_ctx, out, static_cast(padding_value)); - - auto x_length = x_dims[0]; - auto size = (offset > 0) ? x_length + offset : x_length - offset; - const int& x_stride = ComputeStride(0, x_dims); - if (size > 0) { - const auto& out_stride_0 = ComputeStride(0, out_dims); - const auto& out_stride_1 = ComputeStride(1, out_dims); - auto start = - (offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0); - - std::tuple block_grid_size = GetBlockGridSize(size); - - PasteDiagonalKernel< - T><<(block_grid_size), std::get<0>(block_grid_size), 0, - dev_ctx.stream()>>>(out_data, x_data, start, x_length, - out_stride_0 + out_stride_1, x_stride); - } - } else { - const int& x_stride_0 = ComputeStride(0, x_dims); - const int& x_stride_1 = ComputeStride(1, x_dims); - - int64_t size; - if (offset > 0) { - size = std::min(x_dims[0], x_dims[1] - offset); - } else { - size = std::min(x_dims[0] + offset, x_dims[1]); - } - - if (size > 0) { - auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0); - const auto& out_stride_0 = ComputeStride(0, out_dims); - - std::tuple block_grid_size = GetBlockGridSize(size); - - ExtractDiagonalKernel< - T><<(block_grid_size), std::get<0>(block_grid_size), 0, - dev_ctx.stream()>>>(out_data, x_data, start, size, - x_stride_0 + x_stride_1, out_stride_0); - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - diag_v2, ops::DiagV2CUDAKernel, - ops::DiagV2CUDAKernel, - ops::DiagV2CUDAKernel, - ops::DiagV2CUDAKernel); diff --git a/paddle/fluid/operators/diag_v2_op.h b/paddle/fluid/operators/diag_v2_op.h deleted file mode 100644 index f0bf04bada..0000000000 --- a/paddle/fluid/operators/diag_v2_op.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using DDim = framework::DDim; - -static inline int ComputeStride(int axis, DDim dims) { - int size = 1; - for (int i = axis + 1; i < dims.size(); i++) { - size *= dims[i]; - } - return size; -} - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index ec810d4e16..bbf634b4b0 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -37,7 +37,8 @@ const std::unordered_set standard_kernel_suffixs({ * after 2.0, and can no longer be occupied by the previously abandoned ops. * They are marked here uniformly. */ -const std::unordered_set deprecated_op_names({"flatten", +const std::unordered_set deprecated_op_names({"diag", + "flatten", "flatten_grad", "matmul", "matmul_grad", diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 58cd43998b..dfaabf7cae 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -310,6 +310,7 @@ void BCELossInferMeta(const MetaTensor& input, } out->set_dims(input_dims); + out->set_dtype(input.dtype()); out->share_lod(input); } diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index ca71d6a56d..72b88f537f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/unary.h" +#include #include #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/enforce.h" @@ -715,6 +716,45 @@ void UnfoldInferMeta(const MetaTensor& x, out->set_dims(phi::make_ddim(out_dims)); } +void DiagInferMeta(const MetaTensor& x, + int offset, + float padding_value, + MetaTensor* out) { + auto x_dims = x.dims(); + + if (x_dims.size() == 1UL) { + int64_t size_ = x_dims[0] + std::abs(offset); + out->set_dims({size_, size_}); + out->set_dtype(x.dtype()); + } else if (x_dims.size() == 2UL) { + int64_t size_ = 0; + if (offset >= 0) { + // Note(LutaoChu): Do not use std::min here, otherwise the calculation + // of `size_` will have unexpected result on Windows Python3.8 + if (x_dims[0] < x_dims[1] - offset) { + size_ = x_dims[0]; + } else { + size_ = x_dims[1] - offset; + } + } else { + // Note(LutaoChu): Do not use std::min here, otherwise the calculation + // of `size_` will have unexpected result on Windows Python3.8 + if (x_dims[0] + offset < x_dims[1]) { + size_ = x_dims[0] + offset; + } else { + size_ = x_dims[1]; + } + } + out->set_dims({size_}); + out->set_dtype(x.dtype()); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "The input tensor X's dimensions of DiagV2Op should be either 1 or " + "2, but received %d.", + x_dims.size())); + } +} + } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 21cbe76bb1..1a1605bb1c 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -104,4 +104,9 @@ void UnfoldInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void DiagInferMeta(const MetaTensor& x, + int offset, + float padding_value, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/diag_kernel.cc b/paddle/phi/kernels/cpu/diag_kernel.cc new file mode 100644 index 0000000000..d1e0b8e31e --- /dev/null +++ b/paddle/phi/kernels/cpu/diag_kernel.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/diag_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/diag_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void DiagKernel(const Context& dev_ctx, + const DenseTensor& x, + int offset, + float padding_value, + DenseTensor* out) { + auto* x_data = x.data(); + auto x_dims = x.dims(); + T* out_data = dev_ctx.template Alloc(out); + auto out_dims = out->dims(); + + int64_t i; + if (x_dims.size() == 1) { + phi::funcs::SetConstant set_padding_value; + set_padding_value(dev_ctx, out, static_cast(padding_value)); + + auto x_length = x_dims[0]; + const int& x_stride = phi::funcs::ComputeStride(0, x_dims); + + auto out_stride_0 = phi::funcs::ComputeStride(0, out_dims); + auto out_stride_1 = phi::funcs::ComputeStride(1, out_dims); + out_data += (offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0); + + for (i = 0; i < x_length; i++) { + out_data[i * (out_stride_0 + out_stride_1)] = x_data[i * x_stride]; + } + } else { + auto out_length = out_dims[0]; + const int& x_stride_0 = phi::funcs::ComputeStride(0, x_dims); + const int& x_stride_1 = phi::funcs::ComputeStride(1, x_dims); + + auto out_stride_0 = phi::funcs::ComputeStride(0, out_dims); + x_data += (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0); + for (i = 0; i < out_length; i++) { + out_data[i * out_stride_0] = x_data[i * (x_stride_0 + x_stride_1)]; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + diag, CPU, ALL_LAYOUT, phi::DiagKernel, int, float, double, int64_t) {} diff --git a/paddle/phi/kernels/diag_kernel.h b/paddle/phi/kernels/diag_kernel.h new file mode 100644 index 0000000000..8dc919fa63 --- /dev/null +++ b/paddle/phi/kernels/diag_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DiagKernel(const Context& dev_ctx, + const DenseTensor& x, + int offset, + float padding_value, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/diag_functor.h b/paddle/phi/kernels/funcs/diag_functor.h new file mode 100644 index 0000000000..a806d1583a --- /dev/null +++ b/paddle/phi/kernels/funcs/diag_functor.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace phi { +namespace funcs { + +inline int ComputeStride(int axis, phi::DDim dims) { + int size = 1; + for (int i = axis + 1; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/diag_kernel.cu b/paddle/phi/kernels/gpu/diag_kernel.cu new file mode 100644 index 0000000000..fc70639787 --- /dev/null +++ b/paddle/phi/kernels/gpu/diag_kernel.cu @@ -0,0 +1,134 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/diag_kernel.h" + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/diag_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +// Extract the diagonal of a matrix 'x' to a vector 'out'. +template +__global__ void ExtractDiagonalKernel(T* out, + const T* x, + std::ptrdiff_t start, + std::ptrdiff_t size, + const std::ptrdiff_t sumStride, + const std::ptrdiff_t outStride) { + for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + const std::ptrdiff_t xOffset = start + sumStride * idx; + out[outStride * idx] = x[xOffset]; + } +} + +// Paste a vector 'x' to the diagonal of a matrix 'out' +template +__global__ void PasteDiagonalKernel(T* out, + const T* x, + std::ptrdiff_t start, + std::ptrdiff_t x_length, + const std::ptrdiff_t sumStride, + const std::ptrdiff_t xStride) { + for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < x_length; + idx += gridDim.x * blockDim.x) { + const std::ptrdiff_t outOffset = start + sumStride * idx; + out[outOffset] = x[xStride * idx]; + } +} + +template +void DiagKernel(const Context& dev_ctx, + const DenseTensor& x, + int offset, + float padding_value, + DenseTensor* out) { + auto* x_data = x.data(); + auto x_dims = x.dims(); + T* out_data = dev_ctx.template Alloc(out); + auto out_dims = out->dims(); + + auto GetBlockGridSize = [&dev_ctx](int64_t size) { + const int64_t block_size = + std::min(size, static_cast(dev_ctx.GetMaxThreadsPerBlock())); + int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (size + block_size - 1) / block_size); + return std::tuple{block_size, grid_size}; + }; + + if (x_dims.size() == 1) { + phi::funcs::SetConstant set_padding_value; + set_padding_value(dev_ctx, out, static_cast(padding_value)); + + auto x_length = x_dims[0]; + auto size = (offset > 0) ? x_length + offset : x_length - offset; + const int& x_stride = phi::funcs::ComputeStride(0, x_dims); + if (size > 0) { + const auto& out_stride_0 = phi::funcs::ComputeStride(0, out_dims); + const auto& out_stride_1 = phi::funcs::ComputeStride(1, out_dims); + auto start = + (offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0); + + std::tuple block_grid_size = GetBlockGridSize(size); + + PasteDiagonalKernel<<(block_grid_size), + std::get<0>(block_grid_size), + 0, + dev_ctx.stream()>>>(out_data, + x_data, + start, + x_length, + out_stride_0 + out_stride_1, + x_stride); + } + } else { + const int& x_stride_0 = phi::funcs::ComputeStride(0, x_dims); + const int& x_stride_1 = phi::funcs::ComputeStride(1, x_dims); + + int64_t size; + if (offset > 0) { + size = std::min(x_dims[0], x_dims[1] - offset); + } else { + size = std::min(x_dims[0] + offset, x_dims[1]); + } + + if (size > 0) { + auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0); + const auto& out_stride_0 = phi::funcs::ComputeStride(0, out_dims); + + std::tuple block_grid_size = GetBlockGridSize(size); + + ExtractDiagonalKernel<<(block_grid_size), + std::get<0>(block_grid_size), + 0, + dev_ctx.stream()>>>( + out_data, x_data, start, size, x_stride_0 + x_stride_1, out_stride_0); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + diag, GPU, ALL_LAYOUT, phi::DiagKernel, int, int64_t, float, double) {} diff --git a/paddle/phi/ops/compat/diag_sig.cc b/paddle/phi/ops/compat/diag_sig.cc new file mode 100644 index 0000000000..0a14b9095c --- /dev/null +++ b/paddle/phi/ops/compat/diag_sig.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("diag", {"X"}, {"offset", "padding_value"}, {"Out"}); +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(diag_v2, diag); + +PD_REGISTER_ARG_MAPPING_FN(diag_v2, phi::DiagOpArgumentMapping); -- GitLab