未验证 提交 0cdaafea 编写于 作者: Z zhangyuqin1998 提交者: GitHub

delete swish_raw (#54536)

* delete swish_raw

* fix

* Update activation_kernel.cc

* fix
上级 7c2c965d
...@@ -2541,7 +2541,7 @@ ...@@ -2541,7 +2541,7 @@
outputs : outputs :
out : Out out : Out
extra : extra :
attrs : [bool use_mkldnn = false] attrs : [bool use_mkldnn = false, float beta = 1.0]
- op : sync_batch_norm - op : sync_batch_norm
backward : sync_batch_norm_grad backward : sync_batch_norm_grad
......
...@@ -288,7 +288,7 @@ ...@@ -288,7 +288,7 @@
backward : sum_double_grad backward : sum_double_grad
- backward_op : swish_grad - backward_op : swish_grad
forward : swish (Tensor x, float beta = 1.0f) -> Tensor(out) forward : swish (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad) args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad) output : Tensor(x_grad)
infer_meta : infer_meta :
......
...@@ -540,13 +540,13 @@ ...@@ -540,13 +540,13 @@
backward : sum_grad backward : sum_grad
- op : swish - op : swish
args : (Tensor x, float beta = 1.0f) args : (Tensor x)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
param : [x] param : [x]
kernel : kernel :
func : swish_raw func : swish
backward : swish_grad backward : swish_grad
- op : tril_indices - op : tril_indices
......
...@@ -26,19 +26,11 @@ void Relu6Kernel(const Context& dev_ctx, ...@@ -26,19 +26,11 @@ void Relu6Kernel(const Context& dev_ctx,
Relu6RawKernel<T, Context>(dev_ctx, x, 6, out); Relu6RawKernel<T, Context>(dev_ctx, x, 6, out);
} }
template <typename T, typename Context>
void SwishKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
SwishRawKernel<T, Context>(dev_ctx, x, 1.0, out);
}
} // namespace phi } // namespace phi
using complex64 = ::phi::dtype::complex<float>; using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>; using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(relu6, CPU, ALL_LAYOUT, phi::Relu6Kernel, float, double) {} PD_REGISTER_KERNEL(relu6, CPU, ALL_LAYOUT, phi::Relu6Kernel, float, double) {}
PD_REGISTER_KERNEL(swish, CPU, ALL_LAYOUT, phi::SwishKernel, float, double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(relu6, PD_REGISTER_KERNEL(relu6,
...@@ -49,28 +41,14 @@ PD_REGISTER_KERNEL(relu6, ...@@ -49,28 +41,14 @@ PD_REGISTER_KERNEL(relu6,
double, double,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(swish,
GPU,
ALL_LAYOUT,
phi::SwishKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif #endif
#if defined PADDLE_WITH_XPU #if defined PADDLE_WITH_XPU
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {} relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {}
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
relu6, OneDNN, ONEDNN, phi::Relu6Kernel, float, phi::dtype::bfloat16) {} relu6, OneDNN, ONEDNN, phi::Relu6Kernel, float, phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
swish, OneDNN, ONEDNN, phi::SwishKernel, float, phi::dtype::bfloat16) {}
#endif #endif
...@@ -81,7 +81,6 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold) ...@@ -81,7 +81,6 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SwishRaw, beta)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Celu, alpha) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Celu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Logit, eps) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Logit, eps)
......
...@@ -114,7 +114,6 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold) ...@@ -114,7 +114,6 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SwishRaw, SwishFunctor, beta)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CELUFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardTanh, HardTanhFunctor, t_min, t_max) DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardTanh, HardTanhFunctor, t_min, t_max)
...@@ -141,6 +140,16 @@ void HardSwishKernel(const Context& dev_ctx, ...@@ -141,6 +140,16 @@ void HardSwishKernel(const Context& dev_ctx,
dev_ctx, x, out, functor); dev_ctx, x, out, functor);
} }
template <typename T, typename Context>
void SwishKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
funcs::SwishFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = 1.0;
ActivationImpl<T, T, Context, funcs::SwishFunctor<T>>(
dev_ctx, x, out, functor);
}
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {} PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
...@@ -202,6 +211,7 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel) ...@@ -202,6 +211,7 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_KERNEL(log, PD_REGISTER_KERNEL(log,
CPU, CPU,
...@@ -244,7 +254,6 @@ PD_REGISTER_KERNEL(log1p, ...@@ -244,7 +254,6 @@ PD_REGISTER_KERNEL(log1p,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
......
...@@ -132,7 +132,6 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, ...@@ -132,7 +132,6 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
threshold) threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SwishRaw, CudaSwishFunctor, beta)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CudaCELUFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CudaCELUFunctor, alpha)
...@@ -167,6 +166,16 @@ void HardSwishKernel(const Context& dev_ctx, ...@@ -167,6 +166,16 @@ void HardSwishKernel(const Context& dev_ctx,
dev_ctx, x, out, functor); dev_ctx, x, out, functor);
} }
template <typename T, typename Context>
void SwishKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
funcs::CudaSwishFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = 1.0;
ActivationGPUImpl<T, Context, funcs::CudaSwishFunctor<T>>(
dev_ctx, x, out, functor);
}
} // namespace phi } // namespace phi
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
...@@ -262,7 +271,7 @@ PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) ...@@ -262,7 +271,7 @@ PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel) PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
......
...@@ -154,7 +154,6 @@ DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor) ...@@ -154,7 +154,6 @@ DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(SwishRaw, SwishOneDNNFunctor, beta)
template <typename T, typename Context> template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx, void HardSwishKernel(const Context& dev_ctx,
...@@ -187,6 +186,14 @@ void Relu6RawKernel(const Context& dev_ctx, ...@@ -187,6 +186,14 @@ void Relu6RawKernel(const Context& dev_ctx,
functor(dev_ctx, x, 0, threshold, out); functor(dev_ctx, x, 0, threshold, out);
} }
template <typename T, typename Context>
void SwishKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
SwishOneDNNFunctor<T> functor;
functor(dev_ctx, x, 1.0, 0, out);
}
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(round, OneDNN, ONEDNN, phi::RoundKernel, float) {} PD_REGISTER_KERNEL(round, OneDNN, ONEDNN, phi::RoundKernel, float) {}
...@@ -206,5 +213,5 @@ PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel) ...@@ -206,5 +213,5 @@ PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel) PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
...@@ -403,10 +403,9 @@ struct XPUMishFunctor : public funcs::BaseActivationFunctor<T> { ...@@ -403,10 +403,9 @@ struct XPUMishFunctor : public funcs::BaseActivationFunctor<T> {
}; };
template <typename T, typename Context> template <typename T, typename Context>
void SwishRawKernel(const Context& dev_ctx, void SwishKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
float beta, DenseTensor* out) {
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
int r = xpu::swish(dev_ctx.x_context(), int r = xpu::swish(dev_ctx.x_context(),
...@@ -542,12 +541,8 @@ PD_REGISTER_KERNEL( ...@@ -542,12 +541,8 @@ PD_REGISTER_KERNEL(
silu, XPU, ALL_LAYOUT, phi::SiluKernel, float, phi::dtype::float16) {} silu, XPU, ALL_LAYOUT, phi::SiluKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
sigmoid, XPU, ALL_LAYOUT, phi::SigmoidKernel, float, phi::dtype::float16) {} sigmoid, XPU, ALL_LAYOUT, phi::SigmoidKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(swish_raw, PD_REGISTER_KERNEL(
XPU, swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {}
ALL_LAYOUT,
phi::SwishRawKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(hard_sigmoid, PD_REGISTER_KERNEL(hard_sigmoid,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
......
...@@ -180,7 +180,7 @@ class TestConvActOneDNNFusePass(PassAutoScanTest): ...@@ -180,7 +180,7 @@ class TestConvActOneDNNFusePass(PassAutoScanTest):
'swish', 'swish',
inputs={'X': ['conv2d_out']}, inputs={'X': ['conv2d_out']},
outputs={'Out': ['swish_out']}, outputs={'Out': ['swish_out']},
beta=draw(st.floats(min_value=0.1, max_value=1.0)), beta=1.0,
) )
elif act_type == 'clip': elif act_type == 'clip':
act_op = OpConfig( act_op = OpConfig(
......
...@@ -107,7 +107,7 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest): ...@@ -107,7 +107,7 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
activation_type, activation_type,
inputs={"X": ["matmul_output"]}, inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]}, outputs={"Out": ["activation_output"]},
beta=draw(st.floats(min_value=0.1, max_value=1.0)), beta=1.0,
) )
elif activation_type == "clip": elif activation_type == "clip":
activation_op = OpConfig( activation_op = OpConfig(
......
...@@ -95,7 +95,7 @@ class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest): ...@@ -95,7 +95,7 @@ class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest):
activation_type, activation_type,
inputs={"X": ["elementwise_add_output"]}, inputs={"X": ["elementwise_add_output"]},
outputs={"Out": ["activation_output"]}, outputs={"Out": ["activation_output"]},
beta=draw(st.floats(min_value=0.1, max_value=1.0)), beta=1.0,
) )
elif activation_type == "clip": elif activation_type == "clip":
activation_op = OpConfig( activation_op = OpConfig(
......
...@@ -111,7 +111,7 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest): ...@@ -111,7 +111,7 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
activation_type, activation_type,
inputs={'X': ['matmul_output']}, inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']}, outputs={'Out': ['activation_output']},
beta=draw(st.floats(min_value=0.1, max_value=1.0)), beta=1.0,
) )
elif activation_type == 'clip': elif activation_type == 'clip':
activation_op = OpConfig( activation_op = OpConfig(
......
...@@ -113,7 +113,7 @@ class TestOneDNNConvConcatActivationFusePass(PassAutoScanTest): ...@@ -113,7 +113,7 @@ class TestOneDNNConvConcatActivationFusePass(PassAutoScanTest):
activation_type, activation_type,
inputs={'X': ['concat_output']}, inputs={'X': ['concat_output']},
outputs={'Out': ['activation_output']}, outputs={'Out': ['activation_output']},
beta=draw(st.floats(min_value=0.1, max_value=1.0)), beta=1.0,
) )
elif activation_type == 'clip': elif activation_type == 'clip':
activation_op = OpConfig( activation_op = OpConfig(
......
...@@ -83,7 +83,7 @@ class TestElementwiseAddActivationOneDNNFusePass(PassAutoScanTest): ...@@ -83,7 +83,7 @@ class TestElementwiseAddActivationOneDNNFusePass(PassAutoScanTest):
activation_type, activation_type,
inputs={'X': ['eltwise_output']}, inputs={'X': ['eltwise_output']},
outputs={'Out': ['activation_output']}, outputs={'Out': ['activation_output']},
beta=draw(st.floats(min_value=0.1, max_value=1.0)), beta=1.0,
) )
elif activation_type == 'clip': elif activation_type == 'clip':
activation_op = OpConfig( activation_op = OpConfig(
......
...@@ -103,7 +103,7 @@ class TestFCActivationOneDNNFusePass(PassAutoScanTest): ...@@ -103,7 +103,7 @@ class TestFCActivationOneDNNFusePass(PassAutoScanTest):
activation_type, activation_type,
inputs={"X": ["fc_output"]}, inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]}, outputs={"Out": ["activation_output"]},
beta=draw(st.floats(min_value=0.1, max_value=10.0)), beta=1.0,
) )
else: else:
activation_op = OpConfig( activation_op = OpConfig(
......
...@@ -92,7 +92,7 @@ class TestSoftplusActivationOneDNNFusePass(PassAutoScanTest): ...@@ -92,7 +92,7 @@ class TestSoftplusActivationOneDNNFusePass(PassAutoScanTest):
activation_type, activation_type,
inputs={'X': ['softplus_out']}, inputs={'X': ['softplus_out']},
outputs={'Out': ['activation_output']}, outputs={'Out': ['activation_output']},
beta=draw(st.floats(min_value=0.1, max_value=10.0)), beta=1.0,
) )
else: else:
activation_op = OpConfig( activation_op = OpConfig(
......
...@@ -41,7 +41,7 @@ class TrtConvertSwishTest(TrtLayerAutoScanTest): ...@@ -41,7 +41,7 @@ class TrtConvertSwishTest(TrtLayerAutoScanTest):
return np.ones([1, 3, 64, 64]).astype(np.float32) return np.ones([1, 3, 64, 64]).astype(np.float32)
for dims in [0, 1, 2, 3, 4]: for dims in [0, 1, 2, 3, 4]:
for beta in [1.0, 2.0, 3.0]: for beta in [1.0]:
self.dims = dims self.dims = dims
dics = [{"beta": beta}] dics = [{"beta": beta}]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册