Unverified · Commit b4bfa42e authored by HappyAngel, committed by GitHub

[arm]fix ttfnet run error (#4086)

* fix compute error

* fix format, test=develop
Parent 9846b4d2
@@ -300,13 +300,15 @@ void fill_bias_act<float>(float* tensor,
       switch (act_param->active_type) {
         case lite_api::ActivationType::kRelu:
           for (int i = 0; i < remain; i++) {
-            *dst = *src >= 0.f ? *src : 0.f;
+            float tmp = (*src + bias_data);
+            *dst = tmp >= 0.f ? tmp : 0.f;
             src++;
             dst++;
           }
         case lite_api::ActivationType::kRelu6:
           for (int i = 0; i < remain; i++) {
-            float tmp = *src >= 0.f ? *src : 0.f;
+            float tmp = (*src + bias_data);
+            tmp = tmp >= 0.f ? tmp : 0.f;
             *dst = tmp <= act_param->Relu_clipped_coef
                        ? tmp
                        : act_param->Relu_clipped_coef;
@@ -315,10 +317,11 @@ void fill_bias_act<float>(float* tensor,
           }
         case lite_api::ActivationType::kLeakyRelu:
           for (int i = 0; i < remain; i++) {
-            if (*src >= 0.f) {
-              *dst = *src;
+            float tmp = (*src + bias_data);
+            if (tmp >= 0.f) {
+              *dst = tmp;
             } else {
-              *dst = *src * act_param->Leaky_relu_alpha;
+              *dst = tmp * act_param->Leaky_relu_alpha;
             }
             src++;
             dst++;
@@ -336,17 +339,24 @@ void fill_bias_act<float>(float* tensor,
       float32x4_t vbias = vdupq_n_f32(bias_data);
       float* src = data + j * channel_size;
       float* dst = data + j * channel_size;
+      if (cnt > 0) {
 #ifdef __aarch64__
-      asm volatile(FILL_BIAS FILL_STORE
-                   : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
-                   : [vbias] "w"(vbias)
-                   : "memory", "cc", "v0", "v1", "v2", "v3");
+        asm volatile(FILL_BIAS FILL_STORE
+                     :
+                     [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
+                     : [vbias] "w"(vbias)
+                     : "memory", "cc", "v0", "v1", "v2", "v3");
 #else
-      asm volatile(FILL_BIAS FILL_STORE
-                   : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
-                   : [vbias] "w"(vbias)
-                   : "memory", "cc", "q3", "q4", "q5", "q6");
+        asm volatile(FILL_BIAS FILL_STORE
+                     :
+                     [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
+                     : [vbias] "w"(vbias)
+                     : "memory", "cc", "q3", "q4", "q5", "q6");
 #endif
+      }
+      for (int i = 0; i < remain; i++) {
+        *dst = *src + bias_data;
+      }
     }
   }
 }
......
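Note on the fill_bias_act change above: the commit adds the bias before applying the activation in the scalar remainder loops, and guards the inline-assembly main loop so it only runs when there is at least one full block (cnt > 0). Below is a minimal, hypothetical NEON-intrinsics sketch of that structure, using 4-wide blocks instead of the kernel's wider assembly blocks; the helper name is invented for illustration and is not the Paddle-Lite function.

#include <arm_neon.h>

// Bias-only path, sketched: vectorized main loop plus a scalar tail that
// also receives the bias.
void fill_bias_sketch(float* data, float bias, int channel_size) {
  int cnt = channel_size >> 2;       // full float32x4 blocks
  int remain = channel_size & 3;     // leftover elements
  float* p = data;
  if (cnt > 0) {                     // guard: the vector body must not run with cnt == 0
    float32x4_t vbias = vdupq_n_f32(bias);
    for (int i = 0; i < cnt; ++i) {
      float32x4_t v = vld1q_f32(p);
      vst1q_f32(p, vaddq_f32(v, vbias));
      p += 4;
    }
  }
  for (int i = 0; i < remain; ++i) { // scalar tail: bias is applied here too
    *p = *p + bias;
    ++p;
  }
}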
@@ -104,9 +104,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   auto conv_weight_t =
       scope->FindVar(conv_weight_name)->GetMutable<lite::Tensor>();
   auto groups = conv_op_desc->GetAttr<int>("groups");
-  bool depthwise = false;
   if (conv_type_ == "conv2d_transpose") {
-    depthwise = (conv_weight_t->dims()[0] == conv_weight_t->dims()[1] * groups);
     CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
              static_cast<size_t>(conv_weight_t->dims()[1] * groups))
         << "The BN bias's size should be equal to the size of the first "
@@ -120,7 +118,6 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   size_t weight_num = conv_weight_t->data_size();
   bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
   bool is_weight_quantization = conv_op_desc->HasAttr("quantize_weight_bits");
-
   // comupte BN alpha and beta
   Tensor alpha_tensor, beta_tensor;
   alpha_tensor.CopyDataFrom(*bn_bias_t);
@@ -162,12 +159,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
       auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
       // compute new conv_weight for int8
       auto weight_scale = conv_op_desc->GetInputScale(weight_name);
-      if (conv_type_ == "conv2d_transpose" && !depthwise) {
-        int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
-                     conv_weight_t->dims()[3];
+      if (conv_type_ == "conv2d_transpose") {
+        int cout = conv_weight_t->dims()[1] * groups;
+        int cin_group = conv_weight_t->dims()[0] / groups;
+        int c_size = cout * conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
         int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
-        for (int k = 0; k < conv_weight_t->dims()[0]; ++k) {
-          for (int i = 0; i < h; ++i) {
+        for (int k = 0; k < cin_group; ++k) {
+          for (int i = 0; i < cout; ++i) {
             weight_scale[i] *= fabsf(alpha_data[i]);
             if (alpha_data[i] < 0.f) {
               auto ptr_row = conv_weight_d + k * c_size + i * hw;
@@ -203,12 +201,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
     } else {
       // compute new conv_weight
       auto conv_weight_d = conv_weight_t->mutable_data<float>();
-      if (conv_type_ == "conv2d_transpose" && !depthwise) {
-        int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
-                     conv_weight_t->dims()[3];
+      if (conv_type_ == "conv2d_transpose") {
+        int cout = conv_weight_t->dims()[1] * groups;
+        int cin_group = conv_weight_t->dims()[0] / groups;
+        int c_size = cout * conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
         int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
-        for (int k = 0; k < conv_weight_t->dims()[0]; ++k) {
-          for (int i = 0; i < h; ++i) {
+        for (int k = 0; k < cin_group; ++k) {
+          for (int i = 0; i < cout; ++i) {
             auto ptr_row = conv_weight_d + k * c_size + i * hw;
             for (int j = 0; j < hw; ++j) {
               ptr_row[j] *= alpha_data[i];
......
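The ConvBNFuser change drops the old depthwise special case and derives the channel counts from the conv2d_transpose weight layout instead: filters are stored as [cin, cout/groups, kh, kw], so the real number of output channels is dims[1] * groups. The sketch below is one plausible standalone walk over that layout, applying a per-output-channel BN scale alpha[oc]; the names and the std::vector interface are hypothetical and this is not the fuser's own indexing.

#include <vector>

// w has cin * (cout / groups) * kh * kw elements in [cin, cout/groups, kh, kw] order;
// alpha has one factor per global output channel (size cout).
void scale_deconv_weight_by_bn_alpha(std::vector<float>& w,
                                     const std::vector<float>& alpha,
                                     int cin, int cout, int groups,
                                     int kh, int kw) {
  const int cout_per_group = cout / groups;
  const int cin_per_group = cin / groups;
  const int hw = kh * kw;
  for (int ci = 0; ci < cin; ++ci) {
    const int g = ci / cin_per_group;                 // group this input channel belongs to
    for (int oc = 0; oc < cout_per_group; ++oc) {
      const int global_oc = g * cout_per_group + oc;  // channel the BN statistics refer to
      float* row = w.data() + (ci * cout_per_group + oc) * hw;
      for (int j = 0; j < hw; ++j) row[j] *= alpha[global_oc];
    }
  }
}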
@@ -73,7 +73,6 @@ void Conv2DTransposeCompute::Run() {
   int kw = w_dims[3];  // oihw
   int kh = w_dims[2];
   int group = param.groups;
-  bool fuse_relu = param.fuse_relu;
   bool flag_bias = (param.bias != nullptr);
   auto paddings = *param.paddings;

@@ -104,6 +103,7 @@ void Conv2DTransposeCompute::Run() {
   auto dout = param.output->mutable_data<float>();
   auto weights = param.filter->data<float>();
   auto act_param = param.activation_param;
+  bool has_act = act_param.has_active;
   for (int i = 0; i < num; i++) {
     const float* din_batch = din + i * chin * hin * win;
     float* dout_batch = dout + i * chout * hout * wout;
@@ -152,13 +152,14 @@ void Conv2DTransposeCompute::Run() {
                             dout_batch);
       }
       if (flag_bias) {
-        lite::arm::math::fill_bias_relu<float>(
+        act_param.has_active = has_act;
+        lite::arm::math::fill_bias_act<float>(
             dout_batch,
             static_cast<const float*>(param.bias->data<float>()),
             chout,
             wout * hout,
             flag_bias,
-            fuse_relu);
+            &act_param);
       }
     }
   }
......
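For context, the kernel change above routes the bias through lite::arm::math::fill_bias_act instead of fill_bias_relu, so a fused relu, relu6, or leaky_relu is applied in the same pass that adds the bias. A simplified scalar reference of that behaviour, written as a hypothetical free function rather than the optimized Paddle-Lite implementation:

#include <cstddef>

enum class Act { kNone, kRelu, kRelu6, kLeakyRelu };

// Adds a per-channel bias to data (channels x channel_size) and applies the
// selected activation in the same pass.
void fill_bias_act_ref(float* data, const float* bias, int channels,
                       int channel_size, Act act, float relu6_coef,
                       float leaky_alpha) {
  for (int c = 0; c < channels; ++c) {
    float b = bias ? bias[c] : 0.f;
    float* p = data + static_cast<size_t>(c) * channel_size;
    for (int i = 0; i < channel_size; ++i) {
      float v = p[i] + b;  // bias first, activation second
      switch (act) {
        case Act::kRelu:      v = v > 0.f ? v : 0.f; break;
        case Act::kRelu6:     v = v > 0.f ? (v < relu6_coef ? v : relu6_coef) : 0.f; break;
        case Act::kLeakyRelu: v = v > 0.f ? v : v * leaky_alpha; break;
        case Act::kNone:      break;
      }
      p[i] = v;
    }
  }
}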
@@ -168,6 +168,5 @@ using whereindex = paddle::lite::kernels::host::WhereIndexCompute;
 REGISTER_LITE_KERNEL(where_index, kHost, kAny, kAny, whereindex, def)
     .BindInput("Condition",
                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
-    .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
     .Finalize();
@@ -30,7 +30,14 @@ bool CompareOp::InferShapeImpl() const {
   CHECK_OR_FALSE(param_.Out);
   // TODO(Superjomn) Enable data sharing.
   auto input_dims = param_.X->dims();
-  param_.Out->Resize(input_dims);
+  std::vector<int64_t> new_dims;
+  if (input_dims.size() == 2 && input_dims[1] == 1) {
+    new_dims.push_back(input_dims[0]);
+    param_.Out->Resize(new_dims);
+  } else {
+    param_.Out->Resize(input_dims);
+  }
+  // param_.Out->Resize(input_dims);
   return true;
 }

......
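The InferShapeImpl change makes compare ops produce a 1-D output of shape [N] when the input has shape [N, 1]; any other input keeps its shape. A tiny standalone illustration of that rule, with std::vector standing in for the lite::DDim API:

#include <cstdint>
#include <vector>

// Output dims of a compare op under the new rule.
std::vector<int64_t> compare_out_dims(const std::vector<int64_t>& in) {
  if (in.size() == 2 && in[1] == 1) {
    return {in[0]};  // squeeze the trailing 1, e.g. [100, 1] -> [100]
  }
  return in;         // otherwise the output matches the input shape
}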
@@ -141,9 +141,25 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc,
       }
     }
   }
-  if (op_desc.HasAttr("fuse_relu")) {
-    param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
-    param_.activation_param.active_type = lite_api::ActivationType::kRelu;
+  if (op_desc.HasAttr("with_act") && op_desc.GetAttr<bool>("with_act")) {
+    param_.activation_param.has_active = true;
+    auto act_type = op_desc.GetAttr<std::string>("act_type");
+    if (act_type == "relu") {
+      param_.activation_param.active_type = lite_api::ActivationType::kRelu;
+      param_.fuse_relu = true;
+    } else if (act_type == "relu6") {
+      param_.activation_param.active_type = lite_api::ActivationType::kRelu6;
+      param_.activation_param.Relu_clipped_coef =
+          op_desc.GetAttr<float>("fuse_brelu_threshold");  // 6.f
+    } else if (act_type == "leaky_relu") {
+      param_.activation_param.active_type =
+          lite_api::ActivationType::kLeakyRelu;
+      param_.activation_param.Leaky_relu_alpha =
+          op_desc.GetAttr<float>("leaky_relu_alpha");
+    } else {
+      CHECK(false)
+          << "The fused conv only supports fuse with relu and leaky relu";
+    }
   }
   if (op_desc.HasAttr("output_size")) {
     param_.output_size = op_desc.GetAttr<std::vector<int>>("output_size");
......
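The AttachImpl change above switches conv2d_transpose activation fusion from the old fuse_relu boolean to the with_act/act_type attributes. A hedged sketch of just the string-to-enum mapping it performs, written as a free function for illustration; the real code writes into param_.activation_param and also reads the fuse_brelu_threshold and leaky_relu_alpha attributes shown in the diff:

#include <stdexcept>
#include <string>

enum class ActivationType { kRelu, kRelu6, kLeakyRelu };

// Maps the op attribute "act_type" to the fused activation kind.
ActivationType act_type_from_string(const std::string& act_type) {
  if (act_type == "relu") return ActivationType::kRelu;
  if (act_type == "relu6") return ActivationType::kRelu6;          // clip threshold from "fuse_brelu_threshold"
  if (act_type == "leaky_relu") return ActivationType::kLeakyRelu; // slope from "leaky_relu_alpha"
  throw std::invalid_argument("unsupported fused activation: " + act_type);
}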