diff --git a/paddle/fluid/operators/clip_op.cc b/paddle/fluid/operators/clip_op.cc
index 436d1edcedf1e0526d58db75c15da0d5927a5678..6e898d31663fac73bc26d13ddb72acdbe4c6473c 100644
--- a/paddle/fluid/operators/clip_op.cc
+++ b/paddle/fluid/operators/clip_op.cc
@@ -1,21 +1,23 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/clip_op.h"
 #include <memory>
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -23,15 +25,6 @@ namespace operators {
 class ClipOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "clip");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "clip");
-    auto x_dims = ctx->GetInputDim("X");
-    ctx->SetOutputDim("Out", x_dims);
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type =
@@ -176,23 +169,15 @@ class ClipDoubleGradOpMaker : public framework::SingleGradOpMaker {
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(clip, ClipInferShapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
 REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
                   ops::ClipGradOpMaker<paddle::framework::OpDesc>,
                   ops::ClipGradOpMaker<paddle::imperative::OpBase>,
-                  ops::ClipInplaceInferer);
+                  ops::ClipInplaceInferer, ClipInferShapeFunctor);
 REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer,
                   ops::ClipDoubleGradOpMaker<paddle::framework::OpDesc>,
                   ops::ClipDoubleGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(
-    clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ClipKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::ClipKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ClipKernel<paddle::platform::CPUDeviceContext, int64_t>);
-REGISTER_OP_CPU_KERNEL(
-    clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ClipGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::ClipGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ClipGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_VERSION(clip)
     .AddCheckpoint(
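Editorial aside: the hand-written InferShape deleted above is subsumed by phi::UnchangedInferMeta through the DECLARE_INFER_SHAPE_FUNCTOR glue. A minimal standalone analogue of what an "unchanged" infer-meta does, with a hypothetical MetaDesc standing in for phi's MetaTensor interface (not the real phi API):

#include <cstdint>
#include <vector>

// Hypothetical MetaDesc: an "unchanged" infer-meta copies the input's
// shape and dtype to the output, which is what the deleted InferShape
// did by hand with SetOutputDim/ShareLoD.
struct MetaDesc {
  std::vector<int64_t> dims;
  int dtype;  // encoded element type
};

void UnchangedInferMetaSketch(const MetaDesc& x, MetaDesc* out) {
  out->dims = x.dims;    // Out keeps X's shape
  out->dtype = x.dtype;  // and X's element type
}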
diff --git a/paddle/fluid/operators/clip_op.cu b/paddle/fluid/operators/clip_op.cu
deleted file mode 100644
index 846354fcb81c5f07580533c69a598df62e50ddaf..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/clip_op.cu
+++ /dev/null
@@ -1,32 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/clip_op.h"
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ClipKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::ClipKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ClipKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::ClipKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::float16>);
-
-REGISTER_OP_CUDA_KERNEL(
-    clip_grad,
-    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::ClipGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::float16>);
diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h
deleted file mode 100644
index 3b815cd1fa74a6dc094fa48fd11db25393e4f977..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/clip_op.h
+++ /dev/null
@@ -1,196 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/selected_rows_functor.h"
-#include "paddle/fluid/platform/transform.h"
-#if defined(__NVCC__) || defined(__HIPCC__)
-#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
-#endif
-
-namespace paddle {
-namespace operators {
-
-using framework::Tensor;
-using platform::Transform;
-
-template <typename T>
-class ClipFunctor {
- public:
-  explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
-  HOSTDEVICE T operator()(const T x) const {
-    return x < min_ ? min_ : x > max_ ? max_ : x;
-  }
-
- private:
-  T min_;
-  T max_;
-};
-
-template <typename T>
-class ClipGradFunctor {
- public:
-  explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
-  HOSTDEVICE T operator()(const T x, const T y) const {
-    return (y > min_ && y < max_) ? x : static_cast<T>(0);
-  }
-
- private:
-  T min_;
-  T max_;
-};
-
-template <typename DeviceContext, typename T>
-class ClipKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto max = static_cast<T>(context.Attr<float>("max"));
-    Tensor max_cpu;
-    if (context.HasInput("Max")) {
-      auto* max_t = context.Input<Tensor>("Max");
-      auto* max_data = max_t->data<T>();
-      if (platform::is_gpu_place(max_t->place())) {
-        paddle::framework::TensorCopySync(*max_t, platform::CPUPlace(),
-                                          &max_cpu);
-        max_data = max_cpu.data<T>();
-      }
-      max = max_data[0];
-    }
-    max = static_cast<T>(max);
-
-    auto min = static_cast<T>(context.Attr<float>("min"));
-    Tensor min_cpu;
-    if (context.HasInput("Min")) {
-      auto* min_t = context.Input<Tensor>("Min");
-      auto* min_data = min_t->data<T>();
-      if (platform::is_gpu_place(min_t->place())) {
-        paddle::framework::TensorCopySync(*min_t, platform::CPUPlace(),
-                                          &min_cpu);
-        min_data = min_cpu.data<T>();
-      }
-      min = min_data[0];
-    }
-
-    PADDLE_ENFORCE_LE(min, max,
-                      platform::errors::InvalidArgument(
-                          "max should be greater than or equal to min. "
-                          "But received min = %f, max = %f",
-                          static_cast<float>(min), static_cast<float>(max)));
-
-    auto* x_var = context.InputVar("X");
-    if (x_var->IsType<framework::LoDTensor>()) {
-      auto* x = context.Input<framework::LoDTensor>("X");
-      auto* out = context.Output<framework::LoDTensor>("Out");
-      T* out_data = out->mutable_data<T>(context.GetPlace());
-      const T* x_data = x->data<T>();
-      int64_t numel = x->numel();
-      if (platform::is_gpu_place(context.GetPlace())) {
-#if defined(__NVCC__) || defined(__HIPCC__)
-        std::vector<const framework::Tensor*> ins = {x};
-        std::vector<framework::Tensor*> outs = {out};
-        auto functor = ClipFunctor<T>(min, max);
-        paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
-            context.template device_context<platform::CUDADeviceContext>(),
-            ins, &outs, functor);
-#endif
-      } else {
-        Transform<DeviceContext> trans;
-        trans(context.template device_context<DeviceContext>(), x_data,
-              x_data + numel, out_data, ClipFunctor<T>(min, max));
-      }
-    } else if (x_var->IsType<phi::SelectedRows>()) {
-      auto* x = context.Input<phi::SelectedRows>("X");
-      auto* out = context.Output<phi::SelectedRows>("Out");
-      PADDLE_ENFORCE_NE(x, out, platform::errors::InvalidArgument(
-                                    "Inplace clip is not allowed "
-                                    "when x is SelectedRows"));
-      math::scatter::MergeAdd<DeviceContext, T> merge_func;
-      merge_func(context.template device_context<DeviceContext>(), *x, out);
-      auto* out_tensor = out->mutable_value();
-      auto* out_data = out_tensor->data<T>();
-      int64_t numel = out_tensor->numel();
-      Transform<DeviceContext> trans;
-      trans(context.template device_context<DeviceContext>(), out_data,
-            out_data + numel, out_data, ClipFunctor<T>(min, max));
-    } else {
-      PADDLE_THROW(platform::errors::Unavailable(
-          "ClipOp only supports LoDTensor and SelectedRows."));
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class ClipGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto max = static_cast<T>(context.Attr<float>("max"));
-    Tensor max_cpu;
-    if (context.HasInput("Max")) {
-      auto* max_t = context.Input<Tensor>("Max");
-      auto* max_data = max_t->data<T>();
-      if (platform::is_gpu_place(max_t->place())) {
-        paddle::framework::TensorCopySync(*max_t, platform::CPUPlace(),
-                                          &max_cpu);
-        max_data = max_cpu.data<T>();
-      }
-      max = max_data[0];
-    }
-    max = static_cast<T>(max);
-
-    auto min = static_cast<T>(context.Attr<float>("min"));
-    Tensor min_cpu;
-    if (context.HasInput("Min")) {
-      auto* min_t = context.Input<Tensor>("Min");
-      auto* min_data = min_t->data<T>();
-      if (platform::is_gpu_place(min_t->place())) {
-        paddle::framework::TensorCopySync(*min_t, platform::CPUPlace(),
-                                          &min_cpu);
-        min_data = min_cpu.data<T>();
-      }
-      min = min_data[0];
-    }
-    min = static_cast<T>(min);
-
-    auto* d_out =
-        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-    auto* d_x =
-        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    if (d_x != nullptr) {
-      auto* x = context.Input<framework::LoDTensor>("X");
-#if defined(__NVCC__) || defined(__HIPCC__)
-      std::vector<const framework::Tensor*> ins = {d_out, x};
-      std::vector<framework::Tensor*> outs = {d_x};
-      auto functor = ClipGradFunctor<T>(min, max);
-      d_x->mutable_data<T>(context.GetPlace());
-      LaunchSameDimsElementwiseCudaKernel<T>(
-          context.template device_context<platform::CUDADeviceContext>(), ins,
-          &outs, functor);
-#else
-      int64_t numel = d_out->numel();
-      auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
-      const T* d_out_data = d_out->data<T>();
-      const T* x_data = x->data<T>();
-      Transform<DeviceContext> trans;
-      trans(context.template device_context<DeviceContext>(), d_out_data,
-            d_out_data + numel, x_data, d_x_data, ClipGradFunctor<T>(min, max));
-#endif
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
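Aside on the functors deleted above (they reappear under phi later in this patch): ClipFunctor clamps each element into [min, max]. A minimal host-only sketch of the same semantics, with HOSTDEVICE dropped so it compiles as ordinary C++:

#include <cassert>

// Same clamp rule as ClipFunctor::operator().
template <typename T>
T clip(T x, T min, T max) {
  return x < min ? min : (x > max ? max : x);
}

int main() {
  assert(clip(-5.0f, -1.0f, 1.0f) == -1.0f);  // below range: pinned to min
  assert(clip(0.5f, -1.0f, 1.0f) == 0.5f);    // inside range: unchanged
  assert(clip(3.0f, -1.0f, 1.0f) == 1.0f);    // above range: pinned to max
  return 0;
}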
diff --git a/paddle/fluid/operators/clip_op_npu.cc b/paddle/fluid/operators/clip_op_npu.cc
index 372ba707329bb35cc525498b00c91c06278b3bc8..17d7ad97965040b360f872ec4017cd58024f23a1 100644
--- a/paddle/fluid/operators/clip_op_npu.cc
+++ b/paddle/fluid/operators/clip_op_npu.cc
@@ -1,18 +1,18 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/clip_op.h"
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/clip_op_xpu.cc b/paddle/fluid/operators/clip_op_xpu.cc
index c53bb2d9e4d0cbb1b2a21af8f217ce3340dee0fa..c551312837274fcc0df50c7150af97923f7008da 100644
--- a/paddle/fluid/operators/clip_op_xpu.cc
+++ b/paddle/fluid/operators/clip_op_xpu.cc
@@ -1,20 +1,19 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 #ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/operators/clip_op.h"
 #include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index 4544386718813c1344f7e3a933b036cae08a7df9..ac72f23d46ea84fbff854a2685d7b28b4a16d434 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -17,8 +17,8 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/operators/clip_op.h"
 #include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
 
 namespace paddle {
 namespace operators {
@@ -91,7 +91,7 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
     T inv_s = inverse(s);
     platform::Transform<platform::CPUDeviceContext> trans;
     trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
-          out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
+          out->mutable_data<T>(ctx.GetPlace()), phi::ClipFunctor<T>(-s, s));
     auto out_e = framework::EigenVector<T>::Flatten(*out);
     out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
   }
@@ -109,7 +109,7 @@ struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
 
     platform::Transform<platform::CPUDeviceContext> trans;
     trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
-          out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
+          out->mutable_data<T>(ctx.GetPlace()), phi::ClipFunctor<T>(-s, s));
     auto out_e = framework::EigenVector<T>::Flatten(*out);
     out_e.device(*ctx.eigen_device()) =
         (bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
@@ -144,7 +144,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
       auto* start = in_data + i * channel_size;
       auto* end = in_data + (i + 1) * channel_size;
       trans(ctx, start, end, out_data + i * channel_size,
-            ClipFunctor<T>(-s, s));
+            phi::ClipFunctor<T>(-s, s));
     }
     for (int64_t i = 0; i < channel; i++) {
       T s = scale_data[i];
@@ -163,7 +163,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
         auto* start = in_data + i * step_i + j * step_j;
         auto* end = in_data + i * step_i + (j + 1) * step_j;
         auto* cur_out_data = out_data + i * step_i + j * step_j;
-        trans(ctx, start, end, cur_out_data, ClipFunctor<T>(-s, s));
+        trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
         for (int k = 0; k < step_j; k++) {
           cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]);
         }
@@ -200,7 +200,7 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
       auto* start = in_data + i * channel_size;
       auto* end = in_data + (i + 1) * channel_size;
       trans(ctx, start, end, out_data + i * channel_size,
-            ClipFunctor<T>(-s, s));
+            phi::ClipFunctor<T>(-s, s));
     }
     for (int i = 0; i < channel; i++) {
       T s = scale_data[i];
@@ -220,7 +220,7 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
         auto* start = in_data + i * step_i + j * step_j;
         auto* end = in_data + i * step_i + (j + 1) * step_j;
         auto* cur_out_data = out_data + i * step_i + j * step_j;
-        trans(ctx, start, end, cur_out_data, ClipFunctor<T>(-s, s));
+        trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
         for (int k = 0; k < step_j; k++) {
           cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) * s /
                             static_cast<T>(bin_cnt);
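Aside: fake_quantize_op.cc keeps its quantization math and only requalifies the clamp functor as phi::ClipFunctor. The pattern it relies on, a unary functor applied over a range, can be sketched with std::transform standing in for paddle::platform::Transform (an assumption for illustration, not the Paddle API):

#include <algorithm>
#include <vector>

template <typename T>
struct Clip {
  T min, max;
  T operator()(T x) const { return x < min ? min : (x > max ? max : x); }
};

int main() {
  std::vector<float> in = {-2.f, 0.3f, 5.f};
  std::vector<float> out(in.size());
  // Clamp every element into [-1, 1], as ClipFunctor<T>(-s, s) does per scale s.
  std::transform(in.begin(), in.end(), out.begin(), Clip<float>{-1.f, 1.f});
  return 0;
}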
diff --git a/paddle/phi/kernels/clip_grad_kernel.h b/paddle/phi/kernels/clip_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..8a7e5b99fd9248e752dc22d33695fde533b78016
--- /dev/null
+++ b/paddle/phi/kernels/clip_grad_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ClipGradKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    const Scalar& min,
+                    const Scalar& max,
+                    DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/clip_kernel.h b/paddle/phi/kernels/clip_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..14ac8342e03bcfab4b43bbc80ec87ee5f6326755
--- /dev/null
+++ b/paddle/phi/kernels/clip_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/selected_rows.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ClipKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const Scalar& min,
+                const Scalar& max,
+                DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/clip_grad_kernel.cc b/paddle/phi/kernels/cpu/clip_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bccdc0746d51ca63643ab8b5068618ee71ae8751
--- /dev/null
+++ b/paddle/phi/kernels/cpu/clip_grad_kernel.cc
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/clip_grad_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(clip_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ClipGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/cpu/clip_kernel.cc b/paddle/phi/kernels/cpu/clip_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5fd9aea966f8d24fa113ca88fbdf1bfc26791e01
--- /dev/null
+++ b/paddle/phi/kernels/cpu/clip_kernel.cc
@@ -0,0 +1,21 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/clip_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
+
+PD_REGISTER_KERNEL(
+    clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {}
diff --git a/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc b/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc
index 096a54f9fb263d3c153ab687d83bb61c63b117d7..4c4f1aa125a33916bf5790e4386ac65f6e84cd7f 100644
--- a/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc
+++ b/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc
@@ -14,7 +14,6 @@
 
 #include "paddle/phi/kernels/hierarchical_sigmoid_kernel.h"
 
-#include "paddle/fluid/operators/clip_op.h"
 #include "paddle/fluid/operators/math/matrix_bit_code.h"
 #include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
@@ -22,6 +21,7 @@
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function_impl.h"
+#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
 
 namespace phi {
 
@@ -92,8 +92,7 @@ void HierarchicalSigmoidKernel(const Context& ctx,
         pre_out_data,
         pre_out_data + pre_out->numel(),
         pre_out_data,
-        paddle::operators::ClipFunctor<T>(static_cast<T>(-40.0),
-                                          static_cast<T>(40.0)));
+        ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
   bit_code->Sum(*pre_out, out, static_cast<T>(-1));
   // use softrelu to calculate cross entropy
   pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
diff --git a/paddle/phi/kernels/gpu/clip_grad_kernel.cu b/paddle/phi/kernels/gpu/clip_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b76086be6488774fde8a1d96a59fcd3a52a64330
--- /dev/null
+++ b/paddle/phi/kernels/gpu/clip_grad_kernel.cu
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/clip_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(clip_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ClipGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/clip_kernel.cu b/paddle/phi/kernels/gpu/clip_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9e0050db7fdbf178acb4fe5cf7174ebc951fc465
--- /dev/null
+++ b/paddle/phi/kernels/gpu/clip_kernel.cu
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/clip_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
+
+PD_REGISTER_KERNEL(clip,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ClipKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/impl/clip_grad_kernel_impl.h b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..7ce86492327bac5094e87cef593769d3855dc2fb
--- /dev/null
+++ b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h
@@ -0,0 +1,74 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/clip_kernel.h"
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/transform.h"
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
+#endif
+
+namespace phi {
+
+template <typename T>
+class ClipGradFunctor {
+ public:
+  explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
+  HOSTDEVICE T operator()(const T x, const T y) const {
+    return (y > min_ && y < max_) ? x : static_cast<T>(0);
+  }
+
+ private:
+  T min_;
+  T max_;
+};
+
+template <typename T, typename Context>
+void ClipGradKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    const Scalar& min,
+                    const Scalar& max,
+                    DenseTensor* x_grad) {
+  auto max_ = max.to<T>();
+  auto min_ = min.to<T>();
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+  std::vector<const DenseTensor*> ins = {&out_grad, &x};
+  std::vector<DenseTensor*> outs = {x_grad};
+  auto functor = ClipGradFunctor<T>(min_, max_);
+  dev_ctx.template Alloc<T>(x_grad);
+  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+#else
+  int64_t numel = out_grad.numel();
+  auto* d_x_data = dev_ctx.template Alloc<T>(x_grad);
+  const T* d_out_data = out_grad.data<T>();
+  const T* x_data = x.data<T>();
+  paddle::platform::Transform<Context> trans;
+  trans(dev_ctx,
+        d_out_data,
+        d_out_data + numel,
+        x_data,
+        d_x_data,
+        ClipGradFunctor<T>(min_, max_));
+#endif
+}
+
+}  // namespace phi
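Worth noting: ClipGradFunctor uses strict inequalities, so an element sitting exactly on a bound receives zero gradient. A host-only sketch of that rule, where dout is the upstream gradient and x the forward input:

#include <cassert>

// Gradient passes through only where the forward input was strictly
// inside (min, max); clipped positions get zero.
template <typename T>
T clip_grad(T dout, T x, T min, T max) {
  return (x > min && x < max) ? dout : static_cast<T>(0);
}

int main() {
  assert(clip_grad(1.0f, 0.5f, -1.0f, 1.0f) == 1.0f);  // interior: pass through
  assert(clip_grad(1.0f, 1.0f, -1.0f, 1.0f) == 0.0f);  // exactly on the bound: zero
  assert(clip_grad(1.0f, 2.0f, -1.0f, 1.0f) == 0.0f);  // clipped region: zero
  return 0;
}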
" + "But received min = %f, max = %f", + static_cast(min_), + static_cast(max_))); + + T* out_data = dev_ctx.template Alloc(out); + // const T* x_data = x->data(); + // int64_t numel = x->numel(); + const T* x_data = x.data(); + int64_t numel = x.numel(); + if (paddle::platform::is_gpu_place(dev_ctx.GetPlace())) { +#if defined(__NVCC__) || defined(__HIPCC__) + std::vector ins = {&x}; + std::vector outs = {out}; + auto functor = ClipFunctor(min_, max_); + phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); +#endif + } else { + paddle::platform::Transform trans; + trans( + dev_ctx, x_data, x_data + numel, out_data, ClipFunctor(min_, max_)); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.h b/paddle/phi/kernels/selected_rows/clip_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ec56d92c513ea25897f63dd31a25f574df8c6fbc --- /dev/null +++ b/paddle/phi/kernels/selected_rows/clip_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" + +namespace phi { +namespace sr { + +template +void ClipSparseKernel(const Context& dev_ctx, + const SelectedRows& x, + const Scalar& min, + const Scalar& max, + SelectedRows* out); +} // namespace sr +} // namespace phi diff --git a/paddle/phi/kernels/selected_rows/cpu/clip_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/clip_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..0098bf13f2b2f1eb2c8691a6973bfa04798f9560 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/cpu/clip_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/selected_rows/clip_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h" + +PD_REGISTER_KERNEL(clip_sr, + CPU, + ALL_LAYOUT, + phi::sr::ClipSparseKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..a8d659559e19e5c04c0145fbe918065958e6bb64 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/selected_rows/clip_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h" + +PD_REGISTER_KERNEL(clip_sr, + GPU, + ALL_LAYOUT, + phi::sr::ClipSparseKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h b/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..1d95e633b93a6e01fd29363311fc9242facf1a38 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h @@ -0,0 +1,62 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/selected_rows/clip_kernel.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { +namespace sr { + +template +void ClipSparseKernel(const Context& dev_ctx, + const SelectedRows& x, + const Scalar& min, + const Scalar& max, + SelectedRows* out) { + auto max_ = max.to(); + auto min_ = min.to(); + + PADDLE_ENFORCE_LE( + min_, + max_, + errors::InvalidArgument("max should be greater than or equal to min. 
" + "But received min = %f, max = %f", + static_cast(min_), + static_cast(max_))); + + PADDLE_ENFORCE_NE(&x, + out, + errors::InvalidArgument("Inplace clip is not allowed " + "when x is SelectedRows")); + paddle::operators::math::scatter::MergeAdd merge_func; + merge_func(dev_ctx, x, out); + auto* out_tensor = out->mutable_value(); + auto* out_data = out_tensor->data(); + int64_t numel = out_tensor->numel(); + paddle::platform::Transform trans; + trans(dev_ctx, + out_data, + out_data + numel, + out_data, + ClipFunctor(min_, max_)); +} +} // namespace sr +} // namespace phi diff --git a/paddle/phi/ops/compat/clip_sig.cc b/paddle/phi/ops/compat/clip_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..78fa6c36a51492d9339958621c95dac8457e97bd --- /dev/null +++ b/paddle/phi/ops/compat/clip_sig.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" +#include "paddle/utils/small_vector.h" + +namespace phi { + +KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { + paddle::SmallVector attr_names; + attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min"); + attr_names.emplace_back(ctx.HasInput("Max") ? 
"Max" : "max"); + if (ctx.IsDenseTensorInput("X")) { + if (ctx.HasInput("Min")) { + if (ctx.HasInput("Max")) { + return KernelSignature("clip", {"X"}, {"Min", "Max"}, {"Out"}); + } else { + return KernelSignature("clip", {"X"}, {"Min", "max"}, {"Out"}); + } + } else { + if (ctx.HasInput("Max")) { + return KernelSignature("clip", {"X"}, {"min", "Max"}, {"Out"}); + } else { + return KernelSignature("clip", {"X"}, {"min", "max"}, {"Out"}); + } + } + } else if (ctx.IsSelectedRowsInput("X")) { + if (ctx.HasInput("Min")) { + if (ctx.HasInput("Max")) { + return KernelSignature("clip_sr", {"X"}, {"Min", "Max"}, {"Out"}); + } else { + return KernelSignature("clip_sr", {"X"}, {"Min", "max"}, {"Out"}); + } + } else { + if (ctx.HasInput("Max")) { + return KernelSignature("clip_sr", {"X"}, {"min", "Max"}, {"Out"}); + } else { + return KernelSignature("clip_sr", {"X"}, {"min", "max"}, {"Out"}); + } + } + } + + return KernelSignature("unregistered", {}, {}, {}); +} + +KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("Min")) { + if (ctx.HasInput("Max")) { + return KernelSignature("clip_grad", + {"X", GradVarName("Out")}, + {"Min", "Max"}, + {GradVarName("X")}); + } else { + return KernelSignature("clip_grad", + {"X", GradVarName("Out")}, + {"Min", "max"}, + {GradVarName("X")}); + } + } else { + if (ctx.HasInput("Max")) { + return KernelSignature("clip_grad", + {"X", GradVarName("Out")}, + {"min", "Max"}, + {GradVarName("X")}); + } else { + return KernelSignature("clip_grad", + {"X", GradVarName("Out")}, + {"min", "max"}, + {GradVarName("X")}); + } + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(clip, phi::ClipOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(clip_grad, phi::ClipGradOpArgumentMapping);