未验证 提交 ba66bc55 编写于 作者: Q Qi Li 提交者: GitHub

[ascend] add softmax and dropout op, test=develop (#4071)

* [ascend] add softmax and dropout op, test=develop

* [Ascend] address review comments, test=develop
上级 e7d43b02
...@@ -14,6 +14,8 @@ lite_cc_library(subgraph_bridge_concat_op_huawei_ascend_npu SRCS concat_op.cc DE ...@@ -14,6 +14,8 @@ lite_cc_library(subgraph_bridge_concat_op_huawei_ascend_npu SRCS concat_op.cc DE
lite_cc_library(subgraph_bridge_pool_op_huawei_ascend_npu SRCS pool_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_pool_op_huawei_ascend_npu SRCS pool_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_elementwise_ops_huawei_ascend_npu SRCS elementwise_ops.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_elementwise_ops_huawei_ascend_npu SRCS elementwise_ops.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_batch_norm_op_huawei_ascend_npu SRCS batch_norm_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_batch_norm_op_huawei_ascend_npu SRCS batch_norm_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_softmax_op_huawei_ascend_npu SRCS softmax_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_dropout_op_huawei_ascend_npu SRCS dropout_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
set(huawei_ascend_npu_subgraph_bridges set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_registry subgraph_bridge_registry
...@@ -26,4 +28,6 @@ set(huawei_ascend_npu_subgraph_bridges ...@@ -26,4 +28,6 @@ set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_pool_op_huawei_ascend_npu subgraph_bridge_pool_op_huawei_ascend_npu
subgraph_bridge_elementwise_ops_huawei_ascend_npu subgraph_bridge_elementwise_ops_huawei_ascend_npu
subgraph_bridge_batch_norm_op_huawei_ascend_npu subgraph_bridge_batch_norm_op_huawei_ascend_npu
subgraph_bridge_softmax_op_huawei_ascend_npu
subgraph_bridge_dropout_op_huawei_ascend_npu
CACHE INTERNAL "huawei_ascend_npu_subgraph_bridges") CACHE INTERNAL "huawei_ascend_npu_subgraph_bridges")
...@@ -28,7 +28,7 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -28,7 +28,7 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto op_info = op->op_info(); auto op_info = op->op_info();
auto op_type = op_info->Type(); auto op_type = op_info->Type();
auto scope = op->scope(); auto scope = op->scope();
VLOG(3) << "[NPU] Converting " << op_type << " ... "; VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " << op_type << " ... ";
// Get input and output vars and op attributes // Get input and output vars and op attributes
auto x_names = op_info->Input("X"); auto x_names = op_info->Input("X");
...@@ -36,6 +36,14 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -36,6 +36,14 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto out_name = op_info->Output("Out").front(); auto out_name = op_info->Output("Out").front();
auto num = x_names.size(); auto num = x_names.size();
// TODO(qili93): Ascend has bug in ge::op::Concat (i.e. has axis tensor
// input), to be fixed
if (op_info->HasInput("AxisTensor")) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK not support "
"AxisTensor input!";
return FAILED;
}
if (op_info->HasInput("AxisTensor")) { if (op_info->HasInput("AxisTensor")) {
// axis node // axis node
auto axis_name = op_info->Input("AxisTensor").front(); auto axis_name = op_info->Input("AxisTensor").front();
......
...@@ -96,25 +96,41 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -96,25 +96,41 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
filter_dims); filter_dims);
// Check Restrictions: HxW(input) == HxW(filter) if output feature h*w = 1*1 // Check Restrictions: HxW(input) == HxW(filter) if output feature h*w = 1*1
if (output_dims[2] == 1 && output_dims[3] == 1) { if (output_dims[2] == 1) {
int input_h = input_dims[2] + paddings[0] + paddings[1]; int input_h = input_dims[2] + paddings[0] + paddings[1];
int input_w = input_dims[3] + paddings[2] + paddings[3];
int filter_h = (filter_dims[2] - 1) * dilations[0] + 1; int filter_h = (filter_dims[2] - 1) * dilations[0] + 1;
if (input_h != filter_h) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK restriction: "
"input height after padding should equal to filter "
"height after dilation if output height is 1. Input "
"height after padding is: "
<< input_h
<< ", filter height after dilation is: " << filter_h;
return FAILED;
}
}
// Check Restrictions: HxW(input) == HxW(filter) if output feature h*w = 1*1
if (output_dims[3] == 1) {
int input_w = input_dims[3] + paddings[2] + paddings[3];
int filter_w = (filter_dims[3] - 1) * dilations[1] + 1; int filter_w = (filter_dims[3] - 1) * dilations[1] + 1;
CHECK_EQ(input_h, filter_h) << "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK " if (input_w != filter_w) {
"restriction: if output HxW = 1x1, then " LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK restriction: "
"input height after padding should equal to " "input width after padding should equal to filter width "
"filter height after dilation"; "after dilation if output width is 1. Input width after "
CHECK_EQ(input_w, filter_w) << "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK " "padding is: "
"restriction: if output HxW = 1x1, then " << input_w
"input width after padding should equal to " << ", filter width after dilation is: " << filter_w;
"filter width after dilation"; return FAILED;
}
} }
// Check Restrictions: outChannel divide groups should equal to 0 // Check Restrictions: outChannel divide groups should equal to 0
CHECK_EQ(oc % groups, 0) << "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK " if (oc % groups != 0) {
"restriction: out channel divice groups should " LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK restriction: "
"equal to 0"; "out channel divice groups should equal to 0. out channel "
"is: "
<< oc << ", groups is: " << groups;
return FAILED;
}
// Check depthwise mode, and decide whether use DepthwiseConv2D Op // Check depthwise mode, and decide whether use DepthwiseConv2D Op
bool use_depthwise_conv = false; bool use_depthwise_conv = false;
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto dropout_implementation =
op_info->GetAttr<std::string>("dropout_implementation");
auto scale = 1 - op_info->GetAttr<float>("dropout_prob");
if (dropout_implementation == "upscale_in_train") {
scale = 1.f;
}
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x, CvtShape(x_dims));
}
// Dropout node
auto dropout_node = graph->Add<ge::op::Muls>(out_name);
auto dropout_op = dropout_node->data<ge::op::Muls>();
dropout_op->set_input_x(*x_node->data());
dropout_op->set_attr_value(scale);
INPUT_UPDATE(dropout_op, x, x_node);
OUTPUT_UPDATE(dropout_op, y, dropout_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
// Registers DropoutConverter as the subgraph bridge for the `dropout` op type
// on the kHuaweiAscendNPU target.
REGISTER_SUBGRAPH_BRIDGE(
dropout,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::DropoutConverter);
...@@ -104,6 +104,14 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -104,6 +104,14 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto scope = op->scope(); auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "..."; VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// TODO(qili93): Ascend has bug in RealDiv, to be fixed
if (op_type == "elementwise_div" ||
op_type == "fusion_elementwise_div_activation") {
LOG(WARNING)
<< "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK not support RealDiv OP!";
return FAILED;
}
// Get input and output vars and op attributes // Get input and output vars and op attributes
auto x_name = op_info->Input("X").front(); auto x_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_name); auto x = scope->FindTensor(x_name);
...@@ -200,6 +208,15 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -200,6 +208,15 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
INPUT_UPDATE(elt_op, x1, x_node); INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node); INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node); OUTPUT_UPDATE(elt_op, y, elt_node);
} else if (op_type == "elementwise_max" ||
op_type == "fusion_elementwise_max_activation") {
elt_node = graph->Add<ge::op::Maximum>(out_name);
auto elt_op = elt_node->data<ge::op::Maximum>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node);
} else { } else {
LOG(WARNING) << "[NPU] Unsupported op type: " << op_type; LOG(WARNING) << "[NPU] Unsupported op type: " << op_type;
return FAILED; return FAILED;
...@@ -223,7 +240,8 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -223,7 +240,8 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
if (op_type == "fusion_elementwise_add_activation" || if (op_type == "fusion_elementwise_add_activation" ||
op_type == "fusion_elementwise_sub_activation" || op_type == "fusion_elementwise_sub_activation" ||
op_type == "fusion_elementwise_mul_activation" || op_type == "fusion_elementwise_mul_activation" ||
op_type == "fusion_elementwise_div_activation") { op_type == "fusion_elementwise_div_activation" ||
op_type == "fusion_elementwise_max_activation") {
auto act_type = op_info->GetAttr<std::string>("act_type"); auto act_type = op_info->GetAttr<std::string>("act_type");
if (act_type == "leaky_relu") { if (act_type == "leaky_relu") {
auto act_node = graph->Add<ge::op::LeakyRelu>(out_name); auto act_node = graph->Add<ge::op::LeakyRelu>(out_name);
...@@ -269,6 +287,10 @@ REGISTER_SUBGRAPH_BRIDGE( ...@@ -269,6 +287,10 @@ REGISTER_SUBGRAPH_BRIDGE(
elementwise_div, elementwise_div,
kHuaweiAscendNPU, kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter); paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
elementwise_max,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE( REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_add_activation, fusion_elementwise_add_activation,
kHuaweiAscendNPU, kHuaweiAscendNPU,
...@@ -285,3 +307,7 @@ REGISTER_SUBGRAPH_BRIDGE( ...@@ -285,3 +307,7 @@ REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_div_activation, fusion_elementwise_div_activation,
kHuaweiAscendNPU, kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter); paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_max_activation,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
...@@ -86,8 +86,8 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -86,8 +86,8 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} }
} }
if (out_size_node == nullptr) { if (out_size_node == nullptr) {
out_size_node = out_size_node = graph->Add<int>(out_name + "/out_size",
graph->Add(out_name + "/out_size", std::vector<int>({out_h, out_w})); std::vector<int>({out_h, out_w}));
} }
if (interp_method == "bilinear") { if (interp_method == "bilinear") {
......
...@@ -32,8 +32,12 @@ USE_SUBGRAPH_BRIDGE(elementwise_add, kHuaweiAscendNPU); ...@@ -32,8 +32,12 @@ USE_SUBGRAPH_BRIDGE(elementwise_add, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_sub, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(elementwise_sub, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_mul, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(elementwise_mul, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_div, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(elementwise_div, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_max, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_sub_activation, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(fusion_elementwise_sub_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_mul_activation, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(fusion_elementwise_mul_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_div_activation, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(fusion_elementwise_div_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_max_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(batch_norm, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(batch_norm, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(softmax, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(dropout, kHuaweiAscendNPU);
...@@ -93,12 +93,18 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -93,12 +93,18 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
strides, strides,
ksize); ksize);
// Ascend restriction: padT should equals padB, and padL should equals padR // Ascend restriction: padT should equals padB, and padL should equals padR
CHECK_EQ(paddings[0], paddings[1]) << "[HUAWEI_ASCEND_NPU] Padding top " if (paddings[0] != paddings[1]) {
"should equals to padding bottom in " LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Padding top should equals to padding "
"Huawei Ascend NPU DDK"; "bottom in Huawei Ascend NPU DDK, padding top is: "
CHECK_EQ(paddings[2], paddings[3]) << "[HUAWEI_ASCEND_NPU] Padding left " << paddings[0] << ", padding bottom is: " << paddings[1];
"should equals to padding right in " return FAILED;
"Huawei Ascend NPU DDK"; }
if (paddings[2] != paddings[3]) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Padding left should equals to padding "
"right in Huawei Ascend NPU DDK, padding left is: "
<< paddings[2] << ", padding right is: " << paddings[3];
return FAILED;
}
// ceil mode // ceil mode
bool ceil_mode = bool ceil_mode =
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto x_rank = x_dims.size();
auto out_name = op_info->Output("Out").front();
int axis = op_info->HasAttr("axis") ? op_info->GetAttr<int>("axis") : -1;
if (axis < 0) {
axis += x_rank;
}
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Softmax node
auto softmax_node = graph->Add<ge::op::SoftmaxV2>(out_name);
auto softmax_op = softmax_node->data<ge::op::SoftmaxV2>();
softmax_op->set_input_x(*x_node->data());
softmax_op->set_attr_axes({axis});
INPUT_UPDATE(softmax_op, x, x_node);
OUTPUT_UPDATE(softmax_op, y, softmax_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
// Registers SoftmaxConverter as the subgraph bridge for the `softmax` op type
// on the kHuaweiAscendNPU target.
REGISTER_SUBGRAPH_BRIDGE(
softmax,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::SoftmaxConverter);
...@@ -41,19 +41,29 @@ bool DropoutOp::InferShapeImpl() const { ...@@ -41,19 +41,29 @@ bool DropoutOp::InferShapeImpl() const {
bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
auto input = op_desc.Input("X").front(); auto input = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
auto Mask = op_desc.Output("Mask").front();
param_.x = GetVar<lite::Tensor>(scope, input); param_.x = GetVar<lite::Tensor>(scope, input);
param_.output = GetMutableVar<lite::Tensor>(scope, out); param_.output = GetMutableVar<lite::Tensor>(scope, out);
param_.mask = GetMutableVar<lite::Tensor>(scope, Mask);
param_.dropout_prob = op_desc.GetAttr<float>("dropout_prob"); param_.dropout_prob = op_desc.GetAttr<float>("dropout_prob");
param_.is_test = true;
// TODO(sangoly): `is_test` has different attr type in x86 and arm, set auto is_test_type = op_desc.GetAttrType("is_test");
// `true` now. switch (is_test_type) {
// if (op_desc.HasAttr("is_test")) { case OpDescAPI::AttrType::INT:
// param_.is_test = op_desc.GetAttr<bool>("is_test"); param_.is_test = op_desc.GetAttr<int>("is_test");
// } break;
case OpDescAPI::AttrType::BOOLEAN:
param_.is_test = op_desc.GetAttr<bool>("is_test");
break;
default:
LOG(FATAL) << "Unsupported attribute type: the type of attribute "
"`is_test` in BatchNormOP should be int or bool.";
}
if (!param_.is_test) {
auto Mask = op_desc.Output("Mask").front();
param_.mask = GetMutableVar<lite::Tensor>(scope, Mask);
}
param_.fix_seed = op_desc.GetAttr<bool>("fix_seed"); param_.fix_seed = op_desc.GetAttr<bool>("fix_seed");
param_.seed = op_desc.GetAttr<int>("seed"); param_.seed = op_desc.GetAttr<int>("seed");
param_.dropout_implementation = param_.dropout_implementation =
......
...@@ -86,7 +86,7 @@ endif() ...@@ -86,7 +86,7 @@ endif()
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_as_compute SRCS expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_expand_as_compute SRCS expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_crf_decoding_compute SRCS crf_decoding_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_crf_decoding_compute SRCS crf_decoding_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif() endif()
...@@ -160,10 +160,6 @@ TEST(Concat, precision) { ...@@ -160,10 +160,6 @@ TEST(Concat, precision) {
for (int axis : {1, 2}) { for (int axis : {1, 2}) {
for (bool is_use_axis_tensor : {false, true}) { for (bool is_use_axis_tensor : {false, true}) {
// is_use_axis_tensor = true has bugs in Huawei Ascend NPU DDK
if (place == TARGET(kHuaweiAscendNPU) && is_use_axis_tensor) {
continue;
}
LOG(INFO) << "axis:" << axis LOG(INFO) << "axis:" << axis
<< ", is_use_axis_tensor:" << is_use_axis_tensor; << ", is_use_axis_tensor:" << is_use_axis_tensor;
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(
......
...@@ -296,11 +296,6 @@ void TestConvStrides(Place place, float abs_error = 2e-5) { ...@@ -296,11 +296,6 @@ void TestConvStrides(Place place, float abs_error = 2e-5) {
for (auto out_channels : {1, 3}) { for (auto out_channels : {1, 3}) {
for (auto strides : for (auto strides :
std::vector<std::vector<int>>{{2, 2}, {3, 3}, {1, 2}, {3, 1}}) { std::vector<std::vector<int>>{{2, 2}, {3, 3}, {1, 2}, {3, 1}}) {
// Check Huawei Ascend NPU restriction if output HxW = 1x1
// input_w after padding = 4 should equal to fitler_w after dilation = 3
if (place == TARGET(kHuaweiAscendNPU) && dims[3] == 4) {
continue;
}
std::unique_ptr<arena::TestCase> tester(new ConvComputeTester( std::unique_ptr<arena::TestCase> tester(new ConvComputeTester(
place, "def", DDim(dims), out_channels, 3, strides)); place, "def", DDim(dims), out_channels, 3, strides));
arena::Arena arena(std::move(tester), place, abs_error); arena::Arena arena(std::move(tester), place, abs_error);
......
...@@ -35,6 +35,7 @@ class DropoutComputeTester : public arena::TestCase { ...@@ -35,6 +35,7 @@ class DropoutComputeTester : public arena::TestCase {
bool fix_seed_ = true; bool fix_seed_ = true;
int seed_ = 1; int seed_ = 1;
std::string dropout_implementation_ = "downgrade_in_infer"; std::string dropout_implementation_ = "downgrade_in_infer";
int is_test_ = 1;
public: public:
DropoutComputeTester(const Place& place, DropoutComputeTester(const Place& place,
...@@ -73,11 +74,14 @@ class DropoutComputeTester : public arena::TestCase { ...@@ -73,11 +74,14 @@ class DropoutComputeTester : public arena::TestCase {
op_desc->SetType(type_); op_desc->SetType(type_);
op_desc->SetInput("X", {x_}); op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_}); op_desc->SetOutput("Out", {out_});
op_desc->SetOutput("Mask", {mask_}); if (!is_test_) {
op_desc->SetOutput("Mask", {mask_});
}
op_desc->SetAttr("dropout_prob", dropout_prob_); op_desc->SetAttr("dropout_prob", dropout_prob_);
op_desc->SetAttr("fix_seed", fix_seed_); op_desc->SetAttr("fix_seed", fix_seed_);
op_desc->SetAttr("seed", seed_); op_desc->SetAttr("seed", seed_);
op_desc->SetAttr("dropout_implementation", dropout_implementation_); op_desc->SetAttr("dropout_implementation", dropout_implementation_);
op_desc->SetAttr("is_test", is_test_);
} }
void PrepareData() override { void PrepareData() override {
...@@ -94,6 +98,9 @@ TEST(Dropout, precision) { ...@@ -94,6 +98,9 @@ TEST(Dropout, precision) {
#if defined(LITE_WITH_NPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU); place = TARGET(kXPU);
#else #else
......
...@@ -206,11 +206,6 @@ void TestEltDims(Place place, float abs_error) { ...@@ -206,11 +206,6 @@ void TestEltDims(Place place, float abs_error) {
void TestEltTypes(Place place, float abs_error) { void TestEltTypes(Place place, float abs_error) {
for (auto elt_type : for (auto elt_type :
std::vector<std::string>{"add", "sub", "mul", "div", "max"}) { std::vector<std::string>{"add", "sub", "mul", "div", "max"}) {
// Huawei Ascend NPU DDK has bugs in div, and not support max yet
if (place == TARGET(kHuaweiAscendNPU) &&
(elt_type == "div" || elt_type == "max")) {
continue;
}
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0); TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0);
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1); TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1);
} }
...@@ -219,11 +214,6 @@ void TestEltTypes(Place place, float abs_error) { ...@@ -219,11 +214,6 @@ void TestEltTypes(Place place, float abs_error) {
void TestEltFuseAct(Place place, float abs_error) { void TestEltFuseAct(Place place, float abs_error) {
for (auto elt_type : for (auto elt_type :
std::vector<std::string>{"add", "sub", "mul", "div", "max"}) { std::vector<std::string>{"add", "sub", "mul", "div", "max"}) {
// Huawei Ascend NPU DDK has bugs in div, and not support max yet
if (place == TARGET(kHuaweiAscendNPU) &&
(elt_type == "div" || elt_type == "max")) {
continue;
}
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0, "relu"); TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0, "relu");
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1, "relu"); TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1, "relu");
} }
......
...@@ -322,10 +322,6 @@ void TestPoolPaddings(Place place, float abs_error = 2e-5) { ...@@ -322,10 +322,6 @@ void TestPoolPaddings(Place place, float abs_error = 2e-5) {
{1, 1}, {1, 1},
{0, 0, 1, 1}, {0, 0, 1, 1},
{2, 2}); {2, 2});
// Ascend restriction: padT should equals padB, and padL should equals padR
if (place == TARGET(kHuaweiAscendNPU)) {
continue;
}
TestPoolHelper(place, TestPoolHelper(place,
abs_error, abs_error,
{2, 3, 6, 7}, {2, 3, 6, 7},
......
...@@ -103,6 +103,9 @@ TEST(Softmax, precision) { ...@@ -103,6 +103,9 @@ TEST(Softmax, precision) {
#if defined(LITE_WITH_NPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 4e-3; // Using fp16 in NPU abs_error = 4e-3; // Using fp16 in NPU
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 4e-3; // precision_mode default is force_fp16
#elif defined(LITE_WITH_XPU) #elif defined(LITE_WITH_XPU)
place = TARGET(kXPU); place = TARGET(kXPU);
#else #else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册