未验证 提交 358106fc 编写于 作者: C chentianyu03 提交者: GitHub

make abs op support complex types (#30375)

* rewrite abs op

* rewrite abs op and remove abs in activation

* remove abs register in old codes

* fix abs_grad type error

* fix abs double_grad output name error

* modify abs_grad, abs_grad_grad functor for windows building

* format code style

* fix the bug of result is nan when the divisor is zero

* add missing abs attr and add abs for float16
上级 13862008
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/abs_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
class AbsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "abs");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "abs");
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", in_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of abs op.");
AddOutput("Out", "(Tensor), The output tensor of abs op.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("use_cudnn",
"(bool, default false) Only used in cudnn kernel, need "
"install cudnn")
.SetDefault(false);
AddComment(R"DOC(
Abs Operator.
This operator is used to perform elementwise abs for input $X$.
$$out = |x|$$
)DOC");
}
};
class AbsGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@Grad", "AbsGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"X@Grad", "AbsGrad");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
template <typename T>
class AbsGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("abs_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput("X", this->Input("X"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
// AbsGrad: dx=dy if x >=0 else -dy
// AbsDoubleGrad: ddy = ddx if x >=0 else -ddx
template <typename T>
class AbsDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("abs_grad_grad");
// input1: x
op->SetInput("X", this->Input("X"));
// input2: ddx
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(this->Attrs());
// output: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
}
};
class AbsDoubleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(abs, ops::AbsOp, ops::AbsOpMaker,
ops::AbsGradMaker<paddle::framework::OpDesc>,
ops::AbsGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(abs_grad, ops::AbsGradOp,
ops::AbsDoubleGradMaker<paddle::framework::OpDesc>,
ops::AbsDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(abs_grad_grad, ops::AbsDoubleGradOp);
REGISTER_OP_CPU_KERNEL(
abs, ops::AbsKernel<paddle::platform::CPUDeviceContext, float>,
ops::AbsKernel<paddle::platform::CPUDeviceContext, double>,
ops::AbsKernel<paddle::platform::CPUDeviceContext, int>,
ops::AbsKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::AbsKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::AbsKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
abs_grad, ops::AbsGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
abs_grad_grad,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/abs_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
abs, ops::AbsKernel<paddle::platform::CUDADeviceContext, float>,
ops::AbsKernel<paddle::platform::CUDADeviceContext, double>,
ops::AbsKernel<paddle::platform::CUDADeviceContext, int>,
ops::AbsKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::AbsKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::AbsKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::AbsKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
abs_grad, ops::AbsGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::AbsGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::AbsGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::AbsGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
abs_grad_grad,
ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class AbsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
auto numel = x->numel();
auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<math::Real<T>>(
context.GetPlace(), size_t(x->numel() * sizeof(math::Real<T>)));
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
math::AbsFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
};
template <typename DeviceContext, typename T>
class AbsGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* d_out =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* d_x =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto numel = d_out->numel();
auto* dout_data = d_out->data<math::Real<T>>();
auto* x_data = x->data<T>();
auto* dx_data = d_x->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
math::AbsGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
for_range(functor);
}
};
template <typename DeviceContext, typename T>
class AbsDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* ddx = ctx.Input<framework::Tensor>("DDX");
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* ddout = ctx.Output<framework::Tensor>("DDOut");
auto numel = ddx->numel();
auto* ddx_data = ddx->data<T>();
auto* x_data = x->data<T>();
auto* ddout_data = ddout->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
math::AbsGradGradFunctor<T> functor(ddx_data, x_data, ddout_data, numel);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
...@@ -219,13 +219,6 @@ $$out = \\frac{1}{\\sqrt{x}}$$ ...@@ -219,13 +219,6 @@ $$out = \\frac{1}{\\sqrt{x}}$$
)DOC"; )DOC";
UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Operator.
$$out = |x|$$
)DOC";
UNUSED constexpr char CeilDoc[] = R"DOC( UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Operator. Computes ceil of x element-wise. Ceil Operator. Computes ceil of x element-wise.
...@@ -714,7 +707,6 @@ REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc); ...@@ -714,7 +707,6 @@ REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc); REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc); REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc); REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc); REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc); REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc); REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
...@@ -793,26 +785,6 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel { ...@@ -793,26 +785,6 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
} }
}; };
// AbsGrad: dx=dy if x >=0 else -dy
// AbsDoubleGrad: ddy = ddx if x >=0 else -ddx
template <typename T>
class AbsDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
public:
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("abs_grad_grad");
// input1: x
op->SetInput("X", this->Input("X"));
// input2: ddx
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(this->Attrs());
// output: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
}
};
// ReluGrad: dx = dy if y >= 0 else 0 // ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0 // ReluGradGrad: ddy = ddx if y >= 0 else 0
template <typename T> template <typename T>
...@@ -1322,56 +1294,6 @@ REGISTER_OP_CPU_KERNEL( ...@@ -1322,56 +1294,6 @@ REGISTER_OP_CPU_KERNEL(
ops::ExpGradFunctor<int64_t>>); ops::ExpGradFunctor<int64_t>>);
/* ========================================================================== */ /* ========================================================================== */
/* ========================== abs register ============================ */
REGISTER_OPERATOR(
abs, ops::ActivationOp, ops::AbsOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::AbsGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::AbsGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
std::conditional<ops::CanInplaceAct<ops::AbsGradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(abs_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer,
ops::AbsDoubleGradMaker<paddle::framework::OpDesc>,
ops::AbsDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
abs_grad_grad,
ops::ActivationOpDoubleGrad<ops::AbsGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(abs,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::AbsFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::AbsFunctor<double>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::AbsFunctor<int>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::AbsFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
abs_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::AbsGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::AbsGradFunctor<double>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::AbsGradFunctor<int>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::AbsGradFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
abs_grad_grad,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::AbsGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::AbsGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::AbsGradGradFunctor<plat::float16>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::AbsGradGradFunctor<int>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::AbsGradGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================== Log register ==================================*/ /* ========================== Log register ==================================*/
REGISTER_OPERATOR( REGISTER_OPERATOR(
log, ops::ActivationOp, ops::LogOpMaker, ops::ActivationOpInferVarType, log, ops::ActivationOp, ops::LogOpMaker, ops::ActivationOpInferVarType,
......
...@@ -174,40 +174,6 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -174,40 +174,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::ExpGradFunctor<plat::float16>>); ops::ExpGradFunctor<plat::float16>>);
/* ========================================================================== */ /* ========================================================================== */
/* ========================== abs register ============================ */
REGISTER_OP_CUDA_KERNEL(
abs, ops::ActivationKernel<plat::CUDADeviceContext, ops::AbsFunctor<float>>,
ops::ActivationKernel<plat::CUDADeviceContext, ops::AbsFunctor<double>>,
ops::ActivationKernel<plat::CUDADeviceContext, ops::AbsFunctor<int>>,
ops::ActivationKernel<plat::CUDADeviceContext, ops::AbsFunctor<int64_t>>,
ops::ActivationKernel<plat::CUDADeviceContext,
ops::AbsFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
abs_grad, ops::ActivationGradKernel<plat::CUDADeviceContext,
ops::AbsGradFunctor<float>>,
ops::ActivationGradKernel<plat::CUDADeviceContext,
ops::AbsGradFunctor<double>>,
ops::ActivationGradKernel<plat::CUDADeviceContext,
ops::AbsGradFunctor<int>>,
ops::ActivationGradKernel<plat::CUDADeviceContext,
ops::AbsGradFunctor<int64_t>>,
ops::ActivationGradKernel<plat::CUDADeviceContext,
ops::AbsGradFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
abs_grad_grad,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::AbsGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::AbsGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::AbsGradGradFunctor<plat::float16>>,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::AbsGradGradFunctor<int>>,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::AbsGradGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================== Log register ==================================*/ /* ========================== Log register ==================================*/
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, LogFunctor, LogGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, LogFunctor, LogGradFunctor);
......
...@@ -793,26 +793,6 @@ struct RoundFunctor : public BaseActivationFunctor<T> { ...@@ -793,26 +793,6 @@ struct RoundFunctor : public BaseActivationFunctor<T> {
} }
}; };
// abs(x) = |x|
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.abs();
}
};
template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.sign();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// reciprocal(x) = 1 / x // reciprocal(x) = 1 / x
template <typename T> template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> { struct ReciprocalFunctor : public BaseActivationFunctor<T> {
......
...@@ -48,6 +48,18 @@ struct select { ...@@ -48,6 +48,18 @@ struct select {
using type = eval_if_t<Head::value, Head, select<Tail...>>; using type = eval_if_t<Head::value, Head, select<Tail...>>;
}; };
template <typename T>
struct select<T> {
using type = T;
};
template <bool B, typename T>
struct select<cond<B, T>> {
// last one had better be true!
static_assert(B, "No match select type!");
using type = T;
};
template <typename Head, typename... Tail> template <typename Head, typename... Tail>
using select_t = typename select<Head, Tail...>::type; using select_t = typename select<Head, Tail...>::type;
...@@ -63,6 +75,16 @@ using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type; ...@@ -63,6 +75,16 @@ using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
template <typename T, typename RealT> template <typename T, typename RealT>
using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type; using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
template <typename T>
using EnableComplex =
typename std::enable_if<std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value>::type;
template <typename T>
using DisableComplex = typename std::enable_if<
!std::is_same<T, platform::complex64>::value &&
!std::is_same<T, platform::complex128>::value>::type;
template <typename T, typename Enable = void> template <typename T, typename Enable = void>
struct RealFunctor; struct RealFunctor;
...@@ -99,6 +121,76 @@ struct ImagFunctor<T, Complex<T, Real<T>>> { ...@@ -99,6 +121,76 @@ struct ImagFunctor<T, Complex<T, Real<T>>> {
int64_t numel_; int64_t numel_;
}; };
template <typename T, typename Enable = void>
struct AbsFunctor;
template <typename T>
struct AbsFunctor<T, Complex<T, Real<T>>> {
AbsFunctor(const T* input, Real<T>* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = abs(input_[idx]);
}
const T* input_;
Real<T>* output_;
int64_t numel_;
};
template <typename T>
struct AbsFunctor<T, NoComplex<T, Real<T>>> {
AbsFunctor(const T* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = abs(input_[idx]);
}
const T* input_;
T* output_;
int64_t numel_;
};
template <typename T>
struct AbsGradFunctor {
AbsGradFunctor(const math::Real<T>* dout, const T* x, T* output,
int64_t numel)
: dout_(dout), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
if (x_[idx] == T(0)) {
output_[idx] = T(0);
} else {
output_[idx] = T(dout_[idx]) * (x_[idx] / T(abs(x_[idx])));
}
}
const math::Real<T>* dout_;
const T* x_;
T* output_;
int64_t numel_;
};
template <typename T>
struct AbsGradGradFunctor {
AbsGradGradFunctor(const T* ddx, const T* x, T* output, int64_t numel)
: ddx_(ddx), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
if (x_[idx] == T(0)) {
output_[idx] = T(0);
} else {
output_[idx] = T(ddx_[idx]) * x_[idx] / T(abs(x_[idx]));
}
}
const T* ddx_;
const T* x_;
T* output_;
int64_t numel_;
};
template <typename T, typename Enable = void> template <typename T, typename Enable = void>
struct RealToComplexFunctor; struct RealToComplexFunctor;
...@@ -135,16 +227,6 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> { ...@@ -135,16 +227,6 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
int64_t numel_; int64_t numel_;
}; };
template <typename T>
using EnableComplex =
typename std::enable_if<std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value>::type;
template <typename T>
using DisableComplex = typename std::enable_if<
!std::is_same<T, platform::complex64>::value &&
!std::is_same<T, platform::complex128>::value>::type;
template <typename T, typename Enable = void> template <typename T, typename Enable = void>
struct ConjFunctor; struct ConjFunctor;
......
...@@ -361,7 +361,7 @@ HOSTDEVICE inline double(abs)(const complex128& a) { ...@@ -361,7 +361,7 @@ HOSTDEVICE inline double(abs)(const complex128& a) {
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
return thrust::abs(thrust::complex<double>(a.real, a.imag)); return thrust::abs(thrust::complex<double>(a.real, a.imag));
#else #else
return std::abs(std::complex<double>(a)); return std::abs(std::complex<double>(a.real, a.imag));
#endif #endif
} }
......
...@@ -363,7 +363,7 @@ HOSTDEVICE inline float(abs)(const complex64& a) { ...@@ -363,7 +363,7 @@ HOSTDEVICE inline float(abs)(const complex64& a) {
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
return complex64(thrust::abs(thrust::complex<float>(a.real, a.imag))); return complex64(thrust::abs(thrust::complex<float>(a.real, a.imag)));
#else #else
return std::abs(std::complex<float>(a)); return std::abs(std::complex<float>(a.real, a.imag));
#endif #endif
} }
......
...@@ -899,6 +899,16 @@ HOSTDEVICE inline bool(isfinite)(const float16& a) { ...@@ -899,6 +899,16 @@ HOSTDEVICE inline bool(isfinite)(const float16& a) {
return !((isnan)(a)) && !((isinf)(a)); return !((isnan)(a)) && !((isinf)(a));
} }
HOSTDEVICE inline float16(abs)(const float16& a) {
#if (defined(PADDLE_CUDA_FP16) && \
((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
(defined(__HIP_DEVICE_COMPILE__))))
return float16(::fabs(static_cast<float>(a)));
#else
return float16(std::abs(static_cast<float>(a)));
#endif
}
inline std::ostream& operator<<(std::ostream& os, const float16& a) { inline std::ostream& operator<<(std::ostream& os, const float16& a) {
os << static_cast<float>(a); os << static_cast<float>(a);
return os; return os;
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division
import unittest
import numpy as np
import paddle
from op_test import OpTest
class TestComplexAbsOp(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "abs"
self.dtype = np.float64
self.shape = (2, 3, 4, 5)
self.init_input_output()
self.init_grad_input_output()
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x)}
self.outputs = {'Out': self.out}
def init_input_output(self):
self.x = np.random.random(self.shape).astype(
self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
self.out = np.abs(self.x)
def init_grad_input_output(self):
self.grad_out = np.ones(self.shape, self.dtype)
self.grad_x = self.grad_out * (self.x / np.abs(self.x))
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['X'],
'Out',
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
class TestComplexAbsOpZeroValues(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "abs"
self.dtype = np.float64
self.shape = (2, 3, 4, 5)
self.init_input_output()
self.init_grad_input_output()
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x)}
self.outputs = {'Out': self.out}
def init_input_output(self):
self.x = np.zeros(self.shape).astype(self.dtype) + 1J * np.zeros(
self.shape).astype(self.dtype)
self.out = np.abs(self.x)
def init_grad_input_output(self):
self.grad_out = np.ones(self.shape, self.dtype)
self.grad_x = np.zeros(self.shape, self.dtype)
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['X'],
'Out',
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册