未验证 提交 a29c84a2 编写于 作者: H hong19860320 提交者: GitHub

[LITE][NPU][XPU] Refine the registration and implementation of op bridges (#2700)

* Fix the compile error that occurs when the ddk_root path is specified and the build targets Huawei NPU.

* Refine the registration of op bridges and make it similar to the registration of ops and kernels.

* Refine the interfaces of the graph and node for op bridges, and support automatically creating a constant or data node according to the 'persistable' attribute of the target tensor.

* Add unit tests for the scale and softmax op bridges for NPU.
上级 bc5bd154
......@@ -30,7 +30,7 @@ if(NOT NPU_DDK_INC)
message(FATAL_ERROR "Can not find HiAiModelManagerService.h in ${NPU_DDK_ROOT}/include")
endif()
include_directories("${NPU_DDK_ROOT}")
include_directories("${NPU_DDK_ROOT}/include")
set(NPU_SUB_LIB_PATH "lib64")
if(ARM_TARGET_ARCH_ABI STREQUAL "armv8")
......
......@@ -18,8 +18,8 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "ai_ddk_lib/include/HiAiModelManagerService.h"
#include "ai_ddk_lib/include/hiai_ir_build.h"
#include "HiAiModelManagerService.h" // NOLINT
#include "hiai_ir_build.h" // NOLINT
namespace paddle {
namespace lite {
......
......@@ -27,7 +27,7 @@ namespace mir {
void NPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::unordered_set<std::string> supported_lists;
#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) supported_lists.insert(#op_type);
#define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
#include "lite/kernels/npu/bridges/paddle_use_bridges.h"
#undef USE_SUBGRAPH_BRIDGE
auto teller = [&](Node* node) {
......@@ -41,7 +41,7 @@ void NPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
void XPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::unordered_set<std::string> supported_lists;
#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) supported_lists.insert(#op_type);
#define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
#include "lite/kernels/xpu/bridges/paddle_use_bridges.h"
#undef USE_SUBGRAPH_BRIDGE
auto teller = [&](Node* node) {
......
......@@ -43,33 +43,34 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Act node
auto act_node = graph->AddNode<ge::op::Activation>(out_name);
act_node->set_input_x(*x_node);
auto act_node = graph->Add<ge::op::Activation>(out_name);
auto act_op = act_node->data<ge::op::Activation>();
act_op->set_input_x(*x_node->data());
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc.
act_node->set_attr_mode(CvtActMode(op_type));
act_op->set_attr_mode(CvtActMode(op_type));
if (op_type == "relu_clipped") {
auto Relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef");
act_node->set_attr_coef(Relu_clipped_coef);
act_op->set_attr_coef(Relu_clipped_coef);
} else if (op_type == "relu6") {
float Relu_clipped_coef = 6.f;
act_node->set_attr_coef(Relu_clipped_coef);
act_op->set_attr_coef(Relu_clipped_coef);
} else if (op_type == "leaky_relu") {
auto alpha = op_info->GetAttr<float>("alpha");
act_node->set_attr_negative_slope(alpha);
act_op->set_attr_negative_slope(alpha);
} else if (op_type == "hard_sigmoid") {
auto slope = op_info->GetAttr<float>("slope");
auto offset = op_info->GetAttr<float>("offset");
act_node->set_attr_negative_slope(slope);
act_node->set_attr_coef(offset);
act_op->set_attr_negative_slope(slope);
act_op->set_attr_coef(offset);
}
return SUCCESS;
}
......@@ -79,25 +80,27 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
sigmoid,
REGISTER_SUBGRAPH_BRIDGE(sigmoid,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, relu, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, tanh, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
relu_clipped,
REGISTER_SUBGRAPH_BRIDGE(relu, kNPU, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(tanh, kNPU, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(relu_clipped,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, relu6, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
leaky_relu,
REGISTER_SUBGRAPH_BRIDGE(relu6,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, abs, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
softsign,
REGISTER_SUBGRAPH_BRIDGE(leaky_relu,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
softplus,
REGISTER_SUBGRAPH_BRIDGE(abs, kNPU, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(softsign,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
hard_sigmoid,
REGISTER_SUBGRAPH_BRIDGE(softplus,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(hard_sigmoid,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
......@@ -44,20 +44,21 @@ int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
int axis = op_info->GetAttr<int64_t>("axis");
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Axis node
auto axis_const_node = graph->AddNode(out_name + "/axis", axis);
auto axis_node = graph->Add(out_name + "/axis", axis);
// Argmax node
auto argmax_node = graph->AddNode<ge::op::ArgMax>(out_name);
argmax_node->set_input_x1(*x_node);
argmax_node->set_input_x2(*axis_const_node);
auto argmax_node = graph->Add<ge::op::ArgMax>(out_name);
auto argmax_op = argmax_node->data<ge::op::ArgMax>();
argmax_op->set_input_x1(*x_node->data());
argmax_op->set_input_x2(*axis_node->data());
return SUCCESS;
}
......@@ -66,6 +67,6 @@ int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
arg_max,
REGISTER_SUBGRAPH_BRIDGE(arg_max,
kNPU,
paddle::lite::subgraph::npu::ArgmaxConverter);
......@@ -67,30 +67,31 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
bool use_global_stats = op_info->GetAttr<bool>("use_global_stats");
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Scale, Bias, Mean, Variance node
auto scale_const_node = graph->AddNode(scale_name, *scale);
auto bias_const_node = graph->AddNode(bias_name, *bias);
auto mean_const_node = graph->AddNode(mean_name, *mean);
auto variance_const_node = graph->AddNode(variance_name, *variance);
auto scale_node = graph->Add(scale_name, *scale);
auto bias_node = graph->Add(bias_name, *bias);
auto mean_node = graph->Add(mean_name, *mean);
auto variance_node = graph->Add(variance_name, *variance);
// Batch Norm node
auto batch_norm_node = graph->AddNode<ge::op::BatchNormExt2>(y_name);
batch_norm_node->set_input_x(*x_node);
batch_norm_node->set_input_scale(*scale_const_node);
batch_norm_node->set_input_offset(*bias_const_node);
batch_norm_node->set_input_mean(*mean_const_node);
batch_norm_node->set_input_variance(*variance_const_node);
batch_norm_node->set_attr_momentum(momentum);
batch_norm_node->set_attr_epsilon(epsilon);
batch_norm_node->set_attr_mode(mode);
batch_norm_node->set_attr_use_global_stats(use_global_stats);
auto batch_norm_node = graph->Add<ge::op::BatchNormExt2>(y_name);
auto batch_norm_op = batch_norm_node->data<ge::op::BatchNormExt2>();
batch_norm_op->set_input_x(*x_node->data());
batch_norm_op->set_input_scale(*scale_node->data());
batch_norm_op->set_input_offset(*bias_node->data());
batch_norm_op->set_input_mean(*mean_node->data());
batch_norm_op->set_input_variance(*variance_node->data());
batch_norm_op->set_attr_momentum(momentum);
batch_norm_op->set_attr_epsilon(epsilon);
batch_norm_op->set_attr_mode(mode);
batch_norm_op->set_attr_use_global_stats(use_global_stats);
return SUCCESS;
}
......@@ -99,6 +100,6 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
batch_norm,
REGISTER_SUBGRAPH_BRIDGE(batch_norm,
kNPU,
paddle::lite::subgraph::npu::BatchNormConverter);
......@@ -44,21 +44,22 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Traverse all of input nodes which are added into the new created concat
// node
auto concat_node = graph->AddNode<ge::op::Concat>(out_name);
concat_node->set_attr_axis(axis);
concat_node->set_attr_N(num);
concat_node->create_dynamic_input_x(num);
auto concat_node = graph->Add<ge::op::Concat>(out_name);
auto concat_op = concat_node->data<ge::op::Concat>();
concat_op->set_attr_axis(axis);
concat_op->set_attr_N(num);
concat_op->create_dynamic_input_x(num);
int idx = 1;
for (auto& x_name : x_names) {
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
concat_node->set_dynamic_input_x(idx, *x_node);
concat_op->set_dynamic_input_x(idx, *x_node->data());
idx++;
}
return SUCCESS;
......@@ -69,6 +70,6 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
concat,
REGISTER_SUBGRAPH_BRIDGE(concat,
kNPU,
paddle::lite::subgraph::npu::ConcatConverter);
......@@ -67,11 +67,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(dilations.size(), 2L);
// Input node
std::shared_ptr<ge::Operator> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
std::shared_ptr<Node> input_node = nullptr;
if (graph->Has(input_name)) {
input_node = graph->Get(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
input_node = graph->Add(input_name, *input);
}
if (paddings.size() == 2L) {
......@@ -109,104 +109,102 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
// Filter node
auto filter_const_node = graph->AddNode(filter_name, *filter);
auto filter_node = graph->Add(filter_name, *filter);
// Add bias node if exists bias
// Supports the bias nodes with the following dimensions
// 0: {oc}
// 1: {1, oc, oh, ow}
// 2: {n, oc, oh, ow}
std::shared_ptr<ge::Operator> bias_node = nullptr;
std::shared_ptr<Node> bias_node = nullptr;
bool is_channel_bias = false;
if (HasInputArg(op_info, scope, "Bias")) {
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
auto bias_data_size = bias_dims.production();
auto output_data_size = output_dims.production();
std::vector<int64_t> bias_shape;
if (bias_data_size == oc) {
// 0: {oc}
bias_shape = {1, oc, 1, 1};
is_channel_bias = true;
} else if (bias_data_size == output_data_size / bs) {
// 1: {1, oc, oh, ow}
bias_shape = {1, output_dims[1], output_dims[2], output_dims[3]};
} else if (bias_data_size == output_data_size) {
// 2: {n, oc, oh, ow}
bias_shape = output_dims.Vectorize();
if (graph->Has(bias_name)) {
bias_node = graph->Get(bias_name);
} else {
LOG(WARNING) << "[NPU] Bias dimension " << bias_dims
<< " isn't supported in conv2d Op when output dimension is "
<< output_dims;
return FAILED;
}
if (graph->HasNode(bias_name)) {
// Bias node from input node
bias_node = graph->GetNode(bias_name);
} else {
// Bias node with const data
bias_node = graph->AddNode(bias_name, *bias, bias_shape);
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
auto bias_data_size = bias_dims.production();
auto output_data_size = output_dims.production();
std::vector<int64_t> bias_shape;
if (bias_data_size == oc) {
// 0: {oc}
bias_shape = {1, oc, 1, 1};
is_channel_bias = true;
} else if (bias_data_size == output_data_size / bs) {
// 1: {1, oc, oh, ow}
bias_shape = {1, output_dims[1], output_dims[2], output_dims[3]};
} else if (bias_data_size == output_data_size) {
// 2: {n, oc, oh, ow}
bias_shape = output_dims.Vectorize();
} else {
LOG(WARNING)
<< "[NPU] Bias dimension " << bias_dims
<< " isn't supported in conv2d Op when output dimension is "
<< output_dims;
return FAILED;
}
bias_node = graph->Add(bias_name, *bias, bias_shape);
}
}
// Conv node
std::shared_ptr<ge::Operator> conv_node = nullptr;
std::shared_ptr<Node> conv_node = nullptr;
if (use_depthwise_conv && is_depthwise_mode) {
auto depthwise_conv_node =
graph->AddNode<ge::op::ConvolutionDepthwise>(output_name);
depthwise_conv_node->set_input_x(*input_node);
depthwise_conv_node->set_input_filter(*filter_const_node);
depthwise_conv_node->set_attr_mode(1);
depthwise_conv_node->set_attr_algo(0);
depthwise_conv_node->set_attr_format(0); // NCHW
depthwise_conv_node->set_attr_pad_mode(5); // VALID
depthwise_conv_node->set_attr_group(groups);
depthwise_conv_node->set_attr_pad(ge::AttrValue::LIST_INT(
conv_node = graph->Add<ge::op::ConvolutionDepthwise>(output_name);
auto conv_op = conv_node->data<ge::op::ConvolutionDepthwise>();
conv_op->set_input_x(*input_node->data());
conv_op->set_input_filter(*filter_node->data());
conv_op->set_attr_mode(1);
conv_op->set_attr_algo(0);
conv_op->set_attr_format(0); // NCHW
conv_op->set_attr_pad_mode(5); // VALID
conv_op->set_attr_group(groups);
conv_op->set_attr_pad(ge::AttrValue::LIST_INT(
{paddings[0], paddings[1], paddings[2], paddings[3]}));
depthwise_conv_node->set_attr_dilation(
conv_op->set_attr_dilation(
ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
depthwise_conv_node->set_attr_stride(
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
depthwise_conv_node->set_attr_kernel(
conv_op->set_attr_stride(ge::AttrValue::LIST_INT({strides[0], strides[1]}));
conv_op->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
conv_node = depthwise_conv_node;
// ConvolutionDepthwise Op doesn't support bias, so append Add node to
// support bias
if (bias_node != nullptr) {
auto add_node = graph->AddNode<ge::op::Add>(output_name);
add_node->set_input_x1(*depthwise_conv_node);
add_node->set_input_x2(*bias_node);
auto add_node = graph->Add<ge::op::Add>(output_name);
auto add_op = add_node->data<ge::op::Add>();
add_op->set_input_x1(*conv_node->data());
add_op->set_input_x2(*bias_node->data());
conv_node = add_node;
}
} else {
auto common_conv_node = graph->AddNode<ge::op::Convolution>(output_name);
common_conv_node->set_input_x(*input_node);
common_conv_node->set_input_w(*filter_const_node);
common_conv_node->set_attr_mode(1);
common_conv_node->set_attr_pad_mode(0); // NOTSET
common_conv_node->set_attr_group(groups);
common_conv_node->set_attr_pad(ge::AttrValue::LIST_INT(
conv_node = graph->Add<ge::op::Convolution>(output_name);
auto conv_op = conv_node->data<ge::op::Convolution>();
conv_op->set_input_x(*input_node->data());
conv_op->set_input_w(*filter_node->data());
conv_op->set_attr_mode(1);
conv_op->set_attr_pad_mode(0); // NOTSET
conv_op->set_attr_group(groups);
conv_op->set_attr_pad(ge::AttrValue::LIST_INT(
{paddings[0], paddings[0], paddings[2], paddings[2]}));
common_conv_node->set_attr_dilation(
conv_op->set_attr_dilation(
ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
common_conv_node->set_attr_stride(
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
common_conv_node->set_attr_kernel(
conv_op->set_attr_stride(ge::AttrValue::LIST_INT({strides[0], strides[1]}));
conv_op->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
conv_node = common_conv_node;
// Convolution Op only support bias with dimension {1, oc, 1, 1},
// so append Add node if dimension is {1, oc, oh, ow} or (n, oc, oh, ow)
if (bias_node != nullptr) {
if (is_channel_bias) {
common_conv_node->set_input_b(*bias_node);
conv_op->set_input_b(*bias_node->data());
} else {
auto add_node = graph->AddNode<ge::op::Add>(output_name);
add_node->set_input_x1(*common_conv_node);
add_node->set_input_x2(*bias_node);
auto add_node = graph->Add<ge::op::Add>(output_name);
auto add_op = add_node->data<ge::op::Add>();
add_op->set_input_x1(*conv_node->data());
add_op->set_input_x2(*bias_node->data());
conv_node = add_node;
}
}
......@@ -215,9 +213,10 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
if (fuse_relu) {
// Append relu node if fuse_relu is true
auto relu_node = graph->AddNode<ge::op::Activation>(output_name);
relu_node->set_input_x(*conv_node);
relu_node->set_attr_mode(CvtActMode("relu"));
auto relu_node = graph->Add<ge::op::Activation>(output_name);
auto relu_op = relu_node->data<ge::op::Activation>();
relu_op->set_input_x(*conv_node->data());
relu_op->set_attr_mode(CvtActMode("relu"));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -227,9 +226,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
conv2d,
REGISTER_SUBGRAPH_BRIDGE(conv2d,
kNPU,
paddle::lite::subgraph::npu::ConvConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
depthwise_conv2d,
REGISTER_SUBGRAPH_BRIDGE(depthwise_conv2d,
kNPU,
paddle::lite::subgraph::npu::ConvConverter);
......@@ -58,11 +58,11 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(dilations.size(), 2L);
// Input node
std::shared_ptr<ge::Operator> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
std::shared_ptr<Node> input_node = nullptr;
if (graph->Has(input_name)) {
input_node = graph->Get(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
input_node = graph->Add(input_name, *input);
}
// Create input sizes node to describe the dimensions of input tensor
......@@ -83,55 +83,59 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
(input_dims[i + 2] - 1) * strides[i] + kernel_ext - 2 * paddings[i];
input_sizes.push_back(output_size);
}
auto input_sizes_const_node =
graph->AddNode(output_name + "/input_sizes", input_sizes);
auto input_sizes_node = graph->Add(output_name + "/input_sizes", input_sizes);
// Filter node
auto filter_const_node = graph->AddNode(filter_name, *filter);
auto filter_node = graph->Add(filter_name, *filter);
// Deconv node
auto conv_transpose_node = graph->AddNode<ge::op::Deconvolution>(output_name);
conv_transpose_node->set_input_input_sizes(*input_sizes_const_node);
conv_transpose_node->set_input_filter(*filter_const_node);
conv_transpose_node->set_input_x(*input_node);
auto conv_transpose_node = graph->Add<ge::op::Deconvolution>(output_name);
auto conv_transpose_op = conv_transpose_node->data<ge::op::Deconvolution>();
conv_transpose_op->set_input_input_sizes(*input_sizes_node->data());
conv_transpose_op->set_input_filter(*filter_node->data());
conv_transpose_op->set_input_x(*input_node->data());
// Set attributes
conv_transpose_node->set_attr_format(0); // NCHW
conv_transpose_node->set_attr_pad_mode(0); // NOTSET
conv_transpose_node->set_attr_group(groups);
conv_transpose_node->set_attr_pad(ge::AttrValue::LIST_INT(
conv_transpose_op->set_attr_format(0); // NCHW
conv_transpose_op->set_attr_pad_mode(0); // NOTSET
conv_transpose_op->set_attr_group(groups);
conv_transpose_op->set_attr_pad(ge::AttrValue::LIST_INT(
{paddings[0], paddings[1], paddings[2], paddings[3]}));
conv_transpose_node->set_attr_dilation(
conv_transpose_op->set_attr_dilation(
ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
conv_transpose_node->set_attr_stride(
conv_transpose_op->set_attr_stride(
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
conv_transpose_node->set_attr_kernel(
conv_transpose_op->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
// Append add node to add bias if exists bias
std::shared_ptr<ge::Operator> output_node = conv_transpose_node;
if (HasInputArg(op_info, scope, "Bias")) {
// Create bias node
std::shared_ptr<Node> bias_node = nullptr;
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto channel_size = bias->dims().production();
CHECK_EQ(channel_size, filter_dims[1] * groups);
auto bias_const_node =
graph->AddNode(bias_name, *bias, {1, channel_size, 1, 1});
if (graph->Has(bias_name)) {
bias_node = graph->Get(bias_name);
} else {
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto channel_size = bias->dims().production();
CHECK_EQ(channel_size, filter_dims[1] * groups);
bias_node = graph->Add(bias_name, *bias, {1, channel_size, 1, 1});
}
// Append add node to add bias node
auto add_node = graph->AddNode<ge::op::Add>(output_name);
add_node->set_input_x1(*conv_transpose_node);
add_node->set_input_x2(*bias_const_node);
output_node = add_node;
auto add_node = graph->Add<ge::op::Add>(output_name);
auto add_op = add_node->data<ge::op::Add>();
add_op->set_input_x1(*conv_transpose_node->data());
add_op->set_input_x2(*bias_node->data());
conv_transpose_node = add_node;
}
if (fuse_relu) {
// Append relu node if fuse_relu is true
auto relu_node = graph->AddNode<ge::op::Activation>(output_name);
relu_node->set_input_x(*output_node);
relu_node->set_attr_mode(CvtActMode("relu"));
auto relu_node = graph->Add<ge::op::Activation>(output_name);
auto relu_op = relu_node->data<ge::op::Activation>();
relu_op->set_input_x(*conv_transpose_node->data());
relu_op->set_attr_mode(CvtActMode("relu"));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -141,6 +145,6 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
conv2d_transpose,
REGISTER_SUBGRAPH_BRIDGE(conv2d_transpose,
kNPU,
paddle::lite::subgraph::npu::ConvTransposeConverter);
......@@ -74,45 +74,45 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<int>("axis");
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Y node
std::shared_ptr<ge::Operator> y_node = nullptr;
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
} else {
auto y_new_shape = CvtYShape(x_dims, y_dims, axis);
y_node = graph->AddNode(y_name, y_new_shape);
y_node = graph->Add(y_name, *y, y_new_shape);
}
// Elementwise node
std::shared_ptr<ge::Operator> elementwise_node = nullptr;
std::shared_ptr<Node> elt_node = nullptr;
if (op_type == "elementwise_add" ||
op_type == "fusion_elementwise_add_activation") {
auto elt_node = graph->AddNode<ge::op::Add>(out_name);
elt_node->set_input_x1(*x_node);
elt_node->set_input_x2(*y_node);
elementwise_node = elt_node;
elt_node = graph->Add<ge::op::Add>(out_name);
auto elt_op = elt_node->data<ge::op::Add>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
} else if (op_type == "elementwise_sub") {
auto elt_node = graph->AddNode<ge::op::Sub>(out_name);
elt_node->set_input_x1(*x_node);
elt_node->set_input_x2(*y_node);
elementwise_node = elt_node;
elt_node = graph->Add<ge::op::Sub>(out_name);
auto elt_op = elt_node->data<ge::op::Sub>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
} else if (op_type == "elementwise_mul") {
auto elt_node = graph->AddNode<ge::op::Mul>(out_name);
elt_node->set_input_x(*x_node);
elt_node->set_input_y(*y_node);
elementwise_node = elt_node;
elt_node = graph->Add<ge::op::Mul>(out_name);
auto elt_op = elt_node->data<ge::op::Mul>();
elt_op->set_input_x(*x_node->data());
elt_op->set_input_y(*y_node->data());
} else if (op_type == "elementwise_div") {
auto elt_node = graph->AddNode<ge::op::RealDiv>(out_name);
elt_node->set_input_x1(*x_node);
elt_node->set_input_x2(*y_node);
elementwise_node = elt_node;
elt_node = graph->Add<ge::op::RealDiv>(out_name);
auto elt_op = elt_node->data<ge::op::RealDiv>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
} else {
LOG(WARNING) << "[NPU] Unsupported op type: " << op_type;
return FAILED;
......@@ -121,11 +121,12 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Act node
if (op_type == "fusion_elementwise_add_activation") {
auto act_type = op_info->GetAttr<std::string>("act_type");
auto act_node = graph->AddNode<ge::op::Activation>(out_name);
act_node->set_input_x(*elementwise_node);
auto act_node = graph->Add<ge::op::Activation>(out_name);
auto act_op = act_node->data<ge::op::Activation>();
act_op->set_input_x(*elt_node->data());
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc.
act_node->set_attr_mode(CvtActMode(act_type));
act_op->set_attr_mode(CvtActMode(act_type));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -135,18 +136,18 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
elementwise_add,
REGISTER_SUBGRAPH_BRIDGE(elementwise_add,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
fusion_elementwise_add_activation,
REGISTER_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
elementwise_sub,
REGISTER_SUBGRAPH_BRIDGE(elementwise_sub,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
elementwise_mul,
REGISTER_SUBGRAPH_BRIDGE(elementwise_mul,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
elementwise_div,
REGISTER_SUBGRAPH_BRIDGE(elementwise_div,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
......@@ -57,22 +57,24 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< " m: " << m << " k: " << k << " n: " << n;
// Create input node and reshape it to (m, k, 1, 1)
std::shared_ptr<ge::Operator> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
std::shared_ptr<Node> input_node = nullptr;
if (graph->Has(input_name)) {
input_node = graph->Get(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
input_node = graph->Add(input_name, *input);
}
auto reshaped_input_node =
graph->AddNode<ge::op::Reshape>(input_name + "/reshape");
reshaped_input_node->set_input_tensor(*input_node);
reshaped_input_node->set_attr_shape({m, k, 1, 1});
reshaped_input_node->set_attr_axis(0);
graph->Add<ge::op::Reshape>(input_name + "/reshape");
auto reshaped_input_op = reshaped_input_node->data<ge::op::Reshape>();
reshaped_input_op->set_input_tensor(*input_node->data());
reshaped_input_op->set_attr_shape({m, k, 1, 1});
reshaped_input_op->set_attr_axis(0);
// Create w const node, set its shape to (n, k, 1, 1) and fill with
// the transposed w tensor
Tensor transpose_w;
transpose_w.Resize({n, k, 1, 1});
transpose_w.set_persistable(true);
auto transpose_w_data = transpose_w.mutable_data<float>();
auto w_data = w->mutable_data<float>();
for (int i = 0; i < k; i++) {
......@@ -80,29 +82,36 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
transpose_w_data[j * k + i] = w_data[i * n + j];
}
}
auto trans_w_const_node = graph->AddNode(w_name, transpose_w);
auto trans_w_node = graph->Add(w_name, transpose_w);
// FC node
auto fc_node = graph->AddNode<ge::op::FullConnection>(out_name + "/fc");
fc_node->set_input_x(*reshaped_input_node);
fc_node->set_input_w(*trans_w_const_node);
auto fc_node = graph->Add<ge::op::FullConnection>(out_name + "/fc");
auto fc_op = fc_node->data<ge::op::FullConnection>();
fc_op->set_input_x(*reshaped_input_node->data());
fc_op->set_input_w(*trans_w_node->data());
// Add bias node if bias tensor exists
if (HasInputArg(op_info, scope, "Bias")) {
std::shared_ptr<Node> bias_node = nullptr;
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
CHECK_EQ(bias_dims.production(), n);
auto bias_const_node = graph->AddNode(bias_name, *bias, {1, n, 1, 1});
fc_node->set_input_b(*bias_const_node);
if (graph->Has(bias_name)) {
bias_node = graph->Get(bias_name);
} else {
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
CHECK_EQ(bias_dims.production(), n);
bias_node = graph->Add(bias_name, *bias, {1, n, 1, 1});
}
fc_op->set_input_b(*bias_node->data());
}
// Reshape output of FC node from (m, n, 1, 1) to (m, n)
auto reshaped_fc_node = graph->AddNode<ge::op::Reshape>(out_name);
reshaped_fc_node->set_input_tensor(*fc_node);
reshaped_fc_node->set_attr_shape({m, n});
reshaped_fc_node->set_attr_axis(0);
auto reshaped_fc_node = graph->Add<ge::op::Reshape>(out_name);
auto reshaped_fc_op = reshaped_fc_node->data<ge::op::Reshape>();
reshaped_fc_op->set_input_tensor(*fc_node->data());
reshaped_fc_op->set_attr_shape({m, n});
reshaped_fc_op->set_attr_axis(0);
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -111,4 +120,4 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, fc, paddle::lite::subgraph::npu::FCConverter);
REGISTER_SUBGRAPH_BRIDGE(fc, kNPU, paddle::lite::subgraph::npu::FCConverter);
......@@ -21,26 +21,52 @@ namespace lite {
namespace subgraph {
namespace npu {
// Const node
std::shared_ptr<ge::op::Const> Graph::AddNode(const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType precision,
DataLayoutType layout) {
auto node = AddNode<ge::op::Const>(name, precision, layout);
node->set_attr_value(CvtTensor(tensor, shape, precision, layout));
// Registers `node` in the graph under `name`.
// A name may hold several nodes only when both the incoming node and the most
// recently registered node with that name report is_var(); any other
// duplicate is reported through LOG(FATAL).
// Returns the number of nodes stored under `name` after the insertion.
int Graph::Add(const std::string& name, std::shared_ptr<Node> node) {
  auto found = nodes_.find(name);
  if (found == nodes_.end()) {
    // First node registered with this name: create its (empty) node list
    auto inserted = nodes_.insert(
        std::make_pair(name, std::vector<std::shared_ptr<Node>>()));
    CHECK(inserted.second);
    found = inserted.first;
  } else if (!(node->is_var() && found->second.back()->is_var())) {
    // Only variable node can be shared with the same name
    LOG(FATAL) << "[NPU] Const or data node " << name << " is redefined.";
    return -1;
  }
  auto& same_name_nodes = found->second;
  same_name_nodes.push_back(node);
  return same_name_nodes.size();
}
// Const or data node
// Creates the graph node that represents `tensor` under `name`.
// The node kind is chosen from tensor.persistable():
//   - persistable     -> ge::op::Const node whose value is the converted tensor
//   - not persistable -> data node built from `shape`, `precision` and `layout`
std::shared_ptr<Node> Graph::Add(const std::string& name,
                                 const Tensor& tensor,
                                 std::vector<int64_t> shape,
                                 PrecisionType precision,
                                 DataLayoutType layout) {
  if (!tensor.persistable()) {
    // Data node
    return Add(name, shape, precision, layout);
  }
  // Const node
  std::shared_ptr<Node> const_node =
      Add<ge::op::Const>(name, precision, layout);
  auto const_op = const_node->data<ge::op::Const>();
  const_op->set_attr_value(CvtTensor(tensor, shape, precision, layout));
  return const_node;
}
// Data node
std::shared_ptr<ge::op::Data> Graph::AddNode(const std::string& name,
std::vector<int64_t> shape,
PrecisionType precision,
DataLayoutType layout) {
auto node = AddNode<ge::op::Data>(name);
std::shared_ptr<Node> Graph::Add(const std::string& name,
std::vector<int64_t> shape,
PrecisionType precision,
DataLayoutType layout) {
auto node = Add<ge::op::Data>(name, precision, layout);
ge::TensorDesc desc(
ge::Shape(shape), CvtDataLayoutType(layout), CvtPrecisionType(precision));
node->update_input_desc_x(desc);
node->data<ge::op::Data>()->update_input_desc_x(desc);
return node;
}
......
......@@ -19,7 +19,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "graph/op/all_ops.h"
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"
......@@ -28,94 +28,97 @@ namespace lite {
namespace subgraph {
namespace npu {
// Type of graph nodes
class Type {
// Graph and node is defined to collect all of converted HiAI IR nodes
class Node {
public:
Type(PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW),
bool persistable = false)
: precision_(precision), layout_(layout), persistable_(persistable) {}
enum class Role {
kVar = 0,
kConst,
kData,
};
Node(std::shared_ptr<ge::Operator> data,
PrecisionType precision,
DataLayoutType layout,
Role role)
: data_(data), precision_(precision), layout_(layout), role_(role) {}
Node(PrecisionType precision, DataLayoutType layout, Role role)
: precision_(precision), layout_(layout), role_(role) {}
void set_data(std::shared_ptr<ge::Operator> data) { data_ = data; }
void set_precision(PrecisionType precision) { precision_ = precision; }
void set_layout(DataLayoutType layout) { layout_ = layout; }
bool set_persistable(bool persistable) { persistable_ = persistable; }
void set_role(Role role) { role_ = role; }
template <typename T>
std::shared_ptr<T> data() {
return std::static_pointer_cast<T>(data_);
}
std::shared_ptr<ge::Operator> data() { return data_; }
PrecisionType precision() const { return precision_; }
DataLayoutType layout() const { return layout_; }
bool persistable() const { return persistable_; }
bool is_var() const { return role_ == Role::kVar; }
bool is_const() const { return role_ == Role::kConst; }
bool is_data() const { return role_ == Role::kData; }
private:
std::shared_ptr<ge::Operator> data_{nullptr};
PrecisionType precision_{PRECISION(kFloat)};
DataLayoutType layout_{DATALAYOUT(kNCHW)};
bool persistable_{false};
Role role_{Role::kVar};
};
// Graph to collect all of converted HiAI IR nodes
class Graph {
public:
int Add(const std::string& name, std::shared_ptr<Node> node);
// Variable, const or data node
template <typename T>
std::shared_ptr<T> AddNode(const std::string& name,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
auto unique_name = [&](const std::string& key) {
int idx = 1;
auto it = counts_.find(key);
if (it == counts_.end()) {
counts_.insert(std::make_pair(key, idx));
} else {
idx = ++(it->second);
}
return key + "_" + std::to_string(idx);
};
bool persistable = typeid(T) == typeid(ge::op::Const);
auto it = nodes_.find(name);
if (it != nodes_.end()) {
// Only variable can rebind the name
CHECK(!it->second.second.persistable() && !persistable)
<< "[NPU] Node " << name << " redefined.";
// Generate a new unique name as the key to bind the origin node:
// new_name->node
nodes_.insert(std::make_pair(unique_name(name + "_var"), it->second));
nodes_.erase(it);
std::shared_ptr<Node> Add(const std::string& name,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
Node::Role role = Node::Role::kVar;
if (typeid(T) == typeid(ge::op::Const)) {
role = Node::Role::kConst;
} else if (typeid(T) == typeid(ge::op::Data)) {
role = Node::Role::kData;
}
// Create a new node and bind with the name: name->new_node
auto node = std::make_shared<T>(unique_name(name + "_op"));
nodes_.insert(std::make_pair(
name, std::make_pair(node, Type(precision, layout, persistable))));
auto node = std::make_shared<Node>(precision, layout, role);
auto idx = Add(name, node);
CHECK_GE(idx, 1);
// Generate a unique name for the created HiAI IR
node->set_data(std::make_shared<T>(name + "__" + std::to_string(idx)));
return node;
}
// Const node
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const Tensor& tensor,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, tensor, tensor.dims().Vectorize(), precision, layout);
// Const or data node
std::shared_ptr<Node> Add(const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
// Convenience overload: takes the shape from the tensor itself
// (tensor.dims().Vectorize()) and forwards everything to the Add overload
// that accepts an explicit shape.
std::shared_ptr<Node> Add(const std::string& name,
const Tensor& tensor,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, tensor, tensor.dims().Vectorize(), precision, layout);
}
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const Tensor& tensor,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, tensor, dims.Vectorize(), precision, layout);
// Convenience overload: accepts the target shape as a DDim, converts it with
// dims.Vectorize() and forwards to the Add overload that takes the tensor
// together with an explicit shape.
std::shared_ptr<Node> Add(const std::string& name,
const Tensor& tensor,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, tensor, dims.Vectorize(), precision, layout);
}
// Const node
template <typename T>
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const std::vector<T>& data,
std::vector<int64_t> shape = {},
DataLayoutType layout = DATALAYOUT(kNCHW)) {
std::shared_ptr<Node> Add(const std::string& name,
const std::vector<T>& data,
std::vector<int64_t> shape = {},
DataLayoutType layout = DATALAYOUT(kNCHW)) {
const std::type_info& info = typeid(T);
PrecisionType precision = PRECISION(kFloat);
if (info == typeid(float)) {
......@@ -138,78 +141,66 @@ class Graph {
}
Tensor tensor;
tensor.Resize(shape);
tensor.set_persistable(true);
std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
reinterpret_cast<const uint8_t*>(data.data()),
data.size() * sizeof(T));
return AddNode(name, tensor, precision, layout);
return Add(name, tensor, precision, layout);
}
template <typename T>
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const std::vector<T>& data,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, data, dims.Vectorize(), layout);
std::shared_ptr<Node> Add(const std::string& name,
const std::vector<T>& data,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, data, dims.Vectorize(), layout);
}
template <typename T>
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
T value,
std::vector<int64_t> shape = {1},
DataLayoutType layout = DATALAYOUT(kNCHW)) {
std::shared_ptr<Node> Add(const std::string& name,
T value,
std::vector<int64_t> shape = {1},
DataLayoutType layout = DATALAYOUT(kNCHW)) {
int64_t size = 1;
for (auto i : shape) {
size *= i;
}
std::vector<T> data(size, value);
return AddNode(name, data, shape, layout);
return Add(name, data, shape, layout);
}
template <typename T>
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
T value,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, value, dims.Vectorize(), layout);
std::shared_ptr<Node> Add(const std::string& name,
T value,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, value, dims.Vectorize(), layout);
}
// Data node
std::shared_ptr<ge::op::Data> AddNode(
const std::string& name,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<ge::op::Data> AddNode(
const std::string& name,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, dims.Vectorize(), precision, layout);
}
std::shared_ptr<ge::Operator> GetNode(std::string name) {
CHECK(HasNode(name)) << "[NPU] Node " << name << " not found.";
return nodes_.at(name).first;
std::shared_ptr<Node> Add(const std::string& name,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
// Convenience overload without a tensor: accepts the shape as a DDim and
// forwards dims.Vectorize() to another Add overload.
// NOTE(review): presumably this resolves to the std::vector<int64_t> shape
// overload declared just above - confirm against DDim::Vectorize()'s type.
std::shared_ptr<Node> Add(const std::string& name,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, dims.Vectorize(), precision, layout);
}
const Type& GetType(const std::string& name) {
CHECK(HasNode(name)) << "[NPU] Node " << name << " not found.";
return nodes_.at(name).second;
// Returns the most recently added node bound to `name`. A missing name is
// reported through CHECK with the "[NPU] Node ... not found." message.
// The name is taken by const reference (consistent with Has) so callers do
// not pay for a string copy on every lookup, and the map is searched once
// rather than once for the existence check and again for the access.
std::shared_ptr<Node> Get(const std::string& name) {
  auto it = nodes_.find(name);
  CHECK(it != nodes_.end()) << "[NPU] Node " << name << " not found.";
  // The last element is the latest node registered under this name
  return it->second.back();
}
bool HasNode(const std::string& name) {
// Reports whether at least one node has been bound to `name`.
bool Has(const std::string& name) {
  const bool is_bound = nodes_.count(name) != 0;
  return is_bound;
}
private:
std::unordered_map<std::string,
std::pair<std::shared_ptr<ge::Operator>, Type>>
nodes_;
std::unordered_map<std::string, int> counts_;
std::unordered_map<std::string, std::vector<std::shared_ptr<Node>>> nodes_;
};
} // namespace npu
......
......@@ -55,11 +55,11 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
"supported in HiAI DDK";
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Priority: OutSize > scale > out_h/out_w
......@@ -71,17 +71,18 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
// Update out_h and out_w and create out_size node if has OutSize
std::shared_ptr<ge::Operator> out_size_node = nullptr;
std::shared_ptr<Node> out_size_node = nullptr;
if (HasInputArg(op_info, scope, "OutSize")) {
auto out_size_name = op_info->Input("OutSize").front();
auto out_size_type = kernel->GetInputDeclType("OutSize");
CHECK(out_size_type->precision() == PRECISION(kInt32));
CHECK(out_size_type->layout() == DATALAYOUT(kNCHW));
if (graph->HasNode(out_size_name)) {
out_size_node = graph->GetNode(out_size_name);
if (graph->Has(out_size_name)) {
out_size_node = graph->Get(out_size_name);
} else {
auto out_size = scope->FindMutableTensor(out_size_name);
CHECK_EQ(out_size->numel(), 2);
CHECK(out_size->persistable());
auto out_size_data = out_size->mutable_data<int>();
// Update out_h and out_w if has OutSize
out_h = out_size_data[0];
......@@ -97,22 +98,25 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< " is too large, should not exceed " << largest_multiple
<< " in HiAI DDK";
}
out_size_node = graph->AddNode(out_name + "/out_size",
std::vector<int>({out_h, out_w}));
out_size_node =
graph->Add(out_name + "/out_size", std::vector<int>({out_h, out_w}));
}
if (interp_method == "bilinear") {
auto bilinear_interp_node =
graph->AddNode<ge::op::ResizeBilinear>(out_name);
bilinear_interp_node->set_input_x(*x_node);
bilinear_interp_node->set_input_size(*out_size_node);
bilinear_interp_node->set_attr_align_corners(align_corners);
auto bilinear_interp_node = graph->Add<ge::op::ResizeBilinear>(out_name);
auto bilinear_interp_op =
bilinear_interp_node->data<ge::op::ResizeBilinear>();
bilinear_interp_op->set_input_x(*x_node->data());
bilinear_interp_op->set_input_size(*out_size_node->data());
bilinear_interp_op->set_attr_align_corners(align_corners);
} else if (interp_method == "nearest") {
auto nearest_interp_node =
graph->AddNode<ge::op::ResizeNearestNeighbor>(out_name);
nearest_interp_node->set_input_image(*x_node);
nearest_interp_node->set_input_size(*out_size_node);
nearest_interp_node->set_attr_align_corners(align_corners);
graph->Add<ge::op::ResizeNearestNeighbor>(out_name);
auto nearest_interp_op =
nearest_interp_node->data<ge::op::ResizeNearestNeighbor>();
nearest_interp_op->set_input_image(*x_node->data());
nearest_interp_op->set_input_size(*out_size_node->data());
nearest_interp_op->set_attr_align_corners(align_corners);
} else {
LOG(WARNING) << "[NPU] Unsupported interpolate method: " << interp_method;
return FAILED;
......@@ -125,9 +129,9 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
bilinear_interp,
REGISTER_SUBGRAPH_BRIDGE(bilinear_interp,
kNPU,
paddle::lite::subgraph::npu::InterpolateConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
nearest_interp,
REGISTER_SUBGRAPH_BRIDGE(nearest_interp,
kNPU,
paddle::lite::subgraph::npu::InterpolateConverter);
......@@ -56,45 +56,46 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< "[NPU] columns of X must be equal with rows of Y";
int n = y_dims.Slice(y_num_col_dims, y_dims.size()).production();
VLOG(3) << "m:" << m << ",n:" << n << ",k:" << k;
VLOG(3) << "x_name:" << x_name << ", is data: " << graph->HasNode(x_name);
VLOG(3) << "y_name:" << y_name << ", is data: " << graph->HasNode(y_name);
CHECK(graph->HasNode(x_name))
VLOG(3) << "x_name:" << x_name << ", is data: " << graph->Has(x_name);
VLOG(3) << "y_name:" << y_name << ", is data: " << graph->Has(y_name);
CHECK(graph->Has(x_name))
<< "[NPU] MatMul in HiAI DDK only support X is data, Y is const yet.";
// X node which supports persistable and non-persistable tensor, and
// reshape to (m, k)
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
auto reshaped_x_node = graph->AddNode<ge::op::Reshape>(x_name + "/reshape");
reshaped_x_node->set_input_tensor(*x_node);
reshaped_x_node->set_attr_shape({m, k});
reshaped_x_node->set_attr_axis(0);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
auto reshaped_x_node = graph->Add<ge::op::Reshape>(x_name + "/reshape");
auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
reshaped_x_op->set_input_tensor(*x_node->data());
reshaped_x_op->set_attr_shape({m, k});
reshaped_x_op->set_attr_axis(0);
x_node = reshaped_x_node;
} else {
auto x_const_node = graph->AddNode(x_name, *x, {m, k});
x_node = x_const_node;
x_node = graph->Add(x_name, *x, {m, k});
}
// Y node which only supports persistable tensor, and reshape to
// (k,n)
std::shared_ptr<ge::Operator> y_node = nullptr;
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
auto reshaped_y_node = graph->AddNode<ge::op::Reshape>(y_name + "/reshape");
reshaped_y_node->set_input_tensor(*y_node);
reshaped_y_node->set_attr_shape({k, n});
reshaped_y_node->set_attr_axis(0);
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
auto reshaped_y_node = graph->Add<ge::op::Reshape>(y_name + "/reshape");
auto reshaped_y_op = reshaped_y_node->data<ge::op::Reshape>();
reshaped_y_op->set_input_tensor(*y_node->data());
reshaped_y_op->set_attr_shape({k, n});
reshaped_y_op->set_attr_axis(0);
y_node = reshaped_y_node;
} else {
auto y_const_node = graph->AddNode(y_name, *y, {k, n});
y_node = y_const_node;
y_node = graph->Add(y_name, *y, {k, n});
}
// Matmul node
auto mul_node = graph->AddNode<ge::op::MatMul>(out_name);
mul_node->set_input_x1(*x_node);
mul_node->set_input_x2(*y_node);
auto mul_node = graph->Add<ge::op::MatMul>(out_name);
auto mul_op = mul_node->data<ge::op::MatMul>();
mul_op->set_input_x1(*x_node->data());
mul_op->set_input_x2(*y_node->data());
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -103,4 +104,4 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, mul, paddle::lite::subgraph::npu::MulConverter);
REGISTER_SUBGRAPH_BRIDGE(mul, kNPU, paddle::lite::subgraph::npu::MulConverter);
......@@ -45,35 +45,34 @@ int Pad2dConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(padding.size(), 4);
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Padding node
int xds = x_dims.size();
padding.insert(padding.begin(), xds * 2 - 4, 0);
auto padding_const_node =
graph->AddNode(out_name + "/padding", padding, {xds, 2});
auto padding_node = graph->Add(out_name + "/padding", padding, {xds, 2});
// Pad node
auto pad2d_node = graph->AddNode<ge::op::Pad>(out_name);
pad2d_node->set_input_x(*x_node);
pad2d_node->set_input_padding(*padding_const_node);
auto pad2d_node = graph->Add<ge::op::Pad>(out_name);
auto pad2d_op = pad2d_node->data<ge::op::Pad>();
pad2d_op->set_input_x(*x_node->data());
pad2d_op->set_input_padding(*padding_node->data());
auto mode = op_info->GetAttr<std::string>("mode");
if (mode == "constant") {
// Pad value node
auto pad_value = op_info->GetAttr<float>("pad_value");
auto pad_value_const_node =
graph->AddNode(out_name + "/pad_value", pad_value);
pad2d_node->set_input_constant_values(*pad_value_const_node);
pad2d_node->set_attr_T(0); // type of pad_value: 0:float 3:int32
pad2d_node->set_attr_mode(0);
auto pad_value_node = graph->Add(out_name + "/pad_value", pad_value);
pad2d_op->set_input_constant_values(*pad_value_node->data());
pad2d_op->set_attr_T(0); // type of pad_value: 0:float 3:int32
pad2d_op->set_attr_mode(0);
} else if (mode == "reflect") {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
pad2d_node->set_attr_mode(1);
pad2d_op->set_attr_mode(1);
return FAILED;
} else {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
......@@ -87,6 +86,6 @@ int Pad2dConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
pad2d,
REGISTER_SUBGRAPH_BRIDGE(pad2d,
kNPU,
paddle::lite::subgraph::npu::Pad2dConverter);
......@@ -14,40 +14,40 @@
#pragma once
USE_SUBGRAPH_BRIDGE(NPU, sigmoid);
USE_SUBGRAPH_BRIDGE(NPU, relu);
USE_SUBGRAPH_BRIDGE(NPU, tanh);
USE_SUBGRAPH_BRIDGE(NPU, relu_clipped);
USE_SUBGRAPH_BRIDGE(NPU, leaky_relu);
USE_SUBGRAPH_BRIDGE(NPU, softsign);
USE_SUBGRAPH_BRIDGE(NPU, hard_sigmoid);
USE_SUBGRAPH_BRIDGE(sigmoid, kNPU);
USE_SUBGRAPH_BRIDGE(relu, kNPU);
USE_SUBGRAPH_BRIDGE(tanh, kNPU);
USE_SUBGRAPH_BRIDGE(relu_clipped, kNPU);
USE_SUBGRAPH_BRIDGE(leaky_relu, kNPU);
USE_SUBGRAPH_BRIDGE(softsign, kNPU);
USE_SUBGRAPH_BRIDGE(hard_sigmoid, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, batch_norm);
USE_SUBGRAPH_BRIDGE(NPU, concat);
USE_SUBGRAPH_BRIDGE(NPU, conv2d);
USE_SUBGRAPH_BRIDGE(NPU, depthwise_conv2d);
USE_SUBGRAPH_BRIDGE(NPU, conv2d_transpose);
USE_SUBGRAPH_BRIDGE(batch_norm, kNPU);
USE_SUBGRAPH_BRIDGE(concat, kNPU);
USE_SUBGRAPH_BRIDGE(conv2d, kNPU);
USE_SUBGRAPH_BRIDGE(depthwise_conv2d, kNPU);
USE_SUBGRAPH_BRIDGE(conv2d_transpose, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, elementwise_add);
USE_SUBGRAPH_BRIDGE(NPU, fusion_elementwise_add_activation);
USE_SUBGRAPH_BRIDGE(NPU, elementwise_sub);
USE_SUBGRAPH_BRIDGE(NPU, elementwise_mul);
USE_SUBGRAPH_BRIDGE(NPU, elementwise_div);
USE_SUBGRAPH_BRIDGE(elementwise_add, kNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation, kNPU);
USE_SUBGRAPH_BRIDGE(elementwise_sub, kNPU);
USE_SUBGRAPH_BRIDGE(elementwise_mul, kNPU);
USE_SUBGRAPH_BRIDGE(elementwise_div, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, fc);
USE_SUBGRAPH_BRIDGE(NPU, bilinear_interp);
USE_SUBGRAPH_BRIDGE(NPU, nearest_interp);
USE_SUBGRAPH_BRIDGE(NPU, mul);
USE_SUBGRAPH_BRIDGE(NPU, pad2d);
USE_SUBGRAPH_BRIDGE(NPU, pool2d);
USE_SUBGRAPH_BRIDGE(NPU, reduce_mean);
USE_SUBGRAPH_BRIDGE(NPU, reshape);
USE_SUBGRAPH_BRIDGE(NPU, reshape2);
USE_SUBGRAPH_BRIDGE(NPU, scale);
USE_SUBGRAPH_BRIDGE(NPU, shuffle_channel);
USE_SUBGRAPH_BRIDGE(NPU, softmax);
USE_SUBGRAPH_BRIDGE(NPU, split);
USE_SUBGRAPH_BRIDGE(NPU, sqrt);
USE_SUBGRAPH_BRIDGE(NPU, square);
USE_SUBGRAPH_BRIDGE(NPU, transpose);
USE_SUBGRAPH_BRIDGE(NPU, transpose2);
USE_SUBGRAPH_BRIDGE(fc, kNPU);
USE_SUBGRAPH_BRIDGE(bilinear_interp, kNPU);
USE_SUBGRAPH_BRIDGE(nearest_interp, kNPU);
USE_SUBGRAPH_BRIDGE(mul, kNPU);
USE_SUBGRAPH_BRIDGE(pad2d, kNPU);
USE_SUBGRAPH_BRIDGE(pool2d, kNPU);
USE_SUBGRAPH_BRIDGE(reduce_mean, kNPU);
USE_SUBGRAPH_BRIDGE(reshape, kNPU);
USE_SUBGRAPH_BRIDGE(reshape2, kNPU);
USE_SUBGRAPH_BRIDGE(scale, kNPU);
USE_SUBGRAPH_BRIDGE(shuffle_channel, kNPU);
USE_SUBGRAPH_BRIDGE(softmax, kNPU);
USE_SUBGRAPH_BRIDGE(split, kNPU);
USE_SUBGRAPH_BRIDGE(sqrt, kNPU);
USE_SUBGRAPH_BRIDGE(square, kNPU);
USE_SUBGRAPH_BRIDGE(transpose, kNPU);
USE_SUBGRAPH_BRIDGE(transpose2, kNPU);
......@@ -48,11 +48,11 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// pool mode
......@@ -109,19 +109,19 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
// Pooling node
auto pool_node = graph->AddNode<ge::op::Pooling>(out_name);
pool_node->set_input_x(*x_node);
pool_node->set_attr_mode(mode);
pool_node->set_attr_pad_mode(pad_mode);
pool_node->set_attr_global_pooling(global_pooling);
pool_node->set_attr_window(
ge::AttrValue::LIST_INT(ksize.begin(), ksize.end()));
pool_node->set_attr_pad(ge::AttrValue::LIST_INT{
auto pool_node = graph->Add<ge::op::Pooling>(out_name);
auto pool_op = pool_node->data<ge::op::Pooling>();
pool_op->set_input_x(*x_node->data());
pool_op->set_attr_mode(mode);
pool_op->set_attr_pad_mode(pad_mode);
pool_op->set_attr_global_pooling(global_pooling);
pool_op->set_attr_window(ge::AttrValue::LIST_INT(ksize.begin(), ksize.end()));
pool_op->set_attr_pad(ge::AttrValue::LIST_INT{
paddings[0], paddings[1], paddings[2], paddings[3]});
pool_node->set_attr_stride(
pool_op->set_attr_stride(
ge::AttrValue::LIST_INT(strides.begin(), strides.end()));
pool_node->set_attr_ceil_mode(ceil_mode);
// pool_node->set_attr_data_mode(data_mode);
pool_op->set_attr_ceil_mode(ceil_mode);
// pool_op->set_attr_data_mode(data_mode);
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -130,6 +130,6 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
pool2d,
REGISTER_SUBGRAPH_BRIDGE(pool2d,
kNPU,
paddle::lite::subgraph::npu::PoolConverter);
......@@ -52,29 +52,30 @@ int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) {
std::sort(dim.begin(), dim.end());
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Using ReduceSum + Scale to implement ReduceMean
// Dim node
auto dim_const_node = graph->AddNode(out_name + "/dim", dim);
auto dim_node = graph->Add(out_name + "/dim", dim);
// Reduce Sum node
auto reduce_sum_node =
graph->AddNode<ge::op::ReduceSum>(out_name + "/reducesum");
reduce_sum_node->set_input_x(*x_node);
reduce_sum_node->set_input_w(*dim_const_node);
reduce_sum_node->set_attr_keep_dims(keep_dim);
auto reduce_sum_node = graph->Add<ge::op::ReduceSum>(out_name + "/reducesum");
auto reduce_sum_op = reduce_sum_node->data<ge::op::ReduceSum>();
reduce_sum_op->set_input_x(*x_node->data());
reduce_sum_op->set_input_w(*dim_node->data());
reduce_sum_op->set_attr_keep_dims(keep_dim);
// Scale node
auto scale_node = graph->AddNode<ge::op::Scale>(out_name);
scale_node->set_input_x(*reduce_sum_node);
scale_node->set_attr_axis(1);
auto scale_node = graph->Add<ge::op::Scale>(out_name);
auto scale_op = scale_node->data<ge::op::Scale>();
scale_op->set_input_x(*reduce_sum_node->data());
scale_op->set_attr_axis(1);
// Add filter node(fill with scale)
float scale = 1;
......@@ -95,9 +96,8 @@ int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) {
remove(scale_bias_shape.begin(), scale_bias_shape.end(), kDelFlag),
scale_bias_shape.end());
}
auto filter_const_node =
graph->AddNode(out_name + "/filter", scale, scale_bias_shape);
scale_node->set_input_filter(*filter_const_node);
auto filter_node = graph->Add(out_name + "/filter", scale, scale_bias_shape);
scale_op->set_input_filter(*filter_node->data());
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -106,6 +106,6 @@ int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
reduce_mean,
REGISTER_SUBGRAPH_BRIDGE(reduce_mean,
kNPU,
paddle::lite::subgraph::npu::ReduceMeanConverter);
......@@ -24,27 +24,27 @@ Registry& Registry::Instance() {
return x;
}
void Registry::Insert(const std::string& dev_type,
const std::string& op_type,
void Registry::Insert(const std::string& op_type,
const std::string& target,
const cvt_func_type& cvt_func_name) {
auto it = map_.find(dev_type);
auto it = map_.find(target);
if (it == map_.end()) {
map_.insert(std::make_pair(
dev_type, std::unordered_map<std::string, cvt_func_type>()));
target, std::unordered_map<std::string, cvt_func_type>()));
}
map_.at(dev_type).insert(std::make_pair(op_type, cvt_func_name));
map_.at(target).insert(std::make_pair(op_type, cvt_func_name));
}
const cvt_func_type& Registry::Select(const std::string& dev_type,
const std::string& op_type) const {
return map_.at(dev_type).at(op_type);
const cvt_func_type& Registry::Select(const std::string& op_type,
const std::string& target) const {
return map_.at(target).at(op_type);
}
bool Registry::Exists(const std::string& dev_type,
const std::string& op_type) const {
bool found = map_.find(dev_type) != map_.end();
bool Registry::Exists(const std::string& op_type,
const std::string& target) const {
bool found = map_.find(target) != map_.end();
if (found) {
found = map_.at(dev_type).find(op_type) != map_.at(dev_type).end();
found = map_.at(target).find(op_type) != map_.at(target).end();
}
return found;
}
......
......@@ -42,12 +42,12 @@ class Registry {
public:
static Registry& Instance();
void Insert(const std::string& dev_type,
const std::string& op_type,
void Insert(const std::string& op_type,
const std::string& target,
const cvt_func_type& cvt_func_name);
const cvt_func_type& Select(const std::string& dev_type,
const std::string& op_type) const;
bool Exists(const std::string& dev_type, const std::string& op_type) const;
const cvt_func_type& Select(const std::string& op_type,
const std::string& target) const;
bool Exists(const std::string& op_type, const std::string& target) const;
Registry() = default;
private:
......@@ -73,18 +73,18 @@ class Registry {
__test_global_namespace_##uniq_name##__>::value, \
msg)
#define REGISTER_SUBGRAPH_BRIDGE(dev_type, op_type, cvt_func_name) \
#define REGISTER_SUBGRAPH_BRIDGE(op_type__, target__, cvt_func_name) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_subgraph_bridge_##dev_type##_##op_type##__, \
__reg_subgraph_bridge_##op_type__##_##target__##__, \
"REGISTER_SUBGRAPH_BRIDGE must be called in global namespace only " \
"once!"); \
int __reg_subgraph_bridge_##dev_type##_##op_type##_Insert() { \
int __reg_subgraph_bridge_##op_type__##_##target__##_Insert() { \
paddle::lite::subgraph::Registry::Instance().Insert( \
#dev_type, #op_type, cvt_func_name); \
#op_type__, #target__, cvt_func_name); \
return 0; \
}
#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) \
extern int __reg_subgraph_bridge_##dev_type##_##op_type##_Insert(); \
static int __reg_subgraph_bridge_##dev_type##_##op_type##_Insert_return \
UNUSED = __reg_subgraph_bridge_##dev_type##_##op_type##_Insert();
#define USE_SUBGRAPH_BRIDGE(op_type__, target__) \
extern int __reg_subgraph_bridge_##op_type__##_##target__##_Insert(); \
static int __reg_subgraph_bridge_##op_type__##_##target__##_Insert_return \
UNUSED = __reg_subgraph_bridge_##op_type__##_##target__##_Insert();
......@@ -44,16 +44,17 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Reshape node
auto reshape_node = graph->AddNode<ge::op::Reshape>(out_name);
reshape_node->set_input_tensor(*x_node);
auto reshape_node = graph->Add<ge::op::Reshape>(out_name);
auto reshape_op = reshape_node->data<ge::op::Reshape>();
reshape_op->set_input_tensor(*x_node->data());
// Read shape from "ShapeTensor"(input), or "Shape"(input), or "shape"(attr)
if (HasInputArg(op_info, scope, "ShapeTensor")) {
......@@ -64,9 +65,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// auto actual_shape_type = kernel->GetInputDeclType("Shape");
// CHECK(actual_shape_type->precision() == PRECISION(kInt32));
// CHECK(actual_shape_type->layout() == DATALAYOUT(kNCHW));
std::shared_ptr<ge::Operator> actual_shape_node = nullptr;
if (graph->HasNode(actual_shape_name)) {
actual_shape_node = graph->GetNode(actual_shape_name);
std::shared_ptr<Node> actual_shape_node = nullptr;
if (graph->Has(actual_shape_name)) {
actual_shape_node = graph->Get(actual_shape_name);
} else {
auto actual_shape = scope->FindMutableTensor(actual_shape_name);
auto actual_shape_dims = actual_shape->dims();
......@@ -81,12 +82,11 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
"but Shape has "
<< out_shape.size();
}
auto actual_shape_const_node =
graph->AddNode(actual_shape_name,
std::vector<int>(out_shape.begin(), out_shape.end()));
actual_shape_node = actual_shape_const_node;
actual_shape_node =
graph->Add(actual_shape_name,
std::vector<int>(out_shape.begin(), out_shape.end()));
}
reshape_node->set_input_w(*actual_shape_node);
reshape_op->set_input_w(*actual_shape_node->data());
} else {
auto shape = op_info->GetAttr<std::vector<int>>("shape");
auto out_dims = lite::operators::ValidateShape(shape, x_dims);
......@@ -96,7 +96,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
"but shape has "
<< out_shape.size();
}
reshape_node->set_attr_shape(
reshape_op->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
}
......@@ -117,9 +117,10 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// auto xshape_type = kernel->GetOutputDeclType("XShape");
// CHECK(xshape_type->precision() == PRECISION(kFloat));
// CHECK(xshape_type->layout() == DATALAYOUT(kNCHW));
auto xshape_node = graph->AddNode<ge::op::Reshape>(xshape_name);
xshape_node->set_input_tensor(*x_node);
xshape_node->set_attr_shape(
auto xshape_node = graph->Add<ge::op::Reshape>(xshape_name);
auto xshape_op = xshape_node->data<ge::op::Reshape>();
xshape_op->set_input_tensor(*x_node->data());
xshape_op->set_attr_shape(
ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end()));
}
return REBUILD_WHEN_SHAPE_CHANGED;
......@@ -130,9 +131,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
reshape,
REGISTER_SUBGRAPH_BRIDGE(reshape,
kNPU,
paddle::lite::subgraph::npu::ReshapeConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
reshape2,
REGISTER_SUBGRAPH_BRIDGE(reshape2,
kNPU,
paddle::lite::subgraph::npu::ReshapeConverter);
......@@ -37,12 +37,15 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
CHECK_GE(x_dims.size(), 2);
auto x_rank = x_dims.size();
CHECK_GE(x_rank, 2);
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
std::vector<int64_t> scale_bias_shape = {x_dims[1]};
// HiAI only support [n, c, 1, 1] for the shape of scale and bias
std::vector<int64_t> scale_bias_shape = {
1, x_rank < 3 ? 1 : x_dims[x_rank - 3], 1, 1};
float scale = op_info->GetAttr<float>("scale");
float bias = op_info->GetAttr<float>("bias");
bool bias_after_scale = op_info->GetAttr<bool>("bias_after_scale");
......@@ -51,29 +54,28 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x, CvtShape(x_dims));
}
// Scale node
auto scale_node = graph->AddNode<ge::op::Scale>(out_name);
scale_node->set_input_x(*x_node);
scale_node->set_attr_axis(1);
auto scale_node = graph->Add<ge::op::Scale>(out_name);
auto scale_op = scale_node->data<ge::op::Scale>();
scale_op->set_input_x(*x_node->data());
scale_op->set_attr_axis(1);
// Add filter node(fill with scale)
auto filter_const_node =
graph->AddNode(out_name + "/filter", scale, scale_bias_shape);
scale_node->set_input_filter(*filter_const_node);
auto filter_node = graph->Add(out_name + "/filter", scale, scale_bias_shape);
scale_op->set_input_filter(*filter_node->data());
// Add bias node(fill with bias)
if (fabs(bias) > 1e-6f) {
auto bias_const_node =
graph->AddNode(out_name + "/bias", bias, scale_bias_shape);
scale_node->set_input_bias(*bias_const_node);
scale_node->set_attr_has_bias_value(true);
auto bias_node = graph->Add(out_name + "/bias", bias, scale_bias_shape);
scale_op->set_input_bias(*bias_node->data());
scale_op->set_attr_has_bias_value(true);
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -83,6 +85,6 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
scale,
REGISTER_SUBGRAPH_BRIDGE(scale,
kNPU,
paddle::lite::subgraph::npu::ScaleConverter);
......@@ -44,17 +44,19 @@ int ShuffleChannelConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto group = op_info->GetAttr<int>("group");
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Shuffle Channel node
auto shuffle_channel_node = graph->AddNode<ge::op::ShuffleChannel>(out_name);
shuffle_channel_node->set_input_x(*x_node);
shuffle_channel_node->set_attr_group(group);
auto shuffle_channel_node = graph->Add<ge::op::ShuffleChannel>(out_name);
auto shuffle_channel_op =
shuffle_channel_node->data<ge::op::ShuffleChannel>();
shuffle_channel_op->set_input_x(*x_node->data());
shuffle_channel_op->set_attr_group(group);
return SUCCESS;
}
......@@ -63,6 +65,6 @@ int ShuffleChannelConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
shuffle_channel,
REGISTER_SUBGRAPH_BRIDGE(shuffle_channel,
kNPU,
paddle::lite::subgraph::npu::ShuffleChannelConverter);
......@@ -37,29 +37,34 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto x_rank = x_dims.size();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axis = op_info->GetAttr<int>("axis");
if (x_dims.size() > 3) {
CHECK(!(axis == 2 && x_dims[3] > 1))
<< "[NPU] Unsupported softmax params: axis = " << axis
<< " :x_w = " << x_dims[3];
if (axis < 0) {
axis += x_rank;
}
if (axis == 2 && x_rank > 3 && x_dims[3] != 1) {
LOG(WARNING) << "[NPU] Unsupported softmax params: axis = " << axis
<< " :x_w = " << x_dims[3];
return FAILED;
}
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Softmax node
auto softmax_node = graph->AddNode<ge::op::Softmax>(out_name);
softmax_node->set_input_x(*x_node);
softmax_node->set_attr_axis(axis);
auto softmax_node = graph->Add<ge::op::Softmax>(out_name);
auto softmax_op = softmax_node->data<ge::op::Softmax>();
softmax_op->set_input_x(*x_node->data());
softmax_op->set_attr_axis(axis);
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -68,6 +73,6 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
softmax,
REGISTER_SUBGRAPH_BRIDGE(softmax,
kNPU,
paddle::lite::subgraph::npu::SoftmaxConverter);
......@@ -47,33 +47,34 @@ int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) {
int64_t sections_num = static_cast<int64_t>(sections.size());
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Split node
auto split_node = graph->AddNode<ge::op::Split>(op_type + "/" + x_name);
split_node->set_input_x(*x_node);
split_node->set_attr_axis(static_cast<int64_t>(axis));
auto split_node = graph->Add<ge::op::Split>(op_type + "/" + x_name);
auto split_op = split_node->data<ge::op::Split>();
split_op->set_input_x(*x_node->data());
split_op->set_attr_axis(static_cast<int64_t>(axis));
if (num > 0) {
split_node->set_attr_output_num(static_cast<int64_t>(num));
split_op->set_attr_output_num(static_cast<int64_t>(num));
} else {
split_node->set_attr_output_num(sections_num);
split_op->set_attr_output_num(sections_num);
auto size_split = ge::AttrValue::LIST_INT(sections.begin(), sections.end());
split_node->set_attr_size_split(size_split);
split_op->set_attr_size_split(size_split);
}
split_node->create_dynamic_output_y(out_names.size());
split_op->create_dynamic_output_y(out_names.size());
int idx = 1;
for (auto& out_name : out_names) {
auto zero_const_node =
graph->AddNode(out_name + "/zero" + std::to_string(idx), 0);
auto add_node = graph->AddNode<ge::op::Add>(out_name);
add_node->set_input_x1(*split_node, "y" + std::to_string(idx));
add_node->set_input_x2(*zero_const_node);
auto zero_node = graph->Add(out_name + "/zero" + std::to_string(idx), 0);
auto add_node = graph->Add<ge::op::Add>(out_name);
auto add_op = add_node->data<ge::op::Add>();
add_op->set_input_x1(*split_node->data(), "y" + std::to_string(idx));
add_op->set_input_x2(*zero_node->data());
idx++;
}
return REBUILD_WHEN_SHAPE_CHANGED;
......@@ -84,6 +85,6 @@ int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
split,
REGISTER_SUBGRAPH_BRIDGE(split,
kNPU,
paddle::lite::subgraph::npu::SplitConverter);
......@@ -43,16 +43,17 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Sqrt node
auto sqrt_node = graph->AddNode<ge::op::Sqrt>(out_name);
sqrt_node->set_input_x(*x_node);
auto sqrt_node = graph->Add<ge::op::Sqrt>(out_name);
auto sqrt_op = sqrt_node->data<ge::op::Sqrt>();
sqrt_op->set_input_x(*x_node->data());
return SUCCESS;
}
......@@ -61,4 +62,6 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, sqrt, paddle::lite::subgraph::npu::SqrtConverter);
REGISTER_SUBGRAPH_BRIDGE(sqrt,
kNPU,
paddle::lite::subgraph::npu::SqrtConverter);
......@@ -43,16 +43,17 @@ int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Square node
auto square_node = graph->AddNode<ge::op::Square>(out_name);
square_node->set_input_x(*x_node);
auto square_node = graph->Add<ge::op::Square>(out_name);
auto square_op = square_node->data<ge::op::Square>();
square_op->set_input_x(*x_node->data());
return SUCCESS;
}
......@@ -61,6 +62,6 @@ int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
square,
REGISTER_SUBGRAPH_BRIDGE(square,
kNPU,
paddle::lite::subgraph::npu::SquareConverter);
......@@ -41,19 +41,20 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<std::vector<int>>("axis");
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Transpose node
auto transpose_node = graph->AddNode<ge::op::Permute>(out_name);
transpose_node->set_input_x(*x_node);
auto w_const_node = graph->AddNode(out_name + "/w", 1.0f);
transpose_node->set_input_w(*w_const_node);
transpose_node->set_attr_order(
auto transpose_node = graph->Add<ge::op::Permute>(out_name);
auto transpose_op = transpose_node->data<ge::op::Permute>();
transpose_op->set_input_x(*x_node->data());
auto w_node = graph->Add(out_name + "/w", 1.0f);
transpose_op->set_input_w(*w_node->data());
transpose_op->set_attr_order(
ge::AttrValue::LIST_INT(axis.begin(), axis.end()));
return SUCCESS;
}
......@@ -63,9 +64,9 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
transpose,
REGISTER_SUBGRAPH_BRIDGE(transpose,
kNPU,
paddle::lite::subgraph::npu::TransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
transpose2,
REGISTER_SUBGRAPH_BRIDGE(transpose2,
kNPU,
paddle::lite::subgraph::npu::TransposeConverter);
......@@ -45,17 +45,18 @@ int UnsqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< "[NPU] unsqueeze not support axes from tensor now";
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Unsqueeze node
auto unsqueeze_node = graph->AddNode<ge::op::Reshape>(out_name);
unsqueeze_node->set_input_tensor(*x_node);
unsqueeze_node->set_attr_shape(
auto unsqueeze_node = graph->Add<ge::op::Reshape>(out_name);
auto unsqueeze_op = unsqueeze_node->data<ge::op::Reshape>();
unsqueeze_op->set_input_tensor(*x_node->data());
unsqueeze_op->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -65,9 +66,9 @@ int UnsqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU,
unsqueeze,
REGISTER_SUBGRAPH_BRIDGE(unsqueeze,
kNPU,
paddle::lite::subgraph::npu::UnsqueezeConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
unsqueeze2,
REGISTER_SUBGRAPH_BRIDGE(unsqueeze2,
kNPU,
paddle::lite::subgraph::npu::UnsqueezeConverter);
......@@ -85,6 +85,22 @@ ge::Format CvtDataLayoutType(DataLayoutType itype) {
return otype;
}
// Pads a shape to 4 dimensions (NCHW) for HiAI by prepending 1s.
//
// in_shape: the original shape, of any rank.
// Returns a new shape: for rank < 4 it is (4 - rank) leading 1s followed by
// the original dimensions; for rank >= 4 it is an unchanged copy.
std::vector<int64_t> CvtShape(const std::vector<int64_t>& in_shape) {
  const size_t kTargetRank = 4;
  const size_t in_rank = in_shape.size();
  // Compute the padding with an explicit comparison. `4 - in_shape.size()` is
  // evaluated in unsigned arithmetic and wraps around to a huge value when the
  // rank exceeds 4, which would make a padding loop run (almost) forever.
  const size_t pad = in_rank < kTargetRank ? kTargetRank - in_rank : 0;
  std::vector<int64_t> out_shape;
  out_shape.reserve(pad + in_rank);
  out_shape.assign(pad, 1);
  out_shape.insert(out_shape.end(), in_shape.begin(), in_shape.end());
  return out_shape;
}
// DDim overload: expands the dims into a plain vector and forwards it to the
// std::vector<int64_t> overload of CvtShape.
std::vector<int64_t> CvtShape(const DDim& in_dims) {
  const auto dims = in_dims.Vectorize();
  return CvtShape(dims);
}
ge::TensorPtr CvtTensor(const Tensor& in_tensor,
std::vector<int64_t> out_shape,
PrecisionType in_precision,
......
......@@ -19,12 +19,12 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "graph/buffer.h"
#include "graph/graph.h"
#include "graph/model.h"
#include "graph/op/all_ops.h"
#include "graph/operator.h"
#include "graph/operator_reg.h"
#include "lite/core/op_lite.h"
#include "lite/utils/macros.h"
......@@ -70,59 +70,16 @@ ge::DataType CvtPrecisionType(PrecisionType itype);
ge::Format CvtDataLayoutType(DataLayoutType itype);
// Padding the shape to 4-dimensions(NCHW) for HiAI
std::vector<int64_t> CvtShape(const std::vector<int64_t>& in_shape);
std::vector<int64_t> CvtShape(const DDim& in_dims);
ge::TensorPtr CvtTensor(const Tensor& in_tensor,
std::vector<int64_t> out_shape = {},
PrecisionType in_precision = PRECISION(kFloat),
DataLayoutType in_layout = DATALAYOUT(kNCHW));
// Creates a ge::Tensor described by (shape, format, element type) and hands
// the bytes of `data` to it through SetData.
//
// data:   the element values; T selects the ge::DataType via typeid.
// shape:  the tensor shape. When empty, a 1-D shape of data.size() elements
//         is used; otherwise the product of its entries is checked against
//         data.size().
// format: the ge::Format stored in the tensor descriptor.
// Returns the newly created tensor.
template <typename T>
ge::TensorPtr CreateTensorAndFillData(const std::vector<T>& data,
std::vector<int64_t> shape = {},
ge::Format format = ge::FORMAT_NCHW) {
// Map the C++ element type to the matching ge::DataType; only float and the
// fixed-width signed integer types are handled, anything else logs FATAL.
const std::type_info& info = typeid(T);
ge::DataType type = ge::DT_FLOAT;
if (info == typeid(float)) {
type = ge::DT_FLOAT;
} else if (info == typeid(int8_t)) {
type = ge::DT_INT8;
} else if (info == typeid(int16_t)) {
type = ge::DT_INT16;
} else if (info == typeid(int32_t)) {
type = ge::DT_INT32;
} else if (info == typeid(int64_t)) {
type = ge::DT_INT64;
} else {
LOG(FATAL) << "[NPU] Unknow value type " << info.name();
}
if (shape.empty()) {
// No shape given: treat the data as a flat 1-D tensor.
shape = {static_cast<int64_t>(data.size())};
} else {
// The element count implied by the shape must match the data size.
// NOTE(review): `size` is an int while the shape entries are int64_t, so the
// product is narrowed on every step, and the check below compares an
// unsigned size with a signed int - confirm this is acceptable for the
// shapes used by callers.
int size = 1;
for (auto i : shape) {
size *= i;
}
CHECK_EQ(data.size(), size);
}
ge::TensorDesc desc(ge::Shape(shape), format, type);
ge::TensorPtr tensor = std::make_shared<ge::Tensor>();
tensor->SetTensorDesc(desc);
// NOTE(review): `data` is a const reference, so data.data() yields a
// `const T*`; casting it to a non-const `uint8_t*` drops constness - confirm
// against the SetData signature before reusing this helper.
tensor->SetData(reinterpret_cast<uint8_t*>(data.data()),
data.size() * sizeof(T));
return tensor;
}
// Scalar-fill overload: builds a buffer holding one copy of `value` per
// element of `shape` and forwards it, together with `shape` and `format`, to
// the std::vector overload of CreateTensorAndFillData.
template <typename T>
ge::TensorPtr CreateTensorAndFillData(T value,
                                      std::vector<int64_t> shape = {1},
                                      ge::Format format = ge::FORMAT_NCHW) {
  // Total number of elements is the product of all dimensions.
  int64_t num_elements = 1;
  for (size_t d = 0; d < shape.size(); ++d) {
    num_elements *= shape[d];
  }
  std::vector<T> filled(num_elements, value);
  return CreateTensorAndFillData(filled, shape, format);
}
int CvtActMode(std::string act_type);
} // namespace npu
......
......@@ -16,7 +16,7 @@
#include <sys/time.h>
#include <time.h>
#include <utility>
#include "ai_ddk_lib/include/hiai_ir_build.h"
#include "hiai_ir_build.h" // NOLINT
#include "lite/backends/npu/device.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/graph.h"
......@@ -39,13 +39,13 @@ int SubgraphEngine::BuildDeviceProgram() {
op->CheckShape();
op->InferShape();
std::string op_type = op->op_info()->Type();
if (!bridges.Exists("NPU", op_type)) {
if (!bridges.Exists(op_type, "kNPU")) {
return subgraph::FAILED;
}
auto kernel = inst.kernel();
status |= bridges.Select("NPU", op_type)(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
status |= bridges.Select(op_type, "kNPU")(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED;
}
......@@ -57,26 +57,26 @@ int SubgraphEngine::BuildDeviceProgram() {
std::vector<ge::Operator> device_inodes;
std::vector<ge::Operator> device_onodes;
for (auto& input_name : input_names_) {
if (graph.HasNode(input_name)) {
if (!graph.GetType(input_name).persistable()) {
device_inodes.push_back(*graph.GetNode(input_name));
if (graph.Has(input_name)) {
if (graph.Get(input_name)->is_data()) {
device_inodes.push_back(*graph.Get(input_name)->data());
device_inames_.push_back(input_name);
} else {
LOG(WARNING) << "[NPU] Input node " << input_name
<< " is skipped because it is a persistable node.";
<< " is ignored because it is not a data node.";
}
} else {
LOG(WARNING) << "[NPU] Input node " << input_name
<< " is skipped because it does not exist.";
<< " is ignored because it does not exist.";
}
}
for (auto& output_name : output_names_) {
if (graph.HasNode(output_name)) {
device_onodes.push_back(*graph.GetNode(output_name));
if (graph.Has(output_name)) {
device_onodes.push_back(*graph.Get(output_name)->data());
device_onames_.push_back(output_name);
} else {
LOG(WARNING) << "[NPU] Output node " << output_name
<< " is skipped because it does not exist.";
<< " is ignored because it does not exist.";
}
}
CHECK(!device_inames_.empty())
......@@ -108,14 +108,14 @@ int SubgraphEngine::BuildDeviceProgram() {
origin_otensors_.resize(device_onames_.size());
device_otensors_.resize(device_onames_.size());
for (int i = 0; i < device_inames_.size(); i++) {
auto type = graph.GetType(device_inames_[i]);
auto precision = type.precision();
auto layout = type.layout();
auto node = graph.Get(device_inames_[i]);
auto precision = node->precision();
auto layout = node->layout();
origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]);
CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims();
VLOG(3) << "[NPU] Inputs[" << i
<< "] precision: " << PrecisionToStr(precision)
VLOG(3) << "[NPU] Inputs[" << i << "] name: " << device_inames_[i]
<< " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " dims: {"
<< device_idims[i].GetNumber() << ","
<< device_idims[i].GetChannel() << ","
......@@ -129,14 +129,14 @@ int SubgraphEngine::BuildDeviceProgram() {
device_itensors_[i]->Init(&(device_idims[i]));
}
for (int i = 0; i < device_onames_.size(); i++) {
auto type = graph.GetType(device_onames_[i]);
auto precision = type.precision();
auto layout = type.layout();
auto node = graph.Get(device_onames_[i]);
auto precision = node->precision();
auto layout = node->layout();
origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]);
CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "[NPU] Outputs[" << i
<< "] precision: " << PrecisionToStr(precision)
VLOG(3) << "[NPU] Outputs[" << i << "] name: " << device_onames_[i]
<< " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " dims: {"
<< device_odims[i].GetNumber() << ","
<< device_odims[i].GetChannel() << ","
......
......@@ -17,7 +17,7 @@
#include <memory>
#include <string>
#include <vector>
#include "ai_ddk_lib/include/HiAiModelManagerService.h"
#include "HiAiModelManagerService.h"
#include "lite/core/kernel.h"
#include "lite/kernels/npu/bridges/engine.h"
#include "lite/kernels/npu/bridges/registry.h"
......
......@@ -43,20 +43,21 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Act node
if (op_type == "relu") {
graph->AddNode(out_name, graph->builder_.CreateRelu(*x_node));
graph->Add(out_name, graph->builder_.CreateRelu(*x_node->data()));
} else if (op_type == "tanh") {
graph->AddNode(out_name, graph->builder_.CreateUnaryOp("tanh", *x_node));
graph->Add(out_name,
graph->builder_.CreateUnaryOp("tanh", *x_node->data()));
} else if (op_type == "gelu") {
graph->AddNode(out_name, graph->builder_.CreateGelu(*x_node));
graph->Add(out_name, graph->builder_.CreateGelu(*x_node->data()));
} else {
// TODO(hong19860320) supports more activation ops
LOG(WARNING) << "[XPU] Unsupported activation type " << op_type;
......@@ -70,6 +71,6 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, relu, paddle::lite::subgraph::xpu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU, tanh, paddle::lite::subgraph::xpu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU, gelu, paddle::lite::subgraph::xpu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(relu, kXPU, paddle::lite::subgraph::xpu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(tanh, kXPU, paddle::lite::subgraph::xpu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(gelu, kXPU, paddle::lite::subgraph::xpu::ActConverter);
......@@ -64,28 +64,28 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto epsilon = op_info->GetAttr<float>("epsilon");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Scale, Bias, Mean, Variance node
auto scale_const_node = graph->AddNode(scale_name, *scale);
auto bias_const_node = graph->AddNode(bias_name, *bias);
auto mean_const_node = graph->AddNode(mean_name, *mean);
auto variance_const_node = graph->AddNode(variance_name, *variance);
auto scale_node = graph->Add(scale_name, *scale);
auto bias_node = graph->Add(bias_name, *bias);
auto mean_node = graph->Add(mean_name, *mean);
auto variance_node = graph->Add(variance_name, *variance);
// Batch Norm node and extract the first field as the output node
auto batch_norm_node = graph->builder_.CreateBatchNorm(*x_node,
*scale_const_node,
*bias_const_node,
*mean_const_node,
*variance_const_node,
auto batch_norm_data = graph->builder_.CreateBatchNorm(*x_node->data(),
*scale_node->data(),
*bias_node->data(),
*mean_node->data(),
*variance_node->data(),
1,
epsilon);
graph->AddNode(y_name, graph->builder_.GetField(batch_norm_node, 0));
graph->Add(y_name, graph->builder_.GetField(batch_norm_data, 0));
return SUCCESS;
}
......@@ -94,6 +94,6 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
batch_norm,
REGISTER_SUBGRAPH_BRIDGE(batch_norm,
kXPU,
paddle::lite::subgraph::xpu::BatchNormConverter);
......@@ -61,11 +61,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(dilations.size(), 2L);
// Input node
std::shared_ptr<xtcl::xExpr> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
std::shared_ptr<Node> input_node = nullptr;
if (graph->Has(input_name)) {
input_node = graph->Get(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
input_node = graph->Add(input_name, *input);
}
if (paddings.size() == 2L) {
......@@ -99,7 +99,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
DDim output_dims(output_shape);
// Filter node
auto filter_const_node = graph->AddNode(filter_name, *filter);
auto filter_node = graph->Add(filter_name, *filter);
// Conv node
auto conv_attrs = xtcl::make_node<xtcl::network::Conv2DAttrs>();
......@@ -114,9 +114,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
conv_attrs->out_layout = "";
// conv_attrs->out_dtype = "";
auto conv_node =
graph->AddNode(output_name,
graph->builder_.CreateConv2D(
*input_node, *filter_const_node, conv_attrs));
graph->Add(output_name,
graph->builder_.CreateConv2D(
*input_node->data(), *filter_node->data(), conv_attrs));
// Add bias node if exists bias
// supports the bias nodes with the following dimensions
......@@ -149,30 +149,27 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< " isn't supported in conv2d Op when output dimension is "
<< output_dims;
}
std::shared_ptr<xtcl::xExpr> bias_node = nullptr;
if (graph->HasNode(bias_name)) {
// Bias node from input node
bias_node = graph->GetNode(bias_name);
std::shared_ptr<Node> bias_node = nullptr;
if (graph->Has(bias_name)) {
bias_node = graph->Get(bias_name);
} else {
// Bias node with const data
bias_node = graph->AddNode(bias_name, *bias, bias_shape);
bias_node = graph->Add(bias_name, *bias, bias_shape);
}
std::shared_ptr<xtcl::xExpr> add_node = nullptr;
if (is_channel_bias) {
add_node = graph->AddNode(
output_name,
graph->builder_.CreateBiasAdd(*conv_node, 1, *bias_node));
conv_node = graph->Add(output_name,
graph->builder_.CreateBiasAdd(
*conv_node->data(), 1, *bias_node->data()));
} else {
add_node = graph->AddNode(
output_name,
graph->builder_.CreateBinaryOp("add", *conv_node, *bias_node));
conv_node =
graph->Add(output_name,
graph->builder_.CreateBinaryOp(
"add", *conv_node->data(), *bias_node->data()));
}
conv_node = add_node;
}
if (fuse_relu) {
// Append relu node if fuse_relu is true
graph->AddNode(output_name, graph->builder_.CreateRelu(*conv_node));
graph->Add(output_name, graph->builder_.CreateRelu(*conv_node->data()));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -182,9 +179,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
conv2d,
REGISTER_SUBGRAPH_BRIDGE(conv2d,
kXPU,
paddle::lite::subgraph::xpu::ConvConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU,
depthwise_conv2d,
REGISTER_SUBGRAPH_BRIDGE(depthwise_conv2d,
kXPU,
paddle::lite::subgraph::xpu::ConvConverter);
......@@ -46,21 +46,21 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
op_info->GetAttr<std::string>("dropout_implementation");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Dropout node
if (dropout_implementation == "downgrade_in_infer") {
graph->AddNode(
out_name,
graph->builder_.CreateScale(*x_node, 1.f - dropout_prob, 0.0f, false));
graph->Add(out_name,
graph->builder_.CreateScale(
*x_node->data(), 1.f - dropout_prob, 0.0f, false));
} else if (dropout_implementation == "upscale_in_train") {
graph->AddNode(out_name,
graph->builder_.CreateScale(*x_node, 1.0f, 0.0f, false));
graph->Add(out_name,
graph->builder_.CreateScale(*x_node->data(), 1.0f, 0.0f, false));
} else {
LOG(WARNING) << "[XPU] Unsupported dropout_implementation == "
<< dropout_implementation << " for dropout";
......@@ -74,6 +74,6 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
dropout,
REGISTER_SUBGRAPH_BRIDGE(dropout,
kXPU,
paddle::lite::subgraph::xpu::DropoutConverter);
......@@ -50,29 +50,31 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<int>("axis");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Y node
std::shared_ptr<xtcl::xExpr> y_node = nullptr;
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
} else {
y_node = graph->AddNode(y_name, y_dims);
y_node = graph->Add(y_name, *y);
}
// Elementwise node
std::shared_ptr<xtcl::xExpr> elementwise_node = nullptr;
std::shared_ptr<Node> elt_node = nullptr;
if (y_dims.size() == 1) {
elementwise_node = graph->AddNode(
out_name, graph->builder_.CreateBiasAdd(*x_node, axis, *y_node));
elt_node = graph->Add(
out_name,
graph->builder_.CreateBiasAdd(*x_node->data(), axis, *y_node->data()));
} else if (x_dims.size() == y_dims.size()) {
elementwise_node = graph->AddNode(
out_name, graph->builder_.CreateBinaryOp("add", *x_node, *y_node));
elt_node = graph->Add(out_name,
graph->builder_.CreateBinaryOp(
"add", *x_node->data(), *y_node->data()));
} else {
LOG(WARNING)
<< "[XPU] elementwise_add only support y of one dimension, or x "
......@@ -88,6 +90,6 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
elementwise_add,
REGISTER_SUBGRAPH_BRIDGE(elementwise_add,
kXPU,
paddle::lite::subgraph::xpu::ElementwiseConverter);
......@@ -54,38 +54,39 @@ int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto out_dims = out->dims();
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Index node
std::shared_ptr<xtcl::xExpr> index_node = nullptr;
if (graph->HasNode(index_name)) {
index_node = graph->GetNode(index_name);
std::shared_ptr<Node> index_node = nullptr;
if (graph->Has(index_name)) {
index_node = graph->Get(index_name);
} else {
index_node = graph->AddNode(
index_name, index_dims, index_type->precision(), index_type->layout());
index_node = graph->Add(
index_name, *index, index_type->precision(), index_type->layout());
}
// Flatten index node
if (index_dims.size() != 1) {
index_node =
graph->AddNode(index_name + "/reshape",
graph->builder_.CreateReshape(*index_node, {-1}),
index_type->precision(),
index_type->layout());
graph->Add(index_name + "/reshape",
graph->builder_.CreateReshape(*index_node->data(), {-1}),
index_type->precision(),
index_type->layout());
}
// Reshape the gather node with the inferred shape as the output node
auto gather_node = graph->AddNode(
out_name,
graph->builder_.CreateGather(*x_node, *index_node, /* axis= */ 0));
auto gather_node =
graph->Add(out_name,
graph->builder_.CreateGather(
*x_node->data(), *index_node->data(), /* axis= */ 0));
if (out_dims.size() != 2) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*gather_node, CvtShape<xtcl::Integer>(out_dims)));
graph->Add(out_name,
graph->builder_.CreateReshape(
*gather_node->data(), CvtShape<xtcl::Integer>(out_dims)));
}
return SUCCESS;
}
......@@ -95,6 +96,6 @@ int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
gather,
REGISTER_SUBGRAPH_BRIDGE(gather,
kXPU,
paddle::lite::subgraph::xpu::GatherConverter);
......@@ -21,71 +21,71 @@ namespace lite {
namespace subgraph {
namespace xpu {
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
const xtcl::xExpr& layer,
PrecisionType precision,
DataLayoutType layout) {
auto unique_name = [&](const std::string& key) {
int idx = 1;
auto it = counts_.find(key);
if (it == counts_.end()) {
counts_.insert(std::make_pair(key, idx));
} else {
idx = ++(it->second);
}
return key + "_" + std::to_string(idx);
};
int Graph::Add(const std::string& name, std::shared_ptr<Node> node) {
auto it = nodes_.find(name);
if (it != nodes_.end()) {
// Only variable can rebind the name
CHECK(!it->second.second.persistable()) << "[XPU] Node " << name
<< " redefined.";
// Generate a new unique name as the key to bind the origin node if the
// origin node isn't a const node: new_name->node
nodes_.insert(std::make_pair(unique_name(name + "_var"), it->second));
nodes_.erase(it);
// Only variable node can be shared with the same name
if (!node->is_var() || !it->second.back()->is_var()) {
LOG(FATAL) << "[XPU] Const or data node " << name << " is redefined.";
return -1;
}
} else {
auto ret = nodes_.insert(
std::make_pair(name, std::vector<std::shared_ptr<Node>>()));
CHECK(ret.second);
it = ret.first;
}
// Create a new node and bind with the name: name->new_node
auto node = std::make_shared<xtcl::xExpr>(layer);
nodes_.insert(std::make_pair(
name, std::make_pair(node, Type(precision, layout, false))));
builder_.SetLayer(unique_name(name + "_op"));
return node;
it->second.push_back(node);
return it->second.size();
}
// Const node
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
const Tensor& tensor,
PrecisionType precision,
DataLayoutType layout) {
return AddNode(name, tensor, tensor.dims().Vectorize(), precision, layout);
// Variable node: wraps an XTCL layer expression in a Node with role kVar and
// binds it to `name`.
std::shared_ptr<Node> Graph::Add(const std::string& name,
                                 const xtcl::xExpr& layer,
                                 PrecisionType precision,
                                 DataLayoutType layout) {
  auto var_node = std::make_shared<Node>(precision, layout, Node::Role::kVar);
  // Register the node under the name before attaching its expression; the
  // returned value is reused below as the suffix of the layer name.
  const auto bind_count = Add(name, var_node);
  CHECK_GE(bind_count, 1);
  var_node->set_data(std::make_shared<xtcl::xExpr>(layer));
  // Name the current XTCL layer "<name>__<bind_count>" so that it is unique
  builder_.SetLayer(name + "__" + std::to_string(bind_count));
  return var_node;
}
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType precision,
DataLayoutType layout) {
CHECK(!HasNode(name)) << "[NPU] Node " << name << " redefined.";
auto node = std::make_shared<xtcl::xExpr>(builder_.CreateTensor(
name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision)));
nodes_.insert(std::make_pair(
name, std::make_pair(node, Type(precision, layout, true))));
params_.emplace(
std::make_pair(name, *CvtTensor(tensor, shape, precision, layout)));
// Const or data node: the role is chosen from the tensor's `persistable`
// flag. A persistable tensor becomes a const node whose converted contents
// are stored in params_; any other tensor is added as a data node.
std::shared_ptr<Node> Graph::Add(const std::string& name,
                                 const Tensor& tensor,
                                 std::vector<int64_t> shape,
                                 PrecisionType precision,
                                 DataLayoutType layout) {
  // Data node: delegate to the shape-only overload
  if (!tensor.persistable()) {
    return Add(name, shape, precision, layout);
  }
  // Const node: must be the first node bound to this name
  auto const_node =
      std::make_shared<Node>(precision, layout, Node::Role::kConst);
  const auto bind_count = Add(name, const_node);
  CHECK_EQ(bind_count, 1);
  const_node->set_data(std::make_shared<xtcl::xExpr>(builder_.CreateTensor(
      name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision))));
  // Keep the converted tensor contents alongside the graph, keyed by name
  params_.emplace(
      std::make_pair(name, *CvtTensor(tensor, shape, precision, layout)));
  return const_node;
}
// Data node
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
std::vector<int64_t> shape,
PrecisionType precision,
DataLayoutType layout) {
CHECK(!HasNode(name)) << "[NPU] Node " << name << " redefined.";
auto node = std::make_shared<xtcl::xExpr>(builder_.CreateTensor(
name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision)));
nodes_.insert(std::make_pair(
name, std::make_pair(node, Type(precision, layout, false))));
// Creates a node with role kData whose expression is an XTCL tensor built
// from the given name, shape and precision. The name must not be bound yet.
std::shared_ptr<Node> Graph::Add(const std::string& name,
                                 std::vector<int64_t> shape,
                                 PrecisionType precision,
                                 DataLayoutType layout) {
  auto data_node = std::make_shared<Node>(precision, layout, Node::Role::kData);
  // A data node has to be the first (and only) node registered for its name
  const auto bind_count = Add(name, data_node);
  CHECK_EQ(bind_count, 1);
  data_node->set_data(std::make_shared<xtcl::xExpr>(builder_.CreateTensor(
      name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision))));
  return data_node;
}
......
......@@ -28,67 +28,81 @@ namespace lite {
namespace subgraph {
namespace xpu {
// Type of graph nodes
class Type {
// Graph and node is defined to collect all of converted XTCL IR nodes
class Node {
public:
Type(PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW),
bool persistable = false)
: precision_(precision), layout_(layout), persistable_(persistable) {}
enum class Role {
kVar = 0,
kConst,
kData,
};
Node(std::shared_ptr<xtcl::xExpr> data,
PrecisionType precision,
DataLayoutType layout,
Role role)
: data_(data), precision_(precision), layout_(layout), role_(role) {}
Node(PrecisionType precision, DataLayoutType layout, Role role)
: precision_(precision), layout_(layout), role_(role) {}
void set_data(std::shared_ptr<xtcl::xExpr> data) { data_ = data; }
void set_precision(PrecisionType precision) { precision_ = precision; }
void set_layout(DataLayoutType layout) { layout_ = layout; }
void set_persistable(bool persistable) { persistable_ = persistable; }
void set_role(Role role) { role_ = role; }
std::shared_ptr<xtcl::xExpr> data() { return data_; }
PrecisionType precision() const { return precision_; }
DataLayoutType layout() const { return layout_; }
bool persistable() const { return persistable_; }
Role role() const { return role_; }
bool is_var() const { return role_ == Role::kVar; }
bool is_const() const { return role_ == Role::kConst; }
bool is_data() const { return role_ == Role::kData; }
private:
std::shared_ptr<xtcl::xExpr> data_{nullptr};
PrecisionType precision_{PRECISION(kFloat)};
DataLayoutType layout_{DATALAYOUT(kNCHW)};
bool persistable_{false};
Role role_{Role::kVar};
};
// Graph to collect all of converted XPU IR nodes
class Graph {
public:
// Layer node
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const xtcl::xExpr& layer,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
int Add(const std::string& name, std::shared_ptr<Node> node);
// Variable node
std::shared_ptr<Node> Add(const std::string& name,
const xtcl::xExpr& layer,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
// Const or data node
std::shared_ptr<Node> Add(const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
// Convenience overload: uses the tensor's own dims (via dims().Vectorize())
// as the shape and forwards to the overload that takes an explicit shape.
std::shared_ptr<Node> Add(const std::string& name,
const Tensor& tensor,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, tensor, tensor.dims().Vectorize(), precision, layout);
}
// Const node
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const Tensor& tensor,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const Tensor& tensor,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, tensor, dims.Vectorize(), precision, layout);
// Convenience overload: accepts the target shape as a DDim, converts it with
// Vectorize() and forwards to the tensor overload that takes an explicit
// shape.
std::shared_ptr<Node> Add(const std::string& name,
const Tensor& tensor,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, tensor, dims.Vectorize(), precision, layout);
}
// Const node
template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const std::vector<T>& data,
std::vector<int64_t> shape = {},
DataLayoutType layout = DATALAYOUT(kNCHW)) {
std::shared_ptr<Node> Add(const std::string& name,
const std::vector<T>& data,
std::vector<int64_t> shape = {},
DataLayoutType layout = DATALAYOUT(kNCHW)) {
const std::type_info& info = typeid(T);
PrecisionType precision = PRECISION(kFloat);
if (info == typeid(float)) {
......@@ -111,70 +125,61 @@ class Graph {
}
Tensor tensor;
tensor.Resize(shape);
tensor.set_persistable(true);
std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
reinterpret_cast<const uint8_t*>(data.data()),
data.size() * sizeof(T));
return AddNode(name, tensor, precision, layout);
return Add(name, tensor, precision, layout);
}
template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const std::vector<T>& data,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, data, dims.Vectorize(), layout);
// Convenience overload for host data given as std::vector<T>: accepts the
// shape as a DDim, converts it with Vectorize() and forwards the call.
std::shared_ptr<Node> Add(const std::string& name,
const std::vector<T>& data,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, data, dims.Vectorize(), layout);
}
template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
T value,
std::vector<int64_t> shape = {1},
DataLayoutType layout = DATALAYOUT(kNCHW)) {
std::shared_ptr<Node> Add(const std::string& name,
T value,
std::vector<int64_t> shape = {1},
DataLayoutType layout = DATALAYOUT(kNCHW)) {
int64_t size = 1;
for (auto i : shape) {
size *= i;
}
std::vector<T> data(size, value);
return AddNode(name, data, shape, layout);
return Add(name, data, shape, layout);
}
template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
T value,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, value, dims.Vectorize(), layout);
// Convenience overload for a single fill value: accepts the shape as a DDim,
// converts it with Vectorize() and forwards the call.
std::shared_ptr<Node> Add(const std::string& name,
T value,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, value, dims.Vectorize(), layout);
}
// Data node
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, dims.Vectorize(), precision, layout);
}
std::shared_ptr<xtcl::xExpr> GetNode(const std::string& name) {
CHECK(HasNode(name)) << "[XPU] Node " << name << " not found.";
return nodes_.at(name).first;
std::shared_ptr<Node> Add(const std::string& name,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
// Convenience overload without a tensor argument: accepts the shape as a
// DDim, converts it with Vectorize() and forwards to the shape-only overload.
std::shared_ptr<Node> Add(const std::string& name,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, dims.Vectorize(), precision, layout);
}
const Type& GetType(const std::string& name) {
CHECK(HasNode(name)) << "[XPU] Node " << name << " not found.";
return nodes_.at(name).second;
// Returns the last node stored under `name`. The CHECK fails when nothing is
// registered for that name.
std::shared_ptr<Node> Get(const std::string& name) {
  CHECK(Has(name)) << "[XPU] Node " << name << " not found.";
  const auto& bound_nodes = nodes_.at(name);
  return bound_nodes.back();
}
bool HasNode(const std::string& name) {
// True when an entry for `name` exists in nodes_.
bool Has(const std::string& name) {
  return nodes_.count(name) != 0;
}
......@@ -184,9 +189,7 @@ class Graph {
xtcl::network::xTensorCompiler::ParamNDArrayMap params_;
private:
std::unordered_map<std::string, std::pair<std::shared_ptr<xtcl::xExpr>, Type>>
nodes_;
std::unordered_map<std::string, int> counts_;
std::unordered_map<std::string, std::vector<std::shared_ptr<Node>>> nodes_;
};
} // namespace xpu
......
......@@ -51,23 +51,23 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto x_inner_size = x_dims.Slice(axis, x_rank).production();
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
if (reshape) {
auto reshaped_x_dims = x_dims.Slice(0, axis).Vectorize();
reshaped_x_dims.push_back(x_inner_size);
x_node =
graph->AddNode(x_name + "/reshape",
graph->builder_.CreateReshape(
*x_node, CvtShape<xtcl::Integer>(reshaped_x_dims)));
x_node = graph->Add(
x_name + "/reshape",
graph->builder_.CreateReshape(
*x_node->data(), CvtShape<xtcl::Integer>(reshaped_x_dims)));
}
// Scale node
std::shared_ptr<xtcl::xExpr> scale_const_node = nullptr;
std::shared_ptr<Node> scale_node = nullptr;
if (HasInputArg(op_info, scope, "Scale")) {
auto scale_name = op_info->Input("Scale").front();
auto scale_type = kernel->GetInputDeclType("Scale");
......@@ -77,14 +77,13 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto scale_dims = scale->dims();
CHECK_EQ(scale_dims.size(), 1);
CHECK_EQ(scale_dims.production(), x_inner_size);
scale_const_node = graph->AddNode(scale_name, *scale);
scale_node = graph->Add(scale_name, *scale);
} else {
scale_const_node =
graph->AddNode(y_name + "/scale_one", 1.0f, {x_inner_size});
scale_node = graph->Add(y_name + "/scale_one", 1.0f, {x_inner_size});
}
// Bias node
std::shared_ptr<xtcl::xExpr> bias_const_node = nullptr;
std::shared_ptr<Node> bias_node = nullptr;
if (HasInputArg(op_info, scope, "Bias")) {
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
......@@ -94,26 +93,25 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto bias_dims = bias->dims();
CHECK_EQ(bias_dims.size(), 1);
CHECK_EQ(bias_dims.production(), x_inner_size);
bias_const_node = graph->AddNode(bias_name, *bias);
bias_node = graph->Add(bias_name, *bias);
} else {
bias_const_node =
graph->AddNode(y_name + "/bias_zero", 0.0f, {x_inner_size});
bias_node = graph->Add(y_name + "/bias_zero", 0.0f, {x_inner_size});
}
// Layer Norm node
auto layer_norm_node =
graph->AddNode(y_name,
graph->builder_.CreateLayerNorm(*x_node,
*scale_const_node,
*bias_const_node,
axis,
epsilon,
true,
true));
graph->Add(y_name,
graph->builder_.CreateLayerNorm(*x_node->data(),
*scale_node->data(),
*bias_node->data(),
axis,
epsilon,
true,
true));
if (reshape) {
graph->AddNode(y_name,
graph->builder_.CreateReshape(
*layer_norm_node, CvtShape<xtcl::Integer>(y_dims)));
graph->Add(y_name,
graph->builder_.CreateReshape(*layer_norm_node->data(),
CvtShape<xtcl::Integer>(y_dims)));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -123,6 +121,6 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
layer_norm,
REGISTER_SUBGRAPH_BRIDGE(layer_norm,
kXPU,
paddle::lite::subgraph::xpu::LayerNormConverter);
......@@ -57,30 +57,34 @@ int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
// Ids node
std::shared_ptr<xtcl::xExpr> ids_node = nullptr;
if (graph->HasNode(ids_name)) {
ids_node = graph->GetNode(ids_name);
std::shared_ptr<Node> ids_node = nullptr;
if (graph->Has(ids_name)) {
ids_node = graph->Get(ids_name);
} else {
ids_node = graph->AddNode(
ids_node = graph->Add(
ids_name, ids_dims, ids_type->precision(), ids_type->layout());
}
// Flatten Ids node
if (ids_dims.size() != 1) {
ids_node = graph->AddNode(ids_name + "/reshape",
graph->builder_.CreateReshape(*ids_node, {-1}),
ids_type->precision(),
ids_type->layout());
ids_node =
graph->Add(ids_name + "/reshape",
graph->builder_.CreateReshape(*ids_node->data(), {-1}),
ids_type->precision(),
ids_type->layout());
}
auto w_const_node = graph->AddNode(w_name, *w);
// W node
auto w_node = graph->Add(w_name, *w);
// Reshape the gather node with the inferred shape as the output node
auto gather_node = graph->AddNode(
out_name,
graph->builder_.CreateGather(*w_const_node, *ids_node, /* axis= */ 0));
auto gather_node =
graph->Add(out_name,
graph->builder_.CreateGather(
*w_node->data(), *ids_node->data(), /* axis= */ 0));
if (out_dims.size() != 2) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*gather_node, CvtShape<xtcl::Integer>(out_dims)));
graph->Add(out_name,
graph->builder_.CreateReshape(
*gather_node->data(), CvtShape<xtcl::Integer>(out_dims)));
}
return SUCCESS;
}
......@@ -90,6 +94,6 @@ int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
lookup_table,
REGISTER_SUBGRAPH_BRIDGE(lookup_table,
kXPU,
paddle::lite::subgraph::xpu::LookupTableConverter);
......@@ -57,19 +57,19 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto alpha = op_info->GetAttr<float>("alpha");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Y node
std::shared_ptr<xtcl::xExpr> y_node = nullptr;
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
} else {
y_node = graph->AddNode(y_name, y_dims);
y_node = graph->Add(y_name, *y);
}
// Matmul node
......@@ -80,52 +80,55 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
if (x_dims.size() != 3) {
auto m = static_cast<int>(x_dims[x_dims.size() - 2]);
auto k = static_cast<int>(x_dims[x_dims.size() - 1]);
x_node =
graph->AddNode(x_name + "/reshape",
graph->builder_.CreateReshape(*x_node, {-1, m, k}));
x_node = graph->Add(
x_name + "/reshape",
graph->builder_.CreateReshape(*x_node->data(), {-1, m, k}));
if (transpose_x) {
x_node =
graph->AddNode(x_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*x_node, {0, 2, 1}));
x_node = graph->Add(
x_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*x_node->data(), {0, 2, 1}));
}
}
// Reshape and transposed Y node
if (y_dims.size() != 3) {
auto k = static_cast<int>(y_dims[y_dims.size() - 2]);
auto n = static_cast<int>(y_dims[y_dims.size() - 1]);
y_node =
graph->AddNode(y_name + "/reshape",
graph->builder_.CreateReshape(*y_node, {-1, k, n}));
y_node = graph->Add(
y_name + "/reshape",
graph->builder_.CreateReshape(*y_node->data(), {-1, k, n}));
if (!transpose_y) {
y_node =
graph->AddNode(y_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*y_node, {0, 2, 1}));
y_node = graph->Add(
y_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*y_node->data(), {0, 2, 1}));
}
}
// Matmul node
auto matmul_node = graph->AddNode(
out_name, graph->builder_.CreateBatchMatmul(*x_node, *y_node));
auto matmul_node = graph->Add(
out_name,
graph->builder_.CreateBatchMatmul(*x_node->data(), *y_node->data()));
if (fabs(alpha - 1) > 1e-6f) {
matmul_node = graph->AddNode(
out_name, graph->builder_.CreateScale(*matmul_node, alpha));
matmul_node = graph->Add(
out_name, graph->builder_.CreateScale(*matmul_node->data(), alpha));
}
if (out_dims.size() != 3) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*matmul_node, CvtShape<xtcl::Integer>(out_dims)));
graph->Add(out_name,
graph->builder_.CreateReshape(
*matmul_node->data(), CvtShape<xtcl::Integer>(out_dims)));
}
} else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N]
if (transpose_x) {
x_node = graph->AddNode(x_name + "/transpose",
graph->builder_.CreateTranspose(*x_node, {1, 0}));
x_node =
graph->Add(x_name + "/transpose",
graph->builder_.CreateTranspose(*x_node->data(), {1, 0}));
}
auto matmul_node = graph->AddNode(
out_name,
graph->builder_.CreateMatmul2D(*x_node, *y_node, transpose_y));
auto matmul_node =
graph->Add(out_name,
graph->builder_.CreateMatmul2D(
*x_node->data(), *y_node->data(), transpose_y));
if (fabs(alpha - 1) > 1e-6f) {
matmul_node = graph->AddNode(
out_name, graph->builder_.CreateScale(*matmul_node, alpha));
matmul_node = graph->Add(
out_name, graph->builder_.CreateScale(*matmul_node->data(), alpha));
}
} else if (x_dims.size() == 1 && y_dims.size() == 1) {
// x: [K], y: [K], out: [1]
......@@ -141,6 +144,6 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
matmul,
REGISTER_SUBGRAPH_BRIDGE(matmul,
kXPU,
paddle::lite::subgraph::xpu::MatmulConverter);
......@@ -56,49 +56,50 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(x_matrix_dims[1], y_matrix_dims[0]);
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Flatten X node
if (x_dims.size() != 2) {
x_node =
graph->AddNode(x_name + "/reshape",
graph->builder_.CreateReshape(
*x_node, {-1, static_cast<int>(x_matrix_dims[1])}));
x_node = graph->Add(
x_name + "/reshape",
graph->builder_.CreateReshape(
*x_node->data(), {-1, static_cast<int>(x_matrix_dims[1])}));
}
// Y node
std::shared_ptr<xtcl::xExpr> y_node = nullptr;
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
} else {
y_node = graph->AddNode(y_name, y_dims);
y_node = graph->Add(y_name, *y);
}
// Flatten Y node
if (y_dims.size() != 2) {
y_node =
graph->AddNode(y_name + "/reshape",
graph->builder_.CreateReshape(
*y_node, {static_cast<int>(y_matrix_dims[0]), -1}));
y_node = graph->Add(
y_name + "/reshape",
graph->builder_.CreateReshape(
*y_node->data(), {static_cast<int>(y_matrix_dims[0]), -1}));
}
// Reshape the matmul node with the inferred shape as the output node
auto matmul_node = graph->AddNode(
out_name, graph->builder_.CreateMatmul2D(*x_node, *y_node, false));
auto matmul_node = graph->Add(
out_name,
graph->builder_.CreateMatmul2D(*x_node->data(), *y_node->data(), false));
if (out_dims.size() != 2) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*matmul_node, CvtShape<xtcl::Integer>(out_dims)));
graph->Add(out_name,
graph->builder_.CreateReshape(
*matmul_node->data(), CvtShape<xtcl::Integer>(out_dims)));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace xpu
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, mul, paddle::lite::subgraph::xpu::MulConverter);
REGISTER_SUBGRAPH_BRIDGE(mul, kXPU, paddle::lite::subgraph::xpu::MulConverter);
......@@ -14,25 +14,25 @@
#pragma once
USE_SUBGRAPH_BRIDGE(XPU, relu);
USE_SUBGRAPH_BRIDGE(XPU, tanh);
USE_SUBGRAPH_BRIDGE(XPU, conv2d);
USE_SUBGRAPH_BRIDGE(XPU, depthwise_conv2d);
USE_SUBGRAPH_BRIDGE(XPU, elementwise_add);
USE_SUBGRAPH_BRIDGE(XPU, pool2d);
USE_SUBGRAPH_BRIDGE(XPU, softmax);
USE_SUBGRAPH_BRIDGE(XPU, mul);
USE_SUBGRAPH_BRIDGE(XPU, batch_norm);
USE_SUBGRAPH_BRIDGE(XPU, stack);
USE_SUBGRAPH_BRIDGE(XPU, gather);
USE_SUBGRAPH_BRIDGE(XPU, scale);
USE_SUBGRAPH_BRIDGE(XPU, lookup_table);
USE_SUBGRAPH_BRIDGE(XPU, slice);
USE_SUBGRAPH_BRIDGE(XPU, transpose);
USE_SUBGRAPH_BRIDGE(XPU, transpose2);
USE_SUBGRAPH_BRIDGE(XPU, reshape);
USE_SUBGRAPH_BRIDGE(XPU, reshape2);
USE_SUBGRAPH_BRIDGE(XPU, layer_norm);
USE_SUBGRAPH_BRIDGE(XPU, gelu);
USE_SUBGRAPH_BRIDGE(XPU, dropout);
USE_SUBGRAPH_BRIDGE(XPU, matmul);
USE_SUBGRAPH_BRIDGE(relu, kXPU);
USE_SUBGRAPH_BRIDGE(tanh, kXPU);
USE_SUBGRAPH_BRIDGE(conv2d, kXPU);
USE_SUBGRAPH_BRIDGE(depthwise_conv2d, kXPU);
USE_SUBGRAPH_BRIDGE(elementwise_add, kXPU);
USE_SUBGRAPH_BRIDGE(pool2d, kXPU);
USE_SUBGRAPH_BRIDGE(softmax, kXPU);
USE_SUBGRAPH_BRIDGE(mul, kXPU);
USE_SUBGRAPH_BRIDGE(batch_norm, kXPU);
USE_SUBGRAPH_BRIDGE(stack, kXPU);
USE_SUBGRAPH_BRIDGE(gather, kXPU);
USE_SUBGRAPH_BRIDGE(scale, kXPU);
USE_SUBGRAPH_BRIDGE(lookup_table, kXPU);
USE_SUBGRAPH_BRIDGE(slice, kXPU);
USE_SUBGRAPH_BRIDGE(transpose, kXPU);
USE_SUBGRAPH_BRIDGE(transpose2, kXPU);
USE_SUBGRAPH_BRIDGE(reshape, kXPU);
USE_SUBGRAPH_BRIDGE(reshape2, kXPU);
USE_SUBGRAPH_BRIDGE(layer_norm, kXPU);
USE_SUBGRAPH_BRIDGE(gelu, kXPU);
USE_SUBGRAPH_BRIDGE(dropout, kXPU);
USE_SUBGRAPH_BRIDGE(matmul, kXPU);
......@@ -50,21 +50,22 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto exclusive = op_info->GetAttr<bool>("exclusive");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Pool node
if (pooling_type == "max") {
if (global_pooling) {
graph->AddNode(out_name, graph->builder_.CreateGlobalMaxPool2D(*x_node));
graph->Add(out_name,
graph->builder_.CreateGlobalMaxPool2D(*x_node->data()));
} else {
graph->AddNode(
graph->Add(
out_name,
graph->builder_.CreateMaxPool2D(*x_node,
graph->builder_.CreateMaxPool2D(*x_node->data(),
CvtShape<xtcl::xIndexExpr>(ksize),
CvtShape<xtcl::xIndexExpr>(strides),
CvtShape<xtcl::xIndexExpr>(paddings),
......@@ -73,12 +74,13 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
} else if (pooling_type == "avg") {
if (global_pooling) {
graph->AddNode(out_name, graph->builder_.CreateGlobalAvgPool2D(*x_node));
graph->Add(out_name,
graph->builder_.CreateGlobalAvgPool2D(*x_node->data()));
} else {
// !exclusive ---> count_include_pad
graph->AddNode(
graph->Add(
out_name,
graph->builder_.CreateAvgPool2D(*x_node,
graph->builder_.CreateAvgPool2D(*x_node->data(),
CvtShape<xtcl::xIndexExpr>(ksize),
CvtShape<xtcl::xIndexExpr>(strides),
CvtShape<xtcl::xIndexExpr>(paddings),
......@@ -98,6 +100,6 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
pool2d,
REGISTER_SUBGRAPH_BRIDGE(pool2d,
kXPU,
paddle::lite::subgraph::xpu::PoolConverter);
......@@ -44,11 +44,11 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
std::vector<int> shape;
......@@ -59,6 +59,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// CHECK(shape_tensor_type->layout() == DATALAYOUT(kNCHW));
for (auto shape_tensor_name : shape_tensor_names) {
auto shape_tensor = scope->FindMutableTensor(shape_tensor_name);
CHECK(shape_tensor->persistable());
auto shape_tensor_data = shape_tensor->mutable_data<int>();
shape.emplace_back(shape_tensor_data[0]);
}
......@@ -73,6 +74,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// CHECK(actual_shape_type->precision() == PRECISION(kInt32));
// CHECK(actual_shape_type->layout() == DATALAYOUT(kNCHW));
auto actual_shape = scope->FindMutableTensor(actual_shape_name);
CHECK(actual_shape->persistable());
auto actual_shape_dims = actual_shape->dims();
auto actual_shape_data = actual_shape->mutable_data<int>();
auto shape = std::vector<int>(
......@@ -86,9 +88,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto out_dims = operators::ValidateShape(shape, x_dims);
// Reshape node
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*x_node, CvtShape<xtcl::Integer>(out_dims)));
graph->Add(out_name,
graph->builder_.CreateReshape(*x_node->data(),
CvtShape<xtcl::Integer>(out_dims)));
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -97,9 +99,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
reshape2,
REGISTER_SUBGRAPH_BRIDGE(reshape2,
kXPU,
paddle::lite::subgraph::xpu::ReshapeConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU,
reshape,
REGISTER_SUBGRAPH_BRIDGE(reshape,
kXPU,
paddle::lite::subgraph::xpu::ReshapeConverter);
......@@ -46,17 +46,17 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
float bias = op_info->GetAttr<float>("bias");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Scale node
graph->AddNode(
out_name,
graph->builder_.CreateScale(*x_node, scale, bias, bias_after_scale));
graph->Add(out_name,
graph->builder_.CreateScale(
*x_node->data(), scale, bias, bias_after_scale));
return SUCCESS;
}
......@@ -65,6 +65,6 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
scale,
REGISTER_SUBGRAPH_BRIDGE(scale,
kXPU,
paddle::lite::subgraph::xpu::ScaleConverter);
......@@ -46,11 +46,11 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto ends = op_info->GetAttr<std::vector<int>>("ends");
// Input node
std::shared_ptr<xtcl::xExpr> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
std::shared_ptr<Node> input_node = nullptr;
if (graph->Has(input_name)) {
input_node = graph->Get(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
input_node = graph->Add(input_name, *input);
}
// Calculate the begin and end of the slice in all of
......@@ -74,9 +74,9 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
strides.push_back(1);
}
}
graph->AddNode(
out_name,
graph->builder_.CreateStridedSlice(*input_node, begin, end, strides));
graph->Add(out_name,
graph->builder_.CreateStridedSlice(
*input_node->data(), begin, end, strides));
return REBUILD_WHEN_SHAPE_CHANGED;
}
......@@ -85,6 +85,6 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
slice,
REGISTER_SUBGRAPH_BRIDGE(slice,
kXPU,
paddle::lite::subgraph::xpu::SliceConverter);
......@@ -44,15 +44,15 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<int>("axis");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Softmax node
graph->AddNode(out_name, graph->builder_.CreateSoftmax(*x_node, axis));
graph->Add(out_name, graph->builder_.CreateSoftmax(*x_node->data(), axis));
return SUCCESS;
}
......@@ -61,6 +61,6 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
softmax,
REGISTER_SUBGRAPH_BRIDGE(softmax,
kXPU,
paddle::lite::subgraph::xpu::SoftmaxConverter);
......@@ -46,19 +46,19 @@ int StackConverter(void* ctx, OpLite* op, KernelBase* kernel) {
for (auto& x_name : x_names) {
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
x_nodes.push_back(*x_node);
x_nodes.push_back(*x_node->data());
}
// Stack node
graph->AddNode(y_name,
graph->builder_.CreateStack(
xtcl::network::TupleNode::make(x_nodes), axis));
graph->Add(y_name,
graph->builder_.CreateStack(
xtcl::network::TupleNode::make(x_nodes), axis));
return SUCCESS;
}
......@@ -67,6 +67,6 @@ int StackConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
stack,
REGISTER_SUBGRAPH_BRIDGE(stack,
kXPU,
paddle::lite::subgraph::xpu::StackConverter);
......@@ -44,19 +44,19 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<std::vector<int>>("axis");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
x_node = graph->Add(x_name, *x);
}
// Transpose node
graph->AddNode(out_name,
graph->builder_.CreateTranspose(
*x_node,
CvtShape<xtcl::Integer>(
std::vector<int64_t>(axis.begin(), axis.end()))));
graph->Add(out_name,
graph->builder_.CreateTranspose(
*x_node->data(),
CvtShape<xtcl::Integer>(
std::vector<int64_t>(axis.begin(), axis.end()))));
return SUCCESS;
}
......@@ -66,9 +66,9 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
transpose,
REGISTER_SUBGRAPH_BRIDGE(transpose,
kXPU,
paddle::lite::subgraph::xpu::TransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU,
transpose2,
REGISTER_SUBGRAPH_BRIDGE(transpose2,
kXPU,
paddle::lite::subgraph::xpu::TransposeConverter);
......@@ -39,13 +39,13 @@ int SubgraphEngine::BuildDeviceProgram() {
op->CheckShape();
op->InferShape();
std::string op_type = op->op_info()->Type();
if (!bridges.Exists("XPU", op_type)) {
if (!bridges.Exists(op_type, "kXPU")) {
return subgraph::FAILED;
}
auto kernel = inst.kernel();
status |= bridges.Select("XPU", op_type)(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
status |= bridges.Select(op_type, "kXPU")(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED;
}
......@@ -57,26 +57,26 @@ int SubgraphEngine::BuildDeviceProgram() {
std::vector<xtcl::xExpr*> device_inodes;
std::vector<xtcl::xExpr*> device_onodes;
for (auto& input_name : input_names_) {
if (graph.HasNode(input_name)) {
if (!graph.GetType(input_name).persistable()) {
device_inodes.push_back(graph.GetNode(input_name).get());
if (graph.Has(input_name)) {
if (graph.Get(input_name)->is_data()) {
device_inodes.push_back(graph.Get(input_name)->data().get());
device_inames_.push_back(input_name);
} else {
LOG(WARNING) << "[XPU] Input node " << input_name
<< " is skipped because it is a persistable node.";
<< " is ignored because it is not a data node.";
}
} else {
LOG(WARNING) << "[XPU] Input node " << input_name
<< " is skipped because it does not exist.";
<< " is ignored because it does not exist.";
}
}
for (auto& output_name : output_names_) {
if (graph.HasNode(output_name)) {
device_onodes.push_back(graph.GetNode(output_name).get());
if (graph.Has(output_name)) {
device_onodes.push_back(graph.Get(output_name)->data().get());
device_onames_.push_back(output_name);
} else {
LOG(WARNING) << "[XPU] Output node " << output_name
<< " is skipped because it does not exist.";
<< " is ignored because it does not exist.";
}
}
CHECK(!device_inames_.empty())
......@@ -98,14 +98,14 @@ int SubgraphEngine::BuildDeviceProgram() {
origin_otensors_.resize(device_onames_.size());
device_otensors_.resize(device_onames_.size());
for (int i = 0; i < device_inames_.size(); i++) {
auto type = graph.GetType(device_inames_[i]);
auto precision = type.precision();
auto layout = type.layout();
auto node = graph.Get(device_inames_[i]);
auto precision = node->precision();
auto layout = node->layout();
origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]);
CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims();
VLOG(3) << "[XPU] Inputs[" << i
<< "] precision: " << PrecisionToStr(precision)
VLOG(3) << "[XPU] Inputs[" << i << "] name: " << device_inames_[i]
<< " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout)
<< " dims: " << origin_idims_[i];
// Prepare the device input tensors which share data with the origin input
......@@ -122,14 +122,14 @@ int SubgraphEngine::BuildDeviceProgram() {
device_itensors_[i].byte_offset = 0;
}
for (int i = 0; i < device_onames_.size(); i++) {
auto type = graph.GetType(device_onames_[i]);
auto precision = type.precision();
auto layout = type.layout();
auto node = graph.Get(device_onames_[i]);
auto precision = node->precision();
auto layout = node->layout();
origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]);
CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "[XPU] Outputs[" << i
<< "] precision: " << PrecisionToStr(precision)
VLOG(3) << "[XPU] Outputs[" << i << "] name: " << device_onames_[i]
<< " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout)
<< " dims: " << origin_odims_[i];
// Prepare the device output tensors which share data with the origin output
......
......@@ -29,7 +29,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_layer_norm_compute SRCS layer_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
......
......@@ -16,6 +16,7 @@
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
......@@ -23,31 +24,33 @@ namespace lite {
class ScaleComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string output_ = "out";
std::string x_ = "x";
std::string out_ = "out";
DDim x_dims_{{100, 20}};
float scale_ = 0.;
float bias_ = 0.;
DDim dims_{{100, 20}};
bool bias_after_scale_;
public:
ScaleComputeTester(const Place& place,
const std::string& alias,
const DDim& x_dims,
float scale,
float bias,
bool bias_after_scale)
: TestCase(place, alias),
x_dims_(x_dims),
scale_(scale),
bias_(bias),
bias_after_scale_(bias_after_scale) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
auto* out = scope->NewTensor(out_);
CHECK(out);
out->Resize(dims_);
out->Resize(x_dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_);
auto* x = scope->FindTensor(x_);
const auto* x_data = x->data<float>();
float bias = bias_;
......@@ -56,35 +59,34 @@ class ScaleComputeTester : public arena::TestCase {
bias *= scale_;
}
for (int i = 0; i < dims_.production(); i++) {
for (int i = 0; i < x_dims_.production(); i++) {
out_data[i] = x_data[i] * scale_ + bias;
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("scale");
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("scale", scale_);
op_desc->SetAttr("bias", bias_);
op_desc->SetAttr("bias_after_scale", bias_after_scale_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(input_, dims_, data.data());
std::vector<float> x(x_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
SetCommonTensor(x_, x_dims_, x.data());
}
};
TEST(Scale, precision) {
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_ARM)
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 4e-3; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
......@@ -95,13 +97,16 @@ TEST(Scale, precision) {
return;
#endif
for (float scale : {0.123, 2., -1.2}) {
for (float bias : {1., 0., -1.2331}) {
for (bool bias_before : {true, false}) {
std::unique_ptr<arena::TestCase> tester(
new ScaleComputeTester(place, "def", scale, bias, bias_before));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
for (auto x_dims :
std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) {
for (float scale : {0.123, 2., -1.2}) {
for (float bias : {1., 0., -1.2331}) {
for (bool bias_after_scale : {true, false}) {
std::unique_ptr<arena::TestCase> tester(new ScaleComputeTester(
place, "def", DDim(x_dims), scale, bias, bias_after_scale));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
}
......@@ -117,8 +122,8 @@ TEST(Scale, performance) {
return;
#endif
std::unique_ptr<arena::TestCase> tester(
new ScaleComputeTester(place, "def", 1.2, 1.1, true));
std::unique_ptr<arena::TestCase> tester(new ScaleComputeTester(
place, "def", DDim(std::vector<int64_t>{5, 2, 3, 4}), 1.2, 1.1, true));
// To modify the arm context, one can retrieve the context as follows.
// #ifdef LITE_WITH_ARM
......
......@@ -25,33 +25,33 @@ class SoftmaxComputeTest : public arena::TestCase {
protected:
// common attributes for this op.
std::string op_type_ = "softmax";
std::string input_ = "x";
std::string output_ = "out";
DDim dims_{{1, 2, 3, 4}};
DDim x_dims_{{1, 2, 3, 4}};
std::string x_ = "x";
std::string out_ = "out";
int axis_ = 1;
public:
SoftmaxComputeTest(const Place& place,
const std::string& alias,
DDim dims,
DDim x_dims,
int axis)
: TestCase(place, alias), dims_(dims), axis_(axis) {}
: TestCase(place, alias), x_dims_(x_dims), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(input_);
auto out = scope->NewTensor(output_);
auto x = scope->FindTensor(x_);
auto out = scope->NewTensor(out_);
CHECK(out);
out->Resize(dims_);
out->Resize(x_dims_);
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
auto x_rank = dims_.size();
auto x_rank = x_dims_.size();
if (axis_ < 0) {
axis_ += x_rank;
}
int axis_size = dims_[axis_];
int outer_num = dims_.Slice(0, axis_).production();
int inner_num = dims_.Slice(axis_ + 1, x_rank).production();
int axis_size = x_dims_[axis_];
int outer_num = x_dims_.Slice(0, axis_).production();
int inner_num = x_dims_.Slice(axis_ + 1, x_rank).production();
int compute_size = outer_num * inner_num;
for (int i = 0; i < compute_size; i++) {
int idx_inner = i % inner_num;
......@@ -84,15 +84,15 @@ class SoftmaxComputeTest : public arena::TestCase {
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(op_type_);
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
SetCommonTensor(input_, dims_, din.data());
std::vector<float> x(x_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
SetCommonTensor(x_, x_dims_, x.data());
}
};
......@@ -100,18 +100,21 @@ TEST(Softmax, precision) {
LOG(INFO) << "test softmax op";
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_XPU)
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 4e-3; // Using fp16 in NPU
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
std::vector<std::vector<int64_t>> dims{{1, 2, 3, 4}, {2, 3, 4}, {3, 4}};
for (auto dim_in : dims) {
for (auto x_dims :
std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {2, 3, 4}, {3, 4}}) {
for (auto axis : {-1, 0, 1, 2, 3}) {
if (axis >= dim_in.size()) continue;
if (axis >= x_dims.size()) continue;
std::unique_ptr<arena::TestCase> tester(
new SoftmaxComputeTest(place, "def", DDim(dim_in), axis));
new SoftmaxComputeTest(place, "def", DDim(x_dims), axis));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册