diff --git a/paddle/fluid/operators/erf_op.cc b/paddle/fluid/operators/erf_op.cc index f68f670394871114369f8b05b7f958c03d5508d0..64274d098c0585c28196743c09d5e6c78c3fe37d 100644 --- a/paddle/fluid/operators/erf_op.cc +++ b/paddle/fluid/operators/erf_op.cc @@ -16,8 +16,10 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/erf_op.h" -#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -29,18 +31,6 @@ class ErfOp : public framework::OperatorWithKernel { const framework::AttributeMap &attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::InvalidArgument( - "Input(%s) of ErfOp should not be null.", "X")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument( - "Output(%s) of ErfOp should not be null.", "Out")); - - ctx->ShareDim("X", /*->*/ "Out"); - ctx->ShareLoD("X", /*->*/ "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -116,28 +106,10 @@ class ErfGradOpMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(erf, ErfInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); REGISTER_OPERATOR(erf, ops::ErfOp, ops::ErfOpMaker, ops::ErfGradOpMaker, - ops::ErfGradOpMaker); + ops::ErfGradOpMaker, + ErfInferShapeFunctor); REGISTER_OPERATOR(erf_grad, ops::ErfGradOp); -REGISTER_OP_CPU_KERNEL( - erf, ops::ErfKernel, - ops::ErfKernel, - ops::ErfKernel); -REGISTER_OP_CPU_KERNEL( - erf_grad, ops::ErfGradKernel, - ops::ErfGradKernel, - ops::ErfGradKernel); - -REGISTER_OP_CUDA_KERNEL( - erf, ops::ErfKernel, - ops::ErfKernel, - ops::ErfKernel); -REGISTER_OP_CUDA_KERNEL( - erf_grad, ops::ErfGradKernel, - ops::ErfGradKernel, - ops::ErfGradKernel); diff --git a/paddle/fluid/operators/erf_op.h b/paddle/fluid/operators/erf_op.h deleted file mode 100644 index 4780b2e7f5b28d4a743f6d35046891b30cbefd00..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/erf_op.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#ifndef _USE_MATH_DEFINES -#define _USE_MATH_DEFINES -#endif -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/eigen/eigen_function.h" - -namespace paddle { -namespace operators { - -template -class ErfKernel : public framework::OpKernel { - public: - virtual void Compute(const framework::ExecutionContext& context) const { - auto* out = context.Output("Out"); - auto* in = context.Input("X"); - out->mutable_data(in->place()); - - auto eigen_out = framework::EigenVector::Flatten(*out); - auto eigen_in = framework::EigenVector::Flatten(*in); - auto& place = - *context.template device_context().eigen_device(); - EigenErf, T>::Eval(place, eigen_out, - eigen_in); - } -}; - -template -class ErfGradKernel : public framework::OpKernel { - public: - virtual void Compute(const framework::ExecutionContext& context) const { - auto* x = context.Input("X"); - auto* dout = - context.Input(framework::GradVarName("Out")); - auto* dx = context.Output(framework::GradVarName("X")); - - dx->mutable_data(dout->place()); - - auto eigen_x = framework::EigenVector::Flatten(*x); - auto eigen_dout = framework::EigenVector::Flatten(*dout); - auto eigen_dx = framework::EigenVector::Flatten(*dx); - auto& place = - *context.template device_context().eigen_device(); - EigenErfGrad, T>::Eval(place, eigen_dx, - eigen_x, eigen_dout); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/cpu/erf_grad_kernel.cc b/paddle/phi/kernels/cpu/erf_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c1cd0df1531a50524958b527ee39b09460f042c --- /dev/null +++ b/paddle/phi/kernels/cpu/erf_grad_kernel.cc @@ -0,0 +1,27 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/erf_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/erf_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(erf_grad, + CPU, + ALL_LAYOUT, + phi::ErfGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/cpu/erf_kernel.cc b/paddle/phi/kernels/cpu/erf_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..05ce4cab7fcef4d1da7d378093f9f3d04827acd2 --- /dev/null +++ b/paddle/phi/kernels/cpu/erf_kernel.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/erf_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/erf_kernel_impl.h" + +PD_REGISTER_KERNEL( + erf, CPU, ALL_LAYOUT, phi::ErfKernel, float, double, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/erf_grad_kernel.h b/paddle/phi/kernels/erf_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8957fcaf79b9a0a83323682ffa6efaf28f5362bf --- /dev/null +++ b/paddle/phi/kernels/erf_grad_kernel.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ErfGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/erf_kernel.h b/paddle/phi/kernels/erf_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1d5c57d2201c749f541cabe74b112073c2dace06 --- /dev/null +++ b/paddle/phi/kernels/erf_kernel.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ErfKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/erf_grad_kernel.cu b/paddle/phi/kernels/gpu/erf_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..a06863b0a877714bc241d714429a8114f2b0d6f7 --- /dev/null +++ b/paddle/phi/kernels/gpu/erf_grad_kernel.cu @@ -0,0 +1,27 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/erf_grad_kernel.h" +#include "paddle/phi/kernels/impl/erf_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(erf_grad, + GPU, + ALL_LAYOUT, + phi::ErfGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/erf_kernel.cu b/paddle/phi/kernels/gpu/erf_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..8e741be3345fc1d487e1a4148d96f0430f5978f0 --- /dev/null +++ b/paddle/phi/kernels/gpu/erf_kernel.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/erf_kernel.h" +#include "paddle/phi/kernels/impl/erf_kernel_impl.h" + +PD_REGISTER_KERNEL( + erf, GPU, ALL_LAYOUT, phi::ErfKernel, float, double, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/erf_grad_kernel_impl.h b/paddle/phi/kernels/impl/erf_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..5908d9d7dcb50d617d6796f26e2e8793d06b8409 --- /dev/null +++ b/paddle/phi/kernels/impl/erf_grad_kernel_impl.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/erf_grad_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void ErfGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + dev_ctx.template Alloc(x_grad); + + auto eigen_x = EigenVector::Flatten(x); + auto eigen_dout = EigenVector::Flatten(out_grad); + auto eigen_dx = EigenVector::Flatten(*x_grad); + auto& place = *dev_ctx.eigen_device(); + phi::funcs::EigenErfGrad, T>::Eval( + place, eigen_dx, eigen_x, eigen_dout); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/erf_kernel_impl.h b/paddle/phi/kernels/impl/erf_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..aa1f4d349ab71ee0b19a3a2d3ad3d3a47553464c --- /dev/null +++ b/paddle/phi/kernels/impl/erf_kernel_impl.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/erf_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void ErfKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { + dev_ctx.template Alloc(out); + + auto eigen_out = EigenVector::Flatten(*out); + auto eigen_in = EigenVector::Flatten(x); + auto& place = *dev_ctx.eigen_device(); + phi::funcs::EigenErf, T>::Eval( + place, eigen_out, eigen_in); +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/erf_sig.cc b/paddle/phi/ops/compat/erf_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..784727a98042db820b4f83bac84014ebd0e1302e --- /dev/null +++ b/paddle/phi/ops/compat/erf_sig.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature ErfGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "erf_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(erf_grad, phi::ErfGradOpArgumentMapping);