From 2553af4f41ca27ae33acb137755b2eb1c0686bc6 Mon Sep 17 00:00:00 2001 From: furnace <34057289+windstamp@users.noreply.github.com> Date: Fri, 25 Feb 2022 11:08:41 +0800 Subject: [PATCH] [Phi] mv kernel (#39861) [Phi] mv kernel --- paddle/fluid/operators/mv_op.cc | 15 ++-- paddle/fluid/operators/mv_op.cu | 94 -------------------- paddle/fluid/operators/mv_op.h | 105 ----------------------- paddle/phi/kernels/cpu/mv_grad_kernel.cc | 72 ++++++++++++++++ paddle/phi/kernels/cpu/mv_kernel.cc | 22 +++++ paddle/phi/kernels/gpu/mv_grad_kernel.cu | 83 ++++++++++++++++++ paddle/phi/kernels/gpu/mv_kernel.cu | 22 +++++ paddle/phi/kernels/impl/mv_kernel_impl.h | 45 ++++++++++ paddle/phi/kernels/mv_grad_kernel.h | 35 ++++++++ paddle/phi/kernels/mv_kernel.h | 27 ++++++ paddle/phi/ops/compat/mv_sig.cc | 33 +++++++ 11 files changed, 346 insertions(+), 207 deletions(-) delete mode 100644 paddle/fluid/operators/mv_op.cu delete mode 100644 paddle/fluid/operators/mv_op.h create mode 100644 paddle/phi/kernels/cpu/mv_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/mv_kernel.cc create mode 100644 paddle/phi/kernels/gpu/mv_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/mv_kernel.cu create mode 100644 paddle/phi/kernels/impl/mv_kernel_impl.h create mode 100644 paddle/phi/kernels/mv_grad_kernel.h create mode 100644 paddle/phi/kernels/mv_kernel.h create mode 100644 paddle/phi/ops/compat/mv_sig.cc diff --git a/paddle/fluid/operators/mv_op.cc b/paddle/fluid/operators/mv_op.cc index 01135bab6d..ab9f10070f 100644 --- a/paddle/fluid/operators/mv_op.cc +++ b/paddle/fluid/operators/mv_op.cc @@ -12,7 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/operators/mv_op.h" +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" + namespace paddle { namespace operators { @@ -116,10 +122,3 @@ REGISTER_OPERATOR(mv, ops::MVOp, ops::MVOpMaker, ops::MVOpGradMaker, ops::MVOpGradMaker); REGISTER_OPERATOR(mv_grad, ops::MVOpGrad); - -REGISTER_OP_CPU_KERNEL( - mv, ops::MVKernel, - ops::MVKernel); -REGISTER_OP_CPU_KERNEL( - mv_grad, ops::MVGradKernel, - ops::MVGradKernel); diff --git a/paddle/fluid/operators/mv_op.cu b/paddle/fluid/operators/mv_op.cu deleted file mode 100644 index b8b61ae490..0000000000 --- a/paddle/fluid/operators/mv_op.cu +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/mv_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" - -namespace paddle { -namespace operators { - -template -__global__ void MVGradDxCUDAKernel(const int m, const int n, const T *dout, - const T *vec, T *dx) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < m * n; idx += blockDim.x * gridDim.x) { - int i = idx / n; - int j = idx % n; - dx[idx] = dout[i] * vec[j]; - } -} - -// Using dimensional constraints on matrix multiplication, it is -// straight-forward to check the following table for when X and Y -// are both matrices. 
-// -// dX = | dOut Vec^T -// dVec = | X^T dOut -template -class MVGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *x = context.Input("X"); - auto *vec = context.Input("Vec"); - auto *dout = - context.Input(framework::GradVarName("Out")); - auto *dx = context.Output(framework::GradVarName("X")); - auto *dvec = - context.Output(framework::GradVarName("Vec")); - - auto dim_x = x->dims(); - int m = dim_x[0]; - int n = dim_x[1]; - - // get data ptr - const T *x_data = x->data(); - const T *vec_data = vec->data(); - const T *dout_data = dout->data(); - - auto &dev_ctx = - context.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - auto stream = context.cuda_device_context().stream(); - auto config = GetGpuLaunchConfig1D(dev_ctx, m * n); - - if (dx) { - T *dx_data = dx->mutable_data(context.GetPlace()); - - MVGradDxCUDAKernel< - T><<>>( - m, n, dout_data, vec_data, dx_data); - } - - if (dvec) { - T *dvec_data = dvec->mutable_data(context.GetPlace()); - - blas.GEMV(true, dim_x[0], dim_x[1], static_cast(1), x_data, dout_data, - static_cast(0), dvec_data); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - mv, ops::MVKernel, - ops::MVKernel); -REGISTER_OP_CUDA_KERNEL( - mv_grad, ops::MVGradKernel, - ops::MVGradKernel); diff --git a/paddle/fluid/operators/mv_op.h b/paddle/fluid/operators/mv_op.h deleted file mode 100644 index c0a2172af3..0000000000 --- a/paddle/fluid/operators/mv_op.h +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#ifdef PADDLE_WITH_MKLDNN -#include "paddle/fluid/platform/mkldnn_helper.h" -#endif - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class MVKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *x = context.Input("X"); - auto *vec = context.Input("Vec"); - - auto *out = context.Output("Out"); - - auto dim_x = x->dims(); - - // get data ptr - const T *x_data = x->data(); - const T *vec_data = vec->data(); - T *out_data = out->mutable_data(context.GetPlace()); - - auto &dev_ctx = context.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - - blas.GEMV(false, dim_x[0], dim_x[1], static_cast(1), x_data, vec_data, - static_cast(0), out_data); - } -}; - -// Using dimensional constraints on matrix multiplication, it is -// straight-forward to check the following table for when X and Y -// are both matrices. 
-// -// dX = | dOut vec^T -// dVec = | X^T dOut -template -class MVGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *x = context.Input("X"); - auto *vec = context.Input("Vec"); - auto *dout = - context.Input(framework::GradVarName("Out")); - auto *dx = context.Output(framework::GradVarName("X")); - auto *dvec = - context.Output(framework::GradVarName("Vec")); - - auto dim_x = x->dims(); - int m = dim_x[0]; - int n = dim_x[1]; - - // get data ptr - const T *x_data = x->data(); - const T *vec_data = vec->data(); - const T *dout_data = dout->data(); - - if (dx) { - T *dx_data = dx->mutable_data(context.GetPlace()); - - for (int i = 0; i < m; ++i) { - for (int j = 0; j < n; ++j) { - dx_data[i * n + j] = dout_data[i] * vec_data[j]; - } - } - } - - if (dvec) { - T *dvec_data = dvec->mutable_data(context.GetPlace()); - - auto &dev_ctx = context.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - - blas.GEMV(true, dim_x[0], dim_x[1], static_cast(1), x_data, dout_data, - static_cast(0), dvec_data); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/cpu/mv_grad_kernel.cc b/paddle/phi/kernels/cpu/mv_grad_kernel.cc new file mode 100644 index 0000000000..c3b7f94be4 --- /dev/null +++ b/paddle/phi/kernels/cpu/mv_grad_kernel.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/mv_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace phi { + +template +void MvGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& vec, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* vec_grad) { + auto dout = out_grad; + auto dx = x_grad; + auto dvec = vec_grad; + + auto dim_x = x.dims(); + int m = dim_x[0]; + int n = dim_x[1]; + + // get data ptr + const T* x_data = x.data(); + const T* vec_data = vec.data(); + const T* dout_data = dout.data(); + + if (dx) { + T* dx_data = dev_ctx.template Alloc(dx); + + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + dx_data[i * n + j] = dout_data[i] * vec_data[j]; + } + } + } + + if (dvec) { + T* dvec_data = dev_ctx.template Alloc(dvec); + + auto blas = phi::funcs::GetBlas(dev_ctx); + + blas.GEMV(true, + dim_x[0], + dim_x[1], + static_cast(1), + x_data, + dout_data, + static_cast(0), + dvec_data); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(mv_grad, CPU, ALL_LAYOUT, phi::MvGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/cpu/mv_kernel.cc b/paddle/phi/kernels/cpu/mv_kernel.cc new file mode 100644 index 0000000000..7f76ddda6d --- /dev/null +++ b/paddle/phi/kernels/cpu/mv_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/mv_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/mv_kernel_impl.h" + +PD_REGISTER_KERNEL(mv, CPU, ALL_LAYOUT, phi::MvKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/mv_grad_kernel.cu b/paddle/phi/kernels/gpu/mv_grad_kernel.cu new file mode 100644 index 0000000000..9eb8cd375e --- /dev/null +++ b/paddle/phi/kernels/gpu/mv_grad_kernel.cu @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/mv_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace phi { + +template +__global__ void MVGradDxCUDAKernel( + const int m, const int n, const T *dout, const T *vec, T *dx) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < m * n; idx += blockDim.x * gridDim.x) { + int i = idx / n; + int j = idx % n; + dx[idx] = dout[i] * vec[j]; + } +} + +template +void MvGradKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &vec, + const DenseTensor &out_grad, + DenseTensor *x_grad, + DenseTensor *vec_grad) { + auto dout = out_grad; + auto dx = x_grad; + auto dvec = vec_grad; + + auto dim_x = x.dims(); + int m = dim_x[0]; + int n = dim_x[1]; + + // get data ptr + const T *x_data = x.data(); + const T *vec_data = vec.data(); + const T *dout_data = dout.data(); + + auto blas = phi::funcs::GetBlas(dev_ctx); + auto stream = dev_ctx.stream(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, m * n); + + if (dx) { + T *dx_data = dev_ctx.template Alloc(dx); + + MVGradDxCUDAKernel< + T><<>>( + m, n, dout_data, vec_data, dx_data); + } + + if (dvec) { + T *dvec_data = dev_ctx.template Alloc(dvec); + + blas.GEMV(true, + dim_x[0], + dim_x[1], + static_cast(1), + x_data, + dout_data, + static_cast(0), + dvec_data); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(mv_grad, GPU, ALL_LAYOUT, phi::MvGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/gpu/mv_kernel.cu b/paddle/phi/kernels/gpu/mv_kernel.cu new file mode 100644 index 0000000000..1faba5a62d --- /dev/null +++ b/paddle/phi/kernels/gpu/mv_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/mv_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/mv_kernel_impl.h" + +PD_REGISTER_KERNEL(mv, GPU, ALL_LAYOUT, phi::MvKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/mv_kernel_impl.h b/paddle/phi/kernels/impl/mv_kernel_impl.h new file mode 100644 index 0000000000..1754ea323c --- /dev/null +++ b/paddle/phi/kernels/impl/mv_kernel_impl.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace phi { + +template +void MvKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& vec, + DenseTensor* out) { + auto dim_x = x.dims(); + + // get data ptr + const T* x_data = x.data(); + const T* vec_data = vec.data(); + T* out_data = dev_ctx.template Alloc(out); + + auto blas = phi::funcs::GetBlas(dev_ctx); + + blas.GEMV(false, + dim_x[0], + dim_x[1], + static_cast(1), + x_data, + vec_data, + static_cast(0), + out_data); +} + +} // namespace phi diff --git a/paddle/phi/kernels/mv_grad_kernel.h b/paddle/phi/kernels/mv_grad_kernel.h new file mode 100644 index 0000000000..edc73d8936 --- /dev/null +++ b/paddle/phi/kernels/mv_grad_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +// Using dimensional constraints on matrix multiplication, it is +// straight-forward to check the following table for when X and Y +// are both matrices. 
+// +// dX = | dOut vec^T +// dVec = | X^T dOut +template +void MvGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& vec, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* vec_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/mv_kernel.h b/paddle/phi/kernels/mv_kernel.h new file mode 100644 index 0000000000..ab4f0b8279 --- /dev/null +++ b/paddle/phi/kernels/mv_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MvKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& vec, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/mv_sig.cc b/paddle/phi/ops/compat/mv_sig.cc new file mode 100644 index 0000000000..ab0d31ee31 --- /dev/null +++ b/paddle/phi/ops/compat/mv_sig.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature MvOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("mv", {"X", "Vec"}, {}, {"Out"}); +} + +KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("mv_grad", + {"X", "Vec", GradVarName("Out")}, + {}, + {GradVarName("X"), GradVarName("Vec")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(mv, phi::MvOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(mv_grad, phi::MvGradOpArgumentMapping); -- GitLab