未验证 提交 a2f981a4 编写于 作者: Z zhupengyang 提交者: GitHub

[NPU] support relu6 (#2582)

test=develop
上级 ea837ec7
...@@ -148,7 +148,7 @@ int CvtActMode(std::string act_type) { ...@@ -148,7 +148,7 @@ int CvtActMode(std::string act_type) {
act_mode = 1; act_mode = 1;
} else if (act_type == "tanh") { } else if (act_type == "tanh") {
act_mode = 2; act_mode = 2;
} else if (act_type == "relu_clipped") { } else if (act_type == "relu_clipped" || act_type == "relu6") {
act_mode = 3; act_mode = 3;
} else if (act_type == "elu") { } else if (act_type == "elu") {
act_mode = 4; act_mode = 4;
......
...@@ -44,6 +44,9 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op, ...@@ -44,6 +44,9 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
if (op_type == "relu_clipped") { if (op_type == "relu_clipped") {
auto Relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef"); auto Relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef");
act_node->set_attr_coef(Relu_clipped_coef); act_node->set_attr_coef(Relu_clipped_coef);
} else if (op_type == "relu6") {
float Relu_clipped_coef = 6.f;
act_node->set_attr_coef(Relu_clipped_coef);
} else if (op_type == "leaky_relu") { } else if (op_type == "leaky_relu") {
auto alpha = op_info->GetAttr<float>("alpha"); auto alpha = op_info->GetAttr<float>("alpha");
act_node->set_attr_negative_slope(alpha); act_node->set_attr_negative_slope(alpha);
...@@ -70,6 +73,7 @@ REGISTER_NPU_BRIDGE(relu, paddle::lite::kernels::npu::bridges::ActConverter); ...@@ -70,6 +73,7 @@ REGISTER_NPU_BRIDGE(relu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(tanh, paddle::lite::kernels::npu::bridges::ActConverter); REGISTER_NPU_BRIDGE(tanh, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(relu_clipped, REGISTER_NPU_BRIDGE(relu_clipped,
paddle::lite::kernels::npu::bridges::ActConverter); paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(relu6, paddle::lite::kernels::npu::bridges::ActConverter);
// REGISTER_NPU_BRIDGE(elu, paddle::lite::kernels::npu::bridges::ActConverter); // REGISTER_NPU_BRIDGE(elu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(leaky_relu, REGISTER_NPU_BRIDGE(leaky_relu,
paddle::lite::kernels::npu::bridges::ActConverter); paddle::lite::kernels::npu::bridges::ActConverter);
......
...@@ -55,6 +55,10 @@ void act_ref(const std::shared_ptr<operators::ActivationOp> op) { ...@@ -55,6 +55,10 @@ void act_ref(const std::shared_ptr<operators::ActivationOp> op) {
for (size_t i = 0; i < out->numel(); i++) { for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::min(std::max(0.f, x_data[i]), relu_clipped_coef); out_data[i] = std::min(std::max(0.f, x_data[i]), relu_clipped_coef);
} }
} else if (op_type == "relu6") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::min(std::max(0.f, x_data[i]), 6.f);
}
} else if (op_type == "leaky_relu") { } else if (op_type == "leaky_relu") {
auto alpha = op_info->GetAttr<float>("alpha"); auto alpha = op_info->GetAttr<float>("alpha");
for (size_t i = 0; i < out->numel(); i++) { for (size_t i = 0; i < out->numel(); i++) {
...@@ -96,6 +100,8 @@ void test_act(std::vector<int64_t> x_shape, std::string op_type) { ...@@ -96,6 +100,8 @@ void test_act(std::vector<int64_t> x_shape, std::string op_type) {
opdesc.SetInput("X", {x_var_name}); opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name}); opdesc.SetOutput("Out", {out_var_name});
if (op_type == "relu_clipped") { if (op_type == "relu_clipped") {
opdesc.SetAttr("Relu_clipped_coef", 3.f);
} else if (op_type == "relu6") {
opdesc.SetAttr("Relu_clipped_coef", 6.f); opdesc.SetAttr("Relu_clipped_coef", 6.f);
} else if (op_type == "leaky_relu") { } else if (op_type == "leaky_relu") {
opdesc.SetAttr("alpha", 0.02f); opdesc.SetAttr("alpha", 0.02f);
...@@ -125,6 +131,7 @@ TEST(NPUBridges, activation) { ...@@ -125,6 +131,7 @@ TEST(NPUBridges, activation) {
"relu", "relu",
"tanh", "tanh",
"relu_clipped", "relu_clipped",
"relu6",
"leaky_relu", "leaky_relu",
"softsign", "softsign",
"hard_sigmoid"}; "hard_sigmoid"};
...@@ -149,6 +156,8 @@ USE_LITE_OP(tanh); ...@@ -149,6 +156,8 @@ USE_LITE_OP(tanh);
USE_NPU_BRIDGE(tanh); USE_NPU_BRIDGE(tanh);
USE_LITE_OP(relu_clipped); USE_LITE_OP(relu_clipped);
USE_NPU_BRIDGE(relu_clipped); USE_NPU_BRIDGE(relu_clipped);
USE_LITE_OP(relu6);
USE_NPU_BRIDGE(relu6);
USE_LITE_OP(leaky_relu); USE_LITE_OP(leaky_relu);
USE_NPU_BRIDGE(leaky_relu); USE_NPU_BRIDGE(leaky_relu);
......
...@@ -20,6 +20,7 @@ USE_NPU_BRIDGE(sigmoid); ...@@ -20,6 +20,7 @@ USE_NPU_BRIDGE(sigmoid);
USE_NPU_BRIDGE(relu); USE_NPU_BRIDGE(relu);
USE_NPU_BRIDGE(tanh); USE_NPU_BRIDGE(tanh);
USE_NPU_BRIDGE(relu_clipped); USE_NPU_BRIDGE(relu_clipped);
USE_NPU_BRIDGE(relu6);
USE_NPU_BRIDGE(leaky_relu); USE_NPU_BRIDGE(leaky_relu);
USE_NPU_BRIDGE(softsign); USE_NPU_BRIDGE(softsign);
USE_NPU_BRIDGE(hard_sigmoid); USE_NPU_BRIDGE(hard_sigmoid);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册