提交 c552d1ac 编写于 作者: P phlrain

add forward case

上级 4e23ac69
......@@ -1650,9 +1650,6 @@ REGISTER_OPERATOR(logit, ops::LogitOp, ops::LogitOpMaker,
ops::LogitGradOpMaker<paddle::framework::OpDesc>,
ops::LogitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(logit_grad, ops::LogitGradOp);
REGISTER_OP_CPU_KERNEL(
logit, ops::LogitKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogitKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
logit_grad, ops::LogitGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogitGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -1830,24 +1827,6 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR(exp_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(exp,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ExpFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ExpFunctor<double>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ExpFunctor<int>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ExpFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
exp_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ExpGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ExpGradFunctor<double>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ExpGradFunctor<int>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ExpGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================== expm1 register ============================ */
......@@ -1862,13 +1841,6 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR(expm1_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(expm1,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::Expm1Functor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::Expm1Functor<double>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::Expm1Functor<plat::float16>>);
REGISTER_OP_CPU_KERNEL(
expm1_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<float>>,
......
......@@ -273,6 +273,7 @@ USE_PHI_FUNCTOR(Asinh)
USE_PHI_FUNCTOR(Acosh)
USE_PHI_FUNCTOR(Atanh)
USE_PHI_FUNCTOR(Tanh)
USE_PHI_FUNCTOR(Exp)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Tanh)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Tanh)
USE_PHI_FUNCTOR(BRelu)
......@@ -455,37 +456,6 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.exp();
}
};
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// expm1(x) = e^x - 1
template <typename T>
struct Expm1Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.expm1();
}
};
template <typename T>
struct Expm1GradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
......@@ -605,15 +575,6 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.sqrt();
}
};
template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
......@@ -627,15 +588,6 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
}
};
// rsqrt(x) = x^(-1/2)
template <typename T>
struct RsqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.rsqrt();
}
};
template <typename T>
struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
......@@ -689,15 +641,6 @@ struct RoundFunctor : public BaseActivationFunctor<T> {
}
};
// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = static_cast<T>(1) / x;
}
};
template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
......@@ -793,15 +736,6 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.square();
}
};
template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
......@@ -894,27 +828,6 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// For numerical stability, using the following formula instead of softplus(x) =
// log(1 + exp(x))
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
// 1, threshold = 20 by default), otherwise x
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
auto x_beta = static_cast<T>(beta) * x;
out.device(d) = (x_beta > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x_beta.exp()).log() /
static_cast<T>(beta));
}
};
// For numerical stability, using the following formula instead of
// d(softplus(x))/dx = 1 / (1 + exp(-x))
// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta
......@@ -939,24 +852,6 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
template <typename T>
struct MishFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
auto sp = (x > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x.exp()).log());
out.device(d) = x * sp.tanh();
}
};
// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
template <typename T>
......@@ -979,15 +874,6 @@ struct MishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};
// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
......@@ -1198,24 +1084,6 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct LogitFunctor {
template <typename Device, typename X, typename Out, typename P>
void operator()(Device d, X x, Out out, P p, float eps) const {
// logit(x) = ln(x/(1-x))
auto tmp_x =
(x.cwiseMin(static_cast<T>(1.0 - eps))).cwiseMax(static_cast<T>(eps));
if (!eps) {
out.device(d) = (x < static_cast<T>(0.0) || x > static_cast<T>(1.0))
.select(p.constant(static_cast<T>(NAN)),
(tmp_x / (static_cast<T>(1) - tmp_x)).log());
} else {
out.device(d) = (tmp_x / (static_cast<T>(1) - tmp_x)).log();
}
}
};
template <typename T>
struct LogitGradFunctor {
template <typename Device, typename X, typename dOut, typename dX, typename P>
......@@ -1228,21 +1096,6 @@ struct LogitGradFunctor {
}
};
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
}
};
template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
float scale_a;
......@@ -2075,26 +1928,6 @@ class PowGradKernel
}
};
template <typename DeviceContext, typename T>
class LogitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
auto eps = context.Attr<float>("eps");
out->mutable_data<T>(in->place());
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto eigen_p = framework::EigenVector<T>::Flatten(*out);
LogitFunctor<T> functor;
functor(place, eigen_in, eigen_out, eigen_p, eps);
}
};
template <typename DeviceContext, typename T>
class LogitGradKernel : public framework::OpKernel<T> {
public:
......
......@@ -192,14 +192,6 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// reciprocal(x) = 1 / x
__device__ __forceinline__ T operator()(const T x) const { return one / x; }
};
template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
// dx = -dout * out^2
......@@ -212,40 +204,6 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// exp(x) = exp(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(exp(x));
}
};
template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return dout * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// expm1(x) = expm1(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(expm1(x));
}
};
template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
......@@ -279,12 +237,6 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
// square(x) = x * x
__device__ __forceinline__ T operator()(const T x) const { return x * x; }
};
template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
T two = static_cast<T>(2.0f);
......@@ -297,17 +249,6 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// sqrt(x) = sqrt(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(sqrt(x));
}
};
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
T one_half = static_cast<T>(0.5f);
......@@ -322,17 +263,6 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// rsqrt(x) = rsqrt(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(rsqrt(x));
}
};
template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
T minus_one_half = static_cast<T>(-0.5f);
......@@ -466,25 +396,6 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
// stanh(x) = b * tanh(a * x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(scale_a);
MPType b = static_cast<MPType>(scale_b);
return static_cast<T>(b * tanh(a * x));
}
};
template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
......@@ -510,27 +421,6 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
// softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * beta;
return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
}
};
template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
......@@ -556,16 +446,6 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// softsign(x) = x / (1 + abs(x))
__device__ __forceinline__ T operator()(const T x) const {
return x / (one + abs(x));
}
};
template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
......@@ -762,27 +642,6 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
return static_cast<T>(x * tanh(sp));
}
};
template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
......@@ -1292,11 +1151,6 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================== logit register ============================ */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
logit, ops::LogitKernel<paddle::platform::CUDADeviceContext, float>,
ops::LogitKernel<paddle::platform::CUDADeviceContext, double>,
ops::LogitKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
logit_grad,
ops::LogitGradKernel<paddle::platform::CUDADeviceContext, float>,
......
......@@ -279,6 +279,46 @@ struct SelectedRowsAddToTensor<platform::CPUDeviceContext, T> {
}
};
template <typename T>
struct SelectedRowsAddToTensor<phi::CPUContext, T> {
void operator()(const phi::CPUContext& context,
const phi::SelectedRows& input1, framework::Tensor* input2) {
if (UNLIKELY(input1.rows().size() == 0)) {
LOG(WARNING) << "input selected rows is empty!";
return;
}
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(
in1_height, in2_dims[0],
platform::errors::InvalidArgument("The two inputs height must be equal."
"But recieved first input height = "
"[%d], second input height = [%d]",
in1_height, in2_dims[0]));
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(
in1_row_numel, input2->numel() / in1_height,
platform::errors::InvalidArgument(
"The two inputs width must be equal."
"But recieved first input width = [%d], second input width = [%d]",
in1_row_numel, input2->numel() / in1_height));
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->data<T>();
for (size_t i = 0; i < in1_rows.size(); i++) {
for (int64_t j = 0; j < in1_row_numel; j++) {
input2_data[in1_rows[i] * in1_row_numel + j] +=
in1_data[i * in1_row_numel + j];
}
}
}
};
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>;
......@@ -286,6 +326,11 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
platform::bfloat16>;
template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::CPUContext, platform::bfloat16>;
// This is a separated namespace for manipulate SelectedRows typed
// data. Like merge duplicated rows, adding two SelectedRows etc.
//
......@@ -294,30 +339,30 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
// add or mul.
namespace scatter {
template <typename T>
template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_integral<T>::value>::type elementwise_add_to(
phi::funcs::BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
const T* in, T* out) {
phi::funcs::BlasT<DeviceContext, T>* blas, size_t data_len, const T* in,
T* out) {
blas->AXPY(data_len, T(1.f), in, out);
}
template <typename T>
template <typename T, typename DeviceContext>
typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
phi::funcs::BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
const T* in, T* out) {
phi::funcs::BlasT<DeviceContext, T>* blas, size_t data_len, const T* in,
T* out) {
for (size_t i = 0; i < data_len; i++) {
out[i] += in[i];
}
}
template <typename T>
template <typename T, typename DeviceContext>
typename std::enable_if<std::is_same<T, platform::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
const std::unordered_map<int64_t, size_t>& rows_to_id,
int64_t input_width,
const platform::CPUDeviceContext& context, T* out_data) {
int64_t input_width, const DeviceContext& context,
T* out_data) {
#ifndef PADDLE_WITH_MKLDNN
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
#endif
for (auto* input : inputs) {
if (input->rows().size() == 0) {
......@@ -336,22 +381,22 @@ add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
#else
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id.at(input_rows[i]);
elementwise_add_to<T>(&blas, static_cast<size_t>(input_width),
&input_data[i * input_width],
&out_data[out_i * input_width]);
elementwise_add_to<T, DeviceContext>(
&blas, static_cast<size_t>(input_width), &input_data[i * input_width],
&out_data[out_i * input_width]);
}
#endif
}
}
template <typename T>
template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_same<T, platform::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
const std::unordered_map<int64_t, size_t>& rows_to_id,
int64_t input_width,
const platform::CPUDeviceContext& context, T* out_data) {
int64_t input_width, const DeviceContext& context,
T* out_data) {
VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name();
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
......@@ -361,16 +406,16 @@ add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id.at(input_rows[i]);
elementwise_add_to<T>(&blas, static_cast<size_t>(input_width),
&input_data[i * input_width],
&out_data[out_i * input_width]);
elementwise_add_to<T, DeviceContext>(
&blas, static_cast<size_t>(input_width), &input_data[i * input_width],
&out_data[out_i * input_width]);
}
}
}
template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> {
phi::SelectedRows operator()(const platform::CPUDeviceContext& context,
template <typename DeviceContext, typename T>
struct MergeAddImpl {
phi::SelectedRows operator()(const DeviceContext& context,
const phi::SelectedRows& input,
const bool sorted_result = false) {
phi::SelectedRows out;
......@@ -378,15 +423,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const phi::SelectedRows& input, phi::SelectedRows* output,
const bool sorted_result = false) {
void operator()(const DeviceContext& context, const phi::SelectedRows& input,
phi::SelectedRows* output, const bool sorted_result = false) {
std::vector<const phi::SelectedRows*> inputs;
inputs.push_back(&input);
(*this)(context, inputs, output, sorted_result);
}
void operator()(const platform::CPUDeviceContext& context,
void operator()(const DeviceContext& context,
const std::vector<const phi::SelectedRows*>& inputs,
phi::SelectedRows* output, const bool sorted_result = false) {
if (inputs.size() == 0) {
......@@ -461,7 +505,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
out.set_rows(merge_rows);
phi::funcs::SetConstant<platform::CPUDeviceContext, T> constant_functor;
phi::funcs::SetConstant<DeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), static_cast<T>(0.f));
std::unordered_map<int64_t, size_t> rows_to_id;
......@@ -469,11 +513,75 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
rows_to_id[merge_rows[i]] = i;
}
add_sparse_inputs<T>(inputs, rows_to_id, input_width, context, out_data);
add_sparse_inputs<T, DeviceContext>(inputs, rows_to_id, input_width,
context, out_data);
}
}
};
template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> {
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
phi::SelectedRows operator()(const platform::CPUDeviceContext& context,
const phi::SelectedRows& input,
const bool sorted_result) {
return MergeAddImpl<platform::CPUDeviceContext, T>()(context, input,
sorted_result);
}
void operator()(const platform::CPUDeviceContext& context,
const phi::SelectedRows& input, phi::SelectedRows* output,
const bool sorted_result) {
MergeAddImpl<platform::CPUDeviceContext, T>()(context, input, output,
sorted_result);
}
void operator()(const platform::CPUDeviceContext& context,
const std::vector<const phi::SelectedRows*>& inputs,
phi::SelectedRows* output, const bool sorted_result) {
MergeAddImpl<platform::CPUDeviceContext, T>()(context, inputs, output,
sorted_result);
}
};
template <typename T>
struct MergeAdd<phi::CPUContext, T> {
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
phi::SelectedRows operator()(const phi::CPUContext& context,
const phi::SelectedRows& input,
const bool sorted_result) {
return MergeAddImpl<phi::CPUContext, T>()(context, input, sorted_result);
}
void operator()(const phi::CPUContext& context,
const phi::SelectedRows& input, phi::SelectedRows* output,
const bool sorted_result) {
MergeAddImpl<phi::CPUContext, T>()(context, input, output, sorted_result);
}
void operator()(const phi::CPUContext& context,
const std::vector<const phi::SelectedRows*>& inputs,
phi::SelectedRows* output, const bool sorted_result) {
MergeAddImpl<phi::CPUContext, T>()(context, inputs, output, sorted_result);
}
};
#define TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(dtype) \
template struct MergeAddImpl<platform::CPUDeviceContext, dtype>; \
template struct MergeAddImpl<phi::CPUContext, dtype>; \
template struct MergeAdd<platform::CPUDeviceContext, dtype>; \
template struct MergeAdd<phi::CPUContext, dtype>;
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(float)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(double)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int64_t)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::bfloat16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex<float>)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex<double>)
#ifdef PADDLE_WITH_XPU
template <typename T>
struct MergeAdd<platform::XPUDeviceContext, T> {
......@@ -714,17 +822,6 @@ struct MergeAverage<platform::CPUDeviceContext, T> {
}
};
template struct MergeAdd<platform::CPUDeviceContext, int>;
template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template struct MergeAdd<platform::CPUDeviceContext, float>;
template struct MergeAdd<platform::CPUDeviceContext, double>;
template struct MergeAdd<platform::CPUDeviceContext,
paddle::platform::complex<float>>;
template struct MergeAdd<platform::CPUDeviceContext,
paddle::platform::complex<double>>;
template struct MergeAdd<platform::CPUDeviceContext,
paddle::platform::bfloat16>;
#ifdef PADDLE_WITH_XPU
template struct MergeAdd<platform::XPUDeviceContext, float>;
#endif
......
......@@ -174,12 +174,77 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
}
};
template <typename T>
struct SelectedRowsAddTensor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& context,
const phi::SelectedRows& input1,
const framework::Tensor& input2, framework::Tensor* output) {
auto in1_height = input1.height();
auto in2_dims = input2.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(
in1_height, in2_dims[0],
platform::errors::InvalidArgument(
"The two inputs height must be equal."
"But recieved first input height = [%d], first input height = [%d]",
in1_height, in2_dims[0]));
PADDLE_ENFORCE_EQ(
in1_height, out_dims[0],
platform::errors::InvalidArgument(
"The input and output height must be equal."
"But recieved input height = [%d], output height = [%d]",
in1_height, out_dims[0]));
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(
in1_row_numel, input2.numel() / in1_height,
platform::errors::InvalidArgument(
"The two inputs width must be equal."
"But recieved first input width = [%d], second input width = [%d]",
in1_row_numel, input2.numel() / in1_height));
PADDLE_ENFORCE_EQ(
in1_row_numel, output->numel() / in1_height,
platform::errors::InvalidArgument(
"The input and output width must be equal."
"But recieved input width = [%d], output width = [%d]",
in1_row_numel, output->numel() / in1_height));
auto* in1_data = in1_value.data<T>();
auto* in2_data = input2.data<T>();
auto* out_data = output->data<T>();
phi::funcs::SetConstant<phi::GPUContext, T> functor;
functor(context, output, static_cast<T>(0));
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(in1_rows.size(), 1);
paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
SelectedRowsAddTensorKernel<
T, block_size><<<grid, threads, 0, context.stream()>>>(
in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), out_data,
in1_row_numel);
auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
}
};
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, platform::float16>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
platform::float16>;
template struct SelectedRowsAddTensor<phi::GPUContext, float>;
template struct SelectedRowsAddTensor<phi::GPUContext, double>;
template struct SelectedRowsAdd<phi::GPUContext, platform::float16>;
template struct SelectedRowsAddTensor<phi::GPUContext, platform::float16>;
template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
......@@ -285,12 +350,54 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
}
};
template <typename T>
struct SelectedRowsAddToTensor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& context,
const phi::SelectedRows& input1, framework::Tensor* input2) {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(
in1_height, in2_dims[0],
platform::errors::InvalidArgument("The two inputs height must be equal."
"But recieved first input height = "
"[%d], second input height = [%d]",
in1_height, in2_dims[0]));
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(
in1_row_numel, input2->numel() / in1_height,
platform::errors::InvalidArgument(
"The two inputs width must be equal."
"But recieved first input width = [%d], second input width = [%d]",
in1_row_numel, input2->numel() / in1_height));
auto* in1_data = in1_value.data<T>();
auto* in2_data = input2->data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(in1_rows.size(), 1);
paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
SelectedRowsAddToTensorKernel<
T, block_size><<<grid, threads, 0, context.stream()>>>(
in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), in2_data,
in1_row_numel);
}
};
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
platform::float16>;
template struct SelectedRowsAddToTensor<phi::GPUContext, float>;
template struct SelectedRowsAddToTensor<phi::GPUContext, double>;
template struct SelectedRowsAddToTensor<phi::GPUContext, int>;
template struct SelectedRowsAddToTensor<phi::GPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::GPUContext, platform::float16>;
namespace scatter {
......@@ -319,9 +426,9 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
}
}
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
phi::SelectedRows operator()(const platform::CUDADeviceContext& context,
template <typename DeviceContext, typename T>
struct MergeAddImpl {
phi::SelectedRows operator()(const DeviceContext& context,
const phi::SelectedRows& input,
const bool sorted_result = false) {
phi::SelectedRows out;
......@@ -329,9 +436,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
return out;
}
void operator()(const platform::CUDADeviceContext& context,
const phi::SelectedRows& input, phi::SelectedRows* output,
const bool sorted_result = false) {
void operator()(const DeviceContext& context, const phi::SelectedRows& input,
phi::SelectedRows* output, const bool sorted_result = false) {
framework::Vector<int64_t> input_rows(input.rows());
if (input_rows.size() == 0) {
return;
......@@ -350,7 +456,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
phi::funcs::SetConstant<platform::CUDADeviceContext, T> constant_functor;
phi::funcs::SetConstant<DeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), static_cast<T>(0));
auto* out_data = out.mutable_value()->data<T>();
......@@ -369,7 +475,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
mix_vector_out.CopyToCPU();
}
void operator()(const platform::CUDADeviceContext& context,
void operator()(const DeviceContext& context,
const std::vector<const phi::SelectedRows*>& inputs,
phi::SelectedRows* output, const bool sorted_result = false) {
if (inputs.size() == 0) {
......@@ -414,7 +520,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
phi::funcs::SetConstant<platform::CUDADeviceContext, T> constant_functor;
phi::funcs::SetConstant<DeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), static_cast<T>(0));
auto* out_data = out.mutable_value()->data<T>();
......@@ -441,15 +547,69 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
}
};
template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
template struct MergeAdd<platform::CUDADeviceContext, platform::bfloat16>;
template struct MergeAdd<platform::CUDADeviceContext, platform::complex<float>>;
template struct MergeAdd<platform::CUDADeviceContext,
platform::complex<double>>;
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
phi::SelectedRows operator()(const platform::CUDADeviceContext& context,
const phi::SelectedRows& input,
const bool sorted_result) {
return MergeAddImpl<platform::CUDADeviceContext, T>()(context, input,
sorted_result);
}
void operator()(const platform::CUDADeviceContext& context,
const phi::SelectedRows& input, phi::SelectedRows* output,
const bool sorted_result) {
MergeAddImpl<platform::CUDADeviceContext, T>()(context, input, output,
sorted_result);
}
void operator()(const platform::CUDADeviceContext& context,
const std::vector<const phi::SelectedRows*>& inputs,
phi::SelectedRows* output, const bool sorted_result) {
MergeAddImpl<platform::CUDADeviceContext, T>()(context, inputs, output,
sorted_result);
}
};
template <typename T>
struct MergeAdd<phi::GPUContext, T> {
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
phi::SelectedRows operator()(const phi::GPUContext& context,
const phi::SelectedRows& input,
const bool sorted_result) {
return MergeAddImpl<phi::GPUContext, T>()(context, input, sorted_result);
}
void operator()(const phi::GPUContext& context,
const phi::SelectedRows& input, phi::SelectedRows* output,
const bool sorted_result) {
MergeAddImpl<phi::GPUContext, T>()(context, input, output, sorted_result);
}
void operator()(const phi::GPUContext& context,
const std::vector<const phi::SelectedRows*>& inputs,
phi::SelectedRows* output, const bool sorted_result) {
MergeAddImpl<phi::GPUContext, T>()(context, inputs, output, sorted_result);
}
};
#define TEMPLATE_SPECIALIZED_FOR_MERGEADD(dtype) \
template struct MergeAddImpl<platform::CUDADeviceContext, dtype>; \
template struct MergeAddImpl<phi::GPUContext, dtype>; \
template struct MergeAdd<platform::CUDADeviceContext, dtype>; \
template struct MergeAdd<phi::GPUContext, dtype>;
TEMPLATE_SPECIALIZED_FOR_MERGEADD(float)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(double)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(int)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(int64_t)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::float16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::bfloat16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<float>)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<double>)
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
......
......@@ -11,7 +11,7 @@ set_property(GLOBAL PROPERTY PHI_KERNELS "")
# [ 1. Common kernel compilation dependencies ]
set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor)
# remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
......
......@@ -100,5 +100,6 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Tanh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Exp);
} // namespace phi
......@@ -37,6 +37,8 @@ DECLARE_ACTIVATION_KERNEL(Acosh)
DECLARE_ACTIVATION_KERNEL(Atanh)
DECLARE_ACTIVATION_KERNEL(Relu)
DECLARE_ACTIVATION_KERNEL(Tanh)
DECLARE_ACTIVATION_KERNEL(Exp)
DECLARE_ACTIVATION_KERNEL(Expm1)
template <typename T, typename Context>
void BReluKernel(const Context& dev_ctx,
......@@ -57,4 +59,16 @@ void ThresholdedReluKernel(const Context& dev_ctx,
float threshold,
DenseTensor* out);
template <typename T, typename Context>
void LogitKernel(const Context& dev_ctx,
const DenseTensor& x,
float eps,
DenseTensor* out);
template <typename T, typename Context>
void MishKernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
DenseTensor* out);
} // namespace phi
......@@ -104,6 +104,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Tanh, funcs::TanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Exp, funcs::ExpGradFunctor);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu,
funcs::LeakyReluGradFunctor,
......@@ -159,3 +160,12 @@ PD_REGISTER_KERNEL(tanh_triple_grad,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(exp_grad,
CPU,
ALL_LAYOUT,
phi::ExpGradKernel,
float,
double,
int,
int64_t) {}
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/impl/activation_impl.h"
namespace phi {
......@@ -67,11 +68,27 @@ DEFINE_CPU_ACTIVATION_KERNEL(Acosh, funcs::AcoshFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Atanh, funcs::AtanhFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Relu, funcs::ReluCPUFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Tanh, funcs::TanhFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Exp, funcs::ExpFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Expm1, funcs::Expm1Functor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Reciprocal, funcs::ReciprocalFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Square, funcs::SquareFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Sqrt, funcs::SqrtFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Rsqrt, funcs::RsqrtFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Softsign, funcs::SoftsignFunctor<T>)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, funcs::LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
funcs::ThresholdedReluFunctor,
threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, funcs::MishFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, funcs::BReluFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh,
funcs::STanhFunctor,
scale_a,
scale_b)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus,
funcs::SoftplusFunctor,
beta,
threshold)
} // namespace phi
PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
......@@ -94,3 +111,23 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, Tanh)
PD_REGISTER_ACTIVATION_KERNEL(brelu, BRelu)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyRelu)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedRelu)
PD_REGISTER_ACTIVATION_KERNEL(mish, Mish)
PD_REGISTER_ACTIVATION_KERNEL(stanh, STanh)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, Reciprocal)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, Sqrt)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, Rsqrt)
PD_REGISTER_ACTIVATION_KERNEL(softplus, Softplus)
PD_REGISTER_ACTIVATION_KERNEL(softsign, Softsign)
PD_REGISTER_KERNEL(
exp, CPU, ALL_LAYOUT, phi::ExpKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(expm1,
CPU,
ALL_LAYOUT,
phi::Expm1Kernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
PD_REGISTER_KERNEL(
square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
......@@ -100,6 +100,15 @@ struct SinFunctor : public BaseActivationFunctor<T> {
}
};
// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = static_cast<T>(1) / x;
}
};
// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
......@@ -124,6 +133,57 @@ struct CosFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct LogitFunctor {
template <typename Device, typename X, typename Out, typename P>
void operator()(Device d, X x, Out out, P p, float eps) const {
// logit(x) = ln(x/(1-x))
auto tmp_x =
(x.cwiseMin(static_cast<T>(1.0 - eps))).cwiseMax(static_cast<T>(eps));
if (!eps) {
out.device(d) = (x < static_cast<T>(0.0) || x > static_cast<T>(1.0))
.select(p.constant(static_cast<T>(NAN)),
(tmp_x / (static_cast<T>(1) - tmp_x)).log());
} else {
out.device(d) = (tmp_x / (static_cast<T>(1) - tmp_x)).log();
}
}
};
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
template <typename T>
struct MishFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
auto sp = (x > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x.exp()).log());
out.device(d) = x * sp.tanh();
}
};
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
}
};
template <typename T>
struct Tangent {
HOSTDEVICE T operator()(const T& val) const { return tan(val); }
......@@ -151,6 +211,55 @@ struct TanGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.square();
}
};
// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.sqrt();
}
};
// rsqrt(x) = x^(-1/2)
template <typename T>
struct RsqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.rsqrt();
}
};
// For numerical stability, using the following formula instead of softplus(x) =
// log(1 + exp(x))
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
// 1, threshold = 20 by default), otherwise x
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
auto x_beta = static_cast<T>(beta) * x;
out.device(d) = (x_beta > static_cast<T>(threshold))
.select(x,
(static_cast<T>(1) + x_beta.exp()).log() /
static_cast<T>(beta));
}
};
// Tangent(x) = tan(x)
template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
......@@ -452,6 +561,41 @@ struct AtanhGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// exp functor
// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.exp();
}
};
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// expm1(x) = e^x - 1
template <typename T>
struct Expm1Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.expm1();
}
};
// relu(x) = max(x, 0)
template <typename T>
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
......@@ -672,6 +816,15 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};
template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
float alpha;
......@@ -827,6 +980,54 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// exp(x) = exp(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(exp(x));
}
};
template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
// square(x) = x * x
__device__ __forceinline__ T operator()(const T x) const { return x * x; }
};
template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return dout * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// reciprocal(x) = 1 / x
__device__ __forceinline__ T operator()(const T x) const { return one / x; }
};
template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// expm1(x) = expm1(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(expm1(x));
}
};
template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -838,6 +1039,16 @@ struct CudaSinFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// softsign(x) = x / (1 + abs(x))
__device__ __forceinline__ T operator()(const T x) const {
return x / (one + abs(x));
}
};
template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -1049,6 +1260,46 @@ struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
// stanh(x) = b * tanh(a * x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(scale_a);
MPType b = static_cast<MPType>(scale_b);
return static_cast<T>(b * tanh(a * x));
}
};
template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
// softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * beta;
return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
}
};
template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -1064,6 +1315,28 @@ struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// sqrt(x) = sqrt(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(sqrt(x));
}
};
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// rsqrt(x) = rsqrt(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(rsqrt(x));
}
};
template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -1131,6 +1404,27 @@ struct CudaBReluFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
return static_cast<T>(x * tanh(sp));
}
};
template <typename T>
struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
......
......@@ -155,6 +155,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, CudaCoshGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, CudaAsinhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, CudaAcoshGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, CudaAtanhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Exp, CudaExpGradFunctor);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu,
CudaLeakyReluGradFunctor,
......@@ -234,3 +235,12 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
ThresholdedReluGradKernel)
PD_REGISTER_KERNEL(exp_grad,
GPU,
ALL_LAYOUT,
phi::ExpGradKernel,
float,
double,
int,
int64_t) {}
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/activation_grad_impl.h"
#include "paddle/phi/kernels/impl/activation_impl.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
......@@ -88,13 +89,27 @@ DEFINE_GPU_ACTIVATION_KERNEL(Acosh, funcs::CudaAcoshFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Atanh, funcs::CudaAtanhFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Relu, funcs::CudaReluFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Tanh, funcs::CudaTanhFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Exp, funcs::CudaExpFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Expm1, funcs::CudaExpm1Functor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Reciprocal, funcs::CudaReciprocalFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Square, funcs::CudaSquareFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Sqrt, funcs::CudaSqrtFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Rsqrt, funcs::CudaRsqrtFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Softsign, funcs::CudaSoftsignFunctor<T>)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
CudaThresholdedReluFunctor,
threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Stanh, CudaSTanhFunctor, scale_a, scale_b)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus,
CudaSoftplusFunctor,
beta,
threshold)
} // namespace phi
......@@ -142,3 +157,23 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_KERNEL(
exp, GPU, ALL_LAYOUT, phi::ExpKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(expm1,
GPU,
ALL_LAYOUT,
phi::Expm1Kernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(logit, GPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
PD_REGISTER_KERNEL(
square, GPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
......@@ -40,16 +40,16 @@ void ClipByNormKernel<phi::dtype::float16, phi::GPUContext>(
DenseTensor tmp;
tmp.Resize({1});
dev_ctx.template Alloc<float>(&tmp);
kernels::TensorReduceImpl<dtype::float16,
float,
kps::AddFunctor,
kps::SquareFunctor<dtype::float16, float>>(
phi::funcs::ReduceKernel<dtype::float16,
float,
kps::AddFunctor,
kps::SquareFunctor<dtype::float16, float>>(
dev_ctx,
x_in,
&tmp,
kps::SquareFunctor<dtype::float16, float>(),
reduce_dims,
dev_ctx.stream());
reduce_dims);
auto tmp_eigen = EigenVector<float>::Flatten(tmp);
auto x_norm = tmp_eigen.sqrt();
......
......@@ -47,4 +47,20 @@ void ActivationImpl(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void LogitKernel(const Context& dev_ctx,
const DenseTensor& x,
float eps,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
auto eigen_out = EigenVector<T>::Flatten(*out);
auto eigen_in = EigenVector<T>::Flatten(x);
auto& place = *dev_ctx.eigen_device();
auto eigen_p = EigenVector<T>::Flatten(*out);
funcs::LogitFunctor<T> functor;
functor(place, eigen_in, eigen_out, eigen_p, eps);
}
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册