提交 def33631 编写于 作者: P phlrain

update

上级 c552d1ac
...@@ -39,6 +39,7 @@ DECLARE_ACTIVATION_KERNEL(Relu) ...@@ -39,6 +39,7 @@ DECLARE_ACTIVATION_KERNEL(Relu)
DECLARE_ACTIVATION_KERNEL(Tanh) DECLARE_ACTIVATION_KERNEL(Tanh)
DECLARE_ACTIVATION_KERNEL(Exp) DECLARE_ACTIVATION_KERNEL(Exp)
DECLARE_ACTIVATION_KERNEL(Expm1) DECLARE_ACTIVATION_KERNEL(Expm1)
DECLARE_ACTIVATION_KERNEL(Softsign)
template <typename T, typename Context> template <typename T, typename Context>
void BReluKernel(const Context& dev_ctx, void BReluKernel(const Context& dev_ctx,
......
...@@ -74,21 +74,23 @@ DEFINE_CPU_ACTIVATION_KERNEL(Reciprocal, funcs::ReciprocalFunctor<T>) ...@@ -74,21 +74,23 @@ DEFINE_CPU_ACTIVATION_KERNEL(Reciprocal, funcs::ReciprocalFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Square, funcs::SquareFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Square, funcs::SquareFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Sqrt, funcs::SqrtFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Sqrt, funcs::SqrtFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Rsqrt, funcs::RsqrtFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Rsqrt, funcs::RsqrtFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Softsign, funcs::SoftsignFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Softsign, funcs::SoftsignFunctor<T>)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, funcs::LeakyReluFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, funcs::LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
funcs::ThresholdedReluFunctor, funcs::ThresholdedReluFunctor,
threshold) threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, funcs::MishFunctor, threshold) // DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, funcs::MishFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, funcs::BReluFunctor, t_min, t_max) DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, funcs::BReluFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh,
funcs::STanhFunctor, funcs::STanhFunctor,
scale_a, scale_a,
scale_b) scale_b)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, // DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus,
funcs::SoftplusFunctor, // funcs::SoftplusFunctor,
beta, // beta,
threshold) // threshold)
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {} PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
...@@ -111,12 +113,12 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, Tanh) ...@@ -111,12 +113,12 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, Tanh)
PD_REGISTER_ACTIVATION_KERNEL(brelu, BRelu) PD_REGISTER_ACTIVATION_KERNEL(brelu, BRelu)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyRelu) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyRelu)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedRelu) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedRelu)
PD_REGISTER_ACTIVATION_KERNEL(mish, Mish) // PD_REGISTER_ACTIVATION_KERNEL(mish, Mish)
PD_REGISTER_ACTIVATION_KERNEL(stanh, STanh) PD_REGISTER_ACTIVATION_KERNEL(stanh, STanh)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, Reciprocal) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, Reciprocal)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, Sqrt) PD_REGISTER_ACTIVATION_KERNEL(sqrt, Sqrt)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, Rsqrt) PD_REGISTER_ACTIVATION_KERNEL(rsqrt, Rsqrt)
PD_REGISTER_ACTIVATION_KERNEL(softplus, Softplus) // PD_REGISTER_ACTIVATION_KERNEL(softplus, Softplus)
PD_REGISTER_ACTIVATION_KERNEL(softsign, Softsign) PD_REGISTER_ACTIVATION_KERNEL(softsign, Softsign)
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
......
...@@ -41,7 +41,8 @@ void ActivationImpl(const Context& dev_ctx, ...@@ -41,7 +41,8 @@ void ActivationImpl(const Context& dev_ctx,
bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest(); bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace());
if (use_32bit_index && is_gpu_place) { if (use_32bit_index && is_gpu_place) {
functor(*place, To32BitIndex(x), To32BitIndex(out)); // functor(*place, To32BitIndex(x), To32BitIndex(out));
functor(*place, x, out);
} else { } else {
functor(*place, x, out); functor(*place, x, out);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册