From 48b4366c707ab570d7012e213d3eccef73ac40a4 Mon Sep 17 00:00:00 2001
From: Yang <3349368+m3ngyang@users.noreply.github.com>
Date: Tue, 8 Mar 2022 16:51:44 +0800
Subject: [PATCH] [Phi] move ops: maxout/take_along_axis/put_along_axis
 (#39959)

* [Phi] move put_along_axis/take_along_axis/maxout

* use phi::Copy
---
 paddle/fluid/operators/math/maxouting.cc      | 151 +++++++++---------
 paddle/fluid/operators/math/maxouting.cu      | 107 +++++++------
 paddle/fluid/operators/math/maxouting.h       |   2 +-
 paddle/fluid/operators/maxout_op.cc           |  13 +-
 paddle/fluid/operators/maxout_op.cu.cc        |  24 ---
 paddle/fluid/operators/maxout_op.h            |  72 ---------
 paddle/fluid/operators/put_along_axis_op.cc   |  16 +-
 paddle/fluid/operators/put_along_axis_op.cu   | 134 ----------------
 paddle/fluid/operators/put_along_axis_op.h    | 124 --------------
 paddle/fluid/operators/take_along_axis_op.cc  |  16 +-
 paddle/fluid/operators/take_along_axis_op.cu  |  97 -----------
 paddle/fluid/operators/take_along_axis_op.h   |  92 -----------
 paddle/phi/kernels/CMakeLists.txt             |   8 +-
 paddle/phi/kernels/cpu/maxout_grad_kernel.cc  |  20 +++
 paddle/phi/kernels/cpu/maxout_kernel.cc       |  19 +++
 .../kernels/cpu/put_along_axis_grad_kernel.cc |  83 ++++++++++
 .../phi/kernels/cpu/put_along_axis_kernel.cc  |  87 ++++++++++
 .../cpu/take_along_axis_grad_kernel.cc        |  71 ++++++++
 .../phi/kernels/cpu/take_along_axis_kernel.cc |  60 +++++++
 paddle/phi/kernels/gpu/maxout_grad_kernel.cu  |  20 +++
 paddle/phi/kernels/gpu/maxout_kernel.cu       |  19 +++
 .../kernels/gpu/put_along_axis_grad_kernel.cu |  79 +++++++++
 .../phi/kernels/gpu/put_along_axis_kernel.cu  |  86 ++++++++++
 .../gpu/take_along_axis_grad_kernel.cu        |  72 +++++++++
 .../phi/kernels/gpu/take_along_axis_kernel.cu |  59 +++++++
 .../kernels/impl/maxout_grad_kernel_impl.h    |  45 ++++++
 paddle/phi/kernels/impl/maxout_kernel_impl.h  |  37 +++++
 paddle/phi/kernels/maxout_grad_kernel.h       |  30 ++++
 paddle/phi/kernels/maxout_kernel.h            |  28 ++++
 .../phi/kernels/put_along_axis_grad_kernel.h  |  33 ++++
 paddle/phi/kernels/put_along_axis_kernel.h    |  32 ++++
 .../phi/kernels/take_along_axis_grad_kernel.h |  29 ++++
 paddle/phi/kernels/take_along_axis_kernel.h   |  28 ++++
 paddle/phi/ops/compat/maxout_sig.cc           |  33 ++++
 paddle/phi/ops/compat/put_along_axis_sig.cc   |  38 +++++
 paddle/phi/ops/compat/take_along_axis_sig.cc  |  37 +++++
 36 files changed, 1191 insertions(+), 710 deletions(-)
 delete mode 100644 paddle/fluid/operators/maxout_op.cu.cc
 delete mode 100644 paddle/fluid/operators/maxout_op.h
 delete mode 100644 paddle/fluid/operators/put_along_axis_op.cu
 delete mode 100644 paddle/fluid/operators/put_along_axis_op.h
 delete mode 100644 paddle/fluid/operators/take_along_axis_op.cu
 delete mode 100644 paddle/fluid/operators/take_along_axis_op.h
 create mode 100644 paddle/phi/kernels/cpu/maxout_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/maxout_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/put_along_axis_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/take_along_axis_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/maxout_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/maxout_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/put_along_axis_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/take_along_axis_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/maxout_grad_kernel_impl.h
 create mode 100644 paddle/phi/kernels/impl/maxout_kernel_impl.h
 create mode 100644 paddle/phi/kernels/maxout_grad_kernel.h
 create mode 100644 paddle/phi/kernels/maxout_kernel.h
 create mode 100644 paddle/phi/kernels/put_along_axis_grad_kernel.h
 create mode 100644 paddle/phi/kernels/put_along_axis_kernel.h
 create mode 100644 paddle/phi/kernels/take_along_axis_grad_kernel.h
 create mode 100644 paddle/phi/kernels/take_along_axis_kernel.h
 create mode 100644 paddle/phi/ops/compat/maxout_sig.cc
 create mode 100644 paddle/phi/ops/compat/put_along_axis_sig.cc
 create mode 100644 paddle/phi/ops/compat/take_along_axis_sig.cc
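The rest of this patch repeats one mechanical pattern per operator: the fluid OpKernel subclass keyed on (DeviceContext, T) is deleted, and a free function template keyed on (T, Context) is registered with phi. The new declaration headers are not shown below; from the registrations and implementations, the maxout pair plausibly looks like this (a sketch with inferred parameter names, not quoted from the truncated headers):

    // Assumed shape of paddle/phi/kernels/maxout_kernel.h and
    // maxout_grad_kernel.h; names are inferred from the registrations
    // and kernel implementations later in this patch.
    #include "paddle/phi/core/dense_tensor.h"

    namespace phi {

    template <typename T, typename Context>
    void MaxOutKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      int groups,
                      int axis,
                      DenseTensor* out);

    template <typename T, typename Context>
    void MaxOutGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& out,
                          const DenseTensor& out_grad,
                          int groups,
                          int axis,
                          DenseTensor* x_grad);

    }  // namespace phi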
diff --git a/paddle/fluid/operators/math/maxouting.cc b/paddle/fluid/operators/math/maxouting.cc
index 45556e97d1..28ec3a8710 100644
--- a/paddle/fluid/operators/math/maxouting.cc
+++ b/paddle/fluid/operators/math/maxouting.cc
@@ -14,106 +14,107 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/maxouting.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+
 namespace paddle {
 namespace operators {
 namespace math {
 
 // All tensors are in NCHW or NHWC format, and the groups must be greater than 1
-template <typename T>
-class MaxOutFunctor<platform::CPUDeviceContext, T> {
- public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input, framework::Tensor* output,
-                  const int groups, const int axis) {
-    const int batch_size = input.dims()[0];
-    const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
-    const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
-    const int output_channels = output->dims()[axis];
-    int fea_size = input_height * input_width;
-    // c_size means the output size of each sample
-    int c_size = fea_size * output_channels;
-    const T* input_data = input.data<T>();
-    T* output_data = output->mutable_data<T>(context.GetPlace());
-    for (int i = 0; i < batch_size; ++i) {
-      int new_bindex = c_size * i;
-      for (int c = 0; c < output_channels; ++c) {
-        int new_cindex = fea_size * c;
-        for (int f = 0; f < fea_size; ++f) {
-          T ele = static_cast<T>(-FLT_MAX);
-          int input_idx, output_idx;
-          for (int ph = 0; ph < groups; ++ph) {
-            if (axis == 1) {
-              input_idx =
-                  (new_bindex + new_cindex) * groups + ph * fea_size + f;
-            } else {
-              input_idx = (new_bindex + f * output_channels + c) * groups + ph;
-            }
-            T x = input_data[input_idx];
-            ele = ele > x ? ele : x;
-          }
+template <typename DeviceContext, typename T>
+void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
+                                                 const framework::Tensor& input,
+                                                 framework::Tensor* output,
+                                                 const int groups,
+                                                 const int axis) {
+  const int batch_size = input.dims()[0];
+  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
+  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
+  const int output_channels = output->dims()[axis];
+  int fea_size = input_height * input_width;
+  // c_size means the output size of each sample
+  int c_size = fea_size * output_channels;
+  const T* input_data = input.data<T>();
+  T* output_data = output->mutable_data<T>(context.GetPlace());
+  for (int i = 0; i < batch_size; ++i) {
+    int new_bindex = c_size * i;
+    for (int c = 0; c < output_channels; ++c) {
+      int new_cindex = fea_size * c;
+      for (int f = 0; f < fea_size; ++f) {
+        T ele = static_cast<T>(-FLT_MAX);
+        int input_idx, output_idx;
+        for (int ph = 0; ph < groups; ++ph) {
           if (axis == 1) {
-            output_idx = new_bindex + new_cindex + f;
+            input_idx = (new_bindex + new_cindex) * groups + ph * fea_size + f;
           } else {
-            output_idx = new_bindex + f * output_channels + c;
+            input_idx = (new_bindex + f * output_channels + c) * groups + ph;
           }
-          output_data[output_idx] = ele;
+          T x = input_data[input_idx];
+          ele = ele > x ? ele : x;
         }
+        if (axis == 1) {
+          output_idx = new_bindex + new_cindex + f;
+        } else {
+          output_idx = new_bindex + f * output_channels + c;
+        }
+        output_data[output_idx] = ele;
       }
     }
   }
-};
+}
 
-template <typename T>
-class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
- public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input, framework::Tensor* input_grad,
-                  const framework::Tensor& output,
-                  const framework::Tensor& output_grad, const int groups,
-                  const int axis) {
-    const int batch_size = input.dims()[0];
-    const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
-    const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
-    const int output_channels = output.dims()[axis];
-    int fea_size = input_height * input_width;
-    const T* input_data = input.data<T>();
-    const T* output_data = output.data<T>();
-    const T* output_grad_data = output_grad.data<T>();
-    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
+template <typename DeviceContext, typename T>
+void MaxOutGradFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& context, const framework::Tensor& input,
+    framework::Tensor* input_grad, const framework::Tensor& output,
+    const framework::Tensor& output_grad, const int groups, const int axis) {
+  const int batch_size = input.dims()[0];
+  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
+  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
+  const int output_channels = output.dims()[axis];
+  int fea_size = input_height * input_width;
+  const T* input_data = input.data<T>();
+  const T* output_data = output.data<T>();
+  const T* output_grad_data = output_grad.data<T>();
+  T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
 
-    for (int i = 0; i < batch_size; ++i) {
-      int blen = fea_size * output_channels * i;
-      for (int c = 0; c < output_channels; ++c) {
-        int clen = fea_size * c;
-        for (int f = 0; f < fea_size; ++f) {
-          int input_idx0, output_idx;
-          bool continue_match = true;
-          if (axis == 1) {
-            input_idx0 = (blen + clen) * groups + f;
-            output_idx = blen + clen + f;
-          } else {
-            input_idx0 = (blen + f * output_channels + c) * groups;
-            output_idx = blen + f * output_channels + c;
-          }
-          for (int g = 0; g < groups && continue_match; ++g) {
-            int idx_offset = (axis == 1 ? fea_size * g : g);
-            int input_idx = input_idx0 + idx_offset;
-            if (input_data[input_idx] == output_data[output_idx]) {
-              input_grad_data[input_idx] += output_grad_data[output_idx];
-              continue_match = false;
-            }
+  for (int i = 0; i < batch_size; ++i) {
+    int blen = fea_size * output_channels * i;
+    for (int c = 0; c < output_channels; ++c) {
+      int clen = fea_size * c;
+      for (int f = 0; f < fea_size; ++f) {
+        int input_idx0, output_idx;
+        bool continue_match = true;
+        if (axis == 1) {
+          input_idx0 = (blen + clen) * groups + f;
+          output_idx = blen + clen + f;
+        } else {
+          input_idx0 = (blen + f * output_channels + c) * groups;
+          output_idx = blen + f * output_channels + c;
+        }
+        for (int g = 0; g < groups && continue_match; ++g) {
+          int idx_offset = (axis == 1 ? fea_size * g : g);
+          int input_idx = input_idx0 + idx_offset;
+          if (input_data[input_idx] == output_data[output_idx]) {
+            input_grad_data[input_idx] += output_grad_data[output_idx];
+            continue_match = false;
           }
         }
       }
     }
   }
-};
+}
 
 template class MaxOutGradFunctor<platform::CPUDeviceContext, float>;
 template class MaxOutGradFunctor<platform::CPUDeviceContext, double>;
 template class MaxOutFunctor<platform::CPUDeviceContext, float>;
 template class MaxOutFunctor<platform::CPUDeviceContext, double>;
 
+template class MaxOutGradFunctor<phi::CPUContext, float>;
+template class MaxOutGradFunctor<phi::CPUContext, double>;
+template class MaxOutFunctor<phi::CPUContext, float>;
+template class MaxOutFunctor<phi::CPUContext, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
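The CPU functor above is easier to follow with concrete sizes. A self-contained sketch of the same indexing for NCHW (axis == 1), with hypothetical sizes C_in = 4 and groups = 2, so C_out = C_in / groups = 2; each output channel c is the elementwise max over input channels [c * groups, (c + 1) * groups):

    #include <algorithm>
    #include <cfloat>
    #include <cstdio>

    int main() {
      const int C_in = 4, groups = 2, H = 1, W = 2;
      const int C_out = C_in / groups, fea = H * W;
      // Channel-major data: ch0={1,5}, ch1={2,0}, ch2={7,3}, ch3={4,6}
      float in[C_in * fea] = {1, 5, 2, 0, 7, 3, 4, 6};
      float out[C_out * fea];
      for (int c = 0; c < C_out; ++c)
        for (int f = 0; f < fea; ++f) {
          float ele = -FLT_MAX;
          // Same index arithmetic as the functor:
          // (c * fea) * groups + ph * fea + f == (c * groups + ph) * fea + f
          for (int ph = 0; ph < groups; ++ph)
            ele = std::max(ele, in[(c * groups + ph) * fea + f]);
          out[c * fea + f] = ele;
        }
      printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 2 5 7 6
    }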
diff --git a/paddle/fluid/operators/math/maxouting.cu b/paddle/fluid/operators/math/maxouting.cu
index 1856fb4eb4..1d0478db5e 100644
--- a/paddle/fluid/operators/math/maxouting.cu
+++ b/paddle/fluid/operators/math/maxouting.cu
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/maxouting.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
 
 namespace paddle {
 namespace operators {
@@ -95,61 +96,57 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
 /*
  * All tensors are in NCHW or NHWC format.
  */
-template <typename T>
-class MaxOutFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& input, framework::Tensor* output,
-                  const int groups, const int axis) {
-    const int batch_size = input.dims()[0];
-    const int input_channels = input.dims()[axis];
-    const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
-    const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
-    const int output_channels = output->dims()[axis];
-
-    const T* input_data = input.data<T>();
-    T* output_data = output->mutable_data<T>(context.GetPlace());
-    int nthreads = output->numel();
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
-    dim3 grid(blocks, 1);
-
-    KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
-        nthreads, input_data, input_channels, input_height, input_width,
-        groups, axis, output_data);
-  }
-};
+template <typename DeviceContext, typename T>
+void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
+                                                 const framework::Tensor& input,
+                                                 framework::Tensor* output,
+                                                 const int groups,
+                                                 const int axis) {
+  const int batch_size = input.dims()[0];
+  const int input_channels = input.dims()[axis];
+  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
+  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
+  const int output_channels = output->dims()[axis];
+
+  const T* input_data = input.data<T>();
+  T* output_data = output->mutable_data<T>(context.GetPlace());
+  int nthreads = output->numel();
+  int blocks = (nthreads + 1024 - 1) / 1024;
+  dim3 threads(1024, 1);
+  dim3 grid(blocks, 1);
+
+  KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
+      nthreads, input_data, input_channels, input_height, input_width, groups,
+      axis, output_data);
+}
+
 /*
  * All tensors are in NCHW or NHWC format.
  */
-template <typename T>
-class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& input, framework::Tensor* input_grad,
-                  const framework::Tensor& output,
-                  const framework::Tensor& output_grad, const int groups,
-                  const int axis) {
-    const int batch_size = input.dims()[0];
-    const int input_channels = input.dims()[axis];
-    const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
-    const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
-    const int output_channels = output.dims()[axis];
-
-    const T* input_data = input.data<T>();
-    const T* output_data = output.data<T>();
-    const T* output_grad_data = output_grad.data<T>();
-    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
-    int nthreads = output.numel();
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
-    dim3 grid(blocks, 1);
-
-    KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
-        nthreads, input_data, output_data, output_grad_data, input_grad_data,
-        input_channels, input_height, input_width, groups, axis);
-  }
-};
+template <typename DeviceContext, typename T>
+void MaxOutGradFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& context, const framework::Tensor& input,
+    framework::Tensor* input_grad, const framework::Tensor& output,
+    const framework::Tensor& output_grad, const int groups, const int axis) {
+  const int batch_size = input.dims()[0];
+  const int input_channels = input.dims()[axis];
+  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
+  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
+  const int output_channels = output.dims()[axis];
+
+  const T* input_data = input.data<T>();
+  const T* output_data = output.data<T>();
+  const T* output_grad_data = output_grad.data<T>();
+  T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
+  int nthreads = output.numel();
+  int blocks = (nthreads + 1024 - 1) / 1024;
+  dim3 threads(1024, 1);
+  dim3 grid(blocks, 1);
+
+  KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
+      nthreads, input_data, output_data, output_grad_data, input_grad_data,
+      input_channels, input_height, input_width, groups, axis);
+}
 
 template class MaxOutGradFunctor<platform::CUDADeviceContext, float>;
 template class MaxOutGradFunctor<platform::CUDADeviceContext, double>;
@@ -157,6 +154,12 @@ template class MaxOutGradFunctor<platform::CUDADeviceContext, double>;
 template class MaxOutFunctor<platform::CUDADeviceContext, float>;
 template class MaxOutFunctor<platform::CUDADeviceContext, double>;
 
+template class MaxOutGradFunctor<phi::GPUContext, float>;
+template class MaxOutGradFunctor<phi::GPUContext, double>;
+
+template class MaxOutFunctor<phi::GPUContext, float>;
+template class MaxOutFunctor<phi::GPUContext, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/math/maxouting.h b/paddle/fluid/operators/math/maxouting.h
index 0d8372df8a..1f4964f771 100644
--- a/paddle/fluid/operators/math/maxouting.h
+++ b/paddle/fluid/operators/math/maxouting.h
@@ -30,7 +30,7 @@ class MaxOutFunctor {
                   const int axis = 1);
 };
 
-template <typename DeviceContext, class T>
+template <typename DeviceContext, typename T>
 class MaxOutGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
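Both CUDA functors use the same launch arithmetic: one thread per output element, 1024 threads per block, and a grid sized by ceiling division, which is why the kernels themselves must bounds-check against nthreads. The computation, isolated as a small helper (a sketch, not part of the patch):

    // blocks = ceil(nthreads / block); e.g. nthreads = 5000 with block = 1024
    // gives 5 blocks (5120 threads), so the last 120 threads exit early
    // behind the kernel's `index < nthreads` guard.
    inline dim3 GridFor(int nthreads, int block = 1024) {
      return dim3((nthreads + block - 1) / block, 1);
    }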
diff --git a/paddle/fluid/operators/maxout_op.cc b/paddle/fluid/operators/maxout_op.cc
index bd9ebd2977..e55369e069 100644
--- a/paddle/fluid/operators/maxout_op.cc
+++ b/paddle/fluid/operators/maxout_op.cc
@@ -12,14 +12,14 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
 
-#include "paddle/fluid/operators/maxout_op.h"
 #include <vector>
 
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+
 namespace paddle {
 namespace operators {
 
-using framework::Tensor;
-
 class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -130,10 +130,3 @@ REGISTER_OPERATOR(
     paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
     paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
 REGISTER_OPERATOR(maxout_grad, ops::MaxOutOpGrad);
-REGISTER_OP_CPU_KERNEL(
-    maxout, ops::MaxOutKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MaxOutKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    maxout_grad,
-    ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/maxout_op.cu.cc b/paddle/fluid/operators/maxout_op.cu.cc
deleted file mode 100644
index be1e81bb86..0000000000
--- a/paddle/fluid/operators/maxout_op.cu.cc
+++ /dev/null
@@ -1,24 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/maxout_op.h"
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    maxout, ops::MaxOutKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MaxOutKernel<paddle::platform::CUDADeviceContext, double>);
-REGISTER_OP_CUDA_KERNEL(
-    maxout_grad,
-    ops::MaxOutGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MaxOutGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/maxout_op.h b/paddle/fluid/operators/maxout_op.h
deleted file mode 100644
index 9229982939..0000000000
--- a/paddle/fluid/operators/maxout_op.h
+++ /dev/null
@@ -1,72 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/maxouting.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class MaxOutKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* in_x = context.Input<Tensor>("X");
-    Tensor* out = context.Output<Tensor>("Out");
-    int groups = context.template Attr<int>("groups");
-    int axis = context.template Attr<int>("axis");
-    if (axis < 0) {
-      axis += in_x->dims().size();
-    }
-
-    math::MaxOutFunctor<DeviceContext, T> maxout_forward;
-    maxout_forward(context.template device_context<DeviceContext>(), *in_x,
-                   out, groups, axis);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class MaxOutGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* in_x = context.Input<Tensor>("X");
-    const Tensor* out = context.Input<Tensor>("Out");
-    const Tensor* out_grad =
-        context.Input<Tensor>(framework::GradVarName("Out"));
-    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
-    int groups = context.template Attr<int>("groups");
-    int axis = context.template Attr<int>("axis");
-    if (axis < 0) {
-      axis += in_x->dims().size();
-    }
-
-    auto& device_ctx = context.template device_context<DeviceContext>();
-    phi::funcs::SetConstant<DeviceContext, T> zero;
-    if (in_x_grad) {
-      in_x_grad->mutable_data<T>(context.GetPlace());
-      zero(device_ctx, in_x_grad, static_cast<T>(0.0));
-      math::MaxOutGradFunctor<DeviceContext, T> maxout_backward;
-      maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups,
-                      axis);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
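Note that the deleted kernels, not the functors, owned the negative-axis normalization, so any caller of the functors must pass a non-negative axis. The normalization was one line:

    // From the deleted MaxOutKernel: for a 4-D input, axis = -1 becomes 3
    // (NHWC channels) and axis = -3 becomes 1 (NCHW channels).
    if (axis < 0) {
      axis += in_x->dims().size();
    }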
diff --git a/paddle/fluid/operators/put_along_axis_op.cc b/paddle/fluid/operators/put_along_axis_op.cc
index 6b0d6f332b..54e31845ad 100644
--- a/paddle/fluid/operators/put_along_axis_op.cc
+++ b/paddle/fluid/operators/put_along_axis_op.cc
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/put_along_axis_op.h"
 #include <memory>
 #include <string>
 #include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/ddim.h"
 
@@ -123,16 +124,3 @@ REGISTER_OPERATOR(put_along_axis, ops::PutAlongAxisOp, ops::PutAlongAxisOpMaker,
                   paddle::operators::PutAlongAxisInplaceInferer);
 
 REGISTER_OPERATOR(put_along_axis_grad, ops::PutAlongAxisGradOp);
-
-REGISTER_OP_CPU_KERNEL(put_along_axis, ops::PutAlongAxisOpKernel<float>,
-                       ops::PutAlongAxisOpKernel<double>,
-                       ops::PutAlongAxisOpKernel<int>,
-                       ops::PutAlongAxisOpKernel<uint8_t>,
-                       ops::PutAlongAxisOpKernel<int64_t>);
-
-REGISTER_OP_CPU_KERNEL(put_along_axis_grad,
-                       ops::PutAlongAxisGradOpKernel<float>,
-                       ops::PutAlongAxisGradOpKernel<double>,
-                       ops::PutAlongAxisGradOpKernel<int>,
-                       ops::PutAlongAxisGradOpKernel<uint8_t>,
-                       ops::PutAlongAxisGradOpKernel<int64_t>);
diff --git a/paddle/fluid/operators/put_along_axis_op.cu b/paddle/fluid/operators/put_along_axis_op.cu
deleted file mode 100644
index 5508023efa..0000000000
--- a/paddle/fluid/operators/put_along_axis_op.cu
+++ /dev/null
@@ -1,134 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <memory>
-#include <string>
-#include <vector>
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/operators/put_along_axis_op.h"
-#include "paddle/phi/core/ddim.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class PutAlongAxisCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "PutAlongAxisCUDAKernel only runs on GPU device."));
-    auto input = ctx.Input<Tensor>("Input");
-    auto axis = ctx.Attr<int>("Axis");
-    auto value = ctx.Input<Tensor>("Value");
-    auto index = ctx.Input<Tensor>("Index");
-    auto reduce_op = ctx.Attr<std::string>("Reduce");
-    auto result = ctx.Output<Tensor>("Result");
-    const platform::DeviceContext &device_ctx = ctx.device_context();
-
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-
-    framework::TensorCopy(*input, ctx.GetPlace(), result);
-    if (reduce_op == "add") {
-      if (index_type == framework::proto::VarType::INT32) {
-        gpu_scatter_add_kernel<T, int32_t>(*result, axis, *index, *value,
-                                           device_ctx);
-      } else if (index_type == framework::proto::VarType::INT64) {
-        gpu_scatter_add_kernel<T, int64_t>(*result, axis, *index, *value,
-                                           device_ctx);
-      }
-    } else if (reduce_op == "multiply" || reduce_op == "mul") {
-      if (index_type == framework::proto::VarType::INT32) {
-        gpu_scatter_mul_kernel<T, int32_t>(*result, axis, *index, *value,
-                                           device_ctx);
-      } else if (index_type == framework::proto::VarType::INT64) {
-        gpu_scatter_mul_kernel<T, int64_t>(*result, axis, *index, *value,
-                                           device_ctx);
-      }
-    } else if (reduce_op == "assign") {
-      if (index_type == framework::proto::VarType::INT32) {
-        gpu_scatter_assign_kernel<T, int32_t>(*result, axis, *index, *value,
-                                              device_ctx);
-      } else if (index_type == framework::proto::VarType::INT64) {
-        gpu_scatter_assign_kernel<T, int64_t>(*result, axis, *index, *value,
-                                              device_ctx);
-      }
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "can not support reduce_op: '%s' for scatter kernel, only "
-          "support reduce op: 'add‘, 'assign', 'mul' and 'multiply', the "
-          "defalut reduce op is 'assign' ",
-          reduce_op));
-      return;
-    }
-  }
-};
-
-template <typename T>
-class PutAlongAxisGradOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "PutAlongAxisGradOpCUDAKernel only runs on GPU."));
-
-    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
-    auto value_grad = ctx.Output<Tensor>(framework::GradVarName("Value"));
-    auto index = ctx.Input<Tensor>("Index");
-    auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
-    auto axis = ctx.Attr<int>("Axis");
-
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    if (input_grad) {
-      framework::TensorCopy(*result_grad, ctx.GetPlace(), input_grad);
-      if (index_type == framework::proto::VarType::INT32) {
-        gpu_scatter_input_grad_kernel<T, int32_t>(
-            *result_grad, axis, *index, *input_grad, ctx.device_context());
-      } else {
-        gpu_scatter_input_grad_kernel<T, int64_t>(
-            *result_grad, axis, *index, *input_grad, ctx.device_context());
-      }
-    }
-    if (value_grad) {
-      value_grad->Resize(index->dims());
-      value_grad->mutable_data<T>(ctx.GetPlace());
-      if (index_type == framework::proto::VarType::INT32) {
-        gpu_gather_kernel<T, int32_t>(
-            *result_grad, axis, *index, *value_grad,
-            ctx.device_context());  // the gradient of scatter is gather
-      } else if (index_type == framework::proto::VarType::INT64) {
-        gpu_gather_kernel<T, int64_t>(*result_grad, axis, *index, *value_grad,
-                                      ctx.device_context());
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(put_along_axis, ops::PutAlongAxisCUDAKernel<float>,
-                        ops::PutAlongAxisCUDAKernel<double>,
-                        ops::PutAlongAxisCUDAKernel<int64_t>,
-                        ops::PutAlongAxisCUDAKernel<int>,
-                        ops::PutAlongAxisCUDAKernel<plat::float16>);
-REGISTER_OP_CUDA_KERNEL(put_along_axis_grad,
-                        ops::PutAlongAxisGradOpCUDAKernel<float>,
-                        ops::PutAlongAxisGradOpCUDAKernel<double>,
-                        ops::PutAlongAxisGradOpCUDAKernel<int64_t>,
-                        ops::PutAlongAxisGradOpCUDAKernel<int>,
-                        ops::PutAlongAxisGradOpCUDAKernel<plat::float16>);
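The three reduce branches dispatched above differ only in how an incoming value meets the element it lands on. A standalone sketch with hypothetical 1-D data:

    #include <cstdio>

    int main() {
      float self[4] = {10, 20, 30, 40};
      const int index[2] = {1, 3};
      const float value[2] = {5, 7};
      // "assign": self[index[i]] = value[i]   -> {10, 5, 30, 7}
      // "add":    self[index[i]] += value[i]  -> {10, 25, 30, 47}
      // "mul":    self[index[i]] *= value[i]  -> {10, 100, 30, 280}
      for (int i = 0; i < 2; ++i) self[index[i]] += value[i];  // the "add" mode
      printf("%g %g %g %g\n", self[0], self[1], self[2], self[3]);
    }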
diff --git a/paddle/fluid/operators/put_along_axis_op.h b/paddle/fluid/operators/put_along_axis_op.h
deleted file mode 100644
index 38487f5ce2..0000000000
--- a/paddle/fluid/operators/put_along_axis_op.h
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather_scatter_kernel.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class PutAlongAxisOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "PutAlongAxisOpKernel only runs on CPU."));
-
-    auto input = ctx.Input<Tensor>("Input");
-    auto axis = ctx.Attr<int>("Axis");
-    auto value = ctx.Input<Tensor>("Value");
-    auto index = ctx.Input<Tensor>("Index");
-    auto reduce_op = ctx.Attr<std::string>("Reduce");
-    auto result = ctx.Output<Tensor>("Result");
-
-    framework::TensorCopy(*input, ctx.GetPlace(), result);
-    const platform::DeviceContext &device_ctx = ctx.device_context();
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    if (reduce_op == "add") {
-      if (index_type == framework::proto::VarType::INT32) {
-        cpu_scatter_add_kernel<T, int32_t>(*result, axis, *index, *value,
-                                           device_ctx);
-      } else if (index_type == framework::proto::VarType::INT64) {
-        cpu_scatter_add_kernel<T, int64_t>(*result, axis, *index, *value,
-                                           device_ctx);
-      }
-    } else if (reduce_op == "multiply" || reduce_op == "mul") {
-      if (index_type == framework::proto::VarType::INT32) {
-        cpu_scatter_mul_kernel<T, int32_t>(*result, axis, *index, *value,
-                                           device_ctx);
-      } else if (index_type == framework::proto::VarType::INT64) {
-        cpu_scatter_mul_kernel<T, int64_t>(*result, axis, *index, *value,
-                                           device_ctx);
-      }
-    } else if (reduce_op == "assign") {
-      if (index_type == framework::proto::VarType::INT32) {
-        cpu_scatter_assign_kernel<T, int32_t>(*result, axis, *index, *value,
-                                              device_ctx);
-      } else if (index_type == framework::proto::VarType::INT64) {
-        cpu_scatter_assign_kernel<T, int64_t>(*result, axis, *index, *value,
-                                              device_ctx);
-      }
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "can not support reduce_op: '%s' for scatter kernel, only "
-          "support reduce op: 'add‘, 'assign', 'mul' and 'multiply', the "
-          "defalut reduce "
-          "op is 'assign' ",
-          reduce_op));
-      return;
-    }
-  }
-};
-
-template <typename T>
-class PutAlongAxisGradOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "PutAlongAxisGradOpKernel only runs on CPU."));
-
-    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
-    auto value_grad = ctx.Output<Tensor>(framework::GradVarName("Value"));
-    auto index = ctx.Input<Tensor>("Index");
-    auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
-    auto axis = ctx.Attr<int>("Axis");
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-
-    if (input_grad) {
-      framework::TensorCopy(*result_grad, ctx.GetPlace(), input_grad);
-      if (index_type == framework::proto::VarType::INT32) {
-        cpu_scatter_input_grad_kernel<T, int32_t>(
-            // Here passing an unused argument *result_grad, because it's
-            // convenient to instantiate a bunch of template function with the
-            // same arguments list.
-            *result_grad, axis, *index, *input_grad, ctx.device_context());
-      } else {
-        cpu_scatter_input_grad_kernel<T, int64_t>(
-            *result_grad, axis, *index, *input_grad, ctx.device_context());
-      }
-    }
-
-    if (value_grad) {
-      value_grad->Resize(index->dims());
-      value_grad->mutable_data<T>(ctx.GetPlace());
-      if (index_type == framework::proto::VarType::INT32) {
-        cpu_gather_kernel<T, int32_t>(*result_grad, axis, *index, *value_grad,
-                                      ctx.device_context());
-      } else if (index_type == framework::proto::VarType::INT64) {
-        cpu_gather_kernel<T, int64_t>(*result_grad, axis, *index, *value_grad,
-                                      ctx.device_context());
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/take_along_axis_op.cc b/paddle/fluid/operators/take_along_axis_op.cc
index 664f103191..fa8a5e9271 100644
--- a/paddle/fluid/operators/take_along_axis_op.cc
+++ b/paddle/fluid/operators/take_along_axis_op.cc
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/take_along_axis_op.h"
 #include <memory>
 #include <string>
 #include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/ddim.h"
 
@@ -139,16 +140,3 @@ REGISTER_OPERATOR(take_along_axis, ops::TakeAlongAxisOp,
                   ops::TakeAlongAxisGradOpMaker<paddle::imperative::OpBase>);
 
 REGISTER_OPERATOR(take_along_axis_grad, ops::TakeAlongAxisGradOp);
-
-REGISTER_OP_CPU_KERNEL(take_along_axis, ops::TakeAlongAxisOpKernel<float>,
-                       ops::TakeAlongAxisOpKernel<double>,
-                       ops::TakeAlongAxisOpKernel<int>,
-                       ops::TakeAlongAxisOpKernel<uint8_t>,
-                       ops::TakeAlongAxisOpKernel<int64_t>);
-
-REGISTER_OP_CPU_KERNEL(take_along_axis_grad,
-                       ops::TakeAlongAxisGradOpKernel<float>,
-                       ops::TakeAlongAxisGradOpKernel<double>,
-                       ops::TakeAlongAxisGradOpKernel<int>,
-                       ops::TakeAlongAxisGradOpKernel<uint8_t>,
-                       ops::TakeAlongAxisGradOpKernel<int64_t>);
diff --git a/paddle/fluid/operators/take_along_axis_op.cu b/paddle/fluid/operators/take_along_axis_op.cu
deleted file mode 100644
index b6c62d497b..0000000000
--- a/paddle/fluid/operators/take_along_axis_op.cu
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/operators/take_along_axis_op.h"
-#include "paddle/phi/core/ddim.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class TakeAlongAxisCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "This kernel only runs on GPU device."));
-    auto input = ctx.Input<Tensor>("Input");
-    auto axis = ctx.Attr<int>("Axis");
-    auto index = ctx.Input<Tensor>("Index");
-    auto result = ctx.Output<Tensor>("Result");
-    result->Resize(index->dims());
-    result->mutable_data<T>(ctx.GetPlace());
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    if (index_type == framework::proto::VarType::INT32) {
-      gpu_gather_kernel<T, int32_t>(*input, axis, *index, *result,
-                                    ctx.device_context());
-    } else if (index_type == framework::proto::VarType::INT64) {
-      gpu_gather_kernel<T, int64_t>(*input, axis, *index, *result,
-                                    ctx.device_context());
-    }
-  }
-};
-
-template <typename T>
-class TakeAlongAxisGradOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_gpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on GPU."));
-
-    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
-    auto index = ctx.Input<Tensor>("Index");
-    auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
-    auto axis = ctx.Attr<int>("Axis");
-    // We need to know the shape of input matrix to determine the shape of grad
-    // matrix of input.
-    auto input = ctx.Input<Tensor>("Input");
-    input_grad->Resize(input->dims());
-    input_grad->mutable_data<T>(ctx.GetPlace());
-
-    // Set to zero tensor.
-    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    phi::funcs::SetConstant<platform::CUDADeviceContext, T> functor;
-    functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
-            input_grad, static_cast<T>(0));
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-
-    if (index_type == framework::proto::VarType::INT32) {
-      gpu_scatter_add_kernel<T, int32_t>(
-          *input_grad, axis, *index, *result_grad,
-          ctx.device_context());  // the gradient of gather is scatter
-    } else if (index_type == framework::proto::VarType::INT64) {
-      gpu_scatter_add_kernel<T, int64_t>(*input_grad, axis, *index,
-                                         *result_grad, ctx.device_context());
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(take_along_axis, ops::TakeAlongAxisCUDAKernel<float>,
-                        ops::TakeAlongAxisCUDAKernel<double>,
-                        ops::TakeAlongAxisCUDAKernel<int64_t>,
-                        ops::TakeAlongAxisCUDAKernel<int>,
-                        ops::TakeAlongAxisCUDAKernel<plat::float16>);
-REGISTER_OP_CUDA_KERNEL(take_along_axis_grad,
-                        ops::TakeAlongAxisGradOpCUDAKernel<float>,
-                        ops::TakeAlongAxisGradOpCUDAKernel<double>,
-                        ops::TakeAlongAxisGradOpCUDAKernel<int64_t>,
-                        ops::TakeAlongAxisGradOpCUDAKernel<int>,
-                        ops::TakeAlongAxisGradOpCUDAKernel<plat::float16>);
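The comment "the gradient of gather is scatter" is the key invariant behind this operator pair: every output element of take_along_axis reads exactly one input element, so the backward pass scatter-adds out_grad back into those positions. It must add rather than assign because an index may repeat:

    // x = {a, b, c}, index = {2, 2}  ->  out = {c, c}
    // dx = {0, 0, g0 + g1}           where {g0, g1} = out_grad
    float x_grad[3] = {0, 0, 0};
    const int index[2] = {2, 2};
    const float out_grad[2] = {0.5f, 0.25f};
    for (int i = 0; i < 2; ++i) x_grad[index[i]] += out_grad[i];  // {0, 0, 0.75}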
diff --git a/paddle/fluid/operators/take_along_axis_op.h b/paddle/fluid/operators/take_along_axis_op.h
deleted file mode 100644
index fc781dbddf..0000000000
--- a/paddle/fluid/operators/take_along_axis_op.h
+++ /dev/null
@@ -1,92 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <memory>
-#include <string>
-#include <vector>
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather_scatter_kernel.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class TakeAlongAxisOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-
-    auto input = ctx.Input<Tensor>("Input");
-    auto axis = ctx.Attr<int>("Axis");
-    auto index = ctx.Input<Tensor>("Index");
-    auto result = ctx.Output<Tensor>("Result");
-    result->Resize(index->dims());
-    result->mutable_data<T>(ctx.GetPlace());
-
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    if (index_type == framework::proto::VarType::INT32) {
-      cpu_gather_kernel<T, int32_t>(*input, axis, *index, *result,
-                                    ctx.device_context());
-    } else if (index_type == framework::proto::VarType::INT64) {
-      cpu_gather_kernel<T, int64_t>(*input, axis, *index, *result,
-                                    ctx.device_context());
-    }
-  }
-};
-
-template <typename T>
-class TakeAlongAxisGradOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-
-    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
-    auto index = ctx.Input<Tensor>("Index");
-    auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
-    auto axis = ctx.Attr<int>("Axis");
-    // We need to know the shape of input matrix to determine the shape of grad
-    // matrix of input.
-    auto input = ctx.Input<Tensor>("Input");
-    input_grad->Resize(input->dims());
-    input_grad->mutable_data<T>(ctx.GetPlace());
-
-    // Set to zero tensor.
-    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-    phi::funcs::SetConstant<platform::CPUDeviceContext, T> functor;
-    functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
-            input_grad, static_cast<T>(0));
-
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    if (index_type == framework::proto::VarType::INT32) {
-      cpu_scatter_add_kernel<T, int32_t>(
-          *input_grad, axis, *index, *result_grad,
-          ctx.device_context());  // the gradient of gather is scatter
-    } else if (index_type == framework::proto::VarType::INT64) {
-      cpu_scatter_add_kernel<T, int64_t>(*input_grad, axis, *index,
-                                         *result_grad, ctx.device_context());
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index 58ea231bee..de3b5b53f4 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -27,11 +27,17 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
 # Some kernels depend on some targets that are not commonly used.
 # These targets are not suitable for common dependencies.
 # In this case, you need to manually generate them here.
-set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel)
+set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel)
 kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
 kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
 kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
 kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
+kernel_library(maxout_kernel DEPS ${COMMON_KERNEL_DEPS} maxouting)
+kernel_library(maxout_grad_kernel DEPS ${COMMON_KERNEL_DEPS} maxouting)
+kernel_library(put_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
+kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
+kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
+kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
 
 # 4. auto parse and build kernel targets by cmake
 register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS})
diff --git a/paddle/phi/kernels/cpu/maxout_grad_kernel.cc b/paddle/phi/kernels/cpu/maxout_grad_kernel.cc
new file mode 100644
index 0000000000..429344a362
--- /dev/null
+++ b/paddle/phi/kernels/cpu/maxout_grad_kernel.cc
@@ -0,0 +1,20 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/maxout_grad_kernel_impl.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(
+    maxout_grad, CPU, ALL_LAYOUT, phi::MaxOutGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/maxout_kernel.cc b/paddle/phi/kernels/cpu/maxout_kernel.cc
new file mode 100644
index 0000000000..e7cd3ab07f
--- /dev/null
+++ b/paddle/phi/kernels/cpu/maxout_kernel.cc
@@ -0,0 +1,19 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/maxout_kernel_impl.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(maxout, CPU, ALL_LAYOUT, phi::MaxOutKernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
new file mode 100644
index 0000000000..e94d09e033
--- /dev/null
+++ b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/operators/gather_scatter_kernel.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/copy_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PutAlongAxisGradKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& index,
+                            const DenseTensor& out_grad,
+                            int axis,
+                            const std::string& reduce,
+                            DenseTensor* x_grad,
+                            DenseTensor* value_grad) {
+  PADDLE_ENFORCE_EQ(
+      paddle::platform::is_cpu_place(dev_ctx.GetPlace()),
+      true,
+      errors::PreconditionNotMet("PutAlongAxisGradOpKernel only runs on CPU."));
+
+  const auto& index_type =
+      paddle::framework::TransToProtoVarType(index.dtype());
+  if (x_grad) {
+    phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>(
+          // Here passing an unused argument out_grad, because it's
+          // convenient to instantiate a bunch of template function with the
+          // same arguments list.
+          out_grad,
+          axis,
+          index,
+          *x_grad,
+          dev_ctx);
+    } else {
+      paddle::operators::cpu_scatter_input_grad_kernel<T, int64_t>(
+          out_grad, axis, index, *x_grad, dev_ctx);
+    }
+  }
+
+  if (value_grad) {
+    value_grad->Resize(index.dims());
+    value_grad->mutable_data<T>(dev_ctx.GetPlace());
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::cpu_gather_kernel<T, int32_t>(
+          out_grad, axis, index, *value_grad, dev_ctx);
+    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+      paddle::operators::cpu_gather_kernel<T, int64_t>(
+          out_grad, axis, index, *value_grad, dev_ctx);
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(put_along_axis_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::PutAlongAxisGradKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int64_t) {}
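The Input-gradient path above first copies out_grad into x_grad and then calls cpu_scatter_input_grad_kernel; judging from the forward "assign" semantics, that second step zeroes the slots the forward pass overwrote. A hedged 1-D illustration with hypothetical values (assign reduce; the actual kernel lives in gather_scatter_kernel and is not shown in this patch):

    #include <cstdio>

    int main() {
      float out_grad[4] = {1, 2, 3, 4};
      const int index[2] = {1, 3};
      float x_grad[4];
      for (int i = 0; i < 4; ++i) x_grad[i] = out_grad[i];  // the phi::Copy step
      for (int i = 0; i < 2; ++i) x_grad[index[i]] = 0.f;   // scatter_input_grad step
      printf("%g %g %g %g\n", x_grad[0], x_grad[1], x_grad[2], x_grad[3]);  // 1 0 3 0
    }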
diff --git a/paddle/phi/kernels/cpu/put_along_axis_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_kernel.cc
new file mode 100644
index 0000000000..83c9a915ee
--- /dev/null
+++ b/paddle/phi/kernels/cpu/put_along_axis_kernel.cc
@@ -0,0 +1,87 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/put_along_axis_kernel.h"
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/operators/gather_scatter_kernel.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/copy_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PutAlongAxisKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& index,
+                        const DenseTensor& value,
+                        int axis,
+                        const std::string& reduce,
+                        DenseTensor* out) {
+  PADDLE_ENFORCE_EQ(
+      paddle::platform::is_cpu_place(dev_ctx.GetPlace()),
+      true,
+      errors::PreconditionNotMet("PutAlongAxisOpKernel only runs on CPU."));
+
+  phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+  const auto& index_type =
+      paddle::framework::TransToProtoVarType(index.dtype());
+  if (reduce == "add") {
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
+          *out, axis, index, value, dev_ctx);
+    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+      paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
+          *out, axis, index, value, dev_ctx);
+    }
+  } else if (reduce == "multiply" || reduce == "mul") {
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::cpu_scatter_mul_kernel<T, int32_t>(
+          *out, axis, index, value, dev_ctx);
+    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+      paddle::operators::cpu_scatter_mul_kernel<T, int64_t>(
+          *out, axis, index, value, dev_ctx);
+    }
+  } else if (reduce == "assign") {
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::cpu_scatter_assign_kernel<T, int32_t>(
+          *out, axis, index, value, dev_ctx);
+    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+      paddle::operators::cpu_scatter_assign_kernel<T, int64_t>(
+          *out, axis, index, value, dev_ctx);
+    }
+  } else {
+    PADDLE_THROW(errors::InvalidArgument(
+        "cannot support reduce: '%s' for scatter kernel, only "
+        "support reduce op: 'add', 'assign', 'mul' and 'multiply', the "
+        "default reduce "
+        "op is 'assign' ",
+        reduce));
+    return;
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(put_along_axis,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::PutAlongAxisKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
new file mode 100644
index 0000000000..4443383f40
--- /dev/null
+++ b/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
+
+#include "paddle/fluid/operators/gather_scatter_kernel.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TakeAlongAxisGradKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& index,
+                             const DenseTensor& out_grad,
+                             int axis,
+                             DenseTensor* x_grad) {
+  PADDLE_ENFORCE_EQ(
+      paddle::platform::is_cpu_place(dev_ctx.GetPlace()),
+      true,
+      errors::PreconditionNotMet("This kernel only runs on CPU."));
+
+  // We need to know the shape of input matrix to determine the shape of grad
+  // matrix of input.
+  x_grad->Resize(x.dims());
+  dev_ctx.template Alloc<T>(x_grad);
+
+  // Set to zero tensor.
+  phi::funcs::SetConstant<Context, T> functor;
+  functor(dev_ctx, x_grad, static_cast<T>(0));
+
+  const auto& index_type =
+      paddle::framework::TransToProtoVarType(index.dtype());
+  if (index_type == paddle::framework::proto::VarType::INT32) {
+    paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
+        *x_grad,
+        axis,
+        index,
+        out_grad,
+        dev_ctx);  // the gradient of gather is scatter
+  } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
+        *x_grad, axis, index, out_grad, dev_ctx);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(take_along_axis_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::TakeAlongAxisGradKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int64_t) {}
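For reference while reading the forward kernel that follows: take_along_axis with axis = 1 on a 2-D tensor reduces to out[i][j] = x[i][index[i][j]], as in this standalone sketch with hypothetical values:

    #include <cstdio>

    int main() {
      int x[2][3] = {{1, 2, 3}, {4, 5, 6}};
      int idx[2][2] = {{2, 0}, {1, 1}};
      int out[2][2];
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
          out[i][j] = x[i][idx[i][j]];  // out = {{3, 1}, {5, 5}}
      printf("%d %d %d %d\n", out[0][0], out[0][1], out[1][0], out[1][1]);
    }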
diff --git a/paddle/phi/kernels/cpu/take_along_axis_kernel.cc b/paddle/phi/kernels/cpu/take_along_axis_kernel.cc
new file mode 100644
index 0000000000..502db8a22d
--- /dev/null
+++ b/paddle/phi/kernels/cpu/take_along_axis_kernel.cc
@@ -0,0 +1,60 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/take_along_axis_kernel.h"
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/operators/gather_scatter_kernel.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TakeAlongAxisKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& index,
+                         int axis,
+                         DenseTensor* out) {
+  PADDLE_ENFORCE_EQ(
+      paddle::platform::is_cpu_place(dev_ctx.GetPlace()),
+      true,
+      errors::PreconditionNotMet("This kernel only runs on CPU."));
+
+  out->Resize(index.dims());
+  dev_ctx.template Alloc<T>(out);
+
+  const auto& index_type =
+      paddle::framework::TransToProtoVarType(index.dtype());
+  if (index_type == paddle::framework::proto::VarType::INT32) {
+    paddle::operators::cpu_gather_kernel<T, int32_t>(
+        x, axis, index, *out, dev_ctx);
+  } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    paddle::operators::cpu_gather_kernel<T, int64_t>(
+        x, axis, index, *out, dev_ctx);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(take_along_axis,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::TakeAlongAxisKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/maxout_grad_kernel.cu b/paddle/phi/kernels/gpu/maxout_grad_kernel.cu
new file mode 100644
index 0000000000..86ff09fd74
--- /dev/null
+++ b/paddle/phi/kernels/gpu/maxout_grad_kernel.cu
@@ -0,0 +1,20 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/maxout_grad_kernel_impl.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(
+    maxout_grad, GPU, ALL_LAYOUT, phi::MaxOutGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/maxout_kernel.cu b/paddle/phi/kernels/gpu/maxout_kernel.cu
new file mode 100644
index 0000000000..88776a49f1
--- /dev/null
+++ b/paddle/phi/kernels/gpu/maxout_kernel.cu
@@ -0,0 +1,19 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/maxout_kernel_impl.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(maxout, GPU, ALL_LAYOUT, phi::MaxOutKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
new file mode 100644
index 0000000000..f553da361f
--- /dev/null
+++ b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/operators/gather_scatter_kernel.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/copy_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PutAlongAxisGradKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& index,
+                            const DenseTensor& out_grad,
+                            int axis,
+                            const std::string& reduce,
+                            DenseTensor* x_grad,
+                            DenseTensor* value_grad) {
+  PADDLE_ENFORCE_EQ(paddle::platform::is_gpu_place(dev_ctx.GetPlace()),
+                    true,
+                    errors::PreconditionNotMet(
+                        "PutAlongAxisGradOpCUDAKernel only runs on GPU."));
+
+  const auto& index_type =
+      paddle::framework::TransToProtoVarType(index.dtype());
+  if (x_grad) {
+    phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::gpu_scatter_input_grad_kernel<T, int32_t>(
+          out_grad, axis, index, *x_grad, dev_ctx);
+    } else {
+      paddle::operators::gpu_scatter_input_grad_kernel<T, int64_t>(
+          out_grad, axis, index, *x_grad, dev_ctx);
+    }
+  }
+  if (value_grad) {
+    value_grad->Resize(index.dims());
+    value_grad->mutable_data<T>(dev_ctx.GetPlace());
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::gpu_gather_kernel<T, int32_t>(
+          out_grad,
+          axis,
+          index,
+          *value_grad,
+          dev_ctx);  // the gradient of scatter is gather
+    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+      paddle::operators::gpu_gather_kernel<T, int64_t>(
+          out_grad, axis, index, *value_grad, dev_ctx);
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(put_along_axis_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::PutAlongAxisGradKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   phi::dtype::float16) {}
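One detail worth noticing across these registrations: the CPU kernels register float, double, int, uint8_t and int64_t, while their GPU counterparts drop uint8_t and add phi::dtype::float16. This appears to mirror the dtype coverage of the old per-device REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL lists the patch deletes, so the port changes where kernels live without changing which types each backend supports.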
diff --git a/paddle/phi/kernels/gpu/put_along_axis_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_kernel.cu
new file mode 100644
index 0000000000..d363c0c283
--- /dev/null
+++ b/paddle/phi/kernels/gpu/put_along_axis_kernel.cu
@@ -0,0 +1,86 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/put_along_axis_kernel.h"
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/operators/gather_scatter_kernel.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/copy_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PutAlongAxisKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& index,
+                        const DenseTensor& value,
+                        int axis,
+                        const std::string& reduce,
+                        DenseTensor* out) {
+  PADDLE_ENFORCE_EQ(paddle::platform::is_gpu_place(dev_ctx.GetPlace()),
+                    true,
+                    errors::PreconditionNotMet(
+                        "PutAlongAxisCUDAKernel only runs on GPU device."));
+
+  const auto& index_type =
+      paddle::framework::TransToProtoVarType(index.dtype());
+
+  phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+  if (reduce == "add") {
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
+          *out, axis, index, value, dev_ctx);
+    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+      paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
+          *out, axis, index, value, dev_ctx);
+    }
+  } else if (reduce == "multiply" || reduce == "mul") {
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::gpu_scatter_mul_kernel<T, int32_t>(
+          *out, axis, index, value, dev_ctx);
+    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+      paddle::operators::gpu_scatter_mul_kernel<T, int64_t>(
+          *out, axis, index, value, dev_ctx);
+    }
+  } else if (reduce == "assign") {
+    if (index_type == paddle::framework::proto::VarType::INT32) {
+      paddle::operators::gpu_scatter_assign_kernel<T, int32_t>(
+          *out, axis, index, value, dev_ctx);
+    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+      paddle::operators::gpu_scatter_assign_kernel<T, int64_t>(
+          *out, axis, index, value, dev_ctx);
+    }
+  } else {
+    PADDLE_THROW(errors::InvalidArgument(
+        "Reduce op '%s' is not supported by the scatter kernel; only "
+        "'add', 'assign', 'mul' and 'multiply' are supported, and the "
+        "default reduce op is 'assign'.",
+        reduce));
+    return;
+  }
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(put_along_axis,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::PutAlongAxisKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   phi::dtype::float16) {}
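The three reduce branches above differ only in how an incoming value combines with what is already stored at the target slot: accumulate, multiply in, or overwrite. A condensed 1-D sketch of those semantics follows; PutAlongAxis1D is a hypothetical helper, not the gpu_scatter_*_kernel API.

#include <cstdio>
#include <string>
#include <vector>

// out starts as a copy of x (mirroring the phi::Copy above); value[i] is then
// combined into out[index[i]] according to `reduce`.
void PutAlongAxis1D(std::vector<float>* out,
                    const std::vector<int>& index,
                    const std::vector<float>& value,
                    const std::string& reduce) {
  for (size_t i = 0; i < index.size(); ++i) {
    float& slot = (*out)[index[i]];
    if (reduce == "add") {
      slot += value[i];
    } else if (reduce == "multiply" || reduce == "mul") {
      slot *= value[i];
    } else {  // "assign", the default
      slot = value[i];
    }
  }
}

int main() {
  std::vector<float> out = {1.0f, 1.0f, 1.0f};  // copy of x
  PutAlongAxis1D(&out, {0, 2}, {5.0f, 7.0f}, "add");
  std::printf("%.0f %.0f %.0f\n", out[0], out[1], out[2]);  // 6 1 8
  return 0;
}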
diff --git a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
new file mode 100644
index 0000000000..e09cfd370a
--- /dev/null
+++ b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
@@ -0,0 +1,72 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/operators/gather_scatter_kernel.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TakeAlongAxisGradKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& index,
+                             const DenseTensor& out_grad,
+                             int axis,
+                             DenseTensor* x_grad) {
+  PADDLE_ENFORCE_EQ(
+      paddle::platform::is_gpu_place(dev_ctx.GetPlace()),
+      true,
+      errors::PreconditionNotMet("This kernel only runs on GPU."));
+
+  // The shape of the input determines the shape of its grad tensor.
+  x_grad->Resize(x.dims());
+  dev_ctx.template Alloc<T>(x_grad);
+
+  // Initialize x_grad with zeros.
+  phi::funcs::SetConstant<Context, T> functor;
+  functor(dev_ctx, x_grad, static_cast<T>(0));
+  const auto& index_type =
+      paddle::framework::TransToProtoVarType(index.dtype());
+
+  if (index_type == paddle::framework::proto::VarType::INT32) {
+    paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
+        *x_grad,
+        axis,
+        index,
+        out_grad,
+        dev_ctx);  // the gradient of gather is scatter
+  } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
+        *x_grad, axis, index, out_grad, dev_ctx);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(take_along_axis_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::TakeAlongAxisGradKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   phi::dtype::float16) {}
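Symmetrically to the put_along_axis case, the backward pass above scatter-adds out_grad into x_grad: every output element that gathered a given input contributes its gradient to that input, which is why duplicate indices must accumulate rather than assign, and why x_grad is zero-initialized first. A 1-D standalone illustration:

#include <cstdio>
#include <vector>

int main() {
  // Forward: out[i] = x[index[i]], with x of length 3.
  std::vector<int> index = {2, 2, 0};
  std::vector<float> out_grad = {1.0f, 10.0f, 5.0f};
  std::vector<float> x_grad(3, 0.0f);  // zero-init, like SetConstant above
  for (size_t i = 0; i < index.size(); ++i) {
    x_grad[index[i]] += out_grad[i];  // duplicates accumulate: x_grad[2] == 11
  }
  std::printf("%.0f %.0f %.0f\n", x_grad[0], x_grad[1], x_grad[2]);  // 5 0 11
  return 0;
}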
+ +#include "paddle/phi/kernels/take_along_axis_kernel.h" + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/operators/gather_scatter_kernel.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void TakeAlongAxisKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& index, + int axis, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + paddle::platform::is_gpu_place(dev_ctx.GetPlace()), + true, + errors::PreconditionNotMet("This kernel only runs on GPU device.")); + + out->Resize(index.dims()); + dev_ctx.template Alloc(out); + + const auto& index_type = + paddle::framework::TransToProtoVarType(index.dtype()); + if (index_type == paddle::framework::proto::VarType::INT32) { + paddle::operators::gpu_gather_kernel( + x, axis, index, *out, dev_ctx); + } else if (index_type == paddle::framework::proto::VarType::INT64) { + paddle::operators::gpu_gather_kernel( + x, axis, index, *out, dev_ctx); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(take_along_axis, + GPU, + ALL_LAYOUT, + phi::TakeAlongAxisKernel, + double, + int64_t, + int, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/maxout_grad_kernel_impl.h b/paddle/phi/kernels/impl/maxout_grad_kernel_impl.h new file mode 100644 index 0000000000..546ea74674 --- /dev/null +++ b/paddle/phi/kernels/impl/maxout_grad_kernel_impl.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/maxout_grad_kernel.h" + +#include "paddle/fluid/operators/math/maxouting.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void MaxOutGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + int groups, + int axis, + DenseTensor* x_grad) { + if (axis < 0) { + axis += x.dims().size(); + } + + phi::funcs::SetConstant zero; + if (x_grad) { + dev_ctx.template Alloc(x_grad); + zero(dev_ctx, x_grad, static_cast(0.0)); + paddle::operators::math::MaxOutGradFunctor maxout_backward; + maxout_backward(dev_ctx, x, x_grad, out, out_grad, groups, axis); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/maxout_kernel_impl.h b/paddle/phi/kernels/impl/maxout_kernel_impl.h new file mode 100644 index 0000000000..da8c259ebf --- /dev/null +++ b/paddle/phi/kernels/impl/maxout_kernel_impl.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
diff --git a/paddle/phi/kernels/impl/maxout_kernel_impl.h b/paddle/phi/kernels/impl/maxout_kernel_impl.h
new file mode 100644
index 0000000000..da8c259ebf
--- /dev/null
+++ b/paddle/phi/kernels/impl/maxout_kernel_impl.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/maxout_kernel.h"
+
+#include "paddle/fluid/operators/math/maxouting.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MaxOutKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int groups,
+                  int axis,
+                  DenseTensor* out) {
+  if (axis < 0) {
+    axis += x.dims().size();
+  }
+
+  paddle::operators::math::MaxOutFunctor<Context, T> maxout_forward;
+  maxout_forward(dev_ctx, x, out, groups, axis);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/maxout_grad_kernel.h b/paddle/phi/kernels/maxout_grad_kernel.h
new file mode 100644
index 0000000000..1ee4e8cc89
--- /dev/null
+++ b/paddle/phi/kernels/maxout_grad_kernel.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MaxOutGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& out,
+                      const DenseTensor& out_grad,
+                      int groups,
+                      int axis,
+                      DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/maxout_kernel.h b/paddle/phi/kernels/maxout_kernel.h
new file mode 100644
index 0000000000..e582575678
--- /dev/null
+++ b/paddle/phi/kernels/maxout_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MaxOutKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int groups,
+                  int axis,
+                  DenseTensor* out);
+
+}  // namespace phi
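For context on the forward implementation above: maxout splits the channel dimension into input_channels / groups output channels and keeps, at every spatial position, the maximum across each block of `groups` consecutive channels. A minimal NCHW illustration with batch = 1 and H = W = 1; a standalone sketch of the semantics, not MaxOutFunctor itself:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // in_channels = 4, groups = 2  =>  out_channels = 2;
  // out[c] = max(in[c * groups], ..., in[c * groups + groups - 1]).
  std::vector<float> in = {1.0f, 3.0f, -2.0f, 0.5f};
  const int groups = 2;
  const int out_channels = static_cast<int>(in.size()) / groups;
  std::vector<float> out(out_channels);
  for (int c = 0; c < out_channels; ++c) {
    out[c] = *std::max_element(in.begin() + c * groups,
                               in.begin() + (c + 1) * groups);
  }
  std::printf("%.1f %.1f\n", out[0], out[1]);  // 3.0 0.5
  return 0;
}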
diff --git a/paddle/phi/kernels/put_along_axis_grad_kernel.h b/paddle/phi/kernels/put_along_axis_grad_kernel.h
new file mode 100644
index 0000000000..2141443da7
--- /dev/null
+++ b/paddle/phi/kernels/put_along_axis_grad_kernel.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PutAlongAxisGradKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& index,
+                            const DenseTensor& out_grad,
+                            int axis,
+                            const std::string& reduce,
+                            DenseTensor* x_grad,
+                            DenseTensor* value_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/put_along_axis_kernel.h b/paddle/phi/kernels/put_along_axis_kernel.h
new file mode 100644
index 0000000000..797d0e364b
--- /dev/null
+++ b/paddle/phi/kernels/put_along_axis_kernel.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PutAlongAxisKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& index,
+                        const DenseTensor& value,
+                        int axis,
+                        const std::string& reduce,
+                        DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/take_along_axis_grad_kernel.h b/paddle/phi/kernels/take_along_axis_grad_kernel.h
new file mode 100644
index 0000000000..a312c235f6
--- /dev/null
+++ b/paddle/phi/kernels/take_along_axis_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TakeAlongAxisGradKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& index,
+                             const DenseTensor& out_grad,
+                             int axis,
+                             DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/take_along_axis_kernel.h b/paddle/phi/kernels/take_along_axis_kernel.h
new file mode 100644
index 0000000000..e8fb78556d
--- /dev/null
+++ b/paddle/phi/kernels/take_along_axis_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TakeAlongAxisKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& index,
+                         int axis,
+                         DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/maxout_sig.cc b/paddle/phi/ops/compat/maxout_sig.cc
new file mode 100644
index 0000000000..d16dd1c861
--- /dev/null
+++ b/paddle/phi/ops/compat/maxout_sig.cc
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature MaxoutArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("maxout", {"X"}, {"groups", "axis"}, {"Out"});
+}
+
+KernelSignature MaxoutGradArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("maxout_grad",
+                         {"X", "Out", GradVarName("Out")},
+                         {"groups", "axis"},
+                         {GradVarName("X")});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(maxout, phi::MaxoutArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(maxout_grad, phi::MaxoutGradArgumentMapping);
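These compat files are the glue between the legacy fluid op definitions and the new phi kernels: a KernelSignature lists, in order, the phi kernel name, the fluid input names, the attribute names, and the output names. A hypothetical mapping, to make the shape of the API concrete (my_op, X, alpha, and Out are invented names, not part of this patch):

// If a fluid op "my_op" had input "X", attribute "alpha", and output "Out",
// its mapping would read:
//   KernelSignature("my_op", {"X"}, {"alpha"}, {"Out"});
// GradVarName(...) expands to the fluid gradient-variable name, so
// GradVarName("Out") refers to "Out@GRAD" in the mappings above and below.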
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature PutAlongAxisArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("put_along_axis", + {"Input", "Index", "Value"}, + {"Axis", "Reduce"}, + {"Result"}); +} + +KernelSignature PutAlongAxisGradArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("put_along_axis_grad", + {"Input", "Index", GradVarName("Result")}, + {"Axis", "Reduce"}, + {GradVarName("Input"), GradVarName("Value")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(put_along_axis, phi::PutAlongAxisArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(put_along_axis_grad, + phi::PutAlongAxisGradArgumentMapping); diff --git a/paddle/phi/ops/compat/take_along_axis_sig.cc b/paddle/phi/ops/compat/take_along_axis_sig.cc new file mode 100644 index 0000000000..27a996a270 --- /dev/null +++ b/paddle/phi/ops/compat/take_along_axis_sig.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature TakeAlongAxisArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "take_along_axis", {"Input", "Index"}, {"Axis"}, {"Result"}); +} + +KernelSignature TakeAlongAxisGradArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("take_along_axis_grad", + {"Input", "Index", GradVarName("Result")}, + {"Axis"}, + {GradVarName("Input")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(take_along_axis, phi::TakeAlongAxisArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(take_along_axis_grad, + phi::TakeAlongAxisGradArgumentMapping); -- GitLab