Unverified commit a29c84a2, authored by hong19860320, committed by GitHub

[LITE][NPU][XPU] Refine the registration and implementation of op bridges (#2700)

* Fix the compiling error that occurs when specifying the ddk_root path and building for Huawei NPU.

* Refine the registration of op bridges so that it mirrors the registration of ops and kernels (see the usage sketch below).

* Refine the Graph and Node interfaces for op bridges, and support creating constant and data nodes automatically according to the 'persistable' attribute of the target tensor (see the sketch below).

* Add unit tests for the scale and softmax op bridges for NPU.
Parent: bc5bd154
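Under the new scheme the op type comes first and the target comes second, matching how ops and kernels are registered. Below is a minimal sketch, assuming a hypothetical op type "my_op"; the macro arguments and the USE_SUBGRAPH_BRIDGE redefinition are taken from the registrations and subgraph passes in this diff, while the converter body is a placeholder.

// Hypothetical bridge converter for an op type "my_op" (placeholder body).
int MyOpConverter(void* ctx, OpLite* op, KernelBase* kernel) {
  // ... convert the Paddle op and its inputs/outputs into HiAI IR nodes ...
  return SUCCESS;
}

// Old form: REGISTER_SUBGRAPH_BRIDGE(NPU, my_op, MyOpConverter)
// New form: op type first, then the target tag, like op/kernel registration.
REGISTER_SUBGRAPH_BRIDGE(my_op, kNPU, MyOpConverter);

// The NPU/XPU subgraph passes collect the supported op types by redefining
// USE_SUBGRAPH_BRIDGE(op_type, target) before including paddle_use_bridges.h:
#define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
#include "lite/kernels/npu/bridges/paddle_use_bridges.h"
#undef USE_SUBGRAPH_BRIDGE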
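A second sketch shows how a converter resolves its input through the refined Graph/Node interface: Graph::Add(name, tensor) now inspects the tensor's 'persistable' attribute and creates either a Const or a Data node, and the wrapped HiAI operator is reached via Node::data<T>(). The variable names follow the activation converter in this diff; the surrounding converter boilerplate is omitted.

// X node: reuse the node created by a preceding op, or build one from the tensor.
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
  x_node = graph->Get(x_name);
} else {
  // persistable tensor -> ge::op::Const node (data copied via CvtTensor)
  // non-persistable    -> ge::op::Data node (only shape/precision/layout set)
  x_node = graph->Add(x_name, *x);
}

// Op node: the Node wrapper records precision/layout/role; the HiAI operator
// itself is obtained with data<T>() before wiring inputs and attributes.
auto act_node = graph->Add<ge::op::Activation>(out_name);
auto act_op = act_node->data<ge::op::Activation>();
act_op->set_input_x(*x_node->data());
act_op->set_attr_mode(CvtActMode(op_type));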
@@ -30,7 +30,7 @@ if(NOT NPU_DDK_INC)
  message(FATAL_ERROR "Can not find HiAiModelManagerService.h in ${NPU_DDK_ROOT}/include")
 endif()
-include_directories("${NPU_DDK_ROOT}")
+include_directories("${NPU_DDK_ROOT}/include")
 set(NPU_SUB_LIB_PATH "lib64")
 if(ARM_TARGET_ARCH_ABI STREQUAL "armv8")
......
@@ -18,8 +18,8 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include "ai_ddk_lib/include/HiAiModelManagerService.h"
-#include "ai_ddk_lib/include/hiai_ir_build.h"
+#include "HiAiModelManagerService.h"  // NOLINT
+#include "hiai_ir_build.h"  // NOLINT
 namespace paddle {
 namespace lite {
......
@@ -27,7 +27,7 @@ namespace mir {
 void NPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   std::unordered_set<std::string> supported_lists;
-#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) supported_lists.insert(#op_type);
+#define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
 #include "lite/kernels/npu/bridges/paddle_use_bridges.h"
 #undef USE_SUBGRAPH_BRIDGE
   auto teller = [&](Node* node) {
@@ -41,7 +41,7 @@ void NPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 void XPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   std::unordered_set<std::string> supported_lists;
-#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) supported_lists.insert(#op_type);
+#define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
 #include "lite/kernels/xpu/bridges/paddle_use_bridges.h"
 #undef USE_SUBGRAPH_BRIDGE
   auto teller = [&](Node* node) {
......
@@ -43,33 +43,34 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   CHECK(out_type->layout() == DATALAYOUT(kNCHW));
   // X node
-  std::shared_ptr<ge::Operator> x_node = nullptr;
-  if (graph->HasNode(x_name)) {
-    x_node = graph->GetNode(x_name);
+  std::shared_ptr<Node> x_node = nullptr;
+  if (graph->Has(x_name)) {
+    x_node = graph->Get(x_name);
   } else {
-    x_node = graph->AddNode(x_name, x_dims);
+    x_node = graph->Add(x_name, *x);
   }
   // Act node
-  auto act_node = graph->AddNode<ge::op::Activation>(out_name);
-  act_node->set_input_x(*x_node);
+  auto act_node = graph->Add<ge::op::Activation>(out_name);
+  auto act_op = act_node->data<ge::op::Activation>();
+  act_op->set_input_x(*x_node->data());
   // TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
   // clipped_relu etc.
-  act_node->set_attr_mode(CvtActMode(op_type));
+  act_op->set_attr_mode(CvtActMode(op_type));
   if (op_type == "relu_clipped") {
     auto Relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef");
-    act_node->set_attr_coef(Relu_clipped_coef);
+    act_op->set_attr_coef(Relu_clipped_coef);
   } else if (op_type == "relu6") {
     float Relu_clipped_coef = 6.f;
-    act_node->set_attr_coef(Relu_clipped_coef);
+    act_op->set_attr_coef(Relu_clipped_coef);
   } else if (op_type == "leaky_relu") {
     auto alpha = op_info->GetAttr<float>("alpha");
-    act_node->set_attr_negative_slope(alpha);
+    act_op->set_attr_negative_slope(alpha);
   } else if (op_type == "hard_sigmoid") {
     auto slope = op_info->GetAttr<float>("slope");
     auto offset = op_info->GetAttr<float>("offset");
-    act_node->set_attr_negative_slope(slope);
-    act_node->set_attr_coef(offset);
+    act_op->set_attr_negative_slope(slope);
+    act_op->set_attr_coef(offset);
   }
   return SUCCESS;
 }
@@ -79,25 +80,27 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
 } // namespace lite
 } // namespace paddle
-REGISTER_SUBGRAPH_BRIDGE(NPU,
-                         sigmoid,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(NPU, relu, paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(NPU, tanh, paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(NPU,
-                         relu_clipped,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(NPU, relu6, paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(NPU,
-                         leaky_relu,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(NPU, abs, paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(NPU,
-                         softsign,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(NPU,
-                         softplus,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(NPU,
-                         hard_sigmoid,
-                         paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(sigmoid,
+                         kNPU,
+                         paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(relu, kNPU, paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(tanh, kNPU, paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(relu_clipped,
+                         kNPU,
+                         paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(relu6,
+                         kNPU,
+                         paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(leaky_relu,
+                         kNPU,
+                         paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(abs, kNPU, paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(softsign,
+                         kNPU,
+                         paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(softplus,
+                         kNPU,
+                         paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(hard_sigmoid,
+                         kNPU,
+                         paddle::lite::subgraph::npu::ActConverter);
...@@ -44,20 +44,21 @@ int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -44,20 +44,21 @@ int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
int axis = op_info->GetAttr<int64_t>("axis"); int axis = op_info->GetAttr<int64_t>("axis");
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Axis node // Axis node
auto axis_const_node = graph->AddNode(out_name + "/axis", axis); auto axis_node = graph->Add(out_name + "/axis", axis);
// Argmax node // Argmax node
auto argmax_node = graph->AddNode<ge::op::ArgMax>(out_name); auto argmax_node = graph->Add<ge::op::ArgMax>(out_name);
argmax_node->set_input_x1(*x_node); auto argmax_op = argmax_node->data<ge::op::ArgMax>();
argmax_node->set_input_x2(*axis_const_node); argmax_op->set_input_x1(*x_node->data());
argmax_op->set_input_x2(*axis_node->data());
return SUCCESS; return SUCCESS;
} }
...@@ -66,6 +67,6 @@ int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -66,6 +67,6 @@ int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(arg_max,
arg_max, kNPU,
paddle::lite::subgraph::npu::ArgmaxConverter); paddle::lite::subgraph::npu::ArgmaxConverter);
...@@ -67,30 +67,31 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -67,30 +67,31 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
bool use_global_stats = op_info->GetAttr<bool>("use_global_stats"); bool use_global_stats = op_info->GetAttr<bool>("use_global_stats");
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Scale, Bias, Mean, Variance node // Scale, Bias, Mean, Variance node
auto scale_const_node = graph->AddNode(scale_name, *scale); auto scale_node = graph->Add(scale_name, *scale);
auto bias_const_node = graph->AddNode(bias_name, *bias); auto bias_node = graph->Add(bias_name, *bias);
auto mean_const_node = graph->AddNode(mean_name, *mean); auto mean_node = graph->Add(mean_name, *mean);
auto variance_const_node = graph->AddNode(variance_name, *variance); auto variance_node = graph->Add(variance_name, *variance);
// Batch Norm node // Batch Norm node
auto batch_norm_node = graph->AddNode<ge::op::BatchNormExt2>(y_name); auto batch_norm_node = graph->Add<ge::op::BatchNormExt2>(y_name);
batch_norm_node->set_input_x(*x_node); auto batch_norm_op = batch_norm_node->data<ge::op::BatchNormExt2>();
batch_norm_node->set_input_scale(*scale_const_node); batch_norm_op->set_input_x(*x_node->data());
batch_norm_node->set_input_offset(*bias_const_node); batch_norm_op->set_input_scale(*scale_node->data());
batch_norm_node->set_input_mean(*mean_const_node); batch_norm_op->set_input_offset(*bias_node->data());
batch_norm_node->set_input_variance(*variance_const_node); batch_norm_op->set_input_mean(*mean_node->data());
batch_norm_node->set_attr_momentum(momentum); batch_norm_op->set_input_variance(*variance_node->data());
batch_norm_node->set_attr_epsilon(epsilon); batch_norm_op->set_attr_momentum(momentum);
batch_norm_node->set_attr_mode(mode); batch_norm_op->set_attr_epsilon(epsilon);
batch_norm_node->set_attr_use_global_stats(use_global_stats); batch_norm_op->set_attr_mode(mode);
batch_norm_op->set_attr_use_global_stats(use_global_stats);
return SUCCESS; return SUCCESS;
} }
...@@ -99,6 +100,6 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -99,6 +100,6 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(batch_norm,
batch_norm, kNPU,
paddle::lite::subgraph::npu::BatchNormConverter); paddle::lite::subgraph::npu::BatchNormConverter);
...@@ -44,21 +44,22 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -44,21 +44,22 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Traverse all of input nodes which are added into the new created concat // Traverse all of input nodes which are added into the new created concat
// node // node
auto concat_node = graph->AddNode<ge::op::Concat>(out_name); auto concat_node = graph->Add<ge::op::Concat>(out_name);
concat_node->set_attr_axis(axis); auto concat_op = concat_node->data<ge::op::Concat>();
concat_node->set_attr_N(num); concat_op->set_attr_axis(axis);
concat_node->create_dynamic_input_x(num); concat_op->set_attr_N(num);
concat_op->create_dynamic_input_x(num);
int idx = 1; int idx = 1;
for (auto& x_name : x_names) { for (auto& x_name : x_names) {
auto x = scope->FindMutableTensor(x_name); auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims(); auto x_dims = x->dims();
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
concat_node->set_dynamic_input_x(idx, *x_node); concat_op->set_dynamic_input_x(idx, *x_node->data());
idx++; idx++;
} }
return SUCCESS; return SUCCESS;
...@@ -69,6 +70,6 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -69,6 +70,6 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(concat,
concat, kNPU,
paddle::lite::subgraph::npu::ConcatConverter); paddle::lite::subgraph::npu::ConcatConverter);
...@@ -67,11 +67,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -67,11 +67,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(dilations.size(), 2L); CHECK_EQ(dilations.size(), 2L);
// Input node // Input node
std::shared_ptr<ge::Operator> input_node = nullptr; std::shared_ptr<Node> input_node = nullptr;
if (graph->HasNode(input_name)) { if (graph->Has(input_name)) {
input_node = graph->GetNode(input_name); input_node = graph->Get(input_name);
} else { } else {
input_node = graph->AddNode(input_name, input_dims); input_node = graph->Add(input_name, *input);
} }
if (paddings.size() == 2L) { if (paddings.size() == 2L) {
...@@ -109,104 +109,102 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -109,104 +109,102 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} }
// Filter node // Filter node
auto filter_const_node = graph->AddNode(filter_name, *filter); auto filter_node = graph->Add(filter_name, *filter);
// Add bias node if exists bias // Add bias node if exists bias
// Supports the bias nodes with the following dimensions // Supports the bias nodes with the following dimensions
// 0: {oc} // 0: {oc}
// 1: {1, oc, oh, ow} // 1: {1, oc, oh, ow}
// 2: {n, oc, oh, ow} // 2: {n, oc, oh, ow}
std::shared_ptr<ge::Operator> bias_node = nullptr; std::shared_ptr<Node> bias_node = nullptr;
bool is_channel_bias = false; bool is_channel_bias = false;
if (HasInputArg(op_info, scope, "Bias")) { if (HasInputArg(op_info, scope, "Bias")) {
auto bias_name = op_info->Input("Bias").front(); auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias"); if (graph->Has(bias_name)) {
CHECK(bias_type->precision() == PRECISION(kFloat)); bias_node = graph->Get(bias_name);
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
auto bias_data_size = bias_dims.production();
auto output_data_size = output_dims.production();
std::vector<int64_t> bias_shape;
if (bias_data_size == oc) {
// 0: {oc}
bias_shape = {1, oc, 1, 1};
is_channel_bias = true;
} else if (bias_data_size == output_data_size / bs) {
// 1: {1, oc, oh, ow}
bias_shape = {1, output_dims[1], output_dims[2], output_dims[3]};
} else if (bias_data_size == output_data_size) {
// 2: {n, oc, oh, ow}
bias_shape = output_dims.Vectorize();
} else { } else {
LOG(WARNING) << "[NPU] Bias dimension " << bias_dims auto bias_type = kernel->GetInputDeclType("Bias");
<< " isn't supported in conv2d Op when output dimension is " CHECK(bias_type->precision() == PRECISION(kFloat));
<< output_dims; CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
return FAILED; auto bias = scope->FindMutableTensor(bias_name);
} auto bias_dims = bias->dims();
if (graph->HasNode(bias_name)) { auto bias_data_size = bias_dims.production();
// Bias node from input node auto output_data_size = output_dims.production();
bias_node = graph->GetNode(bias_name); std::vector<int64_t> bias_shape;
} else { if (bias_data_size == oc) {
// Bias node with const data // 0: {oc}
bias_node = graph->AddNode(bias_name, *bias, bias_shape); bias_shape = {1, oc, 1, 1};
is_channel_bias = true;
} else if (bias_data_size == output_data_size / bs) {
// 1: {1, oc, oh, ow}
bias_shape = {1, output_dims[1], output_dims[2], output_dims[3]};
} else if (bias_data_size == output_data_size) {
// 2: {n, oc, oh, ow}
bias_shape = output_dims.Vectorize();
} else {
LOG(WARNING)
<< "[NPU] Bias dimension " << bias_dims
<< " isn't supported in conv2d Op when output dimension is "
<< output_dims;
return FAILED;
}
bias_node = graph->Add(bias_name, *bias, bias_shape);
} }
} }
// Conv node // Conv node
std::shared_ptr<ge::Operator> conv_node = nullptr; std::shared_ptr<Node> conv_node = nullptr;
if (use_depthwise_conv && is_depthwise_mode) { if (use_depthwise_conv && is_depthwise_mode) {
auto depthwise_conv_node = conv_node = graph->Add<ge::op::ConvolutionDepthwise>(output_name);
graph->AddNode<ge::op::ConvolutionDepthwise>(output_name); auto conv_op = conv_node->data<ge::op::ConvolutionDepthwise>();
depthwise_conv_node->set_input_x(*input_node); conv_op->set_input_x(*input_node->data());
depthwise_conv_node->set_input_filter(*filter_const_node); conv_op->set_input_filter(*filter_node->data());
depthwise_conv_node->set_attr_mode(1); conv_op->set_attr_mode(1);
depthwise_conv_node->set_attr_algo(0); conv_op->set_attr_algo(0);
depthwise_conv_node->set_attr_format(0); // NCHW conv_op->set_attr_format(0); // NCHW
depthwise_conv_node->set_attr_pad_mode(5); // VALID conv_op->set_attr_pad_mode(5); // VALID
depthwise_conv_node->set_attr_group(groups); conv_op->set_attr_group(groups);
depthwise_conv_node->set_attr_pad(ge::AttrValue::LIST_INT( conv_op->set_attr_pad(ge::AttrValue::LIST_INT(
{paddings[0], paddings[1], paddings[2], paddings[3]})); {paddings[0], paddings[1], paddings[2], paddings[3]}));
depthwise_conv_node->set_attr_dilation( conv_op->set_attr_dilation(
ge::AttrValue::LIST_INT({dilations[0], dilations[1]})); ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
depthwise_conv_node->set_attr_stride( conv_op->set_attr_stride(ge::AttrValue::LIST_INT({strides[0], strides[1]}));
ge::AttrValue::LIST_INT({strides[0], strides[1]})); conv_op->set_attr_kernel(
depthwise_conv_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]})); ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
conv_node = depthwise_conv_node;
// ConvolutionDepthwise Op doesn't support bias, so append Add node to // ConvolutionDepthwise Op doesn't support bias, so append Add node to
// support bias // support bias
if (bias_node != nullptr) { if (bias_node != nullptr) {
auto add_node = graph->AddNode<ge::op::Add>(output_name); auto add_node = graph->Add<ge::op::Add>(output_name);
add_node->set_input_x1(*depthwise_conv_node); auto add_op = add_node->data<ge::op::Add>();
add_node->set_input_x2(*bias_node); add_op->set_input_x1(*conv_node->data());
add_op->set_input_x2(*bias_node->data());
conv_node = add_node; conv_node = add_node;
} }
} else { } else {
auto common_conv_node = graph->AddNode<ge::op::Convolution>(output_name); conv_node = graph->Add<ge::op::Convolution>(output_name);
common_conv_node->set_input_x(*input_node); auto conv_op = conv_node->data<ge::op::Convolution>();
common_conv_node->set_input_w(*filter_const_node); conv_op->set_input_x(*input_node->data());
common_conv_node->set_attr_mode(1); conv_op->set_input_w(*filter_node->data());
common_conv_node->set_attr_pad_mode(0); // NOTSET conv_op->set_attr_mode(1);
common_conv_node->set_attr_group(groups); conv_op->set_attr_pad_mode(0); // NOTSET
common_conv_node->set_attr_pad(ge::AttrValue::LIST_INT( conv_op->set_attr_group(groups);
conv_op->set_attr_pad(ge::AttrValue::LIST_INT(
{paddings[0], paddings[0], paddings[2], paddings[2]})); {paddings[0], paddings[0], paddings[2], paddings[2]}));
common_conv_node->set_attr_dilation( conv_op->set_attr_dilation(
ge::AttrValue::LIST_INT({dilations[0], dilations[1]})); ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
common_conv_node->set_attr_stride( conv_op->set_attr_stride(ge::AttrValue::LIST_INT({strides[0], strides[1]}));
ge::AttrValue::LIST_INT({strides[0], strides[1]})); conv_op->set_attr_kernel(
common_conv_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]})); ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
conv_node = common_conv_node;
// Convolution Op only support bias with dimension {1, oc, 1, 1}, // Convolution Op only support bias with dimension {1, oc, 1, 1},
// so append Add node if dimension is {1, oc, oh, ow} or (n, oc, oh, ow) // so append Add node if dimension is {1, oc, oh, ow} or (n, oc, oh, ow)
if (bias_node != nullptr) { if (bias_node != nullptr) {
if (is_channel_bias) { if (is_channel_bias) {
common_conv_node->set_input_b(*bias_node); conv_op->set_input_b(*bias_node->data());
} else { } else {
auto add_node = graph->AddNode<ge::op::Add>(output_name); auto add_node = graph->Add<ge::op::Add>(output_name);
add_node->set_input_x1(*common_conv_node); auto add_op = add_node->data<ge::op::Add>();
add_node->set_input_x2(*bias_node); add_op->set_input_x1(*conv_node->data());
add_op->set_input_x2(*bias_node->data());
conv_node = add_node; conv_node = add_node;
} }
} }
...@@ -215,9 +213,10 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -215,9 +213,10 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
if (fuse_relu) { if (fuse_relu) {
// Append relu node if fuse_relu is true // Append relu node if fuse_relu is true
auto relu_node = graph->AddNode<ge::op::Activation>(output_name); auto relu_node = graph->Add<ge::op::Activation>(output_name);
relu_node->set_input_x(*conv_node); auto relu_op = relu_node->data<ge::op::Activation>();
relu_node->set_attr_mode(CvtActMode("relu")); relu_op->set_input_x(*conv_node->data());
relu_op->set_attr_mode(CvtActMode("relu"));
} }
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -227,9 +226,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -227,9 +226,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(conv2d,
conv2d, kNPU,
paddle::lite::subgraph::npu::ConvConverter); paddle::lite::subgraph::npu::ConvConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(depthwise_conv2d,
depthwise_conv2d, kNPU,
paddle::lite::subgraph::npu::ConvConverter); paddle::lite::subgraph::npu::ConvConverter);
...@@ -58,11 +58,11 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -58,11 +58,11 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(dilations.size(), 2L); CHECK_EQ(dilations.size(), 2L);
// Input node // Input node
std::shared_ptr<ge::Operator> input_node = nullptr; std::shared_ptr<Node> input_node = nullptr;
if (graph->HasNode(input_name)) { if (graph->Has(input_name)) {
input_node = graph->GetNode(input_name); input_node = graph->Get(input_name);
} else { } else {
input_node = graph->AddNode(input_name, input_dims); input_node = graph->Add(input_name, *input);
} }
// Create input sizes node to describe the dimensions of input tensor // Create input sizes node to describe the dimensions of input tensor
...@@ -83,55 +83,59 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -83,55 +83,59 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
(input_dims[i + 2] - 1) * strides[i] + kernel_ext - 2 * paddings[i]; (input_dims[i + 2] - 1) * strides[i] + kernel_ext - 2 * paddings[i];
input_sizes.push_back(output_size); input_sizes.push_back(output_size);
} }
auto input_sizes_const_node = auto input_sizes_node = graph->Add(output_name + "/input_sizes", input_sizes);
graph->AddNode(output_name + "/input_sizes", input_sizes);
// Filter node // Filter node
auto filter_const_node = graph->AddNode(filter_name, *filter); auto filter_node = graph->Add(filter_name, *filter);
// Deconv node // Deconv node
auto conv_transpose_node = graph->AddNode<ge::op::Deconvolution>(output_name); auto conv_transpose_node = graph->Add<ge::op::Deconvolution>(output_name);
conv_transpose_node->set_input_input_sizes(*input_sizes_const_node); auto conv_transpose_op = conv_transpose_node->data<ge::op::Deconvolution>();
conv_transpose_node->set_input_filter(*filter_const_node); conv_transpose_op->set_input_input_sizes(*input_sizes_node->data());
conv_transpose_node->set_input_x(*input_node); conv_transpose_op->set_input_filter(*filter_node->data());
conv_transpose_op->set_input_x(*input_node->data());
// Set attributes // Set attributes
conv_transpose_node->set_attr_format(0); // NCHW conv_transpose_op->set_attr_format(0); // NCHW
conv_transpose_node->set_attr_pad_mode(0); // NOTSET conv_transpose_op->set_attr_pad_mode(0); // NOTSET
conv_transpose_node->set_attr_group(groups); conv_transpose_op->set_attr_group(groups);
conv_transpose_node->set_attr_pad(ge::AttrValue::LIST_INT( conv_transpose_op->set_attr_pad(ge::AttrValue::LIST_INT(
{paddings[0], paddings[1], paddings[2], paddings[3]})); {paddings[0], paddings[1], paddings[2], paddings[3]}));
conv_transpose_node->set_attr_dilation( conv_transpose_op->set_attr_dilation(
ge::AttrValue::LIST_INT({dilations[0], dilations[1]})); ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
conv_transpose_node->set_attr_stride( conv_transpose_op->set_attr_stride(
ge::AttrValue::LIST_INT({strides[0], strides[1]})); ge::AttrValue::LIST_INT({strides[0], strides[1]}));
conv_transpose_node->set_attr_kernel( conv_transpose_op->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]})); ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
// Append add node to add bias if exists bias // Append add node to add bias if exists bias
std::shared_ptr<ge::Operator> output_node = conv_transpose_node;
if (HasInputArg(op_info, scope, "Bias")) { if (HasInputArg(op_info, scope, "Bias")) {
// Create bias node std::shared_ptr<Node> bias_node = nullptr;
auto bias_name = op_info->Input("Bias").front(); auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias"); if (graph->Has(bias_name)) {
CHECK(bias_type->precision() == PRECISION(kFloat)); bias_node = graph->Get(bias_name);
CHECK(bias_type->layout() == DATALAYOUT(kNCHW)); } else {
auto bias = scope->FindMutableTensor(bias_name); auto bias_type = kernel->GetInputDeclType("Bias");
auto channel_size = bias->dims().production(); CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK_EQ(channel_size, filter_dims[1] * groups); CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias_const_node = auto bias = scope->FindMutableTensor(bias_name);
graph->AddNode(bias_name, *bias, {1, channel_size, 1, 1}); auto channel_size = bias->dims().production();
CHECK_EQ(channel_size, filter_dims[1] * groups);
bias_node = graph->Add(bias_name, *bias, {1, channel_size, 1, 1});
}
// Append add node to add bias node // Append add node to add bias node
auto add_node = graph->AddNode<ge::op::Add>(output_name); auto add_node = graph->Add<ge::op::Add>(output_name);
add_node->set_input_x1(*conv_transpose_node); auto add_op = add_node->data<ge::op::Add>();
add_node->set_input_x2(*bias_const_node); add_op->set_input_x1(*conv_transpose_node->data());
output_node = add_node; add_op->set_input_x2(*bias_node->data());
conv_transpose_node = add_node;
} }
if (fuse_relu) { if (fuse_relu) {
// Append relu node if fuse_relu is true // Append relu node if fuse_relu is true
auto relu_node = graph->AddNode<ge::op::Activation>(output_name); auto relu_node = graph->Add<ge::op::Activation>(output_name);
relu_node->set_input_x(*output_node); auto relu_op = relu_node->data<ge::op::Activation>();
relu_node->set_attr_mode(CvtActMode("relu")); relu_op->set_input_x(*conv_transpose_node->data());
relu_op->set_attr_mode(CvtActMode("relu"));
} }
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -141,6 +145,6 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -141,6 +145,6 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(conv2d_transpose,
conv2d_transpose, kNPU,
paddle::lite::subgraph::npu::ConvTransposeConverter); paddle::lite::subgraph::npu::ConvTransposeConverter);
...@@ -74,45 +74,45 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -74,45 +74,45 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<int>("axis"); auto axis = op_info->GetAttr<int>("axis");
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Y node // Y node
std::shared_ptr<ge::Operator> y_node = nullptr; std::shared_ptr<Node> y_node = nullptr;
if (graph->HasNode(y_name)) { if (graph->Has(y_name)) {
y_node = graph->GetNode(y_name); y_node = graph->Get(y_name);
} else { } else {
auto y_new_shape = CvtYShape(x_dims, y_dims, axis); auto y_new_shape = CvtYShape(x_dims, y_dims, axis);
y_node = graph->AddNode(y_name, y_new_shape); y_node = graph->Add(y_name, *y, y_new_shape);
} }
// Elementwise node // Elementwise node
std::shared_ptr<ge::Operator> elementwise_node = nullptr; std::shared_ptr<Node> elt_node = nullptr;
if (op_type == "elementwise_add" || if (op_type == "elementwise_add" ||
op_type == "fusion_elementwise_add_activation") { op_type == "fusion_elementwise_add_activation") {
auto elt_node = graph->AddNode<ge::op::Add>(out_name); elt_node = graph->Add<ge::op::Add>(out_name);
elt_node->set_input_x1(*x_node); auto elt_op = elt_node->data<ge::op::Add>();
elt_node->set_input_x2(*y_node); elt_op->set_input_x1(*x_node->data());
elementwise_node = elt_node; elt_op->set_input_x2(*y_node->data());
} else if (op_type == "elementwise_sub") { } else if (op_type == "elementwise_sub") {
auto elt_node = graph->AddNode<ge::op::Sub>(out_name); elt_node = graph->Add<ge::op::Sub>(out_name);
elt_node->set_input_x1(*x_node); auto elt_op = elt_node->data<ge::op::Sub>();
elt_node->set_input_x2(*y_node); elt_op->set_input_x1(*x_node->data());
elementwise_node = elt_node; elt_op->set_input_x2(*y_node->data());
} else if (op_type == "elementwise_mul") { } else if (op_type == "elementwise_mul") {
auto elt_node = graph->AddNode<ge::op::Mul>(out_name); elt_node = graph->Add<ge::op::Mul>(out_name);
elt_node->set_input_x(*x_node); auto elt_op = elt_node->data<ge::op::Mul>();
elt_node->set_input_y(*y_node); elt_op->set_input_x(*x_node->data());
elementwise_node = elt_node; elt_op->set_input_y(*y_node->data());
} else if (op_type == "elementwise_div") { } else if (op_type == "elementwise_div") {
auto elt_node = graph->AddNode<ge::op::RealDiv>(out_name); elt_node = graph->Add<ge::op::RealDiv>(out_name);
elt_node->set_input_x1(*x_node); auto elt_op = elt_node->data<ge::op::RealDiv>();
elt_node->set_input_x2(*y_node); elt_op->set_input_x1(*x_node->data());
elementwise_node = elt_node; elt_op->set_input_x2(*y_node->data());
} else { } else {
LOG(WARNING) << "[NPU] Unsupported op type: " << op_type; LOG(WARNING) << "[NPU] Unsupported op type: " << op_type;
return FAILED; return FAILED;
...@@ -121,11 +121,12 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -121,11 +121,12 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Act node // Act node
if (op_type == "fusion_elementwise_add_activation") { if (op_type == "fusion_elementwise_add_activation") {
auto act_type = op_info->GetAttr<std::string>("act_type"); auto act_type = op_info->GetAttr<std::string>("act_type");
auto act_node = graph->AddNode<ge::op::Activation>(out_name); auto act_node = graph->Add<ge::op::Activation>(out_name);
act_node->set_input_x(*elementwise_node); auto act_op = act_node->data<ge::op::Activation>();
act_op->set_input_x(*elt_node->data());
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu, // TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc. // clipped_relu etc.
act_node->set_attr_mode(CvtActMode(act_type)); act_op->set_attr_mode(CvtActMode(act_type));
} }
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -135,18 +136,18 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -135,18 +136,18 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(elementwise_add,
elementwise_add, kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter); paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation,
fusion_elementwise_add_activation, kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter); paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(elementwise_sub,
elementwise_sub, kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter); paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(elementwise_mul,
elementwise_mul, kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter); paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(elementwise_div,
elementwise_div, kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter); paddle::lite::subgraph::npu::ElementwiseConverter);
...@@ -57,22 +57,24 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -57,22 +57,24 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< " m: " << m << " k: " << k << " n: " << n; << " m: " << m << " k: " << k << " n: " << n;
// Create input node and reshape it to (m, k, 1, 1) // Create input node and reshape it to (m, k, 1, 1)
std::shared_ptr<ge::Operator> input_node = nullptr; std::shared_ptr<Node> input_node = nullptr;
if (graph->HasNode(input_name)) { if (graph->Has(input_name)) {
input_node = graph->GetNode(input_name); input_node = graph->Get(input_name);
} else { } else {
input_node = graph->AddNode(input_name, input_dims); input_node = graph->Add(input_name, *input);
} }
auto reshaped_input_node = auto reshaped_input_node =
graph->AddNode<ge::op::Reshape>(input_name + "/reshape"); graph->Add<ge::op::Reshape>(input_name + "/reshape");
reshaped_input_node->set_input_tensor(*input_node); auto reshaped_input_op = reshaped_input_node->data<ge::op::Reshape>();
reshaped_input_node->set_attr_shape({m, k, 1, 1}); reshaped_input_op->set_input_tensor(*input_node->data());
reshaped_input_node->set_attr_axis(0); reshaped_input_op->set_attr_shape({m, k, 1, 1});
reshaped_input_op->set_attr_axis(0);
// Create w const node, set its shape to (n, k, 1, 1) and fill with // Create w const node, set its shape to (n, k, 1, 1) and fill with
// the transposed w tensor // the transposed w tensor
Tensor transpose_w; Tensor transpose_w;
transpose_w.Resize({n, k, 1, 1}); transpose_w.Resize({n, k, 1, 1});
transpose_w.set_persistable(true);
auto transpose_w_data = transpose_w.mutable_data<float>(); auto transpose_w_data = transpose_w.mutable_data<float>();
auto w_data = w->mutable_data<float>(); auto w_data = w->mutable_data<float>();
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
...@@ -80,29 +82,36 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -80,29 +82,36 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
transpose_w_data[j * k + i] = w_data[i * n + j]; transpose_w_data[j * k + i] = w_data[i * n + j];
} }
} }
auto trans_w_const_node = graph->AddNode(w_name, transpose_w); auto trans_w_node = graph->Add(w_name, transpose_w);
// FC node // FC node
auto fc_node = graph->AddNode<ge::op::FullConnection>(out_name + "/fc"); auto fc_node = graph->Add<ge::op::FullConnection>(out_name + "/fc");
fc_node->set_input_x(*reshaped_input_node); auto fc_op = fc_node->data<ge::op::FullConnection>();
fc_node->set_input_w(*trans_w_const_node); fc_op->set_input_x(*reshaped_input_node->data());
fc_op->set_input_w(*trans_w_node->data());
// Add bias node if bias tensor exists // Add bias node if bias tensor exists
if (HasInputArg(op_info, scope, "Bias")) { if (HasInputArg(op_info, scope, "Bias")) {
std::shared_ptr<Node> bias_node = nullptr;
auto bias_name = op_info->Input("Bias").front(); auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias"); if (graph->Has(bias_name)) {
CHECK(bias_type->precision() == PRECISION(kFloat)); bias_node = graph->Get(bias_name);
CHECK(bias_type->layout() == DATALAYOUT(kNCHW)); } else {
auto bias = scope->FindMutableTensor(bias_name); auto bias_type = kernel->GetInputDeclType("Bias");
auto bias_dims = bias->dims(); CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK_EQ(bias_dims.production(), n); CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias_const_node = graph->AddNode(bias_name, *bias, {1, n, 1, 1}); auto bias = scope->FindMutableTensor(bias_name);
fc_node->set_input_b(*bias_const_node); auto bias_dims = bias->dims();
CHECK_EQ(bias_dims.production(), n);
bias_node = graph->Add(bias_name, *bias, {1, n, 1, 1});
}
fc_op->set_input_b(*bias_node->data());
} }
// Reshape output of FC node from (m, n, 1, 1) to (m, n) // Reshape output of FC node from (m, n, 1, 1) to (m, n)
auto reshaped_fc_node = graph->AddNode<ge::op::Reshape>(out_name); auto reshaped_fc_node = graph->Add<ge::op::Reshape>(out_name);
reshaped_fc_node->set_input_tensor(*fc_node); auto reshaped_fc_op = reshaped_fc_node->data<ge::op::Reshape>();
reshaped_fc_node->set_attr_shape({m, n}); reshaped_fc_op->set_input_tensor(*fc_node->data());
reshaped_fc_node->set_attr_axis(0); reshaped_fc_op->set_attr_shape({m, n});
reshaped_fc_op->set_attr_axis(0);
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -111,4 +120,4 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -111,4 +120,4 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, fc, paddle::lite::subgraph::npu::FCConverter); REGISTER_SUBGRAPH_BRIDGE(fc, kNPU, paddle::lite::subgraph::npu::FCConverter);
@@ -21,26 +21,52 @@ namespace lite {
 namespace subgraph {
 namespace npu {

-// Const node
-std::shared_ptr<ge::op::Const> Graph::AddNode(const std::string& name,
-                                              const Tensor& tensor,
-                                              std::vector<int64_t> shape,
-                                              PrecisionType precision,
-                                              DataLayoutType layout) {
-  auto node = AddNode<ge::op::Const>(name, precision, layout);
-  node->set_attr_value(CvtTensor(tensor, shape, precision, layout));
+int Graph::Add(const std::string& name, std::shared_ptr<Node> node) {
+  auto it = nodes_.find(name);
+  if (it != nodes_.end()) {
+    // Only variable node can be shared with the same name
+    if (!node->is_var() || !it->second.back()->is_var()) {
+      LOG(FATAL) << "[NPU] Const or data node " << name << " is redefined.";
+      return -1;
+    }
+  } else {
+    auto ret = nodes_.insert(
+        std::make_pair(name, std::vector<std::shared_ptr<Node>>()));
+    CHECK(ret.second);
+    it = ret.first;
+  }
+  it->second.push_back(node);
+  return it->second.size();
+}
+
+// Const or data node
+std::shared_ptr<Node> Graph::Add(const std::string& name,
+                                 const Tensor& tensor,
+                                 std::vector<int64_t> shape,
+                                 PrecisionType precision,
+                                 DataLayoutType layout) {
+  std::shared_ptr<Node> node = nullptr;
+  if (tensor.persistable()) {
+    // Const node
+    node = Add<ge::op::Const>(name, precision, layout);
+    node->data<ge::op::Const>()->set_attr_value(
+        CvtTensor(tensor, shape, precision, layout));
+  } else {
+    // Data node
+    node = Add(name, shape, precision, layout);
+  }
   return node;
 }

 // Data node
-std::shared_ptr<ge::op::Data> Graph::AddNode(const std::string& name,
-                                             std::vector<int64_t> shape,
-                                             PrecisionType precision,
-                                             DataLayoutType layout) {
-  auto node = AddNode<ge::op::Data>(name);
+std::shared_ptr<Node> Graph::Add(const std::string& name,
+                                 std::vector<int64_t> shape,
+                                 PrecisionType precision,
+                                 DataLayoutType layout) {
+  auto node = Add<ge::op::Data>(name, precision, layout);
   ge::TensorDesc desc(
       ge::Shape(shape), CvtDataLayoutType(layout), CvtPrecisionType(precision));
-  node->update_input_desc_x(desc);
+  node->data<ge::op::Data>()->update_input_desc_x(desc);
   return node;
 }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "ai_ddk_lib/include/graph/op/all_ops.h" #include "graph/op/all_ops.h"
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
...@@ -28,94 +28,97 @@ namespace lite { ...@@ -28,94 +28,97 @@ namespace lite {
namespace subgraph { namespace subgraph {
namespace npu { namespace npu {
// Type of graph nodes // Graph and node is defined to collect all of converted HiAI IR nodes
class Type { class Node {
public: public:
Type(PrecisionType precision = PRECISION(kFloat), enum class Role {
DataLayoutType layout = DATALAYOUT(kNCHW), kVar = 0,
bool persistable = false) kConst,
: precision_(precision), layout_(layout), persistable_(persistable) {} kData,
};
Node(std::shared_ptr<ge::Operator> data,
PrecisionType precision,
DataLayoutType layout,
Role role)
: data_(data), precision_(precision), layout_(layout), role_(role) {}
Node(PrecisionType precision, DataLayoutType layout, Role role)
: precision_(precision), layout_(layout), role_(role) {}
void set_data(std::shared_ptr<ge::Operator> data) { data_ = data; }
void set_precision(PrecisionType precision) { precision_ = precision; } void set_precision(PrecisionType precision) { precision_ = precision; }
void set_layout(DataLayoutType layout) { layout_ = layout; } void set_layout(DataLayoutType layout) { layout_ = layout; }
bool set_persistable(bool persistable) { persistable_ = persistable; } void set_role(Role role) { role_ = role; }
template <typename T>
std::shared_ptr<T> data() {
return std::static_pointer_cast<T>(data_);
}
std::shared_ptr<ge::Operator> data() { return data_; }
PrecisionType precision() const { return precision_; } PrecisionType precision() const { return precision_; }
DataLayoutType layout() const { return layout_; } DataLayoutType layout() const { return layout_; }
bool persistable() const { return persistable_; } bool is_var() const { return role_ == Role::kVar; }
bool is_const() const { return role_ == Role::kConst; }
bool is_data() const { return role_ == Role::kData; }
private: private:
std::shared_ptr<ge::Operator> data_{nullptr};
PrecisionType precision_{PRECISION(kFloat)}; PrecisionType precision_{PRECISION(kFloat)};
DataLayoutType layout_{DATALAYOUT(kNCHW)}; DataLayoutType layout_{DATALAYOUT(kNCHW)};
bool persistable_{false}; Role role_{Role::kVar};
}; };
// Graph to collect all of converted HiAI IR nodes
class Graph { class Graph {
public: public:
int Add(const std::string& name, std::shared_ptr<Node> node);
// Variable, const or data node
template <typename T> template <typename T>
std::shared_ptr<T> AddNode(const std::string& name, std::shared_ptr<Node> Add(const std::string& name,
PrecisionType precision = PRECISION(kFloat), PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) { DataLayoutType layout = DATALAYOUT(kNCHW)) {
auto unique_name = [&](const std::string& key) { Node::Role role = Node::Role::kVar;
int idx = 1; if (typeid(T) == typeid(ge::op::Const)) {
auto it = counts_.find(key); role = Node::Role::kConst;
if (it == counts_.end()) { } else if (typeid(T) == typeid(ge::op::Data)) {
counts_.insert(std::make_pair(key, idx)); role = Node::Role::kData;
} else {
idx = ++(it->second);
}
return key + "_" + std::to_string(idx);
};
bool persistable = typeid(T) == typeid(ge::op::Const);
auto it = nodes_.find(name);
if (it != nodes_.end()) {
// Only variable can rebind the name
CHECK(!it->second.second.persistable() && !persistable)
<< "[NPU] Node " << name << " redefined.";
// Generate a new unique name as the key to bind the origin node:
// new_name->node
nodes_.insert(std::make_pair(unique_name(name + "_var"), it->second));
nodes_.erase(it);
} }
// Create a new node and bind with the name: name->new_node auto node = std::make_shared<Node>(precision, layout, role);
auto node = std::make_shared<T>(unique_name(name + "_op")); auto idx = Add(name, node);
nodes_.insert(std::make_pair( CHECK_GE(idx, 1);
name, std::make_pair(node, Type(precision, layout, persistable)))); // Generate a unique name for the created HiAI IR
node->set_data(std::make_shared<T>(name + "__" + std::to_string(idx)));
return node; return node;
} }
// Const node // Const or data node
std::shared_ptr<ge::op::Const> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, const Tensor& tensor,
const Tensor& tensor, std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat), PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) { DataLayoutType layout = DATALAYOUT(kNCHW));
return AddNode(name, tensor, tensor.dims().Vectorize(), precision, layout);
std::shared_ptr<Node> Add(const std::string& name,
const Tensor& tensor,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, tensor, tensor.dims().Vectorize(), precision, layout);
} }
std::shared_ptr<ge::op::Const> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, const Tensor& tensor,
const Tensor& tensor, DDim dims,
std::vector<int64_t> shape, PrecisionType precision = PRECISION(kFloat),
PrecisionType precision = PRECISION(kFloat), DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)); return Add(name, tensor, dims.Vectorize(), precision, layout);
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const Tensor& tensor,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, tensor, dims.Vectorize(), precision, layout);
} }
// Const node
template <typename T> template <typename T>
std::shared_ptr<ge::op::Const> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, const std::vector<T>& data,
const std::vector<T>& data, std::vector<int64_t> shape = {},
std::vector<int64_t> shape = {}, DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) {
const std::type_info& info = typeid(T); const std::type_info& info = typeid(T);
PrecisionType precision = PRECISION(kFloat); PrecisionType precision = PRECISION(kFloat);
if (info == typeid(float)) { if (info == typeid(float)) {
...@@ -138,78 +141,66 @@ class Graph { ...@@ -138,78 +141,66 @@ class Graph {
} }
Tensor tensor; Tensor tensor;
tensor.Resize(shape); tensor.Resize(shape);
tensor.set_persistable(true);
std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()), std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
reinterpret_cast<const uint8_t*>(data.data()), reinterpret_cast<const uint8_t*>(data.data()),
data.size() * sizeof(T)); data.size() * sizeof(T));
return AddNode(name, tensor, precision, layout); return Add(name, tensor, precision, layout);
} }
template <typename T> template <typename T>
std::shared_ptr<ge::op::Const> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, const std::vector<T>& data,
const std::vector<T>& data, DDim dims,
DDim dims, DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) { return Add(name, data, dims.Vectorize(), layout);
return AddNode(name, data, dims.Vectorize(), layout);
} }
template <typename T> template <typename T>
std::shared_ptr<ge::op::Const> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, T value,
T value, std::vector<int64_t> shape = {1},
std::vector<int64_t> shape = {1}, DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) {
int64_t size = 1; int64_t size = 1;
for (auto i : shape) { for (auto i : shape) {
size *= i; size *= i;
} }
std::vector<T> data(size, value); std::vector<T> data(size, value);
return AddNode(name, data, shape, layout); return Add(name, data, shape, layout);
} }
template <typename T> template <typename T>
std::shared_ptr<ge::op::Const> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, T value,
T value, DDim dims,
DDim dims, DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) { return Add(name, value, dims.Vectorize(), layout);
return AddNode(name, value, dims.Vectorize(), layout);
} }
// Data node // Data node
std::shared_ptr<ge::op::Data> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, std::vector<int64_t> shape,
std::vector<int64_t> shape, PrecisionType precision = PRECISION(kFloat),
PrecisionType precision = PRECISION(kFloat), DataLayoutType layout = DATALAYOUT(kNCHW));
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<Node> Add(const std::string& name,
std::shared_ptr<ge::op::Data> AddNode( DDim dims,
const std::string& name, PrecisionType precision = PRECISION(kFloat),
DDim dims, DataLayoutType layout = DATALAYOUT(kNCHW)) {
PrecisionType precision = PRECISION(kFloat), return Add(name, dims.Vectorize(), precision, layout);
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, dims.Vectorize(), precision, layout);
}
std::shared_ptr<ge::Operator> GetNode(std::string name) {
CHECK(HasNode(name)) << "[NPU] Node " << name << " not found.";
return nodes_.at(name).first;
} }
const Type& GetType(const std::string& name) { std::shared_ptr<Node> Get(std::string name) {
CHECK(HasNode(name)) << "[NPU] Node " << name << " not found."; CHECK(Has(name)) << "[NPU] Node " << name << " not found.";
return nodes_.at(name).second; return nodes_.at(name).back();
} }
bool HasNode(const std::string& name) { bool Has(const std::string& name) {
return nodes_.find(name) != nodes_.end(); return nodes_.find(name) != nodes_.end();
} }
private: private:
std::unordered_map<std::string, std::unordered_map<std::string, std::vector<std::shared_ptr<Node>>> nodes_;
std::pair<std::shared_ptr<ge::Operator>, Type>>
nodes_;
std::unordered_map<std::string, int> counts_;
}; };
} // namespace npu } // namespace npu
......
...@@ -55,11 +55,11 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -55,11 +55,11 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
"supported in HiAI DDK"; "supported in HiAI DDK";
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Priority: OutSize > scale > out_h/out_w // Priority: OutSize > scale > out_h/out_w
...@@ -71,17 +71,18 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -71,17 +71,18 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} }
// Update out_h and out_w and create out_size node if has OutSize // Update out_h and out_w and create out_size node if has OutSize
std::shared_ptr<ge::Operator> out_size_node = nullptr; std::shared_ptr<Node> out_size_node = nullptr;
if (HasInputArg(op_info, scope, "OutSize")) { if (HasInputArg(op_info, scope, "OutSize")) {
auto out_size_name = op_info->Input("OutSize").front(); auto out_size_name = op_info->Input("OutSize").front();
auto out_size_type = kernel->GetInputDeclType("OutSize"); auto out_size_type = kernel->GetInputDeclType("OutSize");
CHECK(out_size_type->precision() == PRECISION(kInt32)); CHECK(out_size_type->precision() == PRECISION(kInt32));
CHECK(out_size_type->layout() == DATALAYOUT(kNCHW)); CHECK(out_size_type->layout() == DATALAYOUT(kNCHW));
if (graph->HasNode(out_size_name)) { if (graph->Has(out_size_name)) {
out_size_node = graph->GetNode(out_size_name); out_size_node = graph->Get(out_size_name);
} else { } else {
auto out_size = scope->FindMutableTensor(out_size_name); auto out_size = scope->FindMutableTensor(out_size_name);
CHECK_EQ(out_size->numel(), 2); CHECK_EQ(out_size->numel(), 2);
CHECK(out_size->persistable());
auto out_size_data = out_size->mutable_data<int>(); auto out_size_data = out_size->mutable_data<int>();
// Update out_h and out_w if has OutSize // Update out_h and out_w if has OutSize
out_h = out_size_data[0]; out_h = out_size_data[0];
...@@ -97,22 +98,25 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -97,22 +98,25 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< " is too large, should not exceed " << largest_multiple << " is too large, should not exceed " << largest_multiple
<< " in HiAI DDK"; << " in HiAI DDK";
} }
out_size_node = graph->AddNode(out_name + "/out_size", out_size_node =
std::vector<int>({out_h, out_w})); graph->Add(out_name + "/out_size", std::vector<int>({out_h, out_w}));
} }
if (interp_method == "bilinear") { if (interp_method == "bilinear") {
auto bilinear_interp_node = auto bilinear_interp_node = graph->Add<ge::op::ResizeBilinear>(out_name);
graph->AddNode<ge::op::ResizeBilinear>(out_name); auto bilinear_interp_op =
bilinear_interp_node->set_input_x(*x_node); bilinear_interp_node->data<ge::op::ResizeBilinear>();
bilinear_interp_node->set_input_size(*out_size_node); bilinear_interp_op->set_input_x(*x_node->data());
bilinear_interp_node->set_attr_align_corners(align_corners); bilinear_interp_op->set_input_size(*out_size_node->data());
bilinear_interp_op->set_attr_align_corners(align_corners);
} else if (interp_method == "nearest") { } else if (interp_method == "nearest") {
auto nearest_interp_node = auto nearest_interp_node =
graph->AddNode<ge::op::ResizeNearestNeighbor>(out_name); graph->Add<ge::op::ResizeNearestNeighbor>(out_name);
nearest_interp_node->set_input_image(*x_node); auto nearest_interp_op =
nearest_interp_node->set_input_size(*out_size_node); nearest_interp_node->data<ge::op::ResizeNearestNeighbor>();
nearest_interp_node->set_attr_align_corners(align_corners); nearest_interp_op->set_input_image(*x_node->data());
nearest_interp_op->set_input_size(*out_size_node->data());
nearest_interp_op->set_attr_align_corners(align_corners);
} else { } else {
LOG(WARNING) << "[NPU] Unsupported interpolate method: " << interp_method; LOG(WARNING) << "[NPU] Unsupported interpolate method: " << interp_method;
return FAILED; return FAILED;
...@@ -125,9 +129,9 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -125,9 +129,9 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(bilinear_interp,
bilinear_interp, kNPU,
paddle::lite::subgraph::npu::InterpolateConverter); paddle::lite::subgraph::npu::InterpolateConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(nearest_interp,
nearest_interp, kNPU,
paddle::lite::subgraph::npu::InterpolateConverter); paddle::lite::subgraph::npu::InterpolateConverter);
...@@ -56,45 +56,46 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -56,45 +56,46 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< "[NPU] columns of X must be equal with rows of Y"; << "[NPU] columns of X must be equal with rows of Y";
int n = y_dims.Slice(y_num_col_dims, y_dims.size()).production(); int n = y_dims.Slice(y_num_col_dims, y_dims.size()).production();
VLOG(3) << "m:" << m << ",n:" << n << ",k:" << k; VLOG(3) << "m:" << m << ",n:" << n << ",k:" << k;
VLOG(3) << "x_name:" << x_name << ", is data: " << graph->HasNode(x_name); VLOG(3) << "x_name:" << x_name << ", is data: " << graph->Has(x_name);
VLOG(3) << "y_name:" << y_name << ", is data: " << graph->HasNode(y_name); VLOG(3) << "y_name:" << y_name << ", is data: " << graph->Has(y_name);
CHECK(graph->HasNode(x_name)) CHECK(graph->Has(x_name))
<< "[NPU] MatMul in HiAI DDK only support X is data, Y is const yet."; << "[NPU] MatMul in HiAI DDK only support X is data, Y is const yet.";
// X node, which supports persistable and non-persistable tensors and is // X node, which supports persistable and non-persistable tensors and is
// reshaped to (m, k) // reshaped to (m, k)
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
auto reshaped_x_node = graph->AddNode<ge::op::Reshape>(x_name + "/reshape"); auto reshaped_x_node = graph->Add<ge::op::Reshape>(x_name + "/reshape");
reshaped_x_node->set_input_tensor(*x_node); auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
reshaped_x_node->set_attr_shape({m, k}); reshaped_x_op->set_input_tensor(*x_node->data());
reshaped_x_node->set_attr_axis(0); reshaped_x_op->set_attr_shape({m, k});
reshaped_x_op->set_attr_axis(0);
x_node = reshaped_x_node; x_node = reshaped_x_node;
} else { } else {
auto x_const_node = graph->AddNode(x_name, *x, {m, k}); x_node = graph->Add(x_name, *x, {m, k});
x_node = x_const_node;
} }
// Y node, which only supports a persistable tensor and is reshaped to // Y node, which only supports a persistable tensor and is reshaped to
// (k, n) // (k, n)
std::shared_ptr<ge::Operator> y_node = nullptr; std::shared_ptr<Node> y_node = nullptr;
if (graph->HasNode(y_name)) { if (graph->Has(y_name)) {
y_node = graph->GetNode(y_name); y_node = graph->Get(y_name);
auto reshaped_y_node = graph->AddNode<ge::op::Reshape>(y_name + "/reshape"); auto reshaped_y_node = graph->Add<ge::op::Reshape>(y_name + "/reshape");
reshaped_y_node->set_input_tensor(*y_node); auto reshaped_y_op = reshaped_y_node->data<ge::op::Reshape>();
reshaped_y_node->set_attr_shape({k, n}); reshaped_y_op->set_input_tensor(*y_node->data());
reshaped_y_node->set_attr_axis(0); reshaped_y_op->set_attr_shape({k, n});
reshaped_y_op->set_attr_axis(0);
y_node = reshaped_y_node; y_node = reshaped_y_node;
} else { } else {
auto y_const_node = graph->AddNode(y_name, *y, {k, n}); y_node = graph->Add(y_name, *y, {k, n});
y_node = y_const_node;
} }
// Matmul node // Matmul node
auto mul_node = graph->AddNode<ge::op::MatMul>(out_name); auto mul_node = graph->Add<ge::op::MatMul>(out_name);
mul_node->set_input_x1(*x_node); auto mul_op = mul_node->data<ge::op::MatMul>();
mul_node->set_input_x2(*y_node); mul_op->set_input_x1(*x_node->data());
mul_op->set_input_x2(*y_node->data());
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -103,4 +104,4 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -103,4 +104,4 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, mul, paddle::lite::subgraph::npu::MulConverter); REGISTER_SUBGRAPH_BRIDGE(mul, kNPU, paddle::lite::subgraph::npu::MulConverter);
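The reshape-to-(m, k) / (k, n) handling above relies on how mul flattens its operands around x_num_col_dims and y_num_col_dims. A standalone sketch of that flattening (the helper and its names are illustrative, not part of this patch):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct MatMulShape { int64_t m, k, n; };

    // Collapse the leading x_num_col_dims axes of X into m and the rest into k;
    // the trailing axes of Y (from y_num_col_dims on) collapse into n.
    MatMulShape FlattenMulDims(const std::vector<int64_t>& x_dims,
                               const std::vector<int64_t>& y_dims,
                               size_t x_num_col_dims,
                               size_t y_num_col_dims) {
      auto prod = [](const std::vector<int64_t>& d, size_t b, size_t e) {
        int64_t p = 1;
        for (size_t i = b; i < e; i++) p *= d[i];
        return p;
      };
      MatMulShape s{prod(x_dims, 0, x_num_col_dims),
                    prod(x_dims, x_num_col_dims, x_dims.size()),
                    prod(y_dims, y_num_col_dims, y_dims.size())};
      // Rows of the flattened Y must equal k, as the converter checks above.
      assert(prod(y_dims, 0, y_num_col_dims) == s.k);
      return s;
    }

    // e.g. x_dims = {2, 3, 4}, x_num_col_dims = 1, y_dims = {12, 5},
    // y_num_col_dims = 1  ->  m = 2, k = 12, n = 5.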
...@@ -45,35 +45,34 @@ int Pad2dConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -45,35 +45,34 @@ int Pad2dConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(padding.size(), 4); CHECK_EQ(padding.size(), 4);
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Padding node // Padding node
int xds = x_dims.size(); int xds = x_dims.size();
padding.insert(padding.begin(), xds * 2 - 4, 0); padding.insert(padding.begin(), xds * 2 - 4, 0);
auto padding_const_node = auto padding_node = graph->Add(out_name + "/padding", padding, {xds, 2});
graph->AddNode(out_name + "/padding", padding, {xds, 2});
// Pad node // Pad node
auto pad2d_node = graph->AddNode<ge::op::Pad>(out_name); auto pad2d_node = graph->Add<ge::op::Pad>(out_name);
pad2d_node->set_input_x(*x_node); auto pad2d_op = pad2d_node->data<ge::op::Pad>();
pad2d_node->set_input_padding(*padding_const_node); pad2d_op->set_input_x(*x_node->data());
pad2d_op->set_input_padding(*padding_node->data());
auto mode = op_info->GetAttr<std::string>("mode"); auto mode = op_info->GetAttr<std::string>("mode");
if (mode == "constant") { if (mode == "constant") {
// Pad value node // Pad value node
auto pad_value = op_info->GetAttr<float>("pad_value"); auto pad_value = op_info->GetAttr<float>("pad_value");
auto pad_value_const_node = auto pad_value_node = graph->Add(out_name + "/pad_value", pad_value);
graph->AddNode(out_name + "/pad_value", pad_value); pad2d_op->set_input_constant_values(*pad_value_node->data());
pad2d_node->set_input_constant_values(*pad_value_const_node); pad2d_op->set_attr_T(0); // type of pad_value: 0:float 3:int32
pad2d_node->set_attr_T(0); // type of pad_value: 0:float 3:int32 pad2d_op->set_attr_mode(0);
pad2d_node->set_attr_mode(0);
} else if (mode == "reflect") { } else if (mode == "reflect") {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK"; LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
pad2d_node->set_attr_mode(1); pad2d_op->set_attr_mode(1);
return FAILED; return FAILED;
} else { } else {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK"; LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
...@@ -87,6 +86,6 @@ int Pad2dConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -87,6 +86,6 @@ int Pad2dConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(pad2d,
pad2d, kNPU,
paddle::lite::subgraph::npu::Pad2dConverter); paddle::lite::subgraph::npu::Pad2dConverter);
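The padding manipulation above turns pad2d's four values into the per-dimension (before, after) table that the Pad operator consumes, with zero rows for the leading unpadded dimensions. A standalone restatement (illustrative helper, not part of this patch):

    #include <vector>

    // Mirrors `padding.insert(padding.begin(), rank * 2 - 4, 0)` followed by a
    // reshape to {rank, 2}: one (before, after) pair per input dimension.
    std::vector<int> ExpandPad2dPadding(std::vector<int> padding, int rank) {
      padding.insert(padding.begin(), rank * 2 - 4, 0);
      return padding;  // interpreted with shape {rank, 2}
    }

    // e.g. rank = 4, padding = {1, 1, 2, 2} ->
    //   {0, 0,  0, 0,  1, 1,  2, 2}  i.e. rows N:(0,0) C:(0,0) H:(1,1) W:(2,2)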
...@@ -14,40 +14,40 @@ ...@@ -14,40 +14,40 @@
#pragma once #pragma once
USE_SUBGRAPH_BRIDGE(NPU, sigmoid); USE_SUBGRAPH_BRIDGE(sigmoid, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, relu); USE_SUBGRAPH_BRIDGE(relu, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, tanh); USE_SUBGRAPH_BRIDGE(tanh, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, relu_clipped); USE_SUBGRAPH_BRIDGE(relu_clipped, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, leaky_relu); USE_SUBGRAPH_BRIDGE(leaky_relu, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, softsign); USE_SUBGRAPH_BRIDGE(softsign, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, hard_sigmoid); USE_SUBGRAPH_BRIDGE(hard_sigmoid, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, batch_norm); USE_SUBGRAPH_BRIDGE(batch_norm, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, concat); USE_SUBGRAPH_BRIDGE(concat, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, conv2d); USE_SUBGRAPH_BRIDGE(conv2d, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, depthwise_conv2d); USE_SUBGRAPH_BRIDGE(depthwise_conv2d, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, conv2d_transpose); USE_SUBGRAPH_BRIDGE(conv2d_transpose, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, elementwise_add); USE_SUBGRAPH_BRIDGE(elementwise_add, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, fusion_elementwise_add_activation); USE_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, elementwise_sub); USE_SUBGRAPH_BRIDGE(elementwise_sub, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, elementwise_mul); USE_SUBGRAPH_BRIDGE(elementwise_mul, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, elementwise_div); USE_SUBGRAPH_BRIDGE(elementwise_div, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, fc); USE_SUBGRAPH_BRIDGE(fc, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, bilinear_interp); USE_SUBGRAPH_BRIDGE(bilinear_interp, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, nearest_interp); USE_SUBGRAPH_BRIDGE(nearest_interp, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, mul); USE_SUBGRAPH_BRIDGE(mul, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, pad2d); USE_SUBGRAPH_BRIDGE(pad2d, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, pool2d); USE_SUBGRAPH_BRIDGE(pool2d, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, reduce_mean); USE_SUBGRAPH_BRIDGE(reduce_mean, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, reshape); USE_SUBGRAPH_BRIDGE(reshape, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, reshape2); USE_SUBGRAPH_BRIDGE(reshape2, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, scale); USE_SUBGRAPH_BRIDGE(scale, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, shuffle_channel); USE_SUBGRAPH_BRIDGE(shuffle_channel, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, softmax); USE_SUBGRAPH_BRIDGE(softmax, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, split); USE_SUBGRAPH_BRIDGE(split, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, sqrt); USE_SUBGRAPH_BRIDGE(sqrt, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, square); USE_SUBGRAPH_BRIDGE(square, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, transpose); USE_SUBGRAPH_BRIDGE(transpose, kNPU);
USE_SUBGRAPH_BRIDGE(NPU, transpose2); USE_SUBGRAPH_BRIDGE(transpose2, kNPU);
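Wiring up a new bridge follows the same two-step pattern used for every entry in this header: register the converter in its own source file, then reference it here so the registration object is linked into any binary that includes the header. A hypothetical example (my_op and MyOpConverter are placeholders, not real symbols):

    // In the converter source file (hypothetical):
    REGISTER_SUBGRAPH_BRIDGE(my_op, kNPU, paddle::lite::subgraph::npu::MyOpConverter);
    // In this header:
    USE_SUBGRAPH_BRIDGE(my_op, kNPU);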
...@@ -48,11 +48,11 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -48,11 +48,11 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto paddings = op_info->GetAttr<std::vector<int>>("paddings"); auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// pool mode // pool mode
...@@ -109,19 +109,19 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -109,19 +109,19 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} }
// Pooling node // Pooling node
auto pool_node = graph->AddNode<ge::op::Pooling>(out_name); auto pool_node = graph->Add<ge::op::Pooling>(out_name);
pool_node->set_input_x(*x_node); auto pool_op = pool_node->data<ge::op::Pooling>();
pool_node->set_attr_mode(mode); pool_op->set_input_x(*x_node->data());
pool_node->set_attr_pad_mode(pad_mode); pool_op->set_attr_mode(mode);
pool_node->set_attr_global_pooling(global_pooling); pool_op->set_attr_pad_mode(pad_mode);
pool_node->set_attr_window( pool_op->set_attr_global_pooling(global_pooling);
ge::AttrValue::LIST_INT(ksize.begin(), ksize.end())); pool_op->set_attr_window(ge::AttrValue::LIST_INT(ksize.begin(), ksize.end()));
pool_node->set_attr_pad(ge::AttrValue::LIST_INT{ pool_op->set_attr_pad(ge::AttrValue::LIST_INT{
paddings[0], paddings[1], paddings[2], paddings[3]}); paddings[0], paddings[1], paddings[2], paddings[3]});
pool_node->set_attr_stride( pool_op->set_attr_stride(
ge::AttrValue::LIST_INT(strides.begin(), strides.end())); ge::AttrValue::LIST_INT(strides.begin(), strides.end()));
pool_node->set_attr_ceil_mode(ceil_mode); pool_op->set_attr_ceil_mode(ceil_mode);
// pool_node->set_attr_data_mode(data_mode); // pool_op->set_attr_data_mode(data_mode);
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -130,6 +130,6 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -130,6 +130,6 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(pool2d,
pool2d, kNPU,
paddle::lite::subgraph::npu::PoolConverter); paddle::lite::subgraph::npu::PoolConverter);
...@@ -52,29 +52,30 @@ int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -52,29 +52,30 @@ int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) {
std::sort(dim.begin(), dim.end()); std::sort(dim.begin(), dim.end());
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Using ReduceSum + Scale to implement ReduceMean // Using ReduceSum + Scale to implement ReduceMean
// Dim node // Dim node
auto dim_const_node = graph->AddNode(out_name + "/dim", dim); auto dim_node = graph->Add(out_name + "/dim", dim);
// Reduce Sum node // Reduce Sum node
auto reduce_sum_node = auto reduce_sum_node = graph->Add<ge::op::ReduceSum>(out_name + "/reducesum");
graph->AddNode<ge::op::ReduceSum>(out_name + "/reducesum"); auto reduce_sum_op = reduce_sum_node->data<ge::op::ReduceSum>();
reduce_sum_node->set_input_x(*x_node); reduce_sum_op->set_input_x(*x_node->data());
reduce_sum_node->set_input_w(*dim_const_node); reduce_sum_op->set_input_w(*dim_node->data());
reduce_sum_node->set_attr_keep_dims(keep_dim); reduce_sum_op->set_attr_keep_dims(keep_dim);
// Scale node // Scale node
auto scale_node = graph->AddNode<ge::op::Scale>(out_name); auto scale_node = graph->Add<ge::op::Scale>(out_name);
scale_node->set_input_x(*reduce_sum_node); auto scale_op = scale_node->data<ge::op::Scale>();
scale_node->set_attr_axis(1); scale_op->set_input_x(*reduce_sum_node->data());
scale_op->set_attr_axis(1);
// Add filter node (filled with scale) // Add filter node (filled with scale)
float scale = 1; float scale = 1;
...@@ -95,9 +96,8 @@ int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -95,9 +96,8 @@ int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) {
remove(scale_bias_shape.begin(), scale_bias_shape.end(), kDelFlag), remove(scale_bias_shape.begin(), scale_bias_shape.end(), kDelFlag),
scale_bias_shape.end()); scale_bias_shape.end());
} }
auto filter_const_node = auto filter_node = graph->Add(out_name + "/filter", scale, scale_bias_shape);
graph->AddNode(out_name + "/filter", scale, scale_bias_shape); scale_op->set_input_filter(*filter_node->data());
scale_node->set_input_filter(*filter_const_node);
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -106,6 +106,6 @@ int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -106,6 +106,6 @@ int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(reduce_mean,
reduce_mean, kNPU,
paddle::lite::subgraph::npu::ReduceMeanConverter); paddle::lite::subgraph::npu::ReduceMeanConverter);
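The scale value filled in on the elided lines above is, conceptually, the reciprocal of the number of reduced elements; that is what turns the ReduceSum into a mean. A minimal standalone sketch assuming `dim` lists the reduced axes (illustrative helper, not part of this patch):

    #include <cstdint>
    #include <vector>

    float MeanScaleFactor(const std::vector<int64_t>& x_dims,
                          const std::vector<int>& dim) {
      int64_t count = 1;
      for (int d : dim) count *= x_dims[d];
      return 1.0f / static_cast<float>(count);
    }

    // e.g. x_dims = {2, 3, 4, 5}, dim = {2, 3} -> count = 20, factor = 0.05,
    // so the mean over H and W equals 0.05 * the sum over H and W.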
...@@ -24,27 +24,27 @@ Registry& Registry::Instance() { ...@@ -24,27 +24,27 @@ Registry& Registry::Instance() {
return x; return x;
} }
void Registry::Insert(const std::string& dev_type, void Registry::Insert(const std::string& op_type,
const std::string& op_type, const std::string& target,
const cvt_func_type& cvt_func_name) { const cvt_func_type& cvt_func_name) {
auto it = map_.find(dev_type); auto it = map_.find(target);
if (it == map_.end()) { if (it == map_.end()) {
map_.insert(std::make_pair( map_.insert(std::make_pair(
dev_type, std::unordered_map<std::string, cvt_func_type>())); target, std::unordered_map<std::string, cvt_func_type>()));
} }
map_.at(dev_type).insert(std::make_pair(op_type, cvt_func_name)); map_.at(target).insert(std::make_pair(op_type, cvt_func_name));
} }
const cvt_func_type& Registry::Select(const std::string& dev_type, const cvt_func_type& Registry::Select(const std::string& op_type,
const std::string& op_type) const { const std::string& target) const {
return map_.at(dev_type).at(op_type); return map_.at(target).at(op_type);
} }
bool Registry::Exists(const std::string& dev_type, bool Registry::Exists(const std::string& op_type,
const std::string& op_type) const { const std::string& target) const {
bool found = map_.find(dev_type) != map_.end(); bool found = map_.find(target) != map_.end();
if (found) { if (found) {
found = map_.at(dev_type).find(op_type) != map_.at(dev_type).end(); found = map_.at(target).find(op_type) != map_.at(target).end();
} }
return found; return found;
} }
......
...@@ -42,12 +42,12 @@ class Registry { ...@@ -42,12 +42,12 @@ class Registry {
public: public:
static Registry& Instance(); static Registry& Instance();
void Insert(const std::string& dev_type, void Insert(const std::string& op_type,
const std::string& op_type, const std::string& target,
const cvt_func_type& cvt_func_name); const cvt_func_type& cvt_func_name);
const cvt_func_type& Select(const std::string& dev_type, const cvt_func_type& Select(const std::string& op_type,
const std::string& op_type) const; const std::string& target) const;
bool Exists(const std::string& dev_type, const std::string& op_type) const; bool Exists(const std::string& op_type, const std::string& target) const;
Registry() = default; Registry() = default;
private: private:
...@@ -73,18 +73,18 @@ class Registry { ...@@ -73,18 +73,18 @@ class Registry {
__test_global_namespace_##uniq_name##__>::value, \ __test_global_namespace_##uniq_name##__>::value, \
msg) msg)
#define REGISTER_SUBGRAPH_BRIDGE(dev_type, op_type, cvt_func_name) \ #define REGISTER_SUBGRAPH_BRIDGE(op_type__, target__, cvt_func_name) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_subgraph_bridge_##dev_type##_##op_type##__, \ __reg_subgraph_bridge_##op_type__##_##target__##__, \
"REGISTER_SUBGRAPH_BRIDGE must be called in global namespace only " \ "REGISTER_SUBGRAPH_BRIDGE must be called in global namespace only " \
"once!"); \ "once!"); \
int __reg_subgraph_bridge_##dev_type##_##op_type##_Insert() { \ int __reg_subgraph_bridge_##op_type__##_##target__##_Insert() { \
paddle::lite::subgraph::Registry::Instance().Insert( \ paddle::lite::subgraph::Registry::Instance().Insert( \
#dev_type, #op_type, cvt_func_name); \ #op_type__, #target__, cvt_func_name); \
return 0; \ return 0; \
} }
#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) \ #define USE_SUBGRAPH_BRIDGE(op_type__, target__) \
extern int __reg_subgraph_bridge_##dev_type##_##op_type##_Insert(); \ extern int __reg_subgraph_bridge_##op_type__##_##target__##_Insert(); \
static int __reg_subgraph_bridge_##dev_type##_##op_type##_Insert_return \ static int __reg_subgraph_bridge_##op_type__##_##target__##_Insert_return \
UNUSED = __reg_subgraph_bridge_##dev_type##_##op_type##_Insert(); UNUSED = __reg_subgraph_bridge_##op_type__##_##target__##_Insert();
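The two macros implement the usual static-registration idiom: REGISTER_SUBGRAPH_BRIDGE defines an Insert function in the converter's translation unit, and USE_SUBGRAPH_BRIDGE forces a reference to it so the linker keeps the registration in whichever binary needs the bridge. A self-contained miniature of the pattern (simplified stand-ins, not the actual lite::subgraph classes):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    using cvt_func = std::function<int()>;

    // target -> (op_type -> converter), constructed on first use.
    std::unordered_map<std::string, std::unordered_map<std::string, cvt_func>>&
    Map() {
      static std::unordered_map<std::string,
                                std::unordered_map<std::string, cvt_func>> m;
      return m;
    }

    #define DEMO_REGISTER_BRIDGE(op__, target__, func__)    \
      int __demo_reg_##op__##_##target__##_Insert() {        \
        Map()[#target__][#op__] = func__;                     \
        return 0;                                             \
      }

    #define DEMO_USE_BRIDGE(op__, target__)                   \
      extern int __demo_reg_##op__##_##target__##_Insert();   \
      static int __demo_reg_##op__##_##target__##_ret =       \
          __demo_reg_##op__##_##target__##_Insert();

    int SoftmaxConverterStub() { return 0; }
    // In the real code these two live in different translation units.
    DEMO_REGISTER_BRIDGE(softmax, kNPU, SoftmaxConverterStub)
    DEMO_USE_BRIDGE(softmax, kNPU)

    int main() {
      std::cout << Map().at("kNPU").count("softmax") << std::endl;  // prints 1
      return 0;
    }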
...@@ -44,16 +44,17 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -44,16 +44,17 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW)); CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Reshape node // Reshape node
auto reshape_node = graph->AddNode<ge::op::Reshape>(out_name); auto reshape_node = graph->Add<ge::op::Reshape>(out_name);
reshape_node->set_input_tensor(*x_node); auto reshape_op = reshape_node->data<ge::op::Reshape>();
reshape_op->set_input_tensor(*x_node->data());
// Read shape from "ShapeTensor"(input), or "Shape"(input), or "shape"(attr) // Read shape from "ShapeTensor"(input), or "Shape"(input), or "shape"(attr)
if (HasInputArg(op_info, scope, "ShapeTensor")) { if (HasInputArg(op_info, scope, "ShapeTensor")) {
...@@ -64,9 +65,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -64,9 +65,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// auto actual_shape_type = kernel->GetInputDeclType("Shape"); // auto actual_shape_type = kernel->GetInputDeclType("Shape");
// CHECK(actual_shape_type->precision() == PRECISION(kInt32)); // CHECK(actual_shape_type->precision() == PRECISION(kInt32));
// CHECK(actual_shape_type->layout() == DATALAYOUT(kNCHW)); // CHECK(actual_shape_type->layout() == DATALAYOUT(kNCHW));
std::shared_ptr<ge::Operator> actual_shape_node = nullptr; std::shared_ptr<Node> actual_shape_node = nullptr;
if (graph->HasNode(actual_shape_name)) { if (graph->Has(actual_shape_name)) {
actual_shape_node = graph->GetNode(actual_shape_name); actual_shape_node = graph->Get(actual_shape_name);
} else { } else {
auto actual_shape = scope->FindMutableTensor(actual_shape_name); auto actual_shape = scope->FindMutableTensor(actual_shape_name);
auto actual_shape_dims = actual_shape->dims(); auto actual_shape_dims = actual_shape->dims();
...@@ -81,12 +82,11 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -81,12 +82,11 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
"but Shape has " "but Shape has "
<< out_shape.size(); << out_shape.size();
} }
auto actual_shape_const_node = actual_shape_node =
graph->AddNode(actual_shape_name, graph->Add(actual_shape_name,
std::vector<int>(out_shape.begin(), out_shape.end())); std::vector<int>(out_shape.begin(), out_shape.end()));
actual_shape_node = actual_shape_const_node;
} }
reshape_node->set_input_w(*actual_shape_node); reshape_op->set_input_w(*actual_shape_node->data());
} else { } else {
auto shape = op_info->GetAttr<std::vector<int>>("shape"); auto shape = op_info->GetAttr<std::vector<int>>("shape");
auto out_dims = lite::operators::ValidateShape(shape, x_dims); auto out_dims = lite::operators::ValidateShape(shape, x_dims);
...@@ -96,7 +96,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -96,7 +96,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
"but shape has " "but shape has "
<< out_shape.size(); << out_shape.size();
} }
reshape_node->set_attr_shape( reshape_op->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end())); ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
} }
...@@ -117,9 +117,10 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -117,9 +117,10 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// auto xshape_type = kernel->GetOutputDeclType("XShape"); // auto xshape_type = kernel->GetOutputDeclType("XShape");
// CHECK(xshape_type->precision() == PRECISION(kFloat)); // CHECK(xshape_type->precision() == PRECISION(kFloat));
// CHECK(xshape_type->layout() == DATALAYOUT(kNCHW)); // CHECK(xshape_type->layout() == DATALAYOUT(kNCHW));
auto xshape_node = graph->AddNode<ge::op::Reshape>(xshape_name); auto xshape_node = graph->Add<ge::op::Reshape>(xshape_name);
xshape_node->set_input_tensor(*x_node); auto xshape_op = xshape_node->data<ge::op::Reshape>();
xshape_node->set_attr_shape( xshape_op->set_input_tensor(*x_node->data());
xshape_op->set_attr_shape(
ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end())); ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end()));
} }
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
...@@ -130,9 +131,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -130,9 +131,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(reshape,
reshape, kNPU,
paddle::lite::subgraph::npu::ReshapeConverter); paddle::lite::subgraph::npu::ReshapeConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(reshape2,
reshape2, kNPU,
paddle::lite::subgraph::npu::ReshapeConverter); paddle::lite::subgraph::npu::ReshapeConverter);
...@@ -37,12 +37,15 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -37,12 +37,15 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(x_type->layout() == DATALAYOUT(kNCHW)); CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name); auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims(); auto x_dims = x->dims();
CHECK_GE(x_dims.size(), 2); auto x_rank = x_dims.size();
CHECK_GE(x_rank, 2);
auto out_name = op_info->Output("Out").front(); auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out"); auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat)); CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW)); CHECK(out_type->layout() == DATALAYOUT(kNCHW));
std::vector<int64_t> scale_bias_shape = {x_dims[1]}; // HiAI only support [n, c, 1, 1] for the shape of scale and bias
std::vector<int64_t> scale_bias_shape = {
1, x_rank < 3 ? 1 : x_dims[x_rank - 3], 1, 1};
float scale = op_info->GetAttr<float>("scale"); float scale = op_info->GetAttr<float>("scale");
float bias = op_info->GetAttr<float>("bias"); float bias = op_info->GetAttr<float>("bias");
bool bias_after_scale = op_info->GetAttr<bool>("bias_after_scale"); bool bias_after_scale = op_info->GetAttr<bool>("bias_after_scale");
...@@ -51,29 +54,28 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -51,29 +54,28 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} }
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x, CvtShape(x_dims));
} }
// Scale node // Scale node
auto scale_node = graph->AddNode<ge::op::Scale>(out_name); auto scale_node = graph->Add<ge::op::Scale>(out_name);
scale_node->set_input_x(*x_node); auto scale_op = scale_node->data<ge::op::Scale>();
scale_node->set_attr_axis(1); scale_op->set_input_x(*x_node->data());
scale_op->set_attr_axis(1);
// Add filter node (filled with scale) // Add filter node (filled with scale)
auto filter_const_node = auto filter_node = graph->Add(out_name + "/filter", scale, scale_bias_shape);
graph->AddNode(out_name + "/filter", scale, scale_bias_shape); scale_op->set_input_filter(*filter_node->data());
scale_node->set_input_filter(*filter_const_node);
// Add bias node (filled with bias) // Add bias node (filled with bias)
if (fabs(bias) > 1e-6f) { if (fabs(bias) > 1e-6f) {
auto bias_const_node = auto bias_node = graph->Add(out_name + "/bias", bias, scale_bias_shape);
graph->AddNode(out_name + "/bias", bias, scale_bias_shape); scale_op->set_input_bias(*bias_node->data());
scale_node->set_input_bias(*bias_const_node); scale_op->set_attr_has_bias_value(true);
scale_node->set_attr_has_bias_value(true);
} }
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -83,6 +85,6 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -83,6 +85,6 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(scale,
scale, kNPU,
paddle::lite::subgraph::npu::ScaleConverter); paddle::lite::subgraph::npu::ScaleConverter);
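The new scale_bias_shape can be read as: broadcast the filter and bias over N, H and W, taking the channel from the third-from-last axis when the input has one. A standalone restatement of the expression with examples (illustrative helper, not part of this patch):

    #include <cstdint>
    #include <vector>

    std::vector<int64_t> ScaleBiasShape(const std::vector<int64_t>& x_dims) {
      const size_t x_rank = x_dims.size();
      return {1, x_rank < 3 ? 1 : x_dims[x_rank - 3], 1, 1};
    }

    // e.g. x_dims = {8, 16, 32, 32} -> {1, 16, 1, 1}
    //      x_dims = {8, 16}         -> {1, 1, 1, 1}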
...@@ -44,17 +44,19 @@ int ShuffleChannelConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -44,17 +44,19 @@ int ShuffleChannelConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto group = op_info->GetAttr<int>("group"); auto group = op_info->GetAttr<int>("group");
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Shuffle Channel node // Shuffle Channel node
auto shuffle_channel_node = graph->AddNode<ge::op::ShuffleChannel>(out_name); auto shuffle_channel_node = graph->Add<ge::op::ShuffleChannel>(out_name);
shuffle_channel_node->set_input_x(*x_node); auto shuffle_channel_op =
shuffle_channel_node->set_attr_group(group); shuffle_channel_node->data<ge::op::ShuffleChannel>();
shuffle_channel_op->set_input_x(*x_node->data());
shuffle_channel_op->set_attr_group(group);
return SUCCESS; return SUCCESS;
} }
...@@ -63,6 +65,6 @@ int ShuffleChannelConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -63,6 +65,6 @@ int ShuffleChannelConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(shuffle_channel,
shuffle_channel, kNPU,
paddle::lite::subgraph::npu::ShuffleChannelConverter); paddle::lite::subgraph::npu::ShuffleChannelConverter);
...@@ -37,29 +37,34 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -37,29 +37,34 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(x_type->layout() == DATALAYOUT(kNCHW)); CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name); auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims(); auto x_dims = x->dims();
auto x_rank = x_dims.size();
auto out_name = op_info->Output("Out").front(); auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out"); auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat)); CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW)); CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axis = op_info->GetAttr<int>("axis"); auto axis = op_info->GetAttr<int>("axis");
if (x_dims.size() > 3) { if (axis < 0) {
CHECK(!(axis == 2 && x_dims[3] > 1)) axis += x_rank;
<< "[NPU] Unsupported softmax params: axis = " << axis }
<< " :x_w = " << x_dims[3]; if (axis == 2 && x_rank > 3 && x_dims[3] != 1) {
LOG(WARNING) << "[NPU] Unsupported softmax params: axis = " << axis
<< " :x_w = " << x_dims[3];
return FAILED;
} }
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Softmax node // Softmax node
auto softmax_node = graph->AddNode<ge::op::Softmax>(out_name); auto softmax_node = graph->Add<ge::op::Softmax>(out_name);
softmax_node->set_input_x(*x_node); auto softmax_op = softmax_node->data<ge::op::Softmax>();
softmax_node->set_attr_axis(axis); softmax_op->set_input_x(*x_node->data());
softmax_op->set_attr_axis(axis);
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -68,6 +73,6 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -68,6 +73,6 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(softmax,
softmax, kNPU,
paddle::lite::subgraph::npu::SoftmaxConverter); paddle::lite::subgraph::npu::SoftmaxConverter);
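The added axis handling wraps a negative axis into [0, rank) and rejects the one configuration the converter flags as unsupported on HiAI. A standalone restatement (illustrative helper, not part of this patch):

    #include <cstdint>
    #include <vector>

    bool SoftmaxSupportedOnNPU(int axis, const std::vector<int64_t>& x_dims) {
      const int x_rank = static_cast<int>(x_dims.size());
      if (axis < 0) axis += x_rank;
      // axis on H of a 4-D tensor is only handled when W == 1.
      return !(axis == 2 && x_rank > 3 && x_dims[3] != 1);
    }

    // e.g. axis = -1 on a {4, 10} input becomes axis = 1 and is accepted;
    // axis = 2 on a {1, 3, 8, 8} input makes the converter return FAILED.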
...@@ -47,33 +47,34 @@ int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -47,33 +47,34 @@ int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) {
int64_t sections_num = static_cast<int64_t>(sections.size()); int64_t sections_num = static_cast<int64_t>(sections.size());
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Split node // Split node
auto split_node = graph->AddNode<ge::op::Split>(op_type + "/" + x_name); auto split_node = graph->Add<ge::op::Split>(op_type + "/" + x_name);
split_node->set_input_x(*x_node); auto split_op = split_node->data<ge::op::Split>();
split_node->set_attr_axis(static_cast<int64_t>(axis)); split_op->set_input_x(*x_node->data());
split_op->set_attr_axis(static_cast<int64_t>(axis));
if (num > 0) { if (num > 0) {
split_node->set_attr_output_num(static_cast<int64_t>(num)); split_op->set_attr_output_num(static_cast<int64_t>(num));
} else { } else {
split_node->set_attr_output_num(sections_num); split_op->set_attr_output_num(sections_num);
auto size_split = ge::AttrValue::LIST_INT(sections.begin(), sections.end()); auto size_split = ge::AttrValue::LIST_INT(sections.begin(), sections.end());
split_node->set_attr_size_split(size_split); split_op->set_attr_size_split(size_split);
} }
split_node->create_dynamic_output_y(out_names.size()); split_op->create_dynamic_output_y(out_names.size());
int idx = 1; int idx = 1;
for (auto& out_name : out_names) { for (auto& out_name : out_names) {
auto zero_const_node = auto zero_node = graph->Add(out_name + "/zero" + std::to_string(idx), 0);
graph->AddNode(out_name + "/zero" + std::to_string(idx), 0); auto add_node = graph->Add<ge::op::Add>(out_name);
auto add_node = graph->AddNode<ge::op::Add>(out_name); auto add_op = add_node->data<ge::op::Add>();
add_node->set_input_x1(*split_node, "y" + std::to_string(idx)); add_op->set_input_x1(*split_node->data(), "y" + std::to_string(idx));
add_node->set_input_x2(*zero_const_node); add_op->set_input_x2(*zero_node->data());
idx++; idx++;
} }
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
...@@ -84,6 +85,6 @@ int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -84,6 +85,6 @@ int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(split,
split, kNPU,
paddle::lite::subgraph::npu::SplitConverter); paddle::lite::subgraph::npu::SplitConverter);
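The Add-with-zero nodes above exist only to give each Paddle output variable a named node bound to one of Split's dynamic outputs y1..yN. A small sketch of the port names they bind to (illustrative helper, not part of this patch):

    #include <string>
    #include <vector>

    std::vector<std::string> SplitOutputPorts(size_t num_outputs) {
      std::vector<std::string> ports;
      for (size_t i = 1; i <= num_outputs; i++) {
        ports.push_back("y" + std::to_string(i));
      }
      return ports;
    }

    // e.g. SplitOutputPorts(3) -> {"y1", "y2", "y3"}, matching the
    // set_input_x1(*split_node->data(), "y" + std::to_string(idx)) calls above.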
...@@ -43,16 +43,17 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -43,16 +43,17 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW)); CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Sqrt node // Sqrt node
auto sqrt_node = graph->AddNode<ge::op::Sqrt>(out_name); auto sqrt_node = graph->Add<ge::op::Sqrt>(out_name);
sqrt_node->set_input_x(*x_node); auto sqrt_op = sqrt_node->data<ge::op::Sqrt>();
sqrt_op->set_input_x(*x_node->data());
return SUCCESS; return SUCCESS;
} }
...@@ -61,4 +62,6 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -61,4 +62,6 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, sqrt, paddle::lite::subgraph::npu::SqrtConverter); REGISTER_SUBGRAPH_BRIDGE(sqrt,
kNPU,
paddle::lite::subgraph::npu::SqrtConverter);
...@@ -43,16 +43,17 @@ int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -43,16 +43,17 @@ int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW)); CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Square node // Square node
auto square_node = graph->AddNode<ge::op::Square>(out_name); auto square_node = graph->Add<ge::op::Square>(out_name);
square_node->set_input_x(*x_node); auto square_op = square_node->data<ge::op::Square>();
square_op->set_input_x(*x_node->data());
return SUCCESS; return SUCCESS;
} }
...@@ -61,6 +62,6 @@ int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -61,6 +62,6 @@ int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(square,
square, kNPU,
paddle::lite::subgraph::npu::SquareConverter); paddle::lite::subgraph::npu::SquareConverter);
...@@ -41,19 +41,20 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -41,19 +41,20 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<std::vector<int>>("axis"); auto axis = op_info->GetAttr<std::vector<int>>("axis");
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Transpose node // Transpose node
auto transpose_node = graph->AddNode<ge::op::Permute>(out_name); auto transpose_node = graph->Add<ge::op::Permute>(out_name);
transpose_node->set_input_x(*x_node); auto transpose_op = transpose_node->data<ge::op::Permute>();
auto w_const_node = graph->AddNode(out_name + "/w", 1.0f); transpose_op->set_input_x(*x_node->data());
transpose_node->set_input_w(*w_const_node); auto w_node = graph->Add(out_name + "/w", 1.0f);
transpose_node->set_attr_order( transpose_op->set_input_w(*w_node->data());
transpose_op->set_attr_order(
ge::AttrValue::LIST_INT(axis.begin(), axis.end())); ge::AttrValue::LIST_INT(axis.begin(), axis.end()));
return SUCCESS; return SUCCESS;
} }
...@@ -63,9 +64,9 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -63,9 +64,9 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(transpose,
transpose, kNPU,
paddle::lite::subgraph::npu::TransposeConverter); paddle::lite::subgraph::npu::TransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(transpose2,
transpose2, kNPU,
paddle::lite::subgraph::npu::TransposeConverter); paddle::lite::subgraph::npu::TransposeConverter);
...@@ -45,17 +45,18 @@ int UnsqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -45,17 +45,18 @@ int UnsqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< "[NPU] unsqueeze not support axes from tensor now"; << "[NPU] unsqueeze not support axes from tensor now";
// X node // X node
std::shared_ptr<ge::Operator> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Unsqueeze node // Unsqueeze node
auto unsqueeze_node = graph->AddNode<ge::op::Reshape>(out_name); auto unsqueeze_node = graph->Add<ge::op::Reshape>(out_name);
unsqueeze_node->set_input_tensor(*x_node); auto unsqueeze_op = unsqueeze_node->data<ge::op::Reshape>();
unsqueeze_node->set_attr_shape( unsqueeze_op->set_input_tensor(*x_node->data());
unsqueeze_op->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end())); ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -65,9 +66,9 @@ int UnsqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -65,9 +66,9 @@ int UnsqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(unsqueeze,
unsqueeze, kNPU,
paddle::lite::subgraph::npu::UnsqueezeConverter); paddle::lite::subgraph::npu::UnsqueezeConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, REGISTER_SUBGRAPH_BRIDGE(unsqueeze2,
unsqueeze2, kNPU,
paddle::lite::subgraph::npu::UnsqueezeConverter); paddle::lite::subgraph::npu::UnsqueezeConverter);
...@@ -85,6 +85,22 @@ ge::Format CvtDataLayoutType(DataLayoutType itype) { ...@@ -85,6 +85,22 @@ ge::Format CvtDataLayoutType(DataLayoutType itype) {
return otype; return otype;
} }
std::vector<int64_t> CvtShape(const std::vector<int64_t>& in_shape) {
std::vector<int64_t> out_shape;
// Pad the shape to 4 dimensions (NCHW)
for (size_t i = in_shape.size(); i < 4; i++) {
out_shape.push_back(1);
}
for (size_t i = 0; i < in_shape.size(); i++) {
out_shape.push_back(in_shape[i]);
}
return out_shape;
}
std::vector<int64_t> CvtShape(const DDim& in_dims) {
return CvtShape(in_dims.Vectorize());
}
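Worked examples of the NCHW padding CvtShape performs; a standalone re-implementation with the same behaviour, for illustration only:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> PadToNCHW(const std::vector<int64_t>& in_shape) {
      std::vector<int64_t> out(4 - std::min<size_t>(in_shape.size(), 4), 1);
      out.insert(out.end(), in_shape.begin(), in_shape.end());
      return out;
    }

    int main() {
      assert((PadToNCHW({64}) == std::vector<int64_t>{1, 1, 1, 64}));
      assert((PadToNCHW({32, 100}) == std::vector<int64_t>{1, 1, 32, 100}));
      assert((PadToNCHW({8, 3, 224, 224}) == std::vector<int64_t>{8, 3, 224, 224}));
      return 0;
    }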
ge::TensorPtr CvtTensor(const Tensor& in_tensor, ge::TensorPtr CvtTensor(const Tensor& in_tensor,
std::vector<int64_t> out_shape, std::vector<int64_t> out_shape,
PrecisionType in_precision, PrecisionType in_precision,
......
...@@ -19,12 +19,12 @@ ...@@ -19,12 +19,12 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "ai_ddk_lib/include/graph/buffer.h" #include "graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h" #include "graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h" #include "graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h" #include "graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h" #include "graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h" #include "graph/operator_reg.h"
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/utils/macros.h" #include "lite/utils/macros.h"
...@@ -70,59 +70,16 @@ ge::DataType CvtPrecisionType(PrecisionType itype); ...@@ -70,59 +70,16 @@ ge::DataType CvtPrecisionType(PrecisionType itype);
ge::Format CvtDataLayoutType(DataLayoutType itype); ge::Format CvtDataLayoutType(DataLayoutType itype);
// Pad the shape to 4 dimensions (NCHW) for HiAI
std::vector<int64_t> CvtShape(const std::vector<int64_t>& in_shape);
std::vector<int64_t> CvtShape(const DDim& in_dims);
ge::TensorPtr CvtTensor(const Tensor& in_tensor, ge::TensorPtr CvtTensor(const Tensor& in_tensor,
std::vector<int64_t> out_shape = {}, std::vector<int64_t> out_shape = {},
PrecisionType in_precision = PRECISION(kFloat), PrecisionType in_precision = PRECISION(kFloat),
DataLayoutType in_layout = DATALAYOUT(kNCHW)); DataLayoutType in_layout = DATALAYOUT(kNCHW));
template <typename T>
ge::TensorPtr CreateTensorAndFillData(const std::vector<T>& data,
std::vector<int64_t> shape = {},
ge::Format format = ge::FORMAT_NCHW) {
const std::type_info& info = typeid(T);
ge::DataType type = ge::DT_FLOAT;
if (info == typeid(float)) {
type = ge::DT_FLOAT;
} else if (info == typeid(int8_t)) {
type = ge::DT_INT8;
} else if (info == typeid(int16_t)) {
type = ge::DT_INT16;
} else if (info == typeid(int32_t)) {
type = ge::DT_INT32;
} else if (info == typeid(int64_t)) {
type = ge::DT_INT64;
} else {
LOG(FATAL) << "[NPU] Unknow value type " << info.name();
}
if (shape.empty()) {
shape = {static_cast<int64_t>(data.size())};
} else {
int size = 1;
for (auto i : shape) {
size *= i;
}
CHECK_EQ(data.size(), size);
}
ge::TensorDesc desc(ge::Shape(shape), format, type);
ge::TensorPtr tensor = std::make_shared<ge::Tensor>();
tensor->SetTensorDesc(desc);
tensor->SetData(reinterpret_cast<uint8_t*>(data.data()),
data.size() * sizeof(T));
return tensor;
}
template <typename T>
ge::TensorPtr CreateTensorAndFillData(T value,
std::vector<int64_t> shape = {1},
ge::Format format = ge::FORMAT_NCHW) {
int64_t size = 1;
for (auto i : shape) {
size *= i;
}
std::vector<T> data(size, value);
return CreateTensorAndFillData(data, shape, format);
}
int CvtActMode(std::string act_type); int CvtActMode(std::string act_type);
} // namespace npu } // namespace npu
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <sys/time.h> #include <sys/time.h>
#include <time.h> #include <time.h>
#include <utility> #include <utility>
#include "ai_ddk_lib/include/hiai_ir_build.h" #include "hiai_ir_build.h" // NOLINT
#include "lite/backends/npu/device.h" #include "lite/backends/npu/device.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/graph.h" #include "lite/kernels/npu/bridges/graph.h"
...@@ -39,13 +39,13 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -39,13 +39,13 @@ int SubgraphEngine::BuildDeviceProgram() {
op->CheckShape(); op->CheckShape();
op->InferShape(); op->InferShape();
std::string op_type = op->op_info()->Type(); std::string op_type = op->op_info()->Type();
if (!bridges.Exists("NPU", op_type)) { if (!bridges.Exists(op_type, "kNPU")) {
return subgraph::FAILED; return subgraph::FAILED;
} }
auto kernel = inst.kernel(); auto kernel = inst.kernel();
status |= bridges.Select("NPU", op_type)(reinterpret_cast<void*>(&graph), status |= bridges.Select(op_type, "kNPU")(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op), const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel)); const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) { if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED; return subgraph::FAILED;
} }
...@@ -57,26 +57,26 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -57,26 +57,26 @@ int SubgraphEngine::BuildDeviceProgram() {
std::vector<ge::Operator> device_inodes; std::vector<ge::Operator> device_inodes;
std::vector<ge::Operator> device_onodes; std::vector<ge::Operator> device_onodes;
for (auto& input_name : input_names_) { for (auto& input_name : input_names_) {
if (graph.HasNode(input_name)) { if (graph.Has(input_name)) {
if (!graph.GetType(input_name).persistable()) { if (graph.Get(input_name)->is_data()) {
device_inodes.push_back(*graph.GetNode(input_name)); device_inodes.push_back(*graph.Get(input_name)->data());
device_inames_.push_back(input_name); device_inames_.push_back(input_name);
} else { } else {
LOG(WARNING) << "[NPU] Input node " << input_name LOG(WARNING) << "[NPU] Input node " << input_name
<< " is skipped because it is a persistable node."; << " is ignored because it is not a data node.";
} }
} else { } else {
LOG(WARNING) << "[NPU] Input node " << input_name LOG(WARNING) << "[NPU] Input node " << input_name
<< " is skipped because it does not exist."; << " is ignored because it does not exist.";
} }
} }
for (auto& output_name : output_names_) { for (auto& output_name : output_names_) {
if (graph.HasNode(output_name)) { if (graph.Has(output_name)) {
device_onodes.push_back(*graph.GetNode(output_name)); device_onodes.push_back(*graph.Get(output_name)->data());
device_onames_.push_back(output_name); device_onames_.push_back(output_name);
} else { } else {
LOG(WARNING) << "[NPU] Output node " << output_name LOG(WARNING) << "[NPU] Output node " << output_name
<< " is skipped because it does not exist."; << " is ignored because it does not exist.";
} }
} }
CHECK(!device_inames_.empty()) CHECK(!device_inames_.empty())
...@@ -108,14 +108,14 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -108,14 +108,14 @@ int SubgraphEngine::BuildDeviceProgram() {
origin_otensors_.resize(device_onames_.size()); origin_otensors_.resize(device_onames_.size());
device_otensors_.resize(device_onames_.size()); device_otensors_.resize(device_onames_.size());
for (int i = 0; i < device_inames_.size(); i++) { for (int i = 0; i < device_inames_.size(); i++) {
auto type = graph.GetType(device_inames_[i]); auto node = graph.Get(device_inames_[i]);
auto precision = type.precision(); auto precision = node->precision();
auto layout = type.layout(); auto layout = node->layout();
origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]); origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]);
CHECK(origin_itensors_[i]); CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims(); origin_idims_[i] = origin_itensors_[i]->dims();
VLOG(3) << "[NPU] Inputs[" << i VLOG(3) << "[NPU] Inputs[" << i << "] name: " << device_inames_[i]
<< "] precision: " << PrecisionToStr(precision) << " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " dims: {" << " layout: " << DataLayoutToStr(layout) << " dims: {"
<< device_idims[i].GetNumber() << "," << device_idims[i].GetNumber() << ","
<< device_idims[i].GetChannel() << "," << device_idims[i].GetChannel() << ","
...@@ -129,14 +129,14 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -129,14 +129,14 @@ int SubgraphEngine::BuildDeviceProgram() {
device_itensors_[i]->Init(&(device_idims[i])); device_itensors_[i]->Init(&(device_idims[i]));
} }
for (int i = 0; i < device_onames_.size(); i++) { for (int i = 0; i < device_onames_.size(); i++) {
auto type = graph.GetType(device_onames_[i]); auto node = graph.Get(device_onames_[i]);
auto precision = type.precision(); auto precision = node->precision();
auto layout = type.layout(); auto layout = node->layout();
origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]); origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]);
CHECK(origin_otensors_[i]); CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims(); origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "[NPU] Outputs[" << i VLOG(3) << "[NPU] Outputs[" << i << "] name: " << device_onames_[i]
<< "] precision: " << PrecisionToStr(precision) << " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " dims: {" << " layout: " << DataLayoutToStr(layout) << " dims: {"
<< device_odims[i].GetNumber() << "," << device_odims[i].GetNumber() << ","
<< device_odims[i].GetChannel() << "," << device_odims[i].GetChannel() << ","
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "ai_ddk_lib/include/HiAiModelManagerService.h" #include "HiAiModelManagerService.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/kernels/npu/bridges/engine.h" #include "lite/kernels/npu/bridges/engine.h"
#include "lite/kernels/npu/bridges/registry.h" #include "lite/kernels/npu/bridges/registry.h"
......
...@@ -43,20 +43,21 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -43,20 +43,21 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW)); CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Act node // Act node
if (op_type == "relu") { if (op_type == "relu") {
graph->AddNode(out_name, graph->builder_.CreateRelu(*x_node)); graph->Add(out_name, graph->builder_.CreateRelu(*x_node->data()));
} else if (op_type == "tanh") { } else if (op_type == "tanh") {
graph->AddNode(out_name, graph->builder_.CreateUnaryOp("tanh", *x_node)); graph->Add(out_name,
graph->builder_.CreateUnaryOp("tanh", *x_node->data()));
} else if (op_type == "gelu") { } else if (op_type == "gelu") {
graph->AddNode(out_name, graph->builder_.CreateGelu(*x_node)); graph->Add(out_name, graph->builder_.CreateGelu(*x_node->data()));
} else { } else {
// TODO(hong19860320) supports more activation ops // TODO(hong19860320) supports more activation ops
LOG(WARNING) << "[XPU] Unsupported activation type " << op_type; LOG(WARNING) << "[XPU] Unsupported activation type " << op_type;
...@@ -70,6 +71,6 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -70,6 +71,6 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, relu, paddle::lite::subgraph::xpu::ActConverter); REGISTER_SUBGRAPH_BRIDGE(relu, kXPU, paddle::lite::subgraph::xpu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU, tanh, paddle::lite::subgraph::xpu::ActConverter); REGISTER_SUBGRAPH_BRIDGE(tanh, kXPU, paddle::lite::subgraph::xpu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU, gelu, paddle::lite::subgraph::xpu::ActConverter); REGISTER_SUBGRAPH_BRIDGE(gelu, kXPU, paddle::lite::subgraph::xpu::ActConverter);
...@@ -64,28 +64,28 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -64,28 +64,28 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto epsilon = op_info->GetAttr<float>("epsilon"); auto epsilon = op_info->GetAttr<float>("epsilon");
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Scale, Bias, Mean, Variance node // Scale, Bias, Mean, Variance node
auto scale_const_node = graph->AddNode(scale_name, *scale); auto scale_node = graph->Add(scale_name, *scale);
auto bias_const_node = graph->AddNode(bias_name, *bias); auto bias_node = graph->Add(bias_name, *bias);
auto mean_const_node = graph->AddNode(mean_name, *mean); auto mean_node = graph->Add(mean_name, *mean);
auto variance_const_node = graph->AddNode(variance_name, *variance); auto variance_node = graph->Add(variance_name, *variance);
// Batch Norm node and extract the first field as the output node // Batch Norm node and extract the first field as the output node
auto batch_norm_node = graph->builder_.CreateBatchNorm(*x_node, auto batch_norm_data = graph->builder_.CreateBatchNorm(*x_node->data(),
*scale_const_node, *scale_node->data(),
*bias_const_node, *bias_node->data(),
*mean_const_node, *mean_node->data(),
*variance_const_node, *variance_node->data(),
1, 1,
epsilon); epsilon);
graph->AddNode(y_name, graph->builder_.GetField(batch_norm_node, 0)); graph->Add(y_name, graph->builder_.GetField(batch_norm_data, 0));
return SUCCESS; return SUCCESS;
} }
...@@ -94,6 +94,6 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -94,6 +94,6 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(batch_norm,
batch_norm, kXPU,
paddle::lite::subgraph::xpu::BatchNormConverter); paddle::lite::subgraph::xpu::BatchNormConverter);
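CreateBatchNorm returns a composite value and GetField(..., 0) selects the normalized output, as the comment above notes; the remaining fields are presumably the updated statistics, which inference does not need (an assumption about the XTCL API, not something this patch states). The selected output computes the standard inference-time batch norm; a reference sketch:

    #include <cmath>

    // Per-element reference, with scale/bias/mean/variance broadcast along axis 1:
    //   y = scale * (x - mean) / sqrt(variance + epsilon) + bias
    float BatchNormRef(float x, float scale, float bias,
                       float mean, float variance, float epsilon) {
      return scale * (x - mean) / std::sqrt(variance + epsilon) + bias;
    }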
...@@ -61,11 +61,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -61,11 +61,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(dilations.size(), 2L); CHECK_EQ(dilations.size(), 2L);
// Input node // Input node
std::shared_ptr<xtcl::xExpr> input_node = nullptr; std::shared_ptr<Node> input_node = nullptr;
if (graph->HasNode(input_name)) { if (graph->Has(input_name)) {
input_node = graph->GetNode(input_name); input_node = graph->Get(input_name);
} else { } else {
input_node = graph->AddNode(input_name, input_dims); input_node = graph->Add(input_name, *input);
} }
if (paddings.size() == 2L) { if (paddings.size() == 2L) {
...@@ -99,7 +99,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -99,7 +99,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
DDim output_dims(output_shape); DDim output_dims(output_shape);
// Filter node // Filter node
auto filter_const_node = graph->AddNode(filter_name, *filter); auto filter_node = graph->Add(filter_name, *filter);
// Conv node // Conv node
auto conv_attrs = xtcl::make_node<xtcl::network::Conv2DAttrs>(); auto conv_attrs = xtcl::make_node<xtcl::network::Conv2DAttrs>();
...@@ -114,9 +114,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -114,9 +114,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
conv_attrs->out_layout = ""; conv_attrs->out_layout = "";
// conv_attrs->out_dtype = ""; // conv_attrs->out_dtype = "";
auto conv_node = auto conv_node =
graph->AddNode(output_name, graph->Add(output_name,
graph->builder_.CreateConv2D( graph->builder_.CreateConv2D(
*input_node, *filter_const_node, conv_attrs)); *input_node->data(), *filter_node->data(), conv_attrs));
// Add bias node if exists bias // Add bias node if exists bias
// supports the bias nodes with the following dimensions // supports the bias nodes with the following dimensions
...@@ -149,30 +149,27 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -149,30 +149,27 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< " isn't supported in conv2d Op when output dimension is " << " isn't supported in conv2d Op when output dimension is "
<< output_dims; << output_dims;
} }
std::shared_ptr<xtcl::xExpr> bias_node = nullptr; std::shared_ptr<Node> bias_node = nullptr;
if (graph->HasNode(bias_name)) { if (graph->Has(bias_name)) {
// Bias node from input node bias_node = graph->Get(bias_name);
bias_node = graph->GetNode(bias_name);
} else { } else {
// Bias node with const data bias_node = graph->Add(bias_name, *bias, bias_shape);
bias_node = graph->AddNode(bias_name, *bias, bias_shape);
} }
std::shared_ptr<xtcl::xExpr> add_node = nullptr;
if (is_channel_bias) { if (is_channel_bias) {
add_node = graph->AddNode( conv_node = graph->Add(output_name,
output_name, graph->builder_.CreateBiasAdd(
graph->builder_.CreateBiasAdd(*conv_node, 1, *bias_node)); *conv_node->data(), 1, *bias_node->data()));
} else { } else {
add_node = graph->AddNode( conv_node =
output_name, graph->Add(output_name,
graph->builder_.CreateBinaryOp("add", *conv_node, *bias_node)); graph->builder_.CreateBinaryOp(
"add", *conv_node->data(), *bias_node->data()));
} }
conv_node = add_node;
} }
if (fuse_relu) { if (fuse_relu) {
// Append relu node if fuse_relu is true // Append relu node if fuse_relu is true
graph->AddNode(output_name, graph->builder_.CreateRelu(*conv_node)); graph->Add(output_name, graph->builder_.CreateRelu(*conv_node->data()));
} }
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -182,9 +179,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -182,9 +179,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(conv2d,
conv2d, kXPU,
paddle::lite::subgraph::xpu::ConvConverter); paddle::lite::subgraph::xpu::ConvConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(depthwise_conv2d,
depthwise_conv2d, kXPU,
paddle::lite::subgraph::xpu::ConvConverter); paddle::lite::subgraph::xpu::ConvConverter);
...@@ -46,21 +46,21 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -46,21 +46,21 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
op_info->GetAttr<std::string>("dropout_implementation"); op_info->GetAttr<std::string>("dropout_implementation");
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Dropout node // Dropout node
if (dropout_implementation == "downgrade_in_infer") { if (dropout_implementation == "downgrade_in_infer") {
graph->AddNode( graph->Add(out_name,
out_name, graph->builder_.CreateScale(
graph->builder_.CreateScale(*x_node, 1.f - dropout_prob, 0.0f, false)); *x_node->data(), 1.f - dropout_prob, 0.0f, false));
} else if (dropout_implementation == "upscale_in_train") { } else if (dropout_implementation == "upscale_in_train") {
graph->AddNode(out_name, graph->Add(out_name,
graph->builder_.CreateScale(*x_node, 1.0f, 0.0f, false)); graph->builder_.CreateScale(*x_node->data(), 1.0f, 0.0f, false));
} else { } else {
LOG(WARNING) << "[XPU] Unsupported dropout_implementation == " LOG(WARNING) << "[XPU] Unsupported dropout_implementation == "
<< dropout_implementation << " for dropout"; << dropout_implementation << " for dropout";
...@@ -74,6 +74,6 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -74,6 +74,6 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(dropout,
dropout, kXPU,
paddle::lite::subgraph::xpu::DropoutConverter); paddle::lite::subgraph::xpu::DropoutConverter);
...@@ -50,29 +50,31 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -50,29 +50,31 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<int>("axis"); auto axis = op_info->GetAttr<int>("axis");
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Y node // Y node
std::shared_ptr<xtcl::xExpr> y_node = nullptr; std::shared_ptr<Node> y_node = nullptr;
if (graph->HasNode(y_name)) { if (graph->Has(y_name)) {
y_node = graph->GetNode(y_name); y_node = graph->Get(y_name);
} else { } else {
y_node = graph->AddNode(y_name, y_dims); y_node = graph->Add(y_name, *y);
} }
// Elementwise node // Elementwise node
std::shared_ptr<xtcl::xExpr> elementwise_node = nullptr; std::shared_ptr<Node> elt_node = nullptr;
if (y_dims.size() == 1) { if (y_dims.size() == 1) {
elementwise_node = graph->AddNode( elt_node = graph->Add(
out_name, graph->builder_.CreateBiasAdd(*x_node, axis, *y_node)); out_name,
graph->builder_.CreateBiasAdd(*x_node->data(), axis, *y_node->data()));
} else if (x_dims.size() == y_dims.size()) { } else if (x_dims.size() == y_dims.size()) {
elementwise_node = graph->AddNode( elt_node = graph->Add(out_name,
out_name, graph->builder_.CreateBinaryOp("add", *x_node, *y_node)); graph->builder_.CreateBinaryOp(
"add", *x_node->data(), *y_node->data()));
} else { } else {
LOG(WARNING) LOG(WARNING)
<< "[XPU] elementwise_add only support y of one dimension, or x " << "[XPU] elementwise_add only support y of one dimension, or x "
...@@ -88,6 +90,6 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -88,6 +90,6 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(elementwise_add,
elementwise_add, kXPU,
paddle::lite::subgraph::xpu::ElementwiseConverter); paddle::lite::subgraph::xpu::ElementwiseConverter);
...@@ -54,38 +54,39 @@ int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -54,38 +54,39 @@ int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto out_dims = out->dims(); auto out_dims = out->dims();
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Index node // Index node
std::shared_ptr<xtcl::xExpr> index_node = nullptr; std::shared_ptr<Node> index_node = nullptr;
if (graph->HasNode(index_name)) { if (graph->Has(index_name)) {
index_node = graph->GetNode(index_name); index_node = graph->Get(index_name);
} else { } else {
index_node = graph->AddNode( index_node = graph->Add(
index_name, index_dims, index_type->precision(), index_type->layout()); index_name, *index, index_type->precision(), index_type->layout());
} }
// Flatten index node // Flatten index node
if (index_dims.size() != 1) { if (index_dims.size() != 1) {
index_node = index_node =
graph->AddNode(index_name + "/reshape", graph->Add(index_name + "/reshape",
graph->builder_.CreateReshape(*index_node, {-1}), graph->builder_.CreateReshape(*index_node->data(), {-1}),
index_type->precision(), index_type->precision(),
index_type->layout()); index_type->layout());
} }
// Reshape the gather node with the inferred shape as the output node // Reshape the gather node with the inferred shape as the output node
auto gather_node = graph->AddNode( auto gather_node =
out_name, graph->Add(out_name,
graph->builder_.CreateGather(*x_node, *index_node, /* axis= */ 0)); graph->builder_.CreateGather(
*x_node->data(), *index_node->data(), /* axis= */ 0));
if (out_dims.size() != 2) { if (out_dims.size() != 2) {
graph->AddNode(out_name, graph->Add(out_name,
graph->builder_.CreateReshape( graph->builder_.CreateReshape(
*gather_node, CvtShape<xtcl::Integer>(out_dims))); *gather_node->data(), CvtShape<xtcl::Integer>(out_dims)));
} }
return SUCCESS; return SUCCESS;
} }
...@@ -95,6 +96,6 @@ int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -95,6 +96,6 @@ int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(gather,
gather, kXPU,
paddle::lite::subgraph::xpu::GatherConverter); paddle::lite::subgraph::xpu::GatherConverter);
...@@ -21,71 +21,71 @@ namespace lite { ...@@ -21,71 +21,71 @@ namespace lite {
namespace subgraph { namespace subgraph {
namespace xpu { namespace xpu {
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name, int Graph::Add(const std::string& name, std::shared_ptr<Node> node) {
const xtcl::xExpr& layer,
PrecisionType precision,
DataLayoutType layout) {
auto unique_name = [&](const std::string& key) {
int idx = 1;
auto it = counts_.find(key);
if (it == counts_.end()) {
counts_.insert(std::make_pair(key, idx));
} else {
idx = ++(it->second);
}
return key + "_" + std::to_string(idx);
};
auto it = nodes_.find(name); auto it = nodes_.find(name);
if (it != nodes_.end()) { if (it != nodes_.end()) {
// Only variable can rebind the name // Only variable node can be shared with the same name
CHECK(!it->second.second.persistable()) << "[XPU] Node " << name if (!node->is_var() || !it->second.back()->is_var()) {
<< " redefined."; LOG(FATAL) << "[XPU] Const or data node " << name << " is redefined.";
// Generate a new unique name as the key to bind the origin node if the return -1;
// origin node isn't a const node: new_name->node }
nodes_.insert(std::make_pair(unique_name(name + "_var"), it->second)); } else {
nodes_.erase(it); auto ret = nodes_.insert(
std::make_pair(name, std::vector<std::shared_ptr<Node>>()));
CHECK(ret.second);
it = ret.first;
} }
// Create a new node and bind with the name: name->new_node it->second.push_back(node);
auto node = std::make_shared<xtcl::xExpr>(layer); return it->second.size();
nodes_.insert(std::make_pair(
name, std::make_pair(node, Type(precision, layout, false))));
builder_.SetLayer(unique_name(name + "_op"));
return node;
} }
// Const node // Variable node
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name, std::shared_ptr<Node> Graph::Add(const std::string& name,
const Tensor& tensor, const xtcl::xExpr& layer,
PrecisionType precision, PrecisionType precision,
DataLayoutType layout) { DataLayoutType layout) {
return AddNode(name, tensor, tensor.dims().Vectorize(), precision, layout); auto node = std::make_shared<Node>(precision, layout, Node::Role::kVar);
auto idx = Add(name, node);
CHECK_GE(idx, 1);
node->set_data(std::make_shared<xtcl::xExpr>(layer));
// Generate a unique name for the current XTCL layer
builder_.SetLayer(name + "__" + std::to_string(idx));
return node;
} }
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name, // Const or data node
const Tensor& tensor, std::shared_ptr<Node> Graph::Add(const std::string& name,
std::vector<int64_t> shape, const Tensor& tensor,
PrecisionType precision, std::vector<int64_t> shape,
DataLayoutType layout) { PrecisionType precision,
CHECK(!HasNode(name)) << "[NPU] Node " << name << " redefined."; DataLayoutType layout) {
auto node = std::make_shared<xtcl::xExpr>(builder_.CreateTensor( std::shared_ptr<Node> node = nullptr;
name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision))); if (tensor.persistable()) {
nodes_.insert(std::make_pair( // Const node
name, std::make_pair(node, Type(precision, layout, true)))); node = std::make_shared<Node>(precision, layout, Node::Role::kConst);
params_.emplace( auto idx = Add(name, node);
std::make_pair(name, *CvtTensor(tensor, shape, precision, layout))); CHECK_EQ(idx, 1);
node->set_data(std::make_shared<xtcl::xExpr>(builder_.CreateTensor(
name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision))));
params_.emplace(
std::make_pair(name, *CvtTensor(tensor, shape, precision, layout)));
} else {
// Data node
node = Add(name, shape, precision, layout);
}
return node; return node;
} }
// Data node // Data node
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name, std::shared_ptr<Node> Graph::Add(const std::string& name,
std::vector<int64_t> shape, std::vector<int64_t> shape,
PrecisionType precision, PrecisionType precision,
DataLayoutType layout) { DataLayoutType layout) {
CHECK(!HasNode(name)) << "[NPU] Node " << name << " redefined."; auto node = std::make_shared<Node>(precision, layout, Node::Role::kData);
auto node = std::make_shared<xtcl::xExpr>(builder_.CreateTensor( auto idx = Add(name, node);
name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision))); CHECK_EQ(idx, 1);
nodes_.insert(std::make_pair( node->set_data(std::make_shared<xtcl::xExpr>(builder_.CreateTensor(
name, std::make_pair(node, Type(precision, layout, false)))); name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision))));
return node; return node;
} }
......
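Note: with the refactored graph.cc above, a single Add(name, tensor, ...) entry point decides between a const node and a data node from the tensor's 'persistable' attribute. A minimal sketch of how a converter would drive it (illustrative only: 'graph' stands for the Graph instance handed to the converter, and the tensor names and shapes below are made up, not part of this patch):

// Hedged sketch, assuming the XPU Graph/Node interface shown above.
// A persistable tensor becomes a const node (its contents are copied into
// params_); a non-persistable tensor becomes a data node (placeholder only).
Tensor filter;                          // hypothetical weight tensor
filter.Resize({32, 3, 3, 3});
filter.set_persistable(true);
auto filter_node = graph->Add("conv1_filter", filter);  // -> Node::Role::kConst

Tensor input;                           // hypothetical activation tensor
input.Resize({1, 3, 224, 224});         // persistable() assumed false by default
auto input_node = graph->Add("conv1_input", input);     // -> Node::Role::kData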
...@@ -28,67 +28,81 @@ namespace lite { ...@@ -28,67 +28,81 @@ namespace lite {
namespace subgraph { namespace subgraph {
namespace xpu { namespace xpu {
// Type of graph nodes // Graph and Node are defined to collect all of the converted XTCL IR nodes
class Type { class Node {
public: public:
Type(PrecisionType precision = PRECISION(kFloat), enum class Role {
DataLayoutType layout = DATALAYOUT(kNCHW), kVar = 0,
bool persistable = false) kConst,
: precision_(precision), layout_(layout), persistable_(persistable) {} kData,
};
Node(std::shared_ptr<xtcl::xExpr> data,
PrecisionType precision,
DataLayoutType layout,
Role role)
: data_(data), precision_(precision), layout_(layout), role_(role) {}
Node(PrecisionType precision, DataLayoutType layout, Role role)
: precision_(precision), layout_(layout), role_(role) {}
void set_data(std::shared_ptr<xtcl::xExpr> data) { data_ = data; }
void set_precision(PrecisionType precision) { precision_ = precision; } void set_precision(PrecisionType precision) { precision_ = precision; }
void set_layout(DataLayoutType layout) { layout_ = layout; } void set_layout(DataLayoutType layout) { layout_ = layout; }
void set_persistable(bool persistable) { persistable_ = persistable; } void set_role(Role role) { role_ = role; }
std::shared_ptr<xtcl::xExpr> data() { return data_; }
PrecisionType precision() const { return precision_; } PrecisionType precision() const { return precision_; }
DataLayoutType layout() const { return layout_; } DataLayoutType layout() const { return layout_; }
bool persistable() const { return persistable_; } Role role() const { return role_; }
bool is_var() const { return role_ == Role::kVar; }
bool is_const() const { return role_ == Role::kConst; }
bool is_data() const { return role_ == Role::kData; }
private: private:
std::shared_ptr<xtcl::xExpr> data_{nullptr};
PrecisionType precision_{PRECISION(kFloat)}; PrecisionType precision_{PRECISION(kFloat)};
DataLayoutType layout_{DATALAYOUT(kNCHW)}; DataLayoutType layout_{DATALAYOUT(kNCHW)};
bool persistable_{false}; Role role_{Role::kVar};
}; };
// Graph to collect all of converted XPU IR nodes
class Graph { class Graph {
public: public:
// Layer node int Add(const std::string& name, std::shared_ptr<Node> node);
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name, // Variable node
const xtcl::xExpr& layer, std::shared_ptr<Node> Add(const std::string& name,
PrecisionType precision = PRECISION(kFloat), const xtcl::xExpr& layer,
DataLayoutType layout = DATALAYOUT(kNCHW)); PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
// Const or data node
std::shared_ptr<Node> Add(const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<Node> Add(const std::string& name,
const Tensor& tensor,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return Add(name, tensor, tensor.dims().Vectorize(), precision, layout);
}
// Const node std::shared_ptr<Node> Add(const std::string& name,
std::shared_ptr<xtcl::xExpr> AddNode( const Tensor& tensor,
const std::string& name, DDim dims,
const Tensor& tensor, PrecisionType precision = PRECISION(kFloat),
PrecisionType precision = PRECISION(kFloat), DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)); return Add(name, tensor, dims.Vectorize(), precision, layout);
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const Tensor& tensor,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, tensor, dims.Vectorize(), precision, layout);
} }
// Const node
template <typename T> template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, const std::vector<T>& data,
const std::vector<T>& data, std::vector<int64_t> shape = {},
std::vector<int64_t> shape = {}, DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) {
const std::type_info& info = typeid(T); const std::type_info& info = typeid(T);
PrecisionType precision = PRECISION(kFloat); PrecisionType precision = PRECISION(kFloat);
if (info == typeid(float)) { if (info == typeid(float)) {
...@@ -111,70 +125,61 @@ class Graph { ...@@ -111,70 +125,61 @@ class Graph {
} }
Tensor tensor; Tensor tensor;
tensor.Resize(shape); tensor.Resize(shape);
tensor.set_persistable(true);
std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()), std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
reinterpret_cast<const uint8_t*>(data.data()), reinterpret_cast<const uint8_t*>(data.data()),
data.size() * sizeof(T)); data.size() * sizeof(T));
return AddNode(name, tensor, precision, layout); return Add(name, tensor, precision, layout);
} }
template <typename T> template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, const std::vector<T>& data,
const std::vector<T>& data, DDim dims,
DDim dims, DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) { return Add(name, data, dims.Vectorize(), layout);
return AddNode(name, data, dims.Vectorize(), layout);
} }
template <typename T> template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, T value,
T value, std::vector<int64_t> shape = {1},
std::vector<int64_t> shape = {1}, DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) {
int64_t size = 1; int64_t size = 1;
for (auto i : shape) { for (auto i : shape) {
size *= i; size *= i;
} }
std::vector<T> data(size, value); std::vector<T> data(size, value);
return AddNode(name, data, shape, layout); return Add(name, data, shape, layout);
} }
template <typename T> template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, T value,
T value, DDim dims,
DDim dims, DataLayoutType layout = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) { return Add(name, value, dims.Vectorize(), layout);
return AddNode(name, value, dims.Vectorize(), layout);
} }
// Data node // Data node
std::shared_ptr<xtcl::xExpr> AddNode( std::shared_ptr<Node> Add(const std::string& name,
const std::string& name, std::vector<int64_t> shape,
std::vector<int64_t> shape, PrecisionType precision = PRECISION(kFloat),
PrecisionType precision = PRECISION(kFloat), DataLayoutType layout = DATALAYOUT(kNCHW));
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<Node> Add(const std::string& name,
std::shared_ptr<xtcl::xExpr> AddNode( DDim dims,
const std::string& name, PrecisionType precision = PRECISION(kFloat),
DDim dims, DataLayoutType layout = DATALAYOUT(kNCHW)) {
PrecisionType precision = PRECISION(kFloat), return Add(name, dims.Vectorize(), precision, layout);
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, dims.Vectorize(), precision, layout);
}
std::shared_ptr<xtcl::xExpr> GetNode(const std::string& name) {
CHECK(HasNode(name)) << "[XPU] Node " << name << " not found.";
return nodes_.at(name).first;
} }
const Type& GetType(const std::string& name) { std::shared_ptr<Node> Get(const std::string& name) {
CHECK(HasNode(name)) << "[XPU] Node " << name << " not found."; CHECK(Has(name)) << "[XPU] Node " << name << " not found.";
return nodes_.at(name).second; return nodes_.at(name).back();
} }
bool HasNode(const std::string& name) { bool Has(const std::string& name) {
return nodes_.find(name) != nodes_.end(); return nodes_.find(name) != nodes_.end();
} }
...@@ -184,9 +189,7 @@ class Graph { ...@@ -184,9 +189,7 @@ class Graph {
xtcl::network::xTensorCompiler::ParamNDArrayMap params_; xtcl::network::xTensorCompiler::ParamNDArrayMap params_;
private: private:
std::unordered_map<std::string, std::pair<std::shared_ptr<xtcl::xExpr>, Type>> std::unordered_map<std::string, std::vector<std::shared_ptr<Node>>> nodes_;
nodes_;
std::unordered_map<std::string, int> counts_;
}; };
} // namespace xpu } // namespace xpu
......
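For orientation, the op bridges below all follow the same pattern against this Graph/Node interface. A minimal sketch of a unary-op converter is given here (purely illustrative: the function name, the relu choice and the surrounding boilerplate are assumptions, not code from this patch):

// Hedged sketch of an op bridge written against the new API (illustrative).
int ReluConverterSketch(void* ctx, OpLite* op, KernelBase* kernel) {
  auto graph = static_cast<Graph*>(ctx);  // assumed cast of the converter context
  auto op_info = op->op_info();
  auto scope = op->scope();
  auto x_name = op_info->Input("X").front();
  auto out_name = op_info->Output("Out").front();
  auto x = scope->FindMutableTensor(x_name);

  // Reuse the node created by a preceding op, or let Add() create a const or
  // data node automatically from the tensor's 'persistable' attribute.
  std::shared_ptr<Node> x_node = nullptr;
  if (graph->Has(x_name)) {
    x_node = graph->Get(x_name);
  } else {
    x_node = graph->Add(x_name, *x);
  }

  // The output is registered under its Paddle variable name; the underlying
  // xtcl::xExpr is reached through Node::data().
  graph->Add(out_name, graph->builder_.CreateRelu(*x_node->data()));
  return SUCCESS;
}

REGISTER_SUBGRAPH_BRIDGE(relu, kXPU, ReluConverterSketch);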
...@@ -51,23 +51,23 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -51,23 +51,23 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto x_inner_size = x_dims.Slice(axis, x_rank).production(); auto x_inner_size = x_dims.Slice(axis, x_rank).production();
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
if (reshape) { if (reshape) {
auto reshaped_x_dims = x_dims.Slice(0, axis).Vectorize(); auto reshaped_x_dims = x_dims.Slice(0, axis).Vectorize();
reshaped_x_dims.push_back(x_inner_size); reshaped_x_dims.push_back(x_inner_size);
x_node = x_node = graph->Add(
graph->AddNode(x_name + "/reshape", x_name + "/reshape",
graph->builder_.CreateReshape( graph->builder_.CreateReshape(
*x_node, CvtShape<xtcl::Integer>(reshaped_x_dims))); *x_node->data(), CvtShape<xtcl::Integer>(reshaped_x_dims)));
} }
// Scale node // Scale node
std::shared_ptr<xtcl::xExpr> scale_const_node = nullptr; std::shared_ptr<Node> scale_node = nullptr;
if (HasInputArg(op_info, scope, "Scale")) { if (HasInputArg(op_info, scope, "Scale")) {
auto scale_name = op_info->Input("Scale").front(); auto scale_name = op_info->Input("Scale").front();
auto scale_type = kernel->GetInputDeclType("Scale"); auto scale_type = kernel->GetInputDeclType("Scale");
...@@ -77,14 +77,13 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -77,14 +77,13 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto scale_dims = scale->dims(); auto scale_dims = scale->dims();
CHECK_EQ(scale_dims.size(), 1); CHECK_EQ(scale_dims.size(), 1);
CHECK_EQ(scale_dims.production(), x_inner_size); CHECK_EQ(scale_dims.production(), x_inner_size);
scale_const_node = graph->AddNode(scale_name, *scale); scale_node = graph->Add(scale_name, *scale);
} else { } else {
scale_const_node = scale_node = graph->Add(y_name + "/scale_one", 1.0f, {x_inner_size});
graph->AddNode(y_name + "/scale_one", 1.0f, {x_inner_size});
} }
// Bias node // Bias node
std::shared_ptr<xtcl::xExpr> bias_const_node = nullptr; std::shared_ptr<Node> bias_node = nullptr;
if (HasInputArg(op_info, scope, "Bias")) { if (HasInputArg(op_info, scope, "Bias")) {
auto bias_name = op_info->Input("Bias").front(); auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias"); auto bias_type = kernel->GetInputDeclType("Bias");
...@@ -94,26 +93,25 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -94,26 +93,25 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto bias_dims = bias->dims(); auto bias_dims = bias->dims();
CHECK_EQ(bias_dims.size(), 1); CHECK_EQ(bias_dims.size(), 1);
CHECK_EQ(bias_dims.production(), x_inner_size); CHECK_EQ(bias_dims.production(), x_inner_size);
bias_const_node = graph->AddNode(bias_name, *bias); bias_node = graph->Add(bias_name, *bias);
} else { } else {
bias_const_node = bias_node = graph->Add(y_name + "/bias_zero", 0.0f, {x_inner_size});
graph->AddNode(y_name + "/bias_zero", 0.0f, {x_inner_size});
} }
// Layer Norm node // Layer Norm node
auto layer_norm_node = auto layer_norm_node =
graph->AddNode(y_name, graph->Add(y_name,
graph->builder_.CreateLayerNorm(*x_node, graph->builder_.CreateLayerNorm(*x_node->data(),
*scale_const_node, *scale_node->data(),
*bias_const_node, *bias_node->data(),
axis, axis,
epsilon, epsilon,
true, true,
true)); true));
if (reshape) { if (reshape) {
graph->AddNode(y_name, graph->Add(y_name,
graph->builder_.CreateReshape( graph->builder_.CreateReshape(*layer_norm_node->data(),
*layer_norm_node, CvtShape<xtcl::Integer>(y_dims))); CvtShape<xtcl::Integer>(y_dims)));
} }
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -123,6 +121,6 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -123,6 +121,6 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(layer_norm,
layer_norm, kXPU,
paddle::lite::subgraph::xpu::LayerNormConverter); paddle::lite::subgraph::xpu::LayerNormConverter);
...@@ -57,30 +57,34 @@ int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -57,30 +57,34 @@ int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} }
// Ids node // Ids node
std::shared_ptr<xtcl::xExpr> ids_node = nullptr; std::shared_ptr<Node> ids_node = nullptr;
if (graph->HasNode(ids_name)) { if (graph->Has(ids_name)) {
ids_node = graph->GetNode(ids_name); ids_node = graph->Get(ids_name);
} else { } else {
ids_node = graph->AddNode( ids_node = graph->Add(
ids_name, ids_dims, ids_type->precision(), ids_type->layout()); ids_name, ids_dims, ids_type->precision(), ids_type->layout());
} }
// Flatten Ids node // Flatten Ids node
if (ids_dims.size() != 1) { if (ids_dims.size() != 1) {
ids_node = graph->AddNode(ids_name + "/reshape", ids_node =
graph->builder_.CreateReshape(*ids_node, {-1}), graph->Add(ids_name + "/reshape",
ids_type->precision(), graph->builder_.CreateReshape(*ids_node->data(), {-1}),
ids_type->layout()); ids_type->precision(),
ids_type->layout());
} }
auto w_const_node = graph->AddNode(w_name, *w);
// W node
auto w_node = graph->Add(w_name, *w);
// Reshape the gather node with the inferred shape as the output node // Reshape the gather node with the inferred shape as the output node
auto gather_node = graph->AddNode( auto gather_node =
out_name, graph->Add(out_name,
graph->builder_.CreateGather(*w_const_node, *ids_node, /* axis= */ 0)); graph->builder_.CreateGather(
*w_node->data(), *ids_node->data(), /* axis= */ 0));
if (out_dims.size() != 2) { if (out_dims.size() != 2) {
graph->AddNode(out_name, graph->Add(out_name,
graph->builder_.CreateReshape( graph->builder_.CreateReshape(
*gather_node, CvtShape<xtcl::Integer>(out_dims))); *gather_node->data(), CvtShape<xtcl::Integer>(out_dims)));
} }
return SUCCESS; return SUCCESS;
} }
...@@ -90,6 +94,6 @@ int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -90,6 +94,6 @@ int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(lookup_table,
lookup_table, kXPU,
paddle::lite::subgraph::xpu::LookupTableConverter); paddle::lite::subgraph::xpu::LookupTableConverter);
...@@ -57,19 +57,19 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -57,19 +57,19 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto alpha = op_info->GetAttr<float>("alpha"); auto alpha = op_info->GetAttr<float>("alpha");
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Y node // Y node
std::shared_ptr<xtcl::xExpr> y_node = nullptr; std::shared_ptr<Node> y_node = nullptr;
if (graph->HasNode(y_name)) { if (graph->Has(y_name)) {
y_node = graph->GetNode(y_name); y_node = graph->Get(y_name);
} else { } else {
y_node = graph->AddNode(y_name, y_dims); y_node = graph->Add(y_name, *y);
} }
// Matmul node // Matmul node
...@@ -80,52 +80,55 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -80,52 +80,55 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
if (x_dims.size() != 3) { if (x_dims.size() != 3) {
auto m = static_cast<int>(x_dims[x_dims.size() - 2]); auto m = static_cast<int>(x_dims[x_dims.size() - 2]);
auto k = static_cast<int>(x_dims[x_dims.size() - 1]); auto k = static_cast<int>(x_dims[x_dims.size() - 1]);
x_node = x_node = graph->Add(
graph->AddNode(x_name + "/reshape", x_name + "/reshape",
graph->builder_.CreateReshape(*x_node, {-1, m, k})); graph->builder_.CreateReshape(*x_node->data(), {-1, m, k}));
if (transpose_x) { if (transpose_x) {
x_node = x_node = graph->Add(
graph->AddNode(x_name + "/reshape/transpose", x_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*x_node, {0, 2, 1})); graph->builder_.CreateTranspose(*x_node->data(), {0, 2, 1}));
} }
} }
// Reshape and transposed Y node // Reshape and transposed Y node
if (y_dims.size() != 3) { if (y_dims.size() != 3) {
auto k = static_cast<int>(y_dims[y_dims.size() - 2]); auto k = static_cast<int>(y_dims[y_dims.size() - 2]);
auto n = static_cast<int>(y_dims[y_dims.size() - 1]); auto n = static_cast<int>(y_dims[y_dims.size() - 1]);
y_node = y_node = graph->Add(
graph->AddNode(y_name + "/reshape", y_name + "/reshape",
graph->builder_.CreateReshape(*y_node, {-1, k, n})); graph->builder_.CreateReshape(*y_node->data(), {-1, k, n}));
if (!transpose_y) { if (!transpose_y) {
y_node = y_node = graph->Add(
graph->AddNode(y_name + "/reshape/transpose", y_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*y_node, {0, 2, 1})); graph->builder_.CreateTranspose(*y_node->data(), {0, 2, 1}));
} }
} }
// Matmul node // Matmul node
auto matmul_node = graph->AddNode( auto matmul_node = graph->Add(
out_name, graph->builder_.CreateBatchMatmul(*x_node, *y_node)); out_name,
graph->builder_.CreateBatchMatmul(*x_node->data(), *y_node->data()));
if (fabs(alpha - 1) > 1e-6f) { if (fabs(alpha - 1) > 1e-6f) {
matmul_node = graph->AddNode( matmul_node = graph->Add(
out_name, graph->builder_.CreateScale(*matmul_node, alpha)); out_name, graph->builder_.CreateScale(*matmul_node->data(), alpha));
} }
if (out_dims.size() != 3) { if (out_dims.size() != 3) {
graph->AddNode(out_name, graph->Add(out_name,
graph->builder_.CreateReshape( graph->builder_.CreateReshape(
*matmul_node, CvtShape<xtcl::Integer>(out_dims))); *matmul_node->data(), CvtShape<xtcl::Integer>(out_dims)));
} }
} else if (x_dims.size() == 2 && y_dims.size() == 2) { } else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N] // x: [M, K], y: [K, N], out: [M, N]
if (transpose_x) { if (transpose_x) {
x_node = graph->AddNode(x_name + "/transpose", x_node =
graph->builder_.CreateTranspose(*x_node, {1, 0})); graph->Add(x_name + "/transpose",
graph->builder_.CreateTranspose(*x_node->data(), {1, 0}));
} }
auto matmul_node = graph->AddNode( auto matmul_node =
out_name, graph->Add(out_name,
graph->builder_.CreateMatmul2D(*x_node, *y_node, transpose_y)); graph->builder_.CreateMatmul2D(
*x_node->data(), *y_node->data(), transpose_y));
if (fabs(alpha - 1) > 1e-6f) { if (fabs(alpha - 1) > 1e-6f) {
matmul_node = graph->AddNode( matmul_node = graph->Add(
out_name, graph->builder_.CreateScale(*matmul_node, alpha)); out_name, graph->builder_.CreateScale(*matmul_node->data(), alpha));
} }
} else if (x_dims.size() == 1 && y_dims.size() == 1) { } else if (x_dims.size() == 1 && y_dims.size() == 1) {
// x: [K], y: [K], out: [1] // x: [K], y: [K], out: [1]
...@@ -141,6 +144,6 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -141,6 +144,6 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(matmul,
matmul, kXPU,
paddle::lite::subgraph::xpu::MatmulConverter); paddle::lite::subgraph::xpu::MatmulConverter);
...@@ -56,49 +56,50 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -56,49 +56,50 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK_EQ(x_matrix_dims[1], y_matrix_dims[0]); CHECK_EQ(x_matrix_dims[1], y_matrix_dims[0]);
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Flatten X node // Flatten X node
if (x_dims.size() != 2) { if (x_dims.size() != 2) {
x_node = x_node = graph->Add(
graph->AddNode(x_name + "/reshape", x_name + "/reshape",
graph->builder_.CreateReshape( graph->builder_.CreateReshape(
*x_node, {-1, static_cast<int>(x_matrix_dims[1])})); *x_node->data(), {-1, static_cast<int>(x_matrix_dims[1])}));
} }
// Y node // Y node
std::shared_ptr<xtcl::xExpr> y_node = nullptr; std::shared_ptr<Node> y_node = nullptr;
if (graph->HasNode(y_name)) { if (graph->Has(y_name)) {
y_node = graph->GetNode(y_name); y_node = graph->Get(y_name);
} else { } else {
y_node = graph->AddNode(y_name, y_dims); y_node = graph->Add(y_name, *y);
} }
// Flatten Y node // Flatten Y node
if (y_dims.size() != 2) { if (y_dims.size() != 2) {
y_node = y_node = graph->Add(
graph->AddNode(y_name + "/reshape", y_name + "/reshape",
graph->builder_.CreateReshape( graph->builder_.CreateReshape(
*y_node, {static_cast<int>(y_matrix_dims[0]), -1})); *y_node->data(), {static_cast<int>(y_matrix_dims[0]), -1}));
} }
// Reshape the matmul node with the inferred shape as the output node // Reshape the matmul node with the inferred shape as the output node
auto matmul_node = graph->AddNode( auto matmul_node = graph->Add(
out_name, graph->builder_.CreateMatmul2D(*x_node, *y_node, false)); out_name,
graph->builder_.CreateMatmul2D(*x_node->data(), *y_node->data(), false));
if (out_dims.size() != 2) { if (out_dims.size() != 2) {
graph->AddNode(out_name, graph->Add(out_name,
graph->builder_.CreateReshape( graph->builder_.CreateReshape(
*matmul_node, CvtShape<xtcl::Integer>(out_dims))); *matmul_node->data(), CvtShape<xtcl::Integer>(out_dims)));
} }
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
} // namespace xpu } // namespace xpu
} // namespace subgraph } // namespace subgraph
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, mul, paddle::lite::subgraph::xpu::MulConverter); REGISTER_SUBGRAPH_BRIDGE(mul, kXPU, paddle::lite::subgraph::xpu::MulConverter);
...@@ -14,25 +14,25 @@ ...@@ -14,25 +14,25 @@
#pragma once #pragma once
USE_SUBGRAPH_BRIDGE(XPU, relu); USE_SUBGRAPH_BRIDGE(relu, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, tanh); USE_SUBGRAPH_BRIDGE(tanh, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, conv2d); USE_SUBGRAPH_BRIDGE(conv2d, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, depthwise_conv2d); USE_SUBGRAPH_BRIDGE(depthwise_conv2d, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, elementwise_add); USE_SUBGRAPH_BRIDGE(elementwise_add, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, pool2d); USE_SUBGRAPH_BRIDGE(pool2d, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, softmax); USE_SUBGRAPH_BRIDGE(softmax, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, mul); USE_SUBGRAPH_BRIDGE(mul, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, batch_norm); USE_SUBGRAPH_BRIDGE(batch_norm, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, stack); USE_SUBGRAPH_BRIDGE(stack, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, gather); USE_SUBGRAPH_BRIDGE(gather, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, scale); USE_SUBGRAPH_BRIDGE(scale, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, lookup_table); USE_SUBGRAPH_BRIDGE(lookup_table, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, slice); USE_SUBGRAPH_BRIDGE(slice, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, transpose); USE_SUBGRAPH_BRIDGE(transpose, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, transpose2); USE_SUBGRAPH_BRIDGE(transpose2, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, reshape); USE_SUBGRAPH_BRIDGE(reshape, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, reshape2); USE_SUBGRAPH_BRIDGE(reshape2, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, layer_norm); USE_SUBGRAPH_BRIDGE(layer_norm, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, gelu); USE_SUBGRAPH_BRIDGE(gelu, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, dropout); USE_SUBGRAPH_BRIDGE(dropout, kXPU);
USE_SUBGRAPH_BRIDGE(XPU, matmul); USE_SUBGRAPH_BRIDGE(matmul, kXPU);
...@@ -50,21 +50,22 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -50,21 +50,22 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto exclusive = op_info->GetAttr<bool>("exclusive"); auto exclusive = op_info->GetAttr<bool>("exclusive");
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Pool node // Pool node
if (pooling_type == "max") { if (pooling_type == "max") {
if (global_pooling) { if (global_pooling) {
graph->AddNode(out_name, graph->builder_.CreateGlobalMaxPool2D(*x_node)); graph->Add(out_name,
graph->builder_.CreateGlobalMaxPool2D(*x_node->data()));
} else { } else {
graph->AddNode( graph->Add(
out_name, out_name,
graph->builder_.CreateMaxPool2D(*x_node, graph->builder_.CreateMaxPool2D(*x_node->data(),
CvtShape<xtcl::xIndexExpr>(ksize), CvtShape<xtcl::xIndexExpr>(ksize),
CvtShape<xtcl::xIndexExpr>(strides), CvtShape<xtcl::xIndexExpr>(strides),
CvtShape<xtcl::xIndexExpr>(paddings), CvtShape<xtcl::xIndexExpr>(paddings),
...@@ -73,12 +74,13 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -73,12 +74,13 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} }
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
if (global_pooling) { if (global_pooling) {
graph->AddNode(out_name, graph->builder_.CreateGlobalAvgPool2D(*x_node)); graph->Add(out_name,
graph->builder_.CreateGlobalAvgPool2D(*x_node->data()));
} else { } else {
// !exclusive ---> count_include_pad // !exclusive ---> count_include_pad
graph->AddNode( graph->Add(
out_name, out_name,
graph->builder_.CreateAvgPool2D(*x_node, graph->builder_.CreateAvgPool2D(*x_node->data(),
CvtShape<xtcl::xIndexExpr>(ksize), CvtShape<xtcl::xIndexExpr>(ksize),
CvtShape<xtcl::xIndexExpr>(strides), CvtShape<xtcl::xIndexExpr>(strides),
CvtShape<xtcl::xIndexExpr>(paddings), CvtShape<xtcl::xIndexExpr>(paddings),
...@@ -98,6 +100,6 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -98,6 +100,6 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(pool2d,
pool2d, kXPU,
paddle::lite::subgraph::xpu::PoolConverter); paddle::lite::subgraph::xpu::PoolConverter);
...@@ -44,11 +44,11 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -44,11 +44,11 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(out_type->layout() == DATALAYOUT(kNCHW)); CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
std::vector<int> shape; std::vector<int> shape;
...@@ -59,6 +59,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -59,6 +59,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// CHECK(shape_tensor_type->layout() == DATALAYOUT(kNCHW)); // CHECK(shape_tensor_type->layout() == DATALAYOUT(kNCHW));
for (auto shape_tensor_name : shape_tensor_names) { for (auto shape_tensor_name : shape_tensor_names) {
auto shape_tensor = scope->FindMutableTensor(shape_tensor_name); auto shape_tensor = scope->FindMutableTensor(shape_tensor_name);
CHECK(shape_tensor->persistable());
auto shape_tensor_data = shape_tensor->mutable_data<int>(); auto shape_tensor_data = shape_tensor->mutable_data<int>();
shape.emplace_back(shape_tensor_data[0]); shape.emplace_back(shape_tensor_data[0]);
} }
...@@ -73,6 +74,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -73,6 +74,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// CHECK(actual_shape_type->precision() == PRECISION(kInt32)); // CHECK(actual_shape_type->precision() == PRECISION(kInt32));
// CHECK(actual_shape_type->layout() == DATALAYOUT(kNCHW)); // CHECK(actual_shape_type->layout() == DATALAYOUT(kNCHW));
auto actual_shape = scope->FindMutableTensor(actual_shape_name); auto actual_shape = scope->FindMutableTensor(actual_shape_name);
CHECK(actual_shape->persistable());
auto actual_shape_dims = actual_shape->dims(); auto actual_shape_dims = actual_shape->dims();
auto actual_shape_data = actual_shape->mutable_data<int>(); auto actual_shape_data = actual_shape->mutable_data<int>();
auto shape = std::vector<int>( auto shape = std::vector<int>(
...@@ -86,9 +88,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -86,9 +88,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto out_dims = operators::ValidateShape(shape, x_dims); auto out_dims = operators::ValidateShape(shape, x_dims);
// Reshape node // Reshape node
graph->AddNode(out_name, graph->Add(out_name,
graph->builder_.CreateReshape( graph->builder_.CreateReshape(*x_node->data(),
*x_node, CvtShape<xtcl::Integer>(out_dims))); CvtShape<xtcl::Integer>(out_dims)));
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -97,9 +99,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -97,9 +99,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(reshape2,
reshape2, kXPU,
paddle::lite::subgraph::xpu::ReshapeConverter); paddle::lite::subgraph::xpu::ReshapeConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(reshape,
reshape, kXPU,
paddle::lite::subgraph::xpu::ReshapeConverter); paddle::lite::subgraph::xpu::ReshapeConverter);
...@@ -46,17 +46,17 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -46,17 +46,17 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
float bias = op_info->GetAttr<float>("bias"); float bias = op_info->GetAttr<float>("bias");
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Scale node // Scale node
graph->AddNode( graph->Add(out_name,
out_name, graph->builder_.CreateScale(
graph->builder_.CreateScale(*x_node, scale, bias, bias_after_scale)); *x_node->data(), scale, bias, bias_after_scale));
return SUCCESS; return SUCCESS;
} }
...@@ -65,6 +65,6 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -65,6 +65,6 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(scale,
scale, kXPU,
paddle::lite::subgraph::xpu::ScaleConverter); paddle::lite::subgraph::xpu::ScaleConverter);
...@@ -46,11 +46,11 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -46,11 +46,11 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto ends = op_info->GetAttr<std::vector<int>>("ends"); auto ends = op_info->GetAttr<std::vector<int>>("ends");
// Input node // Input node
std::shared_ptr<xtcl::xExpr> input_node = nullptr; std::shared_ptr<Node> input_node = nullptr;
if (graph->HasNode(input_name)) { if (graph->Has(input_name)) {
input_node = graph->GetNode(input_name); input_node = graph->Get(input_name);
} else { } else {
input_node = graph->AddNode(input_name, input_dims); input_node = graph->Add(input_name, *input);
} }
// Calculate the begin and end of the slice in all of // Calculate the begin and end of the slice in all of
...@@ -74,9 +74,9 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -74,9 +74,9 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
strides.push_back(1); strides.push_back(1);
} }
} }
graph->AddNode( graph->Add(out_name,
out_name, graph->builder_.CreateStridedSlice(
graph->builder_.CreateStridedSlice(*input_node, begin, end, strides)); *input_node->data(), begin, end, strides));
return REBUILD_WHEN_SHAPE_CHANGED; return REBUILD_WHEN_SHAPE_CHANGED;
} }
...@@ -85,6 +85,6 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -85,6 +85,6 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(slice,
slice, kXPU,
paddle::lite::subgraph::xpu::SliceConverter); paddle::lite::subgraph::xpu::SliceConverter);
...@@ -44,15 +44,15 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -44,15 +44,15 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<int>("axis"); auto axis = op_info->GetAttr<int>("axis");
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Softmax node // Softmax node
graph->AddNode(out_name, graph->builder_.CreateSoftmax(*x_node, axis)); graph->Add(out_name, graph->builder_.CreateSoftmax(*x_node->data(), axis));
return SUCCESS; return SUCCESS;
} }
...@@ -61,6 +61,6 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -61,6 +61,6 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(softmax,
softmax, kXPU,
paddle::lite::subgraph::xpu::SoftmaxConverter); paddle::lite::subgraph::xpu::SoftmaxConverter);
...@@ -46,19 +46,19 @@ int StackConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -46,19 +46,19 @@ int StackConverter(void* ctx, OpLite* op, KernelBase* kernel) {
for (auto& x_name : x_names) { for (auto& x_name : x_names) {
auto x = scope->FindMutableTensor(x_name); auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims(); auto x_dims = x->dims();
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
x_nodes.push_back(*x_node); x_nodes.push_back(*x_node->data());
} }
// Stack node // Stack node
graph->AddNode(y_name, graph->Add(y_name,
graph->builder_.CreateStack( graph->builder_.CreateStack(
xtcl::network::TupleNode::make(x_nodes), axis)); xtcl::network::TupleNode::make(x_nodes), axis));
return SUCCESS; return SUCCESS;
} }
...@@ -67,6 +67,6 @@ int StackConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -67,6 +67,6 @@ int StackConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(stack,
stack, kXPU,
paddle::lite::subgraph::xpu::StackConverter); paddle::lite::subgraph::xpu::StackConverter);
...@@ -44,19 +44,19 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -44,19 +44,19 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto axis = op_info->GetAttr<std::vector<int>>("axis"); auto axis = op_info->GetAttr<std::vector<int>>("axis");
// X node // X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr; std::shared_ptr<Node> x_node = nullptr;
if (graph->HasNode(x_name)) { if (graph->Has(x_name)) {
x_node = graph->GetNode(x_name); x_node = graph->Get(x_name);
} else { } else {
x_node = graph->AddNode(x_name, x_dims); x_node = graph->Add(x_name, *x);
} }
// Transpose node // Transpose node
graph->AddNode(out_name, graph->Add(out_name,
graph->builder_.CreateTranspose( graph->builder_.CreateTranspose(
*x_node, *x_node->data(),
CvtShape<xtcl::Integer>( CvtShape<xtcl::Integer>(
std::vector<int64_t>(axis.begin(), axis.end())))); std::vector<int64_t>(axis.begin(), axis.end()))));
return SUCCESS; return SUCCESS;
} }
...@@ -66,9 +66,9 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -66,9 +66,9 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(transpose,
transpose, kXPU,
paddle::lite::subgraph::xpu::TransposeConverter); paddle::lite::subgraph::xpu::TransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU, REGISTER_SUBGRAPH_BRIDGE(transpose2,
transpose2, kXPU,
paddle::lite::subgraph::xpu::TransposeConverter); paddle::lite::subgraph::xpu::TransposeConverter);
...@@ -39,13 +39,13 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -39,13 +39,13 @@ int SubgraphEngine::BuildDeviceProgram() {
op->CheckShape(); op->CheckShape();
op->InferShape(); op->InferShape();
std::string op_type = op->op_info()->Type(); std::string op_type = op->op_info()->Type();
if (!bridges.Exists("XPU", op_type)) { if (!bridges.Exists(op_type, "kXPU")) {
return subgraph::FAILED; return subgraph::FAILED;
} }
auto kernel = inst.kernel(); auto kernel = inst.kernel();
status |= bridges.Select("XPU", op_type)(reinterpret_cast<void*>(&graph), status |= bridges.Select(op_type, "kXPU")(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op), const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel)); const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) { if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED; return subgraph::FAILED;
} }
...@@ -57,26 +57,26 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -57,26 +57,26 @@ int SubgraphEngine::BuildDeviceProgram() {
std::vector<xtcl::xExpr*> device_inodes; std::vector<xtcl::xExpr*> device_inodes;
std::vector<xtcl::xExpr*> device_onodes; std::vector<xtcl::xExpr*> device_onodes;
for (auto& input_name : input_names_) { for (auto& input_name : input_names_) {
if (graph.HasNode(input_name)) { if (graph.Has(input_name)) {
if (!graph.GetType(input_name).persistable()) { if (graph.Get(input_name)->is_data()) {
device_inodes.push_back(graph.GetNode(input_name).get()); device_inodes.push_back(graph.Get(input_name)->data().get());
device_inames_.push_back(input_name); device_inames_.push_back(input_name);
} else { } else {
LOG(WARNING) << "[XPU] Input node " << input_name LOG(WARNING) << "[XPU] Input node " << input_name
<< " is skipped because it is a persistable node."; << " is ignored because it is not a data node.";
} }
} else { } else {
LOG(WARNING) << "[XPU] Input node " << input_name LOG(WARNING) << "[XPU] Input node " << input_name
<< " is skipped because it does not exist."; << " is ignored because it does not exist.";
} }
} }
for (auto& output_name : output_names_) { for (auto& output_name : output_names_) {
if (graph.HasNode(output_name)) { if (graph.Has(output_name)) {
device_onodes.push_back(graph.GetNode(output_name).get()); device_onodes.push_back(graph.Get(output_name)->data().get());
device_onames_.push_back(output_name); device_onames_.push_back(output_name);
} else { } else {
LOG(WARNING) << "[XPU] Output node " << output_name LOG(WARNING) << "[XPU] Output node " << output_name
<< " is skipped because it does not exist."; << " is ignored because it does not exist.";
} }
} }
CHECK(!device_inames_.empty()) CHECK(!device_inames_.empty())
...@@ -98,14 +98,14 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -98,14 +98,14 @@ int SubgraphEngine::BuildDeviceProgram() {
origin_otensors_.resize(device_onames_.size()); origin_otensors_.resize(device_onames_.size());
device_otensors_.resize(device_onames_.size()); device_otensors_.resize(device_onames_.size());
for (int i = 0; i < device_inames_.size(); i++) { for (int i = 0; i < device_inames_.size(); i++) {
auto type = graph.GetType(device_inames_[i]); auto node = graph.Get(device_inames_[i]);
auto precision = type.precision(); auto precision = node->precision();
auto layout = type.layout(); auto layout = node->layout();
origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]); origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]);
CHECK(origin_itensors_[i]); CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims(); origin_idims_[i] = origin_itensors_[i]->dims();
VLOG(3) << "[XPU] Inputs[" << i VLOG(3) << "[XPU] Inputs[" << i << "] name: " << device_inames_[i]
<< "] precision: " << PrecisionToStr(precision) << " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " layout: " << DataLayoutToStr(layout)
<< " dims: " << origin_idims_[i]; << " dims: " << origin_idims_[i];
// Prepare the device input tensors which share data with the origin input // Prepare the device input tensors which share data with the origin input
...@@ -122,14 +122,14 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -122,14 +122,14 @@ int SubgraphEngine::BuildDeviceProgram() {
device_itensors_[i].byte_offset = 0; device_itensors_[i].byte_offset = 0;
} }
for (int i = 0; i < device_onames_.size(); i++) { for (int i = 0; i < device_onames_.size(); i++) {
auto type = graph.GetType(device_onames_[i]); auto node = graph.Get(device_onames_[i]);
auto precision = type.precision(); auto precision = node->precision();
auto layout = type.layout(); auto layout = node->layout();
origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]); origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]);
CHECK(origin_otensors_[i]); CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims(); origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "[XPU] Outputs[" << i VLOG(3) << "[XPU] Outputs[" << i << "] name: " << device_onames_[i]
<< "] precision: " << PrecisionToStr(precision) << " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " layout: " << DataLayoutToStr(layout)
<< " dims: " << origin_odims_[i]; << " dims: " << origin_odims_[i];
// Prepare the device output tensors which share data with the origin output // Prepare the device output tensors which share data with the origin output
......
...@@ -29,7 +29,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH ...@@ -29,7 +29,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_layer_norm_compute SRCS layer_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_layer_norm_compute SRCS layer_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA) if(LITE_BUILD_EXTRA)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h" #include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -23,31 +24,33 @@ namespace lite { ...@@ -23,31 +24,33 @@ namespace lite {
class ScaleComputeTester : public arena::TestCase { class ScaleComputeTester : public arena::TestCase {
protected: protected:
// common attributes for this op. // common attributes for this op.
std::string input_ = "x"; std::string x_ = "x";
std::string output_ = "out"; std::string out_ = "out";
DDim x_dims_{{100, 20}};
float scale_ = 0.; float scale_ = 0.;
float bias_ = 0.; float bias_ = 0.;
DDim dims_{{100, 20}};
bool bias_after_scale_; bool bias_after_scale_;
public: public:
ScaleComputeTester(const Place& place, ScaleComputeTester(const Place& place,
const std::string& alias, const std::string& alias,
const DDim& x_dims,
float scale, float scale,
float bias, float bias,
bool bias_after_scale) bool bias_after_scale)
: TestCase(place, alias), : TestCase(place, alias),
x_dims_(x_dims),
scale_(scale), scale_(scale),
bias_(bias), bias_(bias),
bias_after_scale_(bias_after_scale) {} bias_after_scale_(bias_after_scale) {}
void RunBaseline(Scope* scope) override { void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_); auto* out = scope->NewTensor(out_);
CHECK(out); CHECK(out);
out->Resize(dims_); out->Resize(x_dims_);
auto* out_data = out->mutable_data<float>(); auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_); auto* x = scope->FindTensor(x_);
const auto* x_data = x->data<float>(); const auto* x_data = x->data<float>();
float bias = bias_; float bias = bias_;
...@@ -56,35 +59,34 @@ class ScaleComputeTester : public arena::TestCase { ...@@ -56,35 +59,34 @@ class ScaleComputeTester : public arena::TestCase {
bias *= scale_; bias *= scale_;
} }
for (int i = 0; i < dims_.production(); i++) { for (int i = 0; i < x_dims_.production(); i++) {
out_data[i] = x_data[i] * scale_ + bias; out_data[i] = x_data[i] * scale_ + bias;
} }
} }
void PrepareOpDesc(cpp::OpDesc* op_desc) { void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("scale"); op_desc->SetType("scale");
op_desc->SetInput("X", {input_}); op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {output_}); op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("scale", scale_); op_desc->SetAttr("scale", scale_);
op_desc->SetAttr("bias", bias_); op_desc->SetAttr("bias", bias_);
op_desc->SetAttr("bias_after_scale", bias_after_scale_); op_desc->SetAttr("bias_after_scale", bias_after_scale_);
} }
void PrepareData() override { void PrepareData() override {
std::vector<float> data(dims_.production()); std::vector<float> x(x_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
for (int i = 0; i < dims_.production(); i++) { SetCommonTensor(x_, x_dims_, x.data());
data[i] = i * 1.1;
}
SetCommonTensor(input_, dims_, data.data());
} }
}; };
TEST(Scale, precision) { TEST(Scale, precision) {
Place place; Place place;
float abs_error = 2e-5; float abs_error = 2e-5;
#if defined(LITE_WITH_ARM) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 4e-3; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM); place = TARGET(kARM);
#elif defined(LITE_WITH_XPU) #elif defined(LITE_WITH_XPU)
place = TARGET(kXPU); place = TARGET(kXPU);
...@@ -95,13 +97,16 @@ TEST(Scale, precision) { ...@@ -95,13 +97,16 @@ TEST(Scale, precision) {
return; return;
#endif #endif
for (float scale : {0.123, 2., -1.2}) { for (auto x_dims :
for (float bias : {1., 0., -1.2331}) { std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) {
for (bool bias_before : {true, false}) { for (float scale : {0.123, 2., -1.2}) {
std::unique_ptr<arena::TestCase> tester( for (float bias : {1., 0., -1.2331}) {
new ScaleComputeTester(place, "def", scale, bias, bias_before)); for (bool bias_after_scale : {true, false}) {
arena::Arena arena(std::move(tester), place, abs_error); std::unique_ptr<arena::TestCase> tester(new ScaleComputeTester(
arena.TestPrecision(); place, "def", DDim(x_dims), scale, bias, bias_after_scale));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
} }
} }
} }
...@@ -117,8 +122,8 @@ TEST(Scale, performance) { ...@@ -117,8 +122,8 @@ TEST(Scale, performance) {
return; return;
#endif #endif
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(new ScaleComputeTester(
new ScaleComputeTester(place, "def", 1.2, 1.1, true)); place, "def", DDim(std::vector<int64_t>{5, 2, 3, 4}), 1.2, 1.1, true));
// To modify the arm context, one can retrieve the context as follows. // To modify the arm context, one can retrieve the context as follows.
// #ifdef LITE_WITH_ARM // #ifdef LITE_WITH_ARM
......
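Note: the updated scale (and softmax) tests initialize their inputs with fill_data_rand from lite/tests/utils/fill_data.h instead of a deterministic ramp. A possible sketch of such a helper is shown below; it is an assumption for illustration only, and the actual header may differ.

#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

// Illustrative stand-in for fill_data_rand: fills `count` elements of `data`
// with uniform random values drawn from [lower, upper].
template <typename T>
void fill_data_rand(T* data, T lower, T upper, int64_t count) {
  std::mt19937 rng(std::random_device{}());
  std::uniform_real_distribution<double> dist(lower, upper);
  for (int64_t i = 0; i < count; ++i) {
    data[i] = static_cast<T>(dist(rng));
  }
}

int main() {
  std::vector<float> x(24);
  fill_data_rand(x.data(), -1.f, 1.f, static_cast<int64_t>(x.size()));
  std::cout << "first value: " << x[0] << std::endl;
  return 0;
}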
...@@ -25,33 +25,33 @@ class SoftmaxComputeTest : public arena::TestCase { ...@@ -25,33 +25,33 @@ class SoftmaxComputeTest : public arena::TestCase {
protected: protected:
// common attributes for this op. // common attributes for this op.
std::string op_type_ = "softmax"; std::string op_type_ = "softmax";
std::string input_ = "x"; DDim x_dims_{{1, 2, 3, 4}};
std::string output_ = "out"; std::string x_ = "x";
DDim dims_{{1, 2, 3, 4}}; std::string out_ = "out";
int axis_ = 1; int axis_ = 1;
public: public:
SoftmaxComputeTest(const Place& place, SoftmaxComputeTest(const Place& place,
const std::string& alias, const std::string& alias,
DDim dims, DDim x_dims,
int axis) int axis)
: TestCase(place, alias), dims_(dims), axis_(axis) {} : TestCase(place, alias), x_dims_(x_dims), axis_(axis) {}
void RunBaseline(Scope* scope) override { void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(input_); auto x = scope->FindTensor(x_);
auto out = scope->NewTensor(output_); auto out = scope->NewTensor(out_);
CHECK(out); CHECK(out);
out->Resize(dims_); out->Resize(x_dims_);
auto x_data = x->data<float>(); auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>(); auto out_data = out->mutable_data<float>();
auto x_rank = dims_.size(); auto x_rank = x_dims_.size();
if (axis_ < 0) { if (axis_ < 0) {
axis_ += x_rank; axis_ += x_rank;
} }
int axis_size = dims_[axis_]; int axis_size = x_dims_[axis_];
int outer_num = dims_.Slice(0, axis_).production(); int outer_num = x_dims_.Slice(0, axis_).production();
int inner_num = dims_.Slice(axis_ + 1, x_rank).production(); int inner_num = x_dims_.Slice(axis_ + 1, x_rank).production();
int compute_size = outer_num * inner_num; int compute_size = outer_num * inner_num;
for (int i = 0; i < compute_size; i++) { for (int i = 0; i < compute_size; i++) {
int idx_inner = i % inner_num; int idx_inner = i % inner_num;
...@@ -84,15 +84,15 @@ class SoftmaxComputeTest : public arena::TestCase { ...@@ -84,15 +84,15 @@ class SoftmaxComputeTest : public arena::TestCase {
void PrepareOpDesc(cpp::OpDesc* op_desc) { void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(op_type_); op_desc->SetType(op_type_);
op_desc->SetInput("X", {input_}); op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {output_}); op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("axis", axis_); op_desc->SetAttr("axis", axis_);
} }
void PrepareData() override { void PrepareData() override {
std::vector<float> din(dims_.production()); std::vector<float> x(x_dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production()); fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
SetCommonTensor(input_, dims_, din.data()); SetCommonTensor(x_, x_dims_, x.data());
} }
}; };
...@@ -100,18 +100,21 @@ TEST(Softmax, precision) { ...@@ -100,18 +100,21 @@ TEST(Softmax, precision) {
LOG(INFO) << "test softmax op"; LOG(INFO) << "test softmax op";
float abs_error = 2e-5; float abs_error = 2e-5;
Place place; Place place;
#if defined(LITE_WITH_XPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 4e-3; // Using fp16 in NPU
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU); place = TARGET(kXPU);
#else #else
return; return;
#endif #endif
std::vector<std::vector<int64_t>> dims{{1, 2, 3, 4}, {2, 3, 4}, {3, 4}}; for (auto x_dims :
for (auto dim_in : dims) { std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {2, 3, 4}, {3, 4}}) {
for (auto axis : {-1, 0, 1, 2, 3}) { for (auto axis : {-1, 0, 1, 2, 3}) {
if (axis >= dim_in.size()) continue; if (axis >= x_dims.size()) continue;
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(
new SoftmaxComputeTest(place, "def", DDim(dim_in), axis)); new SoftmaxComputeTest(place, "def", DDim(x_dims), axis));
arena::Arena arena(std::move(tester), place, abs_error); arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision(); arena.TestPrecision();
} }
......