diff --git a/paddle/fluid/operators/abs_op.cc b/paddle/fluid/operators/abs_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5c431ce77dc76ae08c70cd54989f323a230d47f7 --- /dev/null +++ b/paddle/fluid/operators/abs_op.cc @@ -0,0 +1,192 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/abs_op.h" + +#include +#include +#include +#include +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +class AbsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "abs"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "abs"); + + auto in_dims = ctx->GetInputDim("X"); + + ctx->SetOutputDim("Out", in_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class AbsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of abs op."); + AddOutput("Out", "(Tensor), The output tensor of abs op."); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("use_cudnn", + "(bool, default false) Only used in cudnn kernel, need " + "install cudnn") + .SetDefault(false); + AddComment(R"DOC( +Abs Operator. + +This operator is used to perform elementwise abs for input $X$. +$$out = |x|$$ + +)DOC"); + } +}; + +class AbsGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@Grad", "AbsGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + "X@Grad", "AbsGrad"); + + auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); + ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +template +class AbsGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("abs_grad"); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetInput("X", this->Input("X")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +// AbsGrad: dx=dy if x >=0 else -dy +// AbsDoubleGrad: ddy = ddx if x >=0 else -ddx +template +class AbsDoubleGradMaker : public framework::SingleGradOpMaker { + public: + using ::paddle::framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("abs_grad_grad"); + // input1: x + op->SetInput("X", this->Input("X")); + // input2: ddx + op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); + op->SetAttrMap(this->Attrs()); + // output: ddy + op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out"))); + } +}; + +class AbsDoubleGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + if (ctx->HasOutput("DDOut")) { + ctx->ShareDim("X", "DDOut"); + ctx->ShareLoD("X", "DDOut"); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + return framework::OpKernelType(tensor.type(), tensor.place(), + tensor.layout()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(abs, ops::AbsOp, ops::AbsOpMaker, + ops::AbsGradMaker, + ops::AbsGradMaker); + +REGISTER_OPERATOR(abs_grad, ops::AbsGradOp, + ops::AbsDoubleGradMaker, + ops::AbsDoubleGradMaker); + +REGISTER_OPERATOR(abs_grad_grad, ops::AbsDoubleGradOp); + +REGISTER_OP_CPU_KERNEL( + abs, ops::AbsKernel, + ops::AbsKernel, + ops::AbsKernel, + ops::AbsKernel, + ops::AbsKernel, + ops::AbsKernel); + +REGISTER_OP_CPU_KERNEL( + abs_grad, ops::AbsGradKernel, + ops::AbsGradKernel, + ops::AbsGradKernel, + ops::AbsGradKernel, + ops::AbsGradKernel, + ops::AbsGradKernel); + +REGISTER_OP_CPU_KERNEL( + abs_grad_grad, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel); diff --git a/paddle/fluid/operators/abs_op.cu b/paddle/fluid/operators/abs_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..e373d628f6cbd6b5ee48edc984a68d2767ce0593 --- /dev/null +++ b/paddle/fluid/operators/abs_op.cu @@ -0,0 +1,56 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/abs_op.h" +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + abs, ops::AbsKernel, + ops::AbsKernel, + ops::AbsKernel, + ops::AbsKernel, + ops::AbsKernel, + ops::AbsKernel, + ops::AbsKernel); + +REGISTER_OP_CUDA_KERNEL( + abs_grad, ops::AbsGradKernel, + ops::AbsGradKernel, + ops::AbsGradKernel, + ops::AbsGradKernel, + ops::AbsGradKernel, + ops::AbsGradKernel, + ops::AbsGradKernel); + +REGISTER_OP_CUDA_KERNEL( + abs_grad_grad, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel, + ops::AbsDoubleGradKernel); diff --git a/paddle/fluid/operators/abs_op.h b/paddle/fluid/operators/abs_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c79e83314f3bd39dcf6736e66c0b12956a2b0e81 --- /dev/null +++ b/paddle/fluid/operators/abs_op.h @@ -0,0 +1,90 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +class AbsKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + Tensor* out = context.Output("Out"); + + auto numel = x->numel(); + auto* x_data = x->data(); + auto* out_data = out->mutable_data>( + context.GetPlace(), size_t(x->numel() * sizeof(math::Real))); + + auto& dev_ctx = context.template device_context(); + platform::ForRange for_range(dev_ctx, numel); + math::AbsFunctor functor(x_data, out_data, numel); + for_range(functor); + } +}; + +template +class AbsGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + const framework::Tensor* d_out = + ctx.Input(framework::GradVarName("Out")); + const framework::Tensor* x = ctx.Input("X"); + framework::Tensor* d_x = + ctx.Output(framework::GradVarName("X")); + + auto numel = d_out->numel(); + auto* dout_data = d_out->data>(); + auto* x_data = x->data(); + auto* dx_data = d_x->mutable_data( + ctx.GetPlace(), static_cast(numel * sizeof(T))); + + auto& dev_ctx = ctx.template device_context(); + platform::ForRange for_range(dev_ctx, numel); + math::AbsGradFunctor functor(dout_data, x_data, dx_data, numel); + for_range(functor); + } +}; + +template +class AbsDoubleGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + const framework::Tensor* ddx = ctx.Input("DDX"); + const framework::Tensor* x = ctx.Input("X"); + framework::Tensor* ddout = ctx.Output("DDOut"); + + auto numel = ddx->numel(); + auto* ddx_data = ddx->data(); + auto* x_data = x->data(); + auto* ddout_data = ddout->mutable_data( + ctx.GetPlace(), static_cast(numel * sizeof(T))); + + auto& dev_ctx = ctx.template device_context(); + platform::ForRange for_range(dev_ctx, numel); + math::AbsGradGradFunctor functor(ddx_data, x_data, ddout_data, numel); + for_range(functor); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 19e5902e74318de30500380a171585aab0afdbe2..696606441642c91e5dabacaa1af7e28a575e0557 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -219,13 +219,6 @@ $$out = \\frac{1}{\\sqrt{x}}$$ )DOC"; -UNUSED constexpr char AbsDoc[] = R"DOC( -Abs Operator. - -$$out = |x|$$ - -)DOC"; - UNUSED constexpr char CeilDoc[] = R"DOC( Ceil Operator. Computes ceil of x element-wise. @@ -714,7 +707,6 @@ REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc); REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc); REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc); REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc); -REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc); REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc); REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc); REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc); @@ -793,26 +785,6 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel { } }; -// AbsGrad: dx=dy if x >=0 else -dy -// AbsDoubleGrad: ddy = ddx if x >=0 else -ddx -template -class AbsDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker { - public: - using ::paddle::framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("abs_grad_grad"); - // input1: x - op->SetInput("X", this->Input("X")); - // input2: ddx - op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); - op->SetAttrMap(this->Attrs()); - // output: ddy - op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out"))); - } -}; - // ReluGrad: dx = dy if y >= 0 else 0 // ReluGradGrad: ddy = ddx if y >= 0 else 0 template @@ -1322,56 +1294,6 @@ REGISTER_OP_CPU_KERNEL( ops::ExpGradFunctor>); /* ========================================================================== */ -/* ========================== abs register ============================ */ -REGISTER_OPERATOR( - abs, ops::ActivationOp, ops::AbsOpMaker, ops::ActivationOpInferVarType, - ops::ActivationGradOpMaker::FwdDeps(), - paddle::framework::OpDesc>, - ops::ActivationGradOpMaker::FwdDeps(), - paddle::imperative::OpBase>, - std::conditional>(), - ops::ActFwdInplaceInferer, void>::type); -REGISTER_OPERATOR(abs_grad, ops::ActivationOpGrad, - ops::ActivationGradOpInplaceInferer, - ops::AbsDoubleGradMaker, - ops::AbsDoubleGradMaker); -REGISTER_OPERATOR( - abs_grad_grad, - ops::ActivationOpDoubleGrad::FwdDeps()>, - ops::ActivationDoubleGradOpInplaceInferer); - -REGISTER_OP_CPU_KERNEL(abs, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); -REGISTER_OP_CPU_KERNEL( - abs_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); -REGISTER_OP_CPU_KERNEL( - abs_grad_grad, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>); -/* ========================================================================== */ - /* ========================== Log register ==================================*/ REGISTER_OPERATOR( log, ops::ActivationOp, ops::LogOpMaker, ops::ActivationOpInferVarType, diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 1a6d5de18ec47c8627d351150bfd6bbf808d019b..36777399174f5d2619fbcd40ebf91be1ed29feec 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -174,40 +174,6 @@ REGISTER_OP_CUDA_KERNEL( ops::ExpGradFunctor>); /* ========================================================================== */ -/* ========================== abs register ============================ */ - -REGISTER_OP_CUDA_KERNEL( - abs, ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); -REGISTER_OP_CUDA_KERNEL( - abs_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); -REGISTER_OP_CUDA_KERNEL( - abs_grad_grad, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>); -/* ========================================================================== */ - /* ========================== Log register ==================================*/ REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, LogFunctor, LogGradFunctor); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 6e906d734e1ac630fe4b5a43695a987942df3063..483f5cc2e5cc267b1e0ca3856b32f37acded8c43 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -793,26 +793,6 @@ struct RoundFunctor : public BaseActivationFunctor { } }; -// abs(x) = |x| -template -struct AbsFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.abs(); - } -}; - -template -struct AbsGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * x.sign(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - // reciprocal(x) = 1 / x template struct ReciprocalFunctor : public BaseActivationFunctor { diff --git a/paddle/fluid/operators/math/complex_functors.h b/paddle/fluid/operators/math/complex_functors.h index 18a003d5c9a504f00d6382bb51028481d5d9dcbe..2e9e72eac12aaabe0ac658a2f1c8711267e75936 100644 --- a/paddle/fluid/operators/math/complex_functors.h +++ b/paddle/fluid/operators/math/complex_functors.h @@ -48,6 +48,18 @@ struct select { using type = eval_if_t>; }; +template +struct select { + using type = T; +}; + +template +struct select> { + // last one had better be true! + static_assert(B, "No match select type!"); + using type = T; +}; + template using select_t = typename select::type; @@ -63,6 +75,16 @@ using Complex = typename std::enable_if::value>::type; template using NoComplex = typename std::enable_if::value>::type; +template +using EnableComplex = + typename std::enable_if::value || + std::is_same::value>::type; + +template +using DisableComplex = typename std::enable_if< + !std::is_same::value && + !std::is_same::value>::type; + template struct RealFunctor; @@ -99,6 +121,76 @@ struct ImagFunctor>> { int64_t numel_; }; +template +struct AbsFunctor; + +template +struct AbsFunctor>> { + AbsFunctor(const T* input, Real* output, int64_t numel) + : input_(input), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + output_[idx] = abs(input_[idx]); + } + + const T* input_; + Real* output_; + int64_t numel_; +}; + +template +struct AbsFunctor>> { + AbsFunctor(const T* input, T* output, int64_t numel) + : input_(input), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + output_[idx] = abs(input_[idx]); + } + + const T* input_; + T* output_; + int64_t numel_; +}; + +template +struct AbsGradFunctor { + AbsGradFunctor(const math::Real* dout, const T* x, T* output, + int64_t numel) + : dout_(dout), x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + if (x_[idx] == T(0)) { + output_[idx] = T(0); + } else { + output_[idx] = T(dout_[idx]) * (x_[idx] / T(abs(x_[idx]))); + } + } + + const math::Real* dout_; + const T* x_; + T* output_; + int64_t numel_; +}; + +template +struct AbsGradGradFunctor { + AbsGradGradFunctor(const T* ddx, const T* x, T* output, int64_t numel) + : ddx_(ddx), x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + if (x_[idx] == T(0)) { + output_[idx] = T(0); + } else { + output_[idx] = T(ddx_[idx]) * x_[idx] / T(abs(x_[idx])); + } + } + + const T* ddx_; + const T* x_; + T* output_; + int64_t numel_; +}; + template struct RealToComplexFunctor; @@ -135,16 +227,6 @@ struct ImagToComplexFunctor>> { int64_t numel_; }; -template -using EnableComplex = - typename std::enable_if::value || - std::is_same::value>::type; - -template -using DisableComplex = typename std::enable_if< - !std::is_same::value && - !std::is_same::value>::type; - template struct ConjFunctor; diff --git a/paddle/fluid/platform/complex128.h b/paddle/fluid/platform/complex128.h index 2a2cd3b7be266120086149bba3324929b050f114..58753527c0405e955be158ead2549234a3725f11 100644 --- a/paddle/fluid/platform/complex128.h +++ b/paddle/fluid/platform/complex128.h @@ -361,7 +361,7 @@ HOSTDEVICE inline double(abs)(const complex128& a) { #if defined(__CUDA_ARCH__) return thrust::abs(thrust::complex(a.real, a.imag)); #else - return std::abs(std::complex(a)); + return std::abs(std::complex(a.real, a.imag)); #endif } diff --git a/paddle/fluid/platform/complex64.h b/paddle/fluid/platform/complex64.h index 7da11cfe5ed761b257ea70bc5f1f99063b016666..5f9b3c1118d3fe26b724ac56ffaafc3be502271e 100644 --- a/paddle/fluid/platform/complex64.h +++ b/paddle/fluid/platform/complex64.h @@ -363,7 +363,7 @@ HOSTDEVICE inline float(abs)(const complex64& a) { #if defined(__CUDA_ARCH__) return complex64(thrust::abs(thrust::complex(a.real, a.imag))); #else - return std::abs(std::complex(a)); + return std::abs(std::complex(a.real, a.imag)); #endif } diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 753f0d398c20447809d30a55b595a2d06f368abe..6f0b44f6af60298cca1a65445ef77ba6b1810396 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -899,6 +899,16 @@ HOSTDEVICE inline bool(isfinite)(const float16& a) { return !((isnan)(a)) && !((isinf)(a)); } +HOSTDEVICE inline float16(abs)(const float16& a) { +#if (defined(PADDLE_CUDA_FP16) && \ + ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ + (defined(__HIP_DEVICE_COMPILE__)))) + return float16(::fabs(static_cast(a))); +#else + return float16(std::abs(static_cast(a))); +#endif +} + inline std::ostream& operator<<(std::ostream& os, const float16& a) { os << static_cast(a); return os; diff --git a/python/paddle/fluid/tests/unittests/test_complex_abs.py b/python/paddle/fluid/tests/unittests/test_complex_abs.py new file mode 100644 index 0000000000000000000000000000000000000000..f9bce91e46d91057c3b879d32c43fac21b1482a2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_complex_abs.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, division + +import unittest +import numpy as np + +import paddle +from op_test import OpTest + + +class TestComplexAbsOp(OpTest): + def setUp(self): + paddle.enable_static() + self.op_type = "abs" + self.dtype = np.float64 + self.shape = (2, 3, 4, 5) + self.init_input_output() + self.init_grad_input_output() + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x)} + self.outputs = {'Out': self.out} + + def init_input_output(self): + self.x = np.random.random(self.shape).astype( + self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype) + self.out = np.abs(self.x) + + def init_grad_input_output(self): + self.grad_out = np.ones(self.shape, self.dtype) + self.grad_x = self.grad_out * (self.x / np.abs(self.x)) + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ['X'], + 'Out', + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + +class TestComplexAbsOpZeroValues(OpTest): + def setUp(self): + paddle.enable_static() + self.op_type = "abs" + self.dtype = np.float64 + self.shape = (2, 3, 4, 5) + self.init_input_output() + self.init_grad_input_output() + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x)} + self.outputs = {'Out': self.out} + + def init_input_output(self): + self.x = np.zeros(self.shape).astype(self.dtype) + 1J * np.zeros( + self.shape).astype(self.dtype) + self.out = np.abs(self.x) + + def init_grad_input_output(self): + self.grad_out = np.ones(self.shape, self.dtype) + self.grad_x = np.zeros(self.shape, self.dtype) + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ['X'], + 'Out', + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + +if __name__ == '__main__': + unittest.main()