提交 e7ed7286 编写于 作者: Z zhupengyang 提交者: GitHub

[NPU] fix act; refine act unit tests; fix batch_norm (#2533)

test=develop
上级 f8e20e80
......@@ -142,21 +142,25 @@ ge::TensorPtr CvtTensor(lite::Tensor* in_tensor,
int CvtActMode(std::string act_type) {
int act_mode = 1;
if (act_type == "sigmod") {
if (act_type == "sigmoid") {
act_mode = 0;
} else if (act_type == "relu") {
act_mode = 1;
} else if (act_type == "tanh") {
act_mode = 2;
} else if (act_type == "relu_clipped") {
act_mode = 3;
} else if (act_type == "elu") {
act_mode = 4;
} else if (act_type == "leaky_relu") {
act_mode = 5;
} else if (act_type == "abs") {
act_mode = 6;
} else if (act_type == "softsign") {
act_mode = 8;
} else if (act_type == "softplus") {
act_mode = 9;
} else if (act_type == "hardsigmoid") {
} else if (act_type == "hard_sigmoid") {
act_mode = 10;
} else {
// TODO(hong19860320) support more activation mode
......
......@@ -41,6 +41,19 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
// clipped_relu etc.
act_node->set_attr_mode(lite::npu::CvtActMode(op_type));
if (op_type == "relu_clipped") {
auto Relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef");
act_node->set_attr_coef(Relu_clipped_coef);
} else if (op_type == "leaky_relu") {
auto alpha = op_info->GetAttr<float>("alpha");
act_node->set_attr_negative_slope(alpha);
} else if (op_type == "hard_sigmoid") {
auto slope = op_info->GetAttr<float>("slope");
auto offset = op_info->GetAttr<float>("offset");
act_node->set_attr_negative_slope(slope);
act_node->set_attr_coef(offset);
}
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = act_node;
return outputs_map;
......@@ -52,14 +65,18 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(sigmod, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(sigmoid, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(relu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(tanh, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(elu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(relu_clipped,
paddle::lite::kernels::npu::bridges::ActConverter);
// REGISTER_NPU_BRIDGE(elu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(leaky_relu,
paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(abs, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(softsign,
paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(softplus,
paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(hardsigmoid,
REGISTER_NPU_BRIDGE(hard_sigmoid,
paddle::lite::kernels::npu::bridges::ActConverter);
......@@ -17,7 +17,7 @@
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
#include "lite/operators/relu_op.h"
#include "lite/operators/activation_ops.h"
namespace paddle {
namespace lite {
......@@ -25,69 +25,112 @@ namespace kernels {
namespace npu {
namespace bridges {
void relu_ref(const std::shared_ptr<operators::ReluOp> op) {
void act_ref(const std::shared_ptr<operators::ActivationOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto op_type = op_info->Type();
auto x = scope->FindTensor("x");
auto out = scope->FindMutableTensor("out_ref");
out->Resize(x->dims());
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
DDim x_dims = x->dims();
DDim out_dims = out->dims();
CHECK_EQ(x_dims.production(), out_dims.production());
for (int i = 0; i < out_dims.production(); i++) {
CHECK_EQ(x->numel(), out->numel());
// "sigmoid","relu","tanh","relu_clipped","leaky_relu","softsign","hard_sigmoid"
if (op_type == "sigmoid") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = 1.f / (1.f + std::exp(-x_data[i]));
}
} else if (op_type == "relu") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::max(0.f, x_data[i]);
}
} else if (op_type == "tanh") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = (std::exp(x_data[i]) - std::exp(-x_data[i])) /
(std::exp(x_data[i]) + std::exp(-x_data[i]));
}
} else if (op_type == "relu_clipped") {
auto relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef");
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::min(std::max(0.f, x_data[i]), relu_clipped_coef);
}
} else if (op_type == "leaky_relu") {
auto alpha = op_info->GetAttr<float>("alpha");
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::max(x_data[i], x_data[i] * alpha);
}
} else if (op_type == "softsign") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = x_data[i] / (1 + std::abs(x_data[i]));
}
} else if (op_type == "hard_sigmoid") {
auto slope = op_info->GetAttr<float>("slope");
auto offset = op_info->GetAttr<float>("offset");
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::min(1.f, slope * x_data[i] + offset);
out_data[i] = std::max(0.f, out_data[i]);
}
} else {
LOG(FATAL) << "unsupported activation type: " << op_type;
}
}
void test_relu(int bs, int ic, int ih, int iw) {
void test_act(std::vector<int64_t> x_shape, std::string op_type) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string out_ref_var_name("out_ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
x->Resize(x_shape);
// initialize input&output data
FillTensor<float, int>(x);
FillTensor<float>(x, -8, 8);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("relu");
opdesc.SetType(op_type);
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
if (op_type == "relu_clipped") {
opdesc.SetAttr("Relu_clipped_coef", 6.f);
} else if (op_type == "leaky_relu") {
opdesc.SetAttr("alpha", 0.02f);
} else if (op_type == "hard_sigmoid") {
opdesc.SetAttr("slope", 0.2f);
opdesc.SetAttr("offset", 0.5f);
}
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ReluOp>(opdesc, &scope);
auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
relu_ref(op);
act_ref(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, relu) {
for (auto bs : {1, 3}) {
for (auto ic : {3, 4}) {
for (auto ih : {2, 5}) {
for (auto iw : {5, 9}) {
VLOG(3) << "bs: " << bs << " ic: " << ic << " ih: " << ih
<< " iw: " << iw;
test_relu(bs, ic, ih, iw);
}
}
TEST(NPUBridges, activation) {
std::vector<std::vector<int64_t>> shapes{{1}, {2, 3}, {1, 2, 3, 4}};
std::vector<std::string> types{"sigmoid",
"relu",
"tanh",
"relu_clipped",
"leaky_relu",
"softsign",
"hard_sigmoid"};
for (auto x_shape : shapes) {
for (auto op_type : types) {
test_act(x_shape, op_type);
}
}
}
......@@ -98,5 +141,20 @@ TEST(NPUBridges, relu) {
} // namespace lite
} // namespace paddle
USE_LITE_OP(sigmoid);
USE_NPU_BRIDGE(sigmoid);
USE_LITE_OP(relu);
USE_NPU_BRIDGE(relu);
USE_LITE_OP(tanh);
USE_NPU_BRIDGE(tanh);
USE_LITE_OP(relu_clipped);
USE_NPU_BRIDGE(relu_clipped);
USE_LITE_OP(leaky_relu);
USE_NPU_BRIDGE(leaky_relu);
USE_LITE_OP(softsign);
USE_NPU_BRIDGE(softsign);
USE_LITE_OP(hard_sigmoid);
USE_NPU_BRIDGE(hard_sigmoid);
......@@ -30,8 +30,8 @@ node_map_type BatchNormConverter(
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::BatchNorm> batch_norm_node =
std::make_shared<ge::op::BatchNorm>(unique_op_type);
std::shared_ptr<ge::op::BatchNormExt2> batch_norm_node =
std::make_shared<ge::op::BatchNormExt2>(unique_op_type);
auto x_var_name = op_info->Input("X").front();
auto scale_var_name = op_info->Input("Scale").front();
......@@ -66,7 +66,7 @@ node_map_type BatchNormConverter(
batch_norm_node->set_input_x(*inputs_map.at(x_var_name));
batch_norm_node->set_input_scale(*npu_scale);
batch_norm_node->set_input_b(*npu_bias);
batch_norm_node->set_input_offset(*npu_bias);
batch_norm_node->set_input_mean(*npu_mean);
batch_norm_node->set_input_variance(*npu_variance);
batch_norm_node->set_attr_momentum(npu_momentum);
......
......@@ -16,31 +16,40 @@
#include "lite/kernels/npu/bridges/registry.h"
USE_NPU_BRIDGE(mul);
USE_NPU_BRIDGE(fc);
USE_NPU_BRIDGE(sigmoid);
USE_NPU_BRIDGE(relu);
USE_NPU_BRIDGE(tanh);
USE_NPU_BRIDGE(relu_clipped);
USE_NPU_BRIDGE(leaky_relu);
USE_NPU_BRIDGE(softsign);
USE_NPU_BRIDGE(hard_sigmoid);
USE_NPU_BRIDGE(batch_norm);
USE_NPU_BRIDGE(concat);
USE_NPU_BRIDGE(conv2d);
USE_NPU_BRIDGE(depthwise_conv2d);
USE_NPU_BRIDGE(pool2d);
USE_NPU_BRIDGE(relu);
USE_NPU_BRIDGE(conv2d_transpose);
USE_NPU_BRIDGE(elementwise_add);
USE_NPU_BRIDGE(fusion_elementwise_add_activation);
USE_NPU_BRIDGE(elementwise_sub);
USE_NPU_BRIDGE(elementwise_mul);
USE_NPU_BRIDGE(elementwise_div);
USE_NPU_BRIDGE(scale);
USE_NPU_BRIDGE(softmax);
USE_NPU_BRIDGE(concat);
USE_NPU_BRIDGE(split);
USE_NPU_BRIDGE(transpose);
USE_NPU_BRIDGE(transpose2);
USE_NPU_BRIDGE(shuffle_channel);
USE_NPU_BRIDGE(batch_norm);
USE_NPU_BRIDGE(fc);
USE_NPU_BRIDGE(bilinear_interp);
USE_NPU_BRIDGE(conv2d_transpose);
USE_NPU_BRIDGE(nearest_interp);
USE_NPU_BRIDGE(mul);
USE_NPU_BRIDGE(pad2d);
USE_NPU_BRIDGE(pool2d);
USE_NPU_BRIDGE(reduce_mean);
USE_NPU_BRIDGE(reshape);
USE_NPU_BRIDGE(reshape2);
USE_NPU_BRIDGE(scale);
USE_NPU_BRIDGE(shuffle_channel);
USE_NPU_BRIDGE(softmax);
USE_NPU_BRIDGE(split);
USE_NPU_BRIDGE(sqrt);
USE_NPU_BRIDGE(square);
USE_NPU_BRIDGE(reduce_mean);
USE_NPU_BRIDGE(tanh);
USE_NPU_BRIDGE(nearest_interp);
USE_NPU_BRIDGE(transpose);
USE_NPU_BRIDGE(transpose2);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册