From 5c66338f4e9678d1a1254c6f1adb5d124a15512c Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 18 Feb 2022 15:01:36 +0800 Subject: [PATCH] [pten] trans diagonal kernel into pten (#39575) * trans diagonal kernel into pten * fix by code review --- paddle/fluid/operators/diagonal_op.cc | 12 +- paddle/fluid/operators/diagonal_op.cu | 273 ------------------ paddle/fluid/operators/diagonal_op.h | 163 ----------- .../pten/kernels/cpu/diagonal_grad_kernel.cc | 92 ++++++ paddle/pten/kernels/cpu/diagonal_kernel.cc | 90 ++++++ paddle/pten/kernels/diagonal_grad_kernel.h | 29 ++ paddle/pten/kernels/diagonal_kernel.h | 28 ++ paddle/pten/kernels/funcs/diagonal.h | 85 ++++++ .../pten/kernels/gpu/diagonal_grad_kernel.cu | 168 +++++++++++ paddle/pten/kernels/gpu/diagonal_kernel.cu | 165 +++++++++++ paddle/pten/ops/compat/diagonal_sig.cc | 28 ++ 11 files changed, 687 insertions(+), 446 deletions(-) delete mode 100644 paddle/fluid/operators/diagonal_op.cu delete mode 100644 paddle/fluid/operators/diagonal_op.h create mode 100644 paddle/pten/kernels/cpu/diagonal_grad_kernel.cc create mode 100644 paddle/pten/kernels/cpu/diagonal_kernel.cc create mode 100644 paddle/pten/kernels/diagonal_grad_kernel.h create mode 100644 paddle/pten/kernels/diagonal_kernel.h create mode 100644 paddle/pten/kernels/gpu/diagonal_grad_kernel.cu create mode 100644 paddle/pten/kernels/gpu/diagonal_kernel.cu create mode 100644 paddle/pten/ops/compat/diagonal_sig.cc diff --git a/paddle/fluid/operators/diagonal_op.cc b/paddle/fluid/operators/diagonal_op.cc index dd5a84ade59..7a0c256c37b 100644 --- a/paddle/fluid/operators/diagonal_op.cc +++ b/paddle/fluid/operators/diagonal_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/diagonal_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -169,18 +169,10 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(DiagonalGradNoNeedBufferVarsInferer, } // namespace paddle namespace ops = paddle::operators; + REGISTER_OPERATOR(diagonal, ops::DiagonalOp, ops::DiagonalOpMaker, ops::DiagonalGradOpMaker, ops::DiagonalGradOpMaker); REGISTER_OPERATOR(diagonal_grad, ops::DiagonalGradOp, ops::DiagonalGradNoNeedBufferVarsInferer) - -REGISTER_OP_CPU_KERNEL(diagonal, ops::DiagonalKernel, - ops::DiagonalKernel, ops::DiagonalKernel, - ops::DiagonalKernel, ops::DiagonalKernel); - -REGISTER_OP_CPU_KERNEL(diagonal_grad, ops::DiagonalGradKernel, - ops::DiagonalGradKernel, - ops::DiagonalGradKernel, - ops::DiagonalGradKernel); diff --git a/paddle/fluid/operators/diagonal_op.cu b/paddle/fluid/operators/diagonal_op.cu deleted file mode 100644 index b1268e903df..00000000000 --- a/paddle/fluid/operators/diagonal_op.cu +++ /dev/null @@ -1,273 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/diagonal_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -using platform::PADDLE_CUDA_NUM_THREADS; - -template -__global__ void Diagonal(const T* data1, T* data2, const int64_t offset_, - int64_t axis1_, int64_t axis2_, int64_t* x_stride, - int64_t* out_stride, int64_t numel, bool is_grad) { - CUDA_KERNEL_LOOP(idx, numel) { - int64_t idx_dim[X_DIM_SIZE] = {0}; - int64_t temp = 0; - for (size_t i = 0; i < X_DIM_SIZE - 1; i++) { - idx_dim[i] = (idx - temp) / x_stride[i]; - temp = temp + idx_dim[i] * x_stride[i]; - } - idx_dim[X_DIM_SIZE - 1] = idx - temp; - - int64_t axis1_dim = idx_dim[axis1_]; - int64_t axis2_dim = idx_dim[axis2_]; - - int64_t out_dim[OUT_DIM_SIZE] = {0}; - int temp_pos = 0; - for (int i = 0; i < X_DIM_SIZE; i++) { - if (i != axis1_ && i != axis2_) { - out_dim[temp_pos] = idx_dim[i]; - temp_pos++; - } - } - bool flag = false; - if (offset_ == 0 && axis1_dim == axis2_dim) { - out_dim[temp_pos] = axis1_dim; - flag = true; - } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) { - out_dim[temp_pos] = axis1_dim; - flag = true; - } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) { - out_dim[temp_pos] = axis2_dim; - flag = true; - } - if (!is_grad) { - if (flag) { - int64_t idx_output = 0; - for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) { - idx_output = idx_output + out_dim[i] * out_stride[i]; - } - idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1]; - data2[idx_output] = data1[idx]; - } - } else { - if (flag) { - int64_t idx_output = 0; - for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) { - idx_output = idx_output + out_dim[i] * out_stride[i]; - } - idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1]; - data2[idx] = data1[idx_output]; - } else { - data2[idx] = static_cast(0); - } - } - } -} - -template -class DiagonalCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("Input"); - const auto* input_data = input->data(); - auto input_dim = input->dims().Get(); - auto input_dim_size = input->dims().size(); - - std::vector res_in = vectorize(framework::stride(input->dims())); - paddle::framework::Tensor input_stride_tensor; - framework::TensorFromVector(res_in, context.device_context(), - &input_stride_tensor); - int64_t* input_stride = input_stride_tensor.data(); - - auto* output = context.Output("Out"); - auto* output_data = output->mutable_data(context.GetPlace()); - auto output_dim = output->dims().Get(); - auto output_dim_size = output->dims().size(); - - std::vector res_out = vectorize(framework::stride(output->dims())); - paddle::framework::Tensor output_stride_tensor; - framework::TensorFromVector(res_out, context.device_context(), - &output_stride_tensor); - int64_t* output_stride = output_stride_tensor.data(); - - const int64_t offset_ = context.Attr("offset"); - const int64_t axis1 = context.Attr("axis1"); - int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1; - const int64_t axis2 = context.Attr("axis2"); - int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2; - int64_t numel = input->numel(); - - int threads = PADDLE_CUDA_NUM_THREADS; - int blocks = (numel + threads - 1) / threads; - - switch (input_dim_size) { - case 2: - Diagonal<<>>(input_data, output_data, offset_, - axis1_, axis2_, input_stride, - output_stride, numel, false); - break; - case 3: - Diagonal<<>>(input_data, output_data, offset_, - axis1_, axis2_, input_stride, - output_stride, numel, false); - break; - case 4: - Diagonal<<>>(input_data, output_data, offset_, - axis1_, axis2_, input_stride, - output_stride, numel, false); - break; - case 5: - Diagonal<<>>(input_data, output_data, offset_, - axis1_, axis2_, input_stride, - output_stride, numel, false); - break; - case 6: - Diagonal<<>>(input_data, output_data, offset_, - axis1_, axis2_, input_stride, - output_stride, numel, false); - break; - case 7: - Diagonal<<>>(input_data, output_data, offset_, - axis1_, axis2_, input_stride, - output_stride, numel, false); - break; - case 8: - Diagonal<<>>(input_data, output_data, offset_, - axis1_, axis2_, input_stride, - output_stride, numel, false); - break; - case 9: - Diagonal<<>>(input_data, output_data, offset_, - axis1_, axis2_, input_stride, - output_stride, numel, false); - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "The rank of input should be less than 10, but received %d.", - input_dim_size)); - } - } -}; - -template -class DiagonalGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const auto* dout = - context.Input(framework::GradVarName("Out")); - const auto* dout_data = dout->data(); - auto dout_dim = dout->dims().Get(); - auto dout_dim_size = dout->dims().size(); - - std::vector res_dout = vectorize(framework::stride(dout->dims())); - paddle::framework::Tensor dout_stride_tensor; - framework::TensorFromVector(res_dout, context.device_context(), - &dout_stride_tensor); - int64_t* dout_stride = dout_stride_tensor.data(); - - auto* dx = - context.Output(framework::GradVarName("Input")); - auto* dx_data = dx->mutable_data(context.GetPlace()); - auto dx_dim = dx->dims().Get(); - auto dx_dim_size = dx->dims().size(); - - std::vector res_dx = vectorize(framework::stride(dx->dims())); - paddle::framework::Tensor dx_stride_tensor; - framework::TensorFromVector(res_dx, context.device_context(), - &dx_stride_tensor); - int64_t* dx_stride = dx_stride_tensor.data(); - - const int64_t offset_ = context.Attr("offset"); - const int64_t axis1 = context.Attr("axis1"); - int64_t axis1_ = axis1 < 0 ? dx_dim_size + axis1 : axis1; - const int64_t axis2 = context.Attr("axis2"); - int64_t axis2_ = axis2 < 0 ? dx_dim_size + axis2 : axis2; - - int64_t numel = dx->numel(); - - int threads = PADDLE_CUDA_NUM_THREADS; - int blocks = (numel + threads - 1) / threads; - - switch (dx_dim_size) { - case 2: - Diagonal<<>>(dout_data, dx_data, offset_, - axis1_, axis2_, dx_stride, - dout_stride, numel, true); - break; - case 3: - Diagonal<<>>(dout_data, dx_data, offset_, - axis1_, axis2_, dx_stride, - dout_stride, numel, true); - break; - case 4: - Diagonal<<>>(dout_data, dx_data, offset_, - axis1_, axis2_, dx_stride, - dout_stride, numel, true); - break; - case 5: - Diagonal<<>>(dout_data, dx_data, offset_, - axis1_, axis2_, dx_stride, - dout_stride, numel, true); - break; - case 6: - Diagonal<<>>(dout_data, dx_data, offset_, - axis1_, axis2_, dx_stride, - dout_stride, numel, true); - break; - case 7: - Diagonal<<>>(dout_data, dx_data, offset_, - axis1_, axis2_, dx_stride, - dout_stride, numel, true); - break; - case 8: - Diagonal<<>>(dout_data, dx_data, offset_, - axis1_, axis2_, dx_stride, - dout_stride, numel, true); - break; - case 9: - Diagonal<<>>(dout_data, dx_data, offset_, - axis1_, axis2_, dx_stride, - dout_stride, numel, true); - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "The rank of output(input@Grad) should be less than 10, but " - "received %d.", - dx_dim_size)); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(diagonal, ops::DiagonalCUDAKernel, - ops::DiagonalCUDAKernel, - ops::DiagonalCUDAKernel, - ops::DiagonalCUDAKernel, - ops::DiagonalCUDAKernel, - ops::DiagonalCUDAKernel); - -REGISTER_OP_CUDA_KERNEL(diagonal_grad, ops::DiagonalGradCUDAKernel, - ops::DiagonalGradCUDAKernel, - ops::DiagonalGradCUDAKernel, - ops::DiagonalGradCUDAKernel, - ops::DiagonalGradCUDAKernel); diff --git a/paddle/fluid/operators/diagonal_op.h b/paddle/fluid/operators/diagonal_op.h deleted file mode 100644 index a0380e9e52c..00000000000 --- a/paddle/fluid/operators/diagonal_op.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { -template - -std::vector ComputeDimStride(const std::vector dim) { - size_t dim_size = dim.size(); - std::vector dim_strides; - dim_strides.resize(dim_size); - for (size_t i = 0; i < dim_size - 1; i++) { - size_t temp_stride = 1; - for (size_t j = i + 1; j < dim_size; j++) { - temp_stride = temp_stride * dim[j]; - } - dim_strides[i] = temp_stride; - } - dim_strides[dim_size - 1] = 1; - return dim_strides; -} -template -class DiagonalKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("Input"); - const T* input_data = input->data(); - auto input_dim = vectorize(input->dims()); - auto input_dim_size = input_dim.size(); - - auto* output = context.Output("Out"); - T* output_data = output->mutable_data(context.GetPlace()); - auto output_dim = vectorize(output->dims()); - - const int64_t offset_ = context.Attr("offset"); - const int64_t axis1 = context.Attr("axis1"); - int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1; - const int64_t axis2 = context.Attr("axis2"); - int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2; - - std::vector input_stride = ComputeDimStride(input_dim); - std::vector output_stride = ComputeDimStride(output_dim); - - int64_t numel = input->numel(); - - for (int64_t idx = 0; idx < numel; idx++) { - std::vector idx_dim(input_dim_size); - int64_t temp = 0; - for (size_t i = 0; i < input_dim_size; i++) { - idx_dim[i] = (idx - temp) / input_stride[i]; - temp = temp + idx_dim[i] * input_stride[i]; - } - - int64_t axis1_dim = idx_dim[axis1_]; - int64_t axis2_dim = idx_dim[axis2_]; - - idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_)); - idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_)); - - bool flag = false; - if (offset_ == 0 && axis1_dim == axis2_dim) { - idx_dim.push_back(axis1_dim); - flag = true; - } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) { - idx_dim.push_back(axis1_dim); - flag = true; - } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) { - idx_dim.push_back(axis2_dim); - flag = true; - } - if (flag) { - int64_t idx_output = 0; - for (size_t i = 0; i < idx_dim.size(); i++) { - idx_output = idx_output + idx_dim[i] * output_stride[i]; - } - output_data[idx_output] = input_data[idx]; - } - } - } -}; - -template -class DiagonalGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const auto* dout = - context.Input(framework::GradVarName("Out")); - const T* dout_data = dout->data(); - auto dout_dim = vectorize(dout->dims()); - - auto* dx = - context.Output(framework::GradVarName("Input")); - T* dx_data = dx->mutable_data(context.GetPlace()); - auto dx_dim = vectorize(dx->dims()); - auto dx_dim_size = dx_dim.size(); - - const int64_t offset_ = context.Attr("offset"); - const int64_t axis1 = context.Attr("axis1"); - int64_t axis1_ = axis1 < 0 ? dx_dim_size + axis1 : axis1; - const int64_t axis2 = context.Attr("axis2"); - int64_t axis2_ = axis2 < 0 ? dx_dim_size + axis2 : axis2; - - std::vector dout_stride = ComputeDimStride(dout_dim); - std::vector dx_stride = ComputeDimStride(dx_dim); - - int64_t numel = dx->numel(); - - for (int64_t idx = 0; idx < numel; idx++) { - std::vector idx_dim(dx_dim_size); - int64_t temp = 0; - for (size_t i = 0; i < dx_dim_size; i++) { - idx_dim[i] = (idx - temp) / dx_stride[i]; - temp = temp + idx_dim[i] * dx_stride[i]; - } - - int64_t axis1_dim = idx_dim[axis1_]; - int64_t axis2_dim = idx_dim[axis2_]; - - idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_)); - idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_)); - - bool flag = false; - if (offset_ == 0 && axis1_dim == axis2_dim) { - idx_dim.push_back(axis1_dim); - flag = true; - } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) { - idx_dim.push_back(axis1_dim); - flag = true; - } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) { - idx_dim.push_back(axis2_dim); - flag = true; - } - if (flag) { - int64_t idx_output = 0; - for (size_t i = 0; i < idx_dim.size(); i++) { - idx_output = idx_output + idx_dim[i] * dout_stride[i]; - } - dx_data[idx] = dout_data[idx_output]; - } else { - dx_data[idx] = static_cast(0); - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/pten/kernels/cpu/diagonal_grad_kernel.cc b/paddle/pten/kernels/cpu/diagonal_grad_kernel.cc new file mode 100644 index 00000000000..5d47f6b679e --- /dev/null +++ b/paddle/pten/kernels/cpu/diagonal_grad_kernel.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/diagonal_grad_kernel.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/funcs/diagonal.h" + +namespace pten { + +template +void DiagonalGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + int offset, + int axis1, + int axis2, + DenseTensor* in_grad) { + const auto* dout = &out_grad; + const T* dout_data = dout->data(); + auto dout_dim = vectorize(dout->dims()); + + auto* dx = in_grad; + T* dx_data = dev_ctx.template Alloc(dx); + auto dx_dim = vectorize(dx->dims()); + auto dx_dim_size = dx_dim.size(); + + const int64_t offset_ = offset; + int64_t axis1_ = axis1 < 0 ? dx_dim_size + axis1 : axis1; + int64_t axis2_ = axis2 < 0 ? dx_dim_size + axis2 : axis2; + + std::vector dout_stride = funcs::ComputeDimStride(dout_dim); + std::vector dx_stride = funcs::ComputeDimStride(dx_dim); + + int64_t numel = dx->numel(); + + for (int64_t idx = 0; idx < numel; idx++) { + std::vector idx_dim(dx_dim_size); + int64_t temp = 0; + for (size_t i = 0; i < dx_dim_size; i++) { + idx_dim[i] = (idx - temp) / dx_stride[i]; + temp = temp + idx_dim[i] * dx_stride[i]; + } + + int64_t axis1_dim = idx_dim[axis1_]; + int64_t axis2_dim = idx_dim[axis2_]; + + idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_)); + idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_)); + + bool flag = false; + if (offset_ == 0 && axis1_dim == axis2_dim) { + idx_dim.push_back(axis1_dim); + flag = true; + } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) { + idx_dim.push_back(axis1_dim); + flag = true; + } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) { + idx_dim.push_back(axis2_dim); + flag = true; + } + if (flag) { + int64_t idx_output = 0; + for (size_t i = 0; i < idx_dim.size(); i++) { + idx_output = idx_output + idx_dim[i] * dout_stride[i]; + } + dx_data[idx] = dout_data[idx_output]; + } else { + dx_data[idx] = static_cast(0); + } + } +} +} // namespace pten +PT_REGISTER_KERNEL(diagonal_grad, + CPU, + ALL_LAYOUT, + pten::DiagonalGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/cpu/diagonal_kernel.cc b/paddle/pten/kernels/cpu/diagonal_kernel.cc new file mode 100644 index 00000000000..1b794a64d29 --- /dev/null +++ b/paddle/pten/kernels/cpu/diagonal_kernel.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/diagonal_kernel.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/funcs/diagonal.h" + +namespace pten { + +template +void DiagonalKernel(const Context& dev_ctx, + const DenseTensor& x, + int offset, + int axis1, + int axis2, + DenseTensor* out) { + auto* input = &x; + const T* input_data = input->data(); + auto input_dim = vectorize(input->dims()); + auto input_dim_size = input_dim.size(); + + auto* output = out; + T* output_data = dev_ctx.template Alloc(output); + auto output_dim = vectorize(output->dims()); + + const int64_t offset_ = offset; + int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1; + int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2; + + std::vector input_stride = funcs::ComputeDimStride(input_dim); + std::vector output_stride = funcs::ComputeDimStride(output_dim); + + int64_t numel = input->numel(); + + for (int64_t idx = 0; idx < numel; idx++) { + std::vector idx_dim(input_dim_size); + int64_t temp = 0; + for (size_t i = 0; i < input_dim_size; i++) { + idx_dim[i] = (idx - temp) / input_stride[i]; + temp = temp + idx_dim[i] * input_stride[i]; + } + + int64_t axis1_dim = idx_dim[axis1_]; + int64_t axis2_dim = idx_dim[axis2_]; + + idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_)); + idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_)); + + bool flag = false; + if (offset_ == 0 && axis1_dim == axis2_dim) { + idx_dim.push_back(axis1_dim); + flag = true; + } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) { + idx_dim.push_back(axis1_dim); + flag = true; + } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) { + idx_dim.push_back(axis2_dim); + flag = true; + } + if (flag) { + int64_t idx_output = 0; + for (size_t i = 0; i < idx_dim.size(); i++) { + idx_output = idx_output + idx_dim[i] * output_stride[i]; + } + output_data[idx_output] = input_data[idx]; + } + } +} +} // namespace pten +PT_REGISTER_KERNEL(diagonal, + CPU, + ALL_LAYOUT, + pten::DiagonalKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/pten/kernels/diagonal_grad_kernel.h b/paddle/pten/kernels/diagonal_grad_kernel.h new file mode 100644 index 00000000000..8a1073dec32 --- /dev/null +++ b/paddle/pten/kernels/diagonal_grad_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void DiagonalGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + int offset, + int axis1, + int axis2, + DenseTensor* in_grad); +} // namespace pten diff --git a/paddle/pten/kernels/diagonal_kernel.h b/paddle/pten/kernels/diagonal_kernel.h new file mode 100644 index 00000000000..78a2f1ea8a2 --- /dev/null +++ b/paddle/pten/kernels/diagonal_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void DiagonalKernel(const Context& dev_ctx, + const DenseTensor& x, + int offset, + int axis1, + int axis2, + DenseTensor* out); +} // pten diff --git a/paddle/pten/kernels/funcs/diagonal.h b/paddle/pten/kernels/funcs/diagonal.h index 533010527f4..f36da55f924 100644 --- a/paddle/pten/kernels/funcs/diagonal.h +++ b/paddle/pten/kernels/funcs/diagonal.h @@ -17,11 +17,13 @@ #if defined(__NVCC__) || defined(__HIPCC__) #include #include +#include "paddle/pten/kernels/primitive/kernel_primitives.h" #endif #include #include "paddle/fluid/platform/for_range.h" +#include "paddle/pten/core/dense_tensor.h" namespace pten { namespace funcs { @@ -126,5 +128,88 @@ DenseTensor Diagonal(const DeviceContext& context, } } +template +std::vector ComputeDimStride(const std::vector dim) { + size_t dim_size = dim.size(); + std::vector dim_strides; + dim_strides.resize(dim_size); + for (size_t i = 0; i < dim_size - 1; i++) { + size_t temp_stride = 1; + for (size_t j = i + 1; j < dim_size; j++) { + temp_stride = temp_stride * dim[j]; + } + dim_strides[i] = temp_stride; + } + dim_strides[dim_size - 1] = 1; + return dim_strides; +} + +#if defined(__NVCC__) || defined(__HIPCC__) +template +__global__ void DiagonalCuda(const T* data1, + T* data2, + const int64_t offset_, + int64_t axis1_, + int64_t axis2_, + int64_t* x_stride, + int64_t* out_stride, + int64_t numel, + bool is_grad) { + CUDA_KERNEL_LOOP(idx, numel) { + int64_t idx_dim[X_DIM_SIZE] = {0}; + int64_t temp = 0; + for (size_t i = 0; i < X_DIM_SIZE - 1; i++) { + idx_dim[i] = (idx - temp) / x_stride[i]; + temp = temp + idx_dim[i] * x_stride[i]; + } + idx_dim[X_DIM_SIZE - 1] = idx - temp; + + int64_t axis1_dim = idx_dim[axis1_]; + int64_t axis2_dim = idx_dim[axis2_]; + + int64_t out_dim[OUT_DIM_SIZE] = {0}; + int temp_pos = 0; + for (int i = 0; i < X_DIM_SIZE; i++) { + if (i != axis1_ && i != axis2_) { + out_dim[temp_pos] = idx_dim[i]; + temp_pos++; + } + } + bool flag = false; + if (offset_ == 0 && axis1_dim == axis2_dim) { + out_dim[temp_pos] = axis1_dim; + flag = true; + } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) { + out_dim[temp_pos] = axis1_dim; + flag = true; + } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) { + out_dim[temp_pos] = axis2_dim; + flag = true; + } + if (!is_grad) { + if (flag) { + int64_t idx_output = 0; + for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) { + idx_output = idx_output + out_dim[i] * out_stride[i]; + } + idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1]; + data2[idx_output] = data1[idx]; + } + } else { + if (flag) { + int64_t idx_output = 0; + for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) { + idx_output = idx_output + out_dim[i] * out_stride[i]; + } + idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1]; + data2[idx] = data1[idx_output]; + } else { + data2[idx] = static_cast(0); + } + } + } +} +#endif + } // namespace funcs } // namespace pten diff --git a/paddle/pten/kernels/gpu/diagonal_grad_kernel.cu b/paddle/pten/kernels/gpu/diagonal_grad_kernel.cu new file mode 100644 index 00000000000..02813ea760d --- /dev/null +++ b/paddle/pten/kernels/gpu/diagonal_grad_kernel.cu @@ -0,0 +1,168 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/diagonal_grad_kernel.h" +#include "paddle/pten/kernels/funcs/diagonal.h" + +namespace pten { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +void DiagonalGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + int offset, + int axis1, + int axis2, + DenseTensor* in_grad) { + const auto* dout = &out_grad; + const auto* dout_data = dout->data(); + auto dout_dim = dout->dims().Get(); + auto dout_dim_size = dout->dims().size(); + + std::vector res_dout = vectorize(framework::stride(dout->dims())); + DenseTensor dout_stride_tensor; + paddle::framework::TensorFromVector( + res_dout, dev_ctx, &dout_stride_tensor); + int64_t* dout_stride = dout_stride_tensor.data(); + + auto* dx = in_grad; + auto* dx_data = dev_ctx.template Alloc(dx); + auto dx_dim = dx->dims().Get(); + auto dx_dim_size = dx->dims().size(); + + std::vector res_dx = vectorize(framework::stride(dx->dims())); + DenseTensor dx_stride_tensor; + paddle::framework::TensorFromVector( + res_dx, dev_ctx, &dx_stride_tensor); + int64_t* dx_stride = dx_stride_tensor.data(); + + const int64_t offset_ = offset; + int64_t axis1_ = axis1 < 0 ? dx_dim_size + axis1 : axis1; + int64_t axis2_ = axis2 < 0 ? dx_dim_size + axis2 : axis2; + + int64_t numel = dx->numel(); + + int threads = PADDLE_CUDA_NUM_THREADS; + int blocks = (numel + threads - 1) / threads; + + switch (dx_dim_size) { + case 2: + funcs::DiagonalCuda<<>>(dout_data, + dx_data, + offset_, + axis1_, + axis2_, + dx_stride, + dout_stride, + numel, + true); + break; + case 3: + funcs::DiagonalCuda<<>>(dout_data, + dx_data, + offset_, + axis1_, + axis2_, + dx_stride, + dout_stride, + numel, + true); + break; + case 4: + funcs::DiagonalCuda<<>>(dout_data, + dx_data, + offset_, + axis1_, + axis2_, + dx_stride, + dout_stride, + numel, + true); + break; + case 5: + funcs::DiagonalCuda<<>>(dout_data, + dx_data, + offset_, + axis1_, + axis2_, + dx_stride, + dout_stride, + numel, + true); + break; + case 6: + funcs::DiagonalCuda<<>>(dout_data, + dx_data, + offset_, + axis1_, + axis2_, + dx_stride, + dout_stride, + numel, + true); + break; + case 7: + funcs::DiagonalCuda<<>>(dout_data, + dx_data, + offset_, + axis1_, + axis2_, + dx_stride, + dout_stride, + numel, + true); + break; + case 8: + funcs::DiagonalCuda<<>>(dout_data, + dx_data, + offset_, + axis1_, + axis2_, + dx_stride, + dout_stride, + numel, + true); + break; + case 9: + funcs::DiagonalCuda<<>>(dout_data, + dx_data, + offset_, + axis1_, + axis2_, + dx_stride, + dout_stride, + numel, + true); + break; + default: + PADDLE_THROW(errors::InvalidArgument( + "The rank of output(input@Grad) should be less than 10, but " + "received %d.", + dx_dim_size)); + } +} +} // namespace pten +PT_REGISTER_KERNEL(diagonal_grad, + GPU, + ALL_LAYOUT, + pten::DiagonalGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/gpu/diagonal_kernel.cu b/paddle/pten/kernels/gpu/diagonal_kernel.cu new file mode 100644 index 00000000000..293a1b340b2 --- /dev/null +++ b/paddle/pten/kernels/gpu/diagonal_kernel.cu @@ -0,0 +1,165 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/diagonal_kernel.h" +#include "paddle/pten/kernels/funcs/diagonal.h" + +namespace pten { +using paddle::platform::PADDLE_CUDA_NUM_THREADS; +template +void DiagonalKernel(const Context& dev_ctx, + const DenseTensor& x, + int offset, + int axis1, + int axis2, + DenseTensor* out) { + auto* input = &x; + const auto* input_data = input->data(); + auto input_dim = input->dims().Get(); + auto input_dim_size = input->dims().size(); + + std::vector res_in = vectorize(framework::stride(input->dims())); + DenseTensor input_stride_tensor; + paddle::framework::TensorFromVector( + res_in, dev_ctx, &input_stride_tensor); + int64_t* input_stride = input_stride_tensor.data(); + + auto* output = out; + auto* output_data = dev_ctx.template Alloc(out); + auto output_dim = output->dims().Get(); + auto output_dim_size = output->dims().size(); + + std::vector res_out = vectorize(framework::stride(output->dims())); + DenseTensor output_stride_tensor; + paddle::framework::TensorFromVector( + res_out, dev_ctx, &output_stride_tensor); + int64_t* output_stride = output_stride_tensor.data(); + + const int64_t offset_ = offset; + int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1; + int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2; + int64_t numel = input->numel(); + + int threads = PADDLE_CUDA_NUM_THREADS; + int blocks = (numel + threads - 1) / threads; + + switch (input_dim_size) { + case 2: + funcs::DiagonalCuda<<>>(input_data, + output_data, + offset_, + axis1_, + axis2_, + input_stride, + output_stride, + numel, + false); + break; + case 3: + funcs::DiagonalCuda<<>>(input_data, + output_data, + offset_, + axis1_, + axis2_, + input_stride, + output_stride, + numel, + false); + break; + case 4: + funcs::DiagonalCuda<<>>(input_data, + output_data, + offset_, + axis1_, + axis2_, + input_stride, + output_stride, + numel, + false); + break; + case 5: + funcs::DiagonalCuda<<>>(input_data, + output_data, + offset_, + axis1_, + axis2_, + input_stride, + output_stride, + numel, + false); + break; + case 6: + funcs::DiagonalCuda<<>>(input_data, + output_data, + offset_, + axis1_, + axis2_, + input_stride, + output_stride, + numel, + false); + break; + case 7: + funcs::DiagonalCuda<<>>(input_data, + output_data, + offset_, + axis1_, + axis2_, + input_stride, + output_stride, + numel, + false); + break; + case 8: + funcs::DiagonalCuda<<>>(input_data, + output_data, + offset_, + axis1_, + axis2_, + input_stride, + output_stride, + numel, + false); + break; + case 9: + funcs::DiagonalCuda<<>>(input_data, + output_data, + offset_, + axis1_, + axis2_, + input_stride, + output_stride, + numel, + false); + break; + default: + PADDLE_THROW(errors::InvalidArgument( + "The rank of input should be less than 10, but received %d.", + input_dim_size)); + } +} +} // namespace pten + +PT_REGISTER_KERNEL(diagonal, + GPU, + ALL_LAYOUT, + pten::DiagonalKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/pten/ops/compat/diagonal_sig.cc b/paddle/pten/ops/compat/diagonal_sig.cc new file mode 100644 index 00000000000..7354d3b9223 --- /dev/null +++ b/paddle/pten/ops/compat/diagonal_sig.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature DiagonalGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("diagonal_grad", + {"Input", GradVarName("Out")}, + {"offset", "axis1", "axis2"}, + {GradVarName("Input")}); +} + +} // namespace pten +PT_REGISTER_ARG_MAPPING_FN(diagonal_grad, pten::DiagonalGradOpArgumentMapping); -- GitLab