Unverified commit 4c49f876 authored by Qi Li, committed by GitHub

[ascend] add pool, elementwise, batch_norm op, test=develop (#4050)

* [ascend] add pool and elementwise op, test=develop

* [ascend] add batch_norm op and simplify tensor update, test=develop
Parent 07104881
......@@ -76,6 +76,7 @@ bool Device::Build(std::vector<ge::Operator>& input_nodes, // NOLINT
}
}
VLOG(3) << "Getting input node size " << input_nodes.size();
VLOG(3) << "Getting output node size " << output_nodes.size();
ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes);
// Build IR model
......
......@@ -96,7 +96,9 @@ bool AclModelClient::GetModelIOTensorDim(
ACL_CALL(aclmdlGetInputDims(model_desc_, i, &input_dim));
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
aclFormat data_format = aclmdlGetInputFormat(model_desc_, i);
TensorDesc tensor_desc = TensorDesc(data_type, input_dim, data_format);
const std::string name_str(aclmdlGetInputNameByIndex(model_desc_, i));
TensorDesc tensor_desc =
TensorDesc(name_str, data_type, input_dim, data_format);
input_tensor->push_back(tensor_desc);
}
......@@ -108,7 +110,9 @@ bool AclModelClient::GetModelIOTensorDim(
ACL_CALL(aclmdlGetOutputDims(model_desc_, i, &output_dim));
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
aclFormat data_format = aclmdlGetOutputFormat(model_desc_, i);
TensorDesc tensor_desc = TensorDesc(data_type, output_dim, data_format);
const std::string name_str(aclmdlGetOutputNameByIndex(model_desc_, i));
TensorDesc tensor_desc =
TensorDesc(name_str, data_type, output_dim, data_format);
output_tensor->push_back(tensor_desc);
}
return true;
......@@ -118,12 +122,10 @@ bool AclModelClient::GetTensorFromDataset(
std::vector<std::shared_ptr<ge::Tensor>>* output_tensor) {
size_t device_output_num = aclmdlGetDatasetNumBuffers(output_dataset_);
size_t tensor_output_num = reinterpret_cast<size_t>(output_tensor->size());
if (device_output_num != tensor_output_num) {
LOG(ERROR)
<< "[HUAWEI_ASCEND_NPU] output number not equal, device number is "
<< device_output_num << "tensor number is " << tensor_output_num;
return false;
}
CHECK_EQ(device_output_num, tensor_output_num)
<< "[HUAWEI_ASCEND_NPU] tensor output number should equal to device "
"output number, device output number is "
<< device_output_num << ", tensor output number is " << tensor_output_num;
for (size_t i = 0; i < device_output_num; i++) {
aclDataBuffer* buffer_device = aclmdlGetDatasetBuffer(output_dataset_, i);
void* device_data = aclGetDataBufferAddr(buffer_device);
......@@ -195,7 +197,10 @@ void AclModelClient::CreateOutputDataset(
return;
}
size_t output_size = aclmdlGetNumOutputs(model_desc_);
CHECK_EQ(output_size, output_tensor->size());
CHECK_EQ(output_size, output_tensor->size())
<< "[HUAWEI_ASCEND_NPU] model output number should equal to output "
"tensor size, model output number is "
<< output_size << ", output tensor number is " << output_tensor->size();
for (size_t i = 0; i < output_size; i++) {
size_t buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);
void* buffer_device = nullptr;
......
......@@ -25,15 +25,20 @@ namespace huawei_ascend_npu {
class TensorDesc {
public:
TensorDesc(aclDataType data_type, aclmdlIODims dims, aclFormat format) {
TensorDesc(const std::string name,
aclDataType data_type,
aclmdlIODims dims,
aclFormat format) {
if (format == ACL_FORMAT_NHWC) {
dim_order[1] = 3;
dim_order[2] = 1;
dim_order[3] = 2;
}
// create ge::Tensordesc
VLOG(3) << "[HUAWEI_ASCEND_NPU] Getting tensor name : " << name;
ge_tensor_desc_ = new ge::TensorDesc(
GetGeShape(dims), GetGeFormat(format), GetGeDataType(data_type));
ge_tensor_desc_->SetName(name);
CHECK(ge_tensor_desc_ != nullptr);
VLOG(3) << "[HUAWEI_ASCEND_NPU] Getting data shape : " << repr();
}
......
......@@ -44,7 +44,7 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kXPU)})
#ifndef LITE_WITH_MLU
#if (!defined(LITE_WITH_MLU) && !defined(LITE_WITH_HUAWEI_ASCEND_NPU))
.ExcludeTargets({TARGET(kX86)})
#endif
.ExcludeTargets({TARGET(kBM)})
......
......@@ -11,6 +11,9 @@ lite_cc_library(subgraph_bridge_act_op_huawei_ascend_npu SRCS act_op.cc DEPS ${h
lite_cc_library(subgraph_bridge_conv_op_huawei_ascend_npu SRCS conv_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_interpolate_op_huawei_ascend_npu SRCS interpolate_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_concat_op_huawei_ascend_npu SRCS concat_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_pool_op_huawei_ascend_npu SRCS pool_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_elementwise_ops_huawei_ascend_npu SRCS elementwise_ops.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_batch_norm_op_huawei_ascend_npu SRCS batch_norm_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_registry
......@@ -20,4 +23,7 @@ set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_conv_op_huawei_ascend_npu
subgraph_bridge_interpolate_op_huawei_ascend_npu
subgraph_bridge_concat_op_huawei_ascend_npu
subgraph_bridge_pool_op_huawei_ascend_npu
subgraph_bridge_elementwise_ops_huawei_ascend_npu
subgraph_bridge_batch_norm_op_huawei_ascend_npu
CACHE INTERNAL "huawei_ascend_npu_subgraph_bridges")
......@@ -49,10 +49,8 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto act_node = graph->template Add<ActType>(out_name);
auto act_op = act_node->template data<ActType>();
act_op->set_input_x(*x_node->data());
TENSOR_UPDATE_INPUT(
act_op, x, ge::FORMAT_NCHW, CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_OUTPUT(
act_op, y, ge::FORMAT_NCHW, CvtPrecisionType(act_node->precision()));
INPUT_UPDATE(act_op, x, x_node);
OUTPUT_UPDATE(act_op, y, act_node);
return SUCCESS;
}
......@@ -88,10 +86,8 @@ int ActConverter<ge::op::LeakyRelu>(void* ctx, OpLite* op, KernelBase* kernel) {
// only for leaky_relu
auto alpha = op_info->GetAttr<float>("alpha");
act_op->set_attr_negative_slope(alpha);
TENSOR_UPDATE_INPUT(
act_op, x, ge::FORMAT_NCHW, CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_OUTPUT(
act_op, y, ge::FORMAT_NCHW, CvtPrecisionType(act_node->precision()));
INPUT_UPDATE(act_op, x, x_node);
OUTPUT_UPDATE(act_op, y, act_node);
return SUCCESS;
}
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input data nodes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto scale_name = op_info->Input("Scale").front();
auto scale = scope->FindMutableTensor(scale_name);
auto bias_name = op_info->Input("Bias").front();
auto bias = scope->FindMutableTensor(bias_name);
auto mean_name = op_info->Input("Mean").front();
auto mean = scope->FindMutableTensor(mean_name);
auto variance_name = op_info->Input("Variance").front();
auto variance = scope->FindMutableTensor(variance_name);
// Get output var nodes
auto y_name = op_info->Output("Y").front();
// Get attributes
float epsilon = op_info->GetAttr<float>("epsilon");
// Check is_test
auto is_test_type = op_info->GetAttrType("is_test");
if (is_test_type == OpDescAPI::AttrType::INT) {
CHECK_EQ(op_info->GetAttr<int>("is_test"), 1)
<< "[HUAWEI_ASCEND_NPU] Only is_test=1 or is_test=true is supported in "
"inference mode.";
} else if (is_test_type == OpDescAPI::AttrType::BOOLEAN) {
CHECK_EQ(op_info->GetAttr<bool>("is_test"), true)
<< "[HUAWEI_ASCEND_NPU] Only is_test=1 or is_test=true is supported in "
"inference mode.";
}
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Scale, Bias, Mean, Variance node
auto scale_node = graph->Add(scale_name, *scale);
auto bias_node = graph->Add(bias_name, *bias);
auto mean_node = graph->Add(mean_name, *mean);
auto variance_node = graph->Add(variance_name, *variance);
// Batch Norm node - output nodes
auto batch_norm_node = graph->Add<ge::op::BatchNorm>(y_name + "/batch_norm");
auto batch_norm_op = batch_norm_node->data<ge::op::BatchNorm>();
batch_norm_op->set_input_x(*x_node->data());
batch_norm_op->set_input_scale(*scale_node->data());
batch_norm_op->set_input_offset(*bias_node->data());
batch_norm_op->set_input_mean(*mean_node->data());
batch_norm_op->set_input_variance(*variance_node->data());
batch_norm_op->set_attr_epsilon(epsilon);
batch_norm_op->set_attr_data_format("NCHW");
batch_norm_op->set_attr_is_training(false);
INPUT_UPDATE(batch_norm_op, x, x_node);
INPUT_UPDATE(batch_norm_op, scale, scale_node);
INPUT_UPDATE(batch_norm_op, offset, bias_node);
INPUT_UPDATE(batch_norm_op, mean, mean_node);
INPUT_UPDATE(batch_norm_op, variance, variance_node);
OUTPUT_UPDATE(batch_norm_op, y, batch_norm_node);
// Create Variable node for batch norm output y
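// (the BatchNorm node above is registered under y_name + "/batch_norm", so an
// Identity node is used to expose its "y" result under the original var name)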
auto out_y_node = graph->Add<ge::op::Identity>(y_name);
auto out_y_op = out_y_node->data<ge::op::Identity>();
out_y_op->set_input_x(*batch_norm_node->data(), "y");
INPUT_UPDATE(out_y_op, x, batch_norm_node);
OUTPUT_UPDATE(out_y_op, y, out_y_node);
return SUCCESS;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
batch_norm,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::BatchNormConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -51,10 +51,8 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto concat_op = concat_node->data<ge::op::Concat>();
// set axis input
concat_op->set_input_concat_dim(*axis_node->data());
TENSOR_UPDATE_INPUT(concat_op,
concat_dim,
ge::FORMAT_NCHW,
CvtPrecisionType(axis_node->precision()));
INPUT_UPDATE(concat_op, concat_dim, axis_node);
// set dynamic input
concat_op->set_attr_N(num);
concat_op->create_dynamic_input_x(num);
......@@ -69,17 +67,10 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
x_node = graph->Add(x_name, *x);
}
concat_op->set_dynamic_input_x(idx, *x_node->data());
TENSOR_UPDATE_DYNAMIC_INPUT(concat_op,
x,
idx,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
DYNAMIC_INPUT_UPDATE(concat_op, x, idx, x_node);
idx++;
}
TENSOR_UPDATE_OUTPUT(concat_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(concat_node->precision()));
OUTPUT_UPDATE(concat_op, y, concat_node);
} else {
auto concat_node = graph->Add<ge::op::ConcatD>(out_name);
auto concat_op = concat_node->data<ge::op::ConcatD>();
......@@ -97,17 +88,10 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
x_node = graph->Add(x_name, *x);
}
concat_op->set_dynamic_input_x(idx, *x_node->data());
TENSOR_UPDATE_DYNAMIC_INPUT(concat_op,
x,
idx,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
DYNAMIC_INPUT_UPDATE(concat_op, x, idx, x_node);
idx++;
}
TENSOR_UPDATE_OUTPUT(concat_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(concat_node->precision()));
OUTPUT_UPDATE(concat_op, y, concat_node);
}
return SUCCESS;
......
......@@ -182,19 +182,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
conv_op->set_attr_data_format("NCHW");
if (bias_node != nullptr && is_channel_bias) {
conv_op->set_input_bias(*bias_node->data());
TENSOR_UPDATE_INPUT(conv_op,
bias,
ge::FORMAT_NCHW,
CvtPrecisionType(bias_node->precision()));
INPUT_UPDATE(conv_op, bias, bias_node);
}
TENSOR_UPDATE_INPUT(
conv_op, x, ge::FORMAT_NCHW, CvtPrecisionType(input_node->precision()));
TENSOR_UPDATE_INPUT(conv_op,
filter,
ge::FORMAT_NCHW,
CvtPrecisionType(filter_node->precision()));
TENSOR_UPDATE_OUTPUT(
conv_op, y, ge::FORMAT_NCHW, CvtPrecisionType(conv_node->precision()));
INPUT_UPDATE(conv_op, x, input_node);
INPUT_UPDATE(conv_op, filter, filter_node);
OUTPUT_UPDATE(conv_op, y, conv_node);
} else {
conv_node = graph->Add<ge::op::Conv2D>(output_name);
auto conv_op = conv_node->data<ge::op::Conv2D>();
......@@ -210,19 +202,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
conv_op->set_attr_data_format("NCHW");
if (bias_node != nullptr && is_channel_bias) {
conv_op->set_input_bias(*bias_node->data());
TENSOR_UPDATE_INPUT(conv_op,
bias,
ge::FORMAT_NCHW,
CvtPrecisionType(bias_node->precision()));
INPUT_UPDATE(conv_op, bias, bias_node);
}
TENSOR_UPDATE_INPUT(
conv_op, x, ge::FORMAT_NCHW, CvtPrecisionType(input_node->precision()));
TENSOR_UPDATE_INPUT(conv_op,
filter,
ge::FORMAT_NCHW,
CvtPrecisionType(filter_node->precision()));
TENSOR_UPDATE_OUTPUT(
conv_op, y, ge::FORMAT_NCHW, CvtPrecisionType(conv_node->precision()));
INPUT_UPDATE(conv_op, x, input_node);
INPUT_UPDATE(conv_op, filter, filter_node);
OUTPUT_UPDATE(conv_op, y, conv_node);
}
// append Add node to support bias
if (bias_node != nullptr && !is_channel_bias) {
......@@ -230,7 +214,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto add_op = add_node->data<ge::op::Add>();
add_op->set_input_x1(*conv_node->data());
add_op->set_input_x2(*bias_node->data());
conv_node = add_node;
INPUT_UPDATE(add_op, x1, conv_node);
INPUT_UPDATE(add_op, x2, bias_node);
OUTPUT_UPDATE(add_op, y, add_node);
}
CHECK(conv_node);
......@@ -241,11 +227,15 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto act_node = graph->Add<ge::op::Relu>(output_name);
auto act_op = act_node->data<ge::op::Relu>();
act_op->set_input_x(*conv_node->data());
INPUT_UPDATE(act_op, x, conv_node);
OUTPUT_UPDATE(act_op, y, act_node);
} else if (act_type == "leaky_relu") {
auto act_node = graph->Add<ge::op::LeakyRelu>(output_name);
auto act_op = act_node->data<ge::op::LeakyRelu>();
act_op->set_input_x(*conv_node->data());
act_op->set_attr_negative_slope(leaky_relu_alpha);
INPUT_UPDATE(act_op, x, conv_node);
OUTPUT_UPDATE(act_op, y, act_node);
} else {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] act type not supported: "
<< act_type;
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
void CvtXYShape(std::vector<int64_t>* x_shape,
std::vector<int64_t>* y_shape,
int axis) {
int x_shape_size = x_shape->size();
int y_shape_size = y_shape->size();
CHECK_GE(x_shape_size, y_shape_size);
// only support:
// 1. same shape
// 2. (n,c,h,w) * (1,c,1,1)
// 3. (n,c,h,w) * (n,c,1,1)
// 4. (n,c,h,w) * (1,c,h,1)
// 5. (n,c,h,w) * (1,c,h,w)
// 6. (n,c,h,w) * (n,c,1,w)
if (*x_shape == *y_shape) {
*x_shape = CvtShape(*x_shape);
*y_shape = CvtShape(*y_shape);
return;
}
if (y_shape_size == 1) {
for (int i = 0; i < 4 - x_shape_size; i++) {
x_shape->push_back(1);
}
int64_t n = x_shape->at(0);
int64_t c = x_shape->at(1);
int64_t h = x_shape->at(2);
int64_t w = x_shape->at(3);
if (axis == 0) {
*x_shape = std::vector<int64_t>{1, n, c * h * w, 1};
} else if (axis == 2) {
*x_shape = std::vector<int64_t>{n * c, h, w, 1};
} else if (axis == 3) {
*x_shape = std::vector<int64_t>{n * c * h, w, 1, 1};
}
*y_shape = std::vector<int64_t>{1, y_shape->at(0), 1, 1};
return;
}
if (y_shape_size == 2) {
for (int i = 0; i < 4 - x_shape_size; i++) {
x_shape->push_back(1);
}
int64_t n = x_shape->at(0);
int64_t c = x_shape->at(1);
int64_t h = x_shape->at(2);
int64_t w = x_shape->at(3);
if (axis == 0) {
y_shape->insert(y_shape->end(), 2, 1);
} else if (axis == 1) {
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 1);
} else if (axis == 2) {
*x_shape = std::vector<int64_t>{n * c, h, w, 1};
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 1);
}
return;
}
if (y_shape_size == 3) {
y_shape->insert(y_shape->begin(), 1);
int64_t n = x_shape->at(0);
int64_t c = x_shape->at(1);
int64_t h = x_shape->at(2);
int64_t w = x_shape->at(3);
if (axis == 0) {
*x_shape = std::vector<int64_t>{1, n * c * h, w, 1};
*y_shape = std::vector<int64_t>{1, n * c * h, 1, 1};
}
return;
}
}
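// Worked example (illustrative sketch, following the y_shape_size == 1 branch
// above), assuming X is an NCHW tensor and Y is a per-channel 1-D tensor
// broadcast along axis 1:
//   std::vector<int64_t> x_shape{2, 3, 4, 5};
//   std::vector<int64_t> y_shape{3};
//   CvtXYShape(&x_shape, &y_shape, /*axis=*/1);
//   // x_shape stays {2, 3, 4, 5}; y_shape becomes {1, 3, 1, 1}, so the
//   // elementwise op broadcasts Y along the N, H and W dimensions.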
int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Input("Y").front();
auto y = scope->FindTensor(y_name);
auto y_dims = y->dims();
auto out_name = op_info->Output("Out").front();
auto out = scope->FindTensor(out_name);
auto out_dims = out->dims();
auto axis = op_info->GetAttr<int>("axis");
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
auto x_new_shape = x_dims.Vectorize();
auto y_new_shape = y_dims.Vectorize();
CvtXYShape(&x_new_shape, &y_new_shape, axis);
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
auto shape_node = graph->Add<int64_t>(x_name + "/shape", x_new_shape);
auto reshaped_x_node = graph->Add<ge::op::Reshape>(x_name + "/reshape");
auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
reshaped_x_op->set_input_x(*x_node->data());
reshaped_x_op->set_input_shape(*shape_node->data());
reshaped_x_op->set_attr_axis(0);
INPUT_UPDATE(reshaped_x_op, x, x_node);
INPUT_UPDATE(reshaped_x_op, shape, shape_node);
OUTPUT_UPDATE(reshaped_x_op, y, reshaped_x_node);
x_node = reshaped_x_node;
} else {
x_node = graph->Add(x_name, *x, x_new_shape);
}
// Y node
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
auto shape_node = graph->Add<int64_t>(y_name + "/shape", y_new_shape);
auto reshaped_y_node = graph->Add<ge::op::Reshape>(y_name + "/reshape");
auto reshaped_y_op = reshaped_y_node->data<ge::op::Reshape>();
reshaped_y_op->set_input_x(*y_node->data());
reshaped_y_op->set_input_shape(*shape_node->data());
reshaped_y_op->set_attr_axis(0);
INPUT_UPDATE(reshaped_y_op, x, y_node);
INPUT_UPDATE(reshaped_y_op, shape, shape_node);
OUTPUT_UPDATE(reshaped_y_op, y, reshaped_y_node);
y_node = reshaped_y_node;
} else {
y_node = graph->Add(y_name, *y, y_new_shape);
}
// Elementwise node
std::shared_ptr<Node> elt_node = nullptr;
if (op_type == "elementwise_add" ||
op_type == "fusion_elementwise_add_activation") {
elt_node = graph->Add<ge::op::Add>(out_name);
auto elt_op = elt_node->data<ge::op::Add>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node);
} else if (op_type == "elementwise_sub" ||
op_type == "fusion_elementwise_sub_activation") {
elt_node = graph->Add<ge::op::Sub>(out_name);
auto elt_op = elt_node->data<ge::op::Sub>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node);
} else if (op_type == "elementwise_mul" ||
op_type == "fusion_elementwise_mul_activation") {
elt_node = graph->Add<ge::op::Mul>(out_name);
auto elt_op = elt_node->data<ge::op::Mul>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node);
} else if (op_type == "elementwise_div" ||
op_type == "fusion_elementwise_div_activation") {
elt_node = graph->Add<ge::op::RealDiv>(out_name);
auto elt_op = elt_node->data<ge::op::RealDiv>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node);
} else {
LOG(WARNING) << "[NPU] Unsupported op type: " << op_type;
return FAILED;
}
auto out_shape = out_dims.Vectorize();
if (out_shape != x_new_shape) {
auto shape_node = graph->Add<int64_t>(out_name + "/shape", out_shape);
auto reshaped_elt_node = graph->Add<ge::op::Reshape>(out_name);
auto reshaped_elt_op = reshaped_elt_node->data<ge::op::Reshape>();
reshaped_elt_op->set_input_x(*elt_node->data());
reshaped_elt_op->set_input_shape(*shape_node->data());
reshaped_elt_op->set_attr_axis(0);
INPUT_UPDATE(reshaped_elt_op, x, elt_node);
INPUT_UPDATE(reshaped_elt_op, shape, shape_node);
OUTPUT_UPDATE(reshaped_elt_op, y, reshaped_elt_node);
elt_node = reshaped_elt_node;
}
// Act node
if (op_type == "fusion_elementwise_add_activation" ||
op_type == "fusion_elementwise_sub_activation" ||
op_type == "fusion_elementwise_mul_activation" ||
op_type == "fusion_elementwise_div_activation") {
auto act_type = op_info->GetAttr<std::string>("act_type");
if (act_type == "leaky_relu") {
auto act_node = graph->Add<ge::op::LeakyRelu>(out_name);
auto act_op = act_node->data<ge::op::LeakyRelu>();
act_op->set_input_x(*elt_node->data());
auto alpha = op_info->GetAttr<float>("alpha");
act_op->set_attr_negative_slope(alpha);
INPUT_UPDATE(act_op, x, elt_node);
OUTPUT_UPDATE(act_op, y, act_node);
} else if (act_type == "relu") {
auto act_node = graph->Add<ge::op::Relu>(out_name);
auto act_op = act_node->data<ge::op::Relu>();
act_op->set_input_x(*elt_node->data());
INPUT_UPDATE(act_op, x, elt_node);
OUTPUT_UPDATE(act_op, y, act_node);
} else {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Unsupported act type: " << act_type;
return FAILED;
}
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
elementwise_add,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
elementwise_sub,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
elementwise_mul,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
elementwise_div,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_add_activation,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_sub_activation,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_mul_activation,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_div_activation,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
......@@ -37,19 +37,29 @@ class Node {
kData,
};
Node(std::shared_ptr<ge::Operator> data,
Node(std::string name,
std::shared_ptr<ge::Operator> data,
PrecisionType precision,
DataLayoutType layout,
Role role)
: data_(data), precision_(precision), layout_(layout), role_(role) {}
Node(PrecisionType precision, DataLayoutType layout, Role role)
: precision_(precision), layout_(layout), role_(role) {}
: name_(name),
data_(data),
precision_(precision),
layout_(layout),
role_(role) {}
Node(std::string name,
PrecisionType precision,
DataLayoutType layout,
Role role)
: name_(name), precision_(precision), layout_(layout), role_(role) {}
void set_name(std::string name) { name_ = name; }
void set_data(std::shared_ptr<ge::Operator> data) { data_ = data; }
void set_precision(PrecisionType precision) { precision_ = precision; }
void set_layout(DataLayoutType layout) { layout_ = layout; }
void set_role(Role role) { role_ = role; }
std::string name() { return name_; }
template <typename T>
std::shared_ptr<T> data() {
return std::static_pointer_cast<T>(data_);
......@@ -62,6 +72,7 @@ class Node {
bool is_data() const { return role_ == Role::kData; }
private:
std::string name_{};
std::shared_ptr<ge::Operator> data_{nullptr};
PrecisionType precision_{PRECISION(kFloat)};
DataLayoutType layout_{DATALAYOUT(kNCHW)};
......@@ -83,10 +94,10 @@ class Graph {
} else if (typeid(T) == typeid(ge::op::Data)) {
role = Node::Role::kData;
}
auto node = std::make_shared<Node>(precision, layout, role);
auto node = std::make_shared<Node>(name, precision, layout, role);
auto idx = Add(name, node);
CHECK_GE(idx, 1);
// Generate a unique name for the created HiAI IR
// Generate a unique name for the created Huawei Ascend NPU IR
node->set_data(
std::make_shared<T>(name + "__" + paddle::lite::to_string(idx)));
return node;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -97,18 +97,9 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
bilinear_interp_op->set_input_x(*x_node->data());
bilinear_interp_op->set_input_size(*out_size_node->data());
bilinear_interp_op->set_attr_align_corners(align_corners);
TENSOR_UPDATE_INPUT(bilinear_interp_op,
x,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_INPUT(bilinear_interp_op,
size,
ge::FORMAT_NCHW,
CvtPrecisionType(out_size_node->precision()));
TENSOR_UPDATE_OUTPUT(bilinear_interp_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(bilinear_interp_node->precision()));
INPUT_UPDATE(bilinear_interp_op, x, x_node);
INPUT_UPDATE(bilinear_interp_op, size, out_size_node);
OUTPUT_UPDATE(bilinear_interp_op, y, bilinear_interp_node);
} else if (interp_method == "nearest") {
auto nearest_interp_node =
graph->Add<ge::op::ResizeNearestNeighborV2>(out_name);
......@@ -117,18 +108,9 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
nearest_interp_op->set_input_x(*x_node->data());
nearest_interp_op->set_input_size(*out_size_node->data());
nearest_interp_op->set_attr_align_corners(align_corners);
TENSOR_UPDATE_INPUT(nearest_interp_op,
x,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_INPUT(nearest_interp_op,
size,
ge::FORMAT_NCHW,
CvtPrecisionType(out_size_node->precision()));
TENSOR_UPDATE_OUTPUT(nearest_interp_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(nearest_interp_node->precision()));
INPUT_UPDATE(nearest_interp_op, x, x_node);
INPUT_UPDATE(nearest_interp_op, size, out_size_node);
OUTPUT_UPDATE(nearest_interp_op, y, nearest_interp_node);
} else {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Unsupported interpolate method: "
<< interp_method;
......
......@@ -22,9 +22,18 @@ USE_SUBGRAPH_BRIDGE(relu6, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(leaky_relu, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(softsign, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(softplus, kHuaweiAscendNPU);
// conv
USE_SUBGRAPH_BRIDGE(conv2d, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(depthwise_conv2d, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(bilinear_interp, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(nearest_interp, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(concat, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(pool2d, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_add, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_sub, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_mul, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_div, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_sub_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_mul_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_div_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(batch_norm, kHuaweiAscendNPU);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/pool_op.h"
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto pooling_type = op_info->GetAttr<std::string>("pooling_type");
auto global_pooling = op_info->GetAttr<bool>("global_pooling");
auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
CHECK_EQ(op_info->GetAttr<bool>("exclusive"), true)
<< "[HUAWEI_ASCEND_NPU] Only exclusive=true is supported for Huawei "
"Ascend NPU DDK.";
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// pool mode: 0:max pooling or 1:avg pooling
int mode = 0;
if (pooling_type == "max") {
mode = 0;
} else if (pooling_type == "avg") {
mode = 1;
} else {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Unsupported pooling type: "
<< pooling_type;
return FAILED;
}
// pad algorithm
std::string padding_algorithm("");
if (op_info->HasAttr("padding_algorithm")) {
padding_algorithm = op_info->GetAttr<std::string>("padding_algorithm");
}
// paddings and strides
if (paddings.size() == 2L) {
for (size_t i = 0; i < 2L; ++i) {
int copy_pad = *(paddings.begin() + 2 * i);
paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
}
}
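// e.g. a 2-element paddings attribute {ph, pw} is expanded in place to
// {ph, ph, pw, pw}, i.e. (pad_top, pad_bottom, pad_left, pad_right)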
CHECK_EQ(paddings.size(), 4L) << "[HUAWEI_ASCEND_NPU] Paddings size should "
"be the same as or twice the input size.";
bool adaptive = false;
if (op_info->HasAttr("adaptive")) {
adaptive = op_info->GetAttr<bool>("adaptive");
}
auto strides = op_info->GetAttr<std::vector<int>>("strides");
lite::operators::UpdatePadding(&paddings,
global_pooling,
adaptive,
padding_algorithm,
x->dims(),
strides,
ksize);
// Ascend restriction: padT should equal padB, and padL should equal padR
CHECK_EQ(paddings[0], paddings[1]) << "[HUAWEI_ASCEND_NPU] Padding top "
"should equals to padding bottom in "
"Huawei Ascend NPU DDK";
CHECK_EQ(paddings[2], paddings[3]) << "[HUAWEI_ASCEND_NPU] Padding left "
"should equals to padding right in "
"Huawei Ascend NPU DDK";
// ceil mode
bool ceil_mode =
op_info->HasAttr("ceil_mode") && op_info->GetAttr<bool>("ceil_mode");
// Pooling node
auto pool_node = graph->Add<ge::op::Pooling>(out_name);
auto pool_op = pool_node->data<ge::op::Pooling>();
pool_op->set_input_x(*x_node->data());
pool_op->set_attr_mode(mode);
pool_op->set_attr_global_pooling(global_pooling);
pool_op->set_attr_window(ge::Operator::OpListInt({ksize[0], ksize[1]}));
pool_op->set_attr_stride(ge::Operator::OpListInt({strides[0], strides[1]}));
pool_op->set_attr_pad(ge::Operator::OpListInt(
{paddings[0], paddings[1], paddings[2], paddings[3]}));
// "0" (ceil mode) or "1" (floor mode). Defaults to "0"
if (!ceil_mode) {
pool_op->set_attr_ceil_mode(1);
}
INPUT_UPDATE(pool_op, x, x_node);
OUTPUT_UPDATE(pool_op, y, pool_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
pool2d,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::PoolConverter);
......@@ -19,9 +19,7 @@
#include <memory>
#include <string>
#include <vector>
// #include "graph/buffer.h"
#include "graph/tensor.h"
#include "graph/types.h"
#include "lite/backends/huawei_ascend_npu/utils.h"
#include "lite/core/op_lite.h"
#include "lite/utils/macros.h"
......@@ -30,16 +28,34 @@ namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
#define TENSOR_UPDATE_INPUT(op, attr, format, dtype) \
ge::TensorDesc _##op##_input_desc_##attr(ge::Shape(), format, dtype); \
#define INPUT_UPDATE(...) TENSOR_INPUT_UPDATE(__VA_ARGS__, ge::FORMAT_NCHW)
#define OUTPUT_UPDATE(...) TENSOR_OUTPUT_UPDATE(__VA_ARGS__, ge::FORMAT_NCHW)
#define DYNAMIC_INPUT_UPDATE(...) \
TENSOR_DYNAMIC_INPUT_UPDATE(__VA_ARGS__, ge::FORMAT_NCHW)
#define DYNAMIC_OUTPUT_UPDATE(...) \
TENSOR_DYNAMIC_OUTPUT_UPDATE(__VA_ARGS__, ge::FORMAT_NCHW)
#define TENSOR_INPUT_UPDATE(op, attr, node, format) \
ge::TensorDesc _##op##_input_desc_##attr( \
ge::Shape(), format, CvtPrecisionType(node->precision())); \
_##op##_input_desc_##attr.SetName(node->name()); \
op->update_input_desc_##attr(_##op##_input_desc_##attr);
#define TENSOR_UPDATE_OUTPUT(op, attr, format, dtype) \
ge::TensorDesc _##op##_output_desc_##attr(ge::Shape(), format, dtype); \
#define TENSOR_OUTPUT_UPDATE(op, attr, node, format) \
ge::TensorDesc _##op##_output_desc_##attr( \
ge::Shape(), format, CvtPrecisionType(node->precision())); \
_##op##_output_desc_##attr.SetName(node->name()); \
op->update_output_desc_##attr(_##op##_output_desc_##attr);
#define TENSOR_UPDATE_DYNAMIC_INPUT(op, attr, idx, format, dtype) \
ge::TensorDesc _##op##_input_desc_##attr##_##idx( \
ge::Shape(), format, dtype); \
#define TENSOR_DYNAMIC_INPUT_UPDATE(op, attr, idx, node, format) \
ge::TensorDesc _##op##_input_desc_##attr##_##idx( \
ge::Shape(), format, CvtPrecisionType(node->precision())); \
_##op##_input_desc_##attr##_##idx.SetName(node->name()); \
op->update_dynamic_input_desc_##attr(idx, _##op##_input_desc_##attr##_##idx);
#define TENSOR_DYNAMIC_OUTPUT_UPDATE(op, attr, idx, node, format) \
ge::TensorDesc _##op##_output_desc_##attr##_##idx( \
ge::Shape(), format, CvtPrecisionType(node->precision())); \
_##op##_output_desc_##attr##_##idx.SetName(node->name()); \
op->update_dynamic_output_desc_##attr(idx, \
_##op##_output_desc_##attr##_##idx);
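// Illustrative expansion (minimal sketch): a call such as
//   INPUT_UPDATE(conv_op, x, input_node);
// forwards to TENSOR_INPUT_UPDATE(conv_op, x, input_node, ge::FORMAT_NCHW),
// which expands roughly to:
//   ge::TensorDesc _conv_op_input_desc_x(
//       ge::Shape(), ge::FORMAT_NCHW,
//       CvtPrecisionType(input_node->precision()));
//   _conv_op_input_desc_x.SetName(input_node->name());
//   conv_op->update_input_desc_x(_conv_op_input_desc_x);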
// Type/tensor converters for converting Paddle type/tensor to HiAI type/tensor
bool HasInputArg(const OpInfo* op_info,
......
......@@ -220,7 +220,7 @@ bool DeviceProgram::ShareBufferWithOriginTensors(
CHECK(!model_name_.empty() && model_client_);
// Query the dimensions of the device input and output tensors if not
// initialized
VLOG(3) << "[HUAWEI_ASCEND_NPU] Sharing buffer with origin tnsors...";
VLOG(3) << "[HUAWEI_ASCEND_NPU] Sharing buffer with origin tensors...";
if (device_idims_.empty() || device_odims_.empty()) {
if (!(model_client_->GetModelIOTensorDim(&device_idims_, &device_odims_))) {
LOG(WARNING)
......
......@@ -117,10 +117,12 @@ class BatchNormComputeTest : public arena::TestCase {
op_desc->SetInput("Mean", {mean_});
op_desc->SetInput("Variance", {variance_});
op_desc->SetOutput("Y", {output_});
op_desc->SetOutput("MeanOut", {mean_out_});
op_desc->SetOutput("VarianceOut", {variance_out_});
op_desc->SetOutput("SavedMean", {saved_mean_});
op_desc->SetOutput("SavedVariance", {saved_variance_});
if (!is_test_) {
op_desc->SetOutput("MeanOut", {mean_out_});
op_desc->SetOutput("VarianceOut", {variance_out_});
op_desc->SetOutput("SavedMean", {saved_mean_});
op_desc->SetOutput("SavedVariance", {saved_variance_});
}
op_desc->SetAttr("epsilon", epsilon_);
op_desc->SetAttr("momentum", momentum_);
op_desc->SetAttr("use_global_stats", use_global_stats_);
......@@ -159,6 +161,9 @@ TEST(BatchNorm, precision) {
Place place;
#if defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU);
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_NPU)
place = TARGET(kNPU);
#else
......
......@@ -206,6 +206,11 @@ void TestEltDims(Place place, float abs_error) {
void TestEltTypes(Place place, float abs_error) {
for (auto elt_type :
std::vector<std::string>{"add", "sub", "mul", "div", "max"}) {
// Huawei Ascend NPU DDK has bugs in div, and does not support max yet
if (place == TARGET(kHuaweiAscendNPU) &&
(elt_type == "div" || elt_type == "max")) {
continue;
}
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0);
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1);
}
......@@ -214,6 +219,11 @@ void TestEltTypes(Place place, float abs_error) {
void TestEltFuseAct(Place place, float abs_error) {
for (auto elt_type :
std::vector<std::string>{"add", "sub", "mul", "div", "max"}) {
// Huawei Ascend NPU DDK has bugs in div, and does not support max yet
if (place == TARGET(kHuaweiAscendNPU) &&
(elt_type == "div" || elt_type == "max")) {
continue;
}
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0, "relu");
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1, "relu");
}
......@@ -226,6 +236,9 @@ TEST(Elementwise, precision) {
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
......
......@@ -322,6 +322,10 @@ void TestPoolPaddings(Place place, float abs_error = 2e-5) {
{1, 1},
{0, 0, 1, 1},
{2, 2});
// Ascend restriction: padT should equal padB, and padL should equal padR
if (place == TARGET(kHuaweiAscendNPU)) {
continue;
}
TestPoolHelper(place,
abs_error,
{2, 3, 6, 7},
......@@ -381,6 +385,9 @@ TEST(Pool, precision) {
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU);
#else
......