Unverified commit 2d3cbb49 authored by levi131, committed by GitHub

Add lgamma_op kernel and unittest (#32913)

* run pre-commit

* use HOST or DEVICE instead of HOSTDEVICE in implementation of lgamma op

* add test for fp32

* add lgamma to op_threshold_white_list

* add cuda kernel for lgamma op

* modify numeric grad delta

* fix small English issue

* change LaunchElementwiseCudaKernel to LaunchSameDimsElementwiseCudaKernel
Parent 5756d3e5
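For reference, the math implemented by this commit is out = lgamma(x) = log(Gamma(x)) in the forward pass and dx = dout * digamma(x) in the backward pass. Below is a minimal NumPy/SciPy sketch of the same computation, for reference only: the kernels in this commit use Eigen's lgamma/digamma, and scipy.special is not used by the op itself.

import numpy as np
from scipy import special  # reference implementation only, not used by the op

def lgamma_forward(x):
    # Elementwise forward: out[i] = log(Gamma(x[i])).
    return special.gammaln(x)

def lgamma_backward(x, dout):
    # Elementwise backward: d/dx log(Gamma(x)) = digamma(x), so dx = dout * digamma(x).
    return dout * special.psi(x)

x = np.random.random((5, 20)).astype(np.float64) + 1  # same shape and offset as the unittest below
dout = np.ones_like(x)
print(lgamma_forward(x)[0][0], lgamma_backward(x, dout)[0][0])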
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/lgamma_op.h"
namespace paddle {
namespace operators {
class LgammaOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of lgamma op.");
AddOutput("Out", "(Tensor), The output tensor of lgamma op.");
AddComment(R"DOC(
Lgamma Operator.
This operator performs elementwise lgamma for input $X$.
$$out = log\Gamma(x)$$
)DOC");
}
};
class LgammaOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Lgamma");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Lgamma");
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", in_dims);
ctx->ShareLoD("X", "Out");
}
};
template <typename T>
class LgammaGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("lgamma_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput("X", this->Input("X"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
class LgammaGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@Grad", "LgammaGrad");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LgammaGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"X@Grad", "LgammaGrad");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(lgamma, ops::LgammaOp, ops::LgammaOpMaker,
ops::LgammaGradMaker<paddle::framework::OpDesc>,
ops::LgammaGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(lgamma_grad, ops::LgammaGradOp);
REGISTER_OP_CPU_KERNEL(
lgamma, ops::LgammaKernel<paddle::platform::CPUDeviceContext, float>,
ops::LgammaKernel<paddle::platform::CPUDeviceContext, double>)
REGISTER_OP_CPU_KERNEL(
lgamma_grad,
ops::LgammaGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LgammaGradKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unsupported/Eigen/SpecialFunctions>
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/lgamma_op.h"
#include "paddle/fluid/operators/math/complex_functors.h"

namespace paddle {
namespace operators {

template <typename T, typename Enable = void>
struct CudaLgammaFunctor;

template <typename T>
struct CudaLgammaFunctor<T, math::NoComplex<T, math::Real<T>>> {
  __device__ __forceinline__ T operator()(const T* args) const {
    return Eigen::numext::lgamma(args[0]);
  }
};

template <typename T>
class LgammaKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* x = context.Input<Tensor>("X");
    Tensor* out = context.Output<Tensor>("Out");
    out->mutable_data<math::Real<T>>(context.GetPlace());

    auto& dev_ctx = context.device_context<platform::CUDADeviceContext>();
    std::vector<const framework::Tensor*> ins = {x};
    std::vector<framework::Tensor*> outs = {out};
    auto functor = CudaLgammaFunctor<T>();
    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T,
                                        math::Real<T>>(dev_ctx, ins, &outs,
                                                       functor);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    lgamma, ops::LgammaKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LgammaKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    lgamma_grad,
    ops::LgammaGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LgammaGradKernel<paddle::platform::CUDADeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <unsupported/Eigen/SpecialFunctions>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

template <typename T>
struct LgammaFunctor {
  LgammaFunctor(const T* input, T* output, int64_t numel)
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx] = Eigen::numext::lgamma(input_[idx]);
  }

 private:
  const T* input_;
  T* output_;
  int64_t numel_;
};

template <typename T>
struct LgammaGradFunctor {
  LgammaGradFunctor(const T* dout, const T* x, T* output, int64_t numel)
      : dout_(dout), x_(x), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]);
  }

 private:
  const T* dout_;
  const T* x_;
  T* output_;
  int64_t numel_;
};

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class LgammaKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* x = context.Input<Tensor>("X");
    Tensor* out = context.Output<Tensor>("Out");

    auto numel = x->numel();
    auto* x_data = x->data<T>();
    auto* out_data = out->mutable_data<T>(context.GetPlace(),
                                          size_t(x->numel() * sizeof(T)));

    auto& dev_ctx = context.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    LgammaFunctor<T> functor(x_data, out_data, numel);
    for_range(functor);
  }
};

template <typename DeviceContext, typename T>
class LgammaGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    const framework::Tensor* d_out =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
    framework::Tensor* d_x =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));

    auto numel = d_out->numel();
    auto* dout_data = d_out->data<T>();
    auto* x_data = x->data<T>();
    auto* dx_data = d_x->mutable_data<T>(
        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    LgammaGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
    for_range(functor);
  }
};

}  // namespace operators
}  // namespace paddle
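For context on the backward path defined above: the gradient kernel relies on the identity $$\frac{d}{dx} \log\Gamma(x) = \psi(x)$$ where $\psi$ is the digamma function, so by the chain rule the input gradient is dout * digamma(x). That is exactly what LgammaGradFunctor evaluates elementwise via Eigen::numext::digamma.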
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import math
import numpy as np
import paddle
from op_test import OpTest

paddle.enable_static()


class TestLgammaOp(OpTest):
    def setUp(self):
        self.op_type = 'lgamma'
        self.init_dtype_type()
        shape = (5, 20)
        data = np.random.random(shape).astype(self.dtype) + 1
        self.inputs = {'X': data}
        result = np.ones(shape).astype(self.dtype)
        for i in range(shape[0]):
            for j in range(shape[1]):
                result[i][j] = math.lgamma(data[i][j])
        self.outputs = {'Out': result}

    def init_dtype_type(self):
        self.dtype = np.float64

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out', numeric_grad_delta=1e-7)


class TestLgammaOpFp32(TestLgammaOp):
    def init_dtype_type(self):
        self.dtype = np.float32

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out', numeric_grad_delta=0.005)


if __name__ == "__main__":
    unittest.main()
@@ -45,6 +45,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
    'bilateral_slice',\
    'cudnn_lstm', \
    'rnn', \
    'lgamma', \
]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\