Commit 68b65a76 authored by Z zhupengyang, committed by GitHub

[NPU] support relu6 (#2582)

test=develop
Parent 8aee9417
......@@ -148,7 +148,7 @@ int CvtActMode(std::string act_type) {
act_mode = 1;
} else if (act_type == "tanh") {
act_mode = 2;
} else if (act_type == "relu_clipped") {
} else if (act_type == "relu_clipped" || act_type == "relu6") {
act_mode = 3;
} else if (act_type == "elu") {
act_mode = 4;
......
......@@ -44,6 +44,9 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
if (op_type == "relu_clipped") {
auto Relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef");
act_node->set_attr_coef(Relu_clipped_coef);
} else if (op_type == "relu6") {
float Relu_clipped_coef = 6.f;
act_node->set_attr_coef(Relu_clipped_coef);
} else if (op_type == "leaky_relu") {
auto alpha = op_info->GetAttr<float>("alpha");
act_node->set_attr_negative_slope(alpha);
......@@ -70,6 +73,7 @@ REGISTER_NPU_BRIDGE(relu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(tanh, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(relu_clipped,
paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(relu6, paddle::lite::kernels::npu::bridges::ActConverter);
// REGISTER_NPU_BRIDGE(elu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(leaky_relu,
paddle::lite::kernels::npu::bridges::ActConverter);
......
......@@ -55,6 +55,10 @@ void act_ref(const std::shared_ptr<operators::ActivationOp> op) {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::min(std::max(0.f, x_data[i]), relu_clipped_coef);
}
} else if (op_type == "relu6") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::min(std::max(0.f, x_data[i]), 6.f);
}
} else if (op_type == "leaky_relu") {
auto alpha = op_info->GetAttr<float>("alpha");
for (size_t i = 0; i < out->numel(); i++) {
......@@ -96,6 +100,8 @@ void test_act(std::vector<int64_t> x_shape, std::string op_type) {
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
if (op_type == "relu_clipped") {
opdesc.SetAttr("Relu_clipped_coef", 3.f);
} else if (op_type == "relu6") {
opdesc.SetAttr("Relu_clipped_coef", 6.f);
} else if (op_type == "leaky_relu") {
opdesc.SetAttr("alpha", 0.02f);
......@@ -125,6 +131,7 @@ TEST(NPUBridges, activation) {
"relu",
"tanh",
"relu_clipped",
"relu6",
"leaky_relu",
"softsign",
"hard_sigmoid"};
......@@ -149,6 +156,8 @@ USE_LITE_OP(tanh);
USE_NPU_BRIDGE(tanh);
USE_LITE_OP(relu_clipped);
USE_NPU_BRIDGE(relu_clipped);
USE_LITE_OP(relu6);
USE_NPU_BRIDGE(relu6);
USE_LITE_OP(leaky_relu);
USE_NPU_BRIDGE(leaky_relu);
......
......@@ -20,6 +20,7 @@ USE_NPU_BRIDGE(sigmoid);
USE_NPU_BRIDGE(relu);
USE_NPU_BRIDGE(tanh);
USE_NPU_BRIDGE(relu_clipped);
USE_NPU_BRIDGE(relu6);
USE_NPU_BRIDGE(leaky_relu);
USE_NPU_BRIDGE(softsign);
USE_NPU_BRIDGE(hard_sigmoid);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register