Unverified commit 05da0c72 authored by: H hong19860320, committed by: GitHub

[LITE][NPU][XPU] Support multiple types for XPU and NPU op bridges (#2646)

* Support multiple types for XPU and NPU op bridges

* Add lookup_table, gather, slice, stack and scale op bridges for supporting BERT

* Fix the definition of lookup_table kernel for X86
Parent e1c4adfd
......@@ -169,6 +169,10 @@ endif()
########################################################################################
if(LITE_WITH_XPU)
include(xpu)
endif()
include(external/mklml) # download mklml package
include(external/xbyak) # download xbyak package
include(external/libxsmm) # download, build, install libxsmm
......@@ -188,10 +192,6 @@ if(LITE_WITH_CUDA)
include(cuda)
endif()
if(LITE_WITH_XPU)
include(xpu)
endif()
include(generic) # simplify cmake module
include(ccache) # set ccache for compilation
include(util) # set unittest and link libs
......
......@@ -89,7 +89,7 @@ else()
endif()
find_library(XPU_SDK_LLVM_FILE NAMES LLVM-8
PATHS ${XPU_SDK_ROOT}/XTDK/shlib/gcc482)
PATHS ${XPU_SDK_ROOT}/XTDK/shlib)
if(NOT XPU_SDK_LLVM_FILE)
message(FATAL_ERROR "Can not find LLVM Library in ${XPU_SDK_ROOT}")
......@@ -99,7 +99,7 @@ else()
set_property(TARGET xpu_sdk_llvm PROPERTY IMPORTED_LOCATION ${XPU_SDK_LLVM_FILE})
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_GLOG=1")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_GLOG=1 -D_GLIBCXX_USE_CXX11_ABI=0")
set(xpu_runtime_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu runtime libs")
set(xpu_builder_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu builder libs")
......@@ -61,6 +61,7 @@ std::unique_ptr<hiai::AiModelMngerClient> Device::Build(
return nullptr;
}
ir_build.ReleaseModelBuff(om_model_buf);
VLOG(3) << "[NPU] Build done";
return model_client;
}
......
......@@ -28,8 +28,8 @@ std::unique_ptr<xtcl::network::xRuntimeInstance> Device::Build(
CHECK(outputs != nullptr);
CHECK_GT(outputs->size(), 0);
// The XPU compiler builds the graph and fills in all of the constant params;
// only one output is supported now.
// The XPU compiler builds the graph and fills in all of the constant params,
// and uses a TupleNode to support multiple outputs
xtcl::Array<xtcl::xExpr> all_outs;
for (size_t i = 0; i < outputs->size(); i++) {
all_outs.push_back(*outputs->at(i));
......@@ -40,6 +40,7 @@ std::unique_ptr<xtcl::network::xRuntimeInstance> Device::Build(
auto compiler = xtcl::network::xTensorCompiler(network, target);
compiler.SetParams(*params); // Set the data of constant tensors
compiler.Build();
VLOG(3) << "[XPU] Build done";
return std::unique_ptr<xtcl::network::xRuntimeInstance>(
new xtcl::network::xRuntimeInstance(compiler.CreateRuntimeInstance()));
}
......
......@@ -24,39 +24,56 @@
DEFINE_string(model_file, "", "model file path of combined protobuf model");
DEFINE_string(params_file, "", "params file path of combined protobuf model");
DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors");
DEFINE_int32(output_tensor_num, 1, "number of output tensors");
DEFINE_string(input_tensor_shape, "1,3,224,224", "shape of input tensors");
DEFINE_string(input_tensor_type, "float32", "data type of input tensors");
DEFINE_string(output_tensor_type, "float32", "data type of output tensors");
namespace paddle {
namespace lite {
// Helper functions for loading a model from the command line, running it, and
// verifying the output data
std::vector<std::vector<int64_t>> ShapeParsing(std::string txt) {
std::vector<std::vector<int64_t>> shape;
while (!txt.empty()) {
size_t idx = txt.find_first_of(":");
std::string dims = txt.substr(0, idx);
std::vector<int64_t> s;
while (!dims.empty()) {
size_t idx = dims.find_first_of(",");
int d = atoi(dims.substr(0, idx).c_str());
std::vector<std::string> TypeParsing(std::string text) {
std::vector<std::string> types;
while (!text.empty()) {
size_t index = text.find_first_of(":");
std::string type = text.substr(0, index);
VLOG(3) << type;
types.push_back(type);
if (index == std::string::npos) {
break;
} else {
text = text.substr(index + 1);
}
}
return types;
}
std::vector<std::vector<int64_t>> ShapeParsing(std::string text) {
std::vector<std::vector<int64_t>> shapes;
while (!text.empty()) {
size_t index = text.find_first_of(":");
std::string slice = text.substr(0, index);
std::vector<int64_t> shape;
while (!slice.empty()) {
size_t index = slice.find_first_of(",");
int d = atoi(slice.substr(0, index).c_str());
VLOG(3) << d;
s.push_back(d);
if (idx == std::string::npos) {
shape.push_back(d);
if (index == std::string::npos) {
break;
} else {
dims = dims.substr(idx + 1);
slice = slice.substr(index + 1);
}
}
shape.push_back(s);
if (idx == std::string::npos) {
shapes.push_back(shape);
if (index == std::string::npos) {
break;
} else {
txt = txt.substr(idx + 1);
text = text.substr(index + 1);
}
}
return shape;
return shapes;
}
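
A usage sketch of the two parsers above (a hypothetical driver, not part of the test; the inputs match the flag formats documented later in this file):

void ParsingDemo() {
  auto shapes = ShapeParsing("1,3,224,224:1,80");
  // shapes == {{1, 3, 224, 224}, {1, 80}}
  auto types = TypeParsing("float32:int64");
  // types == {"float32", "int64"}
  CHECK_EQ(shapes.size(), 2u);
  CHECK_EQ(types.size(), 2u);
}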
int64_t ShapeProduction(std::vector<int64_t> shape) {
......@@ -70,40 +87,55 @@ int64_t ShapeProduction(std::vector<int64_t> shape) {
void FillInputTensors(
const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const std::vector<std::string>& input_tensor_type,
const float value) {
#define FILL_TENSOR_WITH_TYPE(type) \
auto input_tensor_data = input_tensor->mutable_data<type>(); \
for (int j = 0; j < input_tensor_size; j++) { \
input_tensor_data[j] = static_cast<type>(value); \
}
for (int i = 0; i < input_tensor_shape.size(); i++) {
auto input_tensor = predictor->GetInput(i);
input_tensor->Resize(input_tensor_shape[i]);
auto input_tensor_data = input_tensor->mutable_data<float>();
auto input_tensor_size = ShapeProduction(input_tensor->shape());
for (int j = 0; j < input_tensor_size; j++) {
input_tensor_data[j] = value;
if (input_tensor_type[i] == "float32") {
FILL_TENSOR_WITH_TYPE(float)
} else if (input_tensor_type[i] == "int64") {
FILL_TENSOR_WITH_TYPE(int64_t)
}
}
#undef FILL_TENSOR_WITH_TYPE
}
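
The macro above dispatches on the flag-provided type name because Tensor::mutable_data is templated on the element type; a template formulation of what each branch expands to (a sketch, not part of the test):

template <typename T>
void FillWith(lite_api::Tensor* tensor, int64_t size, float value) {
  T* data = tensor->mutable_data<T>();
  for (int64_t j = 0; j < size; j++) {
    data[j] = static_cast<T>(value);  // same body as FILL_TENSOR_WITH_TYPE
  }
}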
void CheckOutputTensors(
const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
const int output_tensor_num) {
for (int i = 0; i < output_tensor_num; i++) {
const std::vector<std::string>& output_tensor_type) {
#define CHECK_TENSOR_WITH_TYPE(type) \
auto tar_output_tensor_data = tar_output_tensor->data<type>(); \
auto ref_output_tensor_data = ref_output_tensor->data<type>(); \
for (size_t j = 0; j < ref_output_tensor_size; j++) { \
auto abs_diff = \
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]); \
auto rel_diff = abs_diff / (std::fabs(ref_output_tensor_data[j]) + 1e-6); \
VLOG(5) << "val: " << tar_output_tensor_data[j] \
<< " ref: " << ref_output_tensor_data[j] \
<< " abs_diff: " << abs_diff << " rel_diff: " << rel_diff; \
EXPECT_LT(rel_diff, 0.1); \
}
for (int i = 0; i < output_tensor_type.size(); i++) {
auto tar_output_tensor = tar_predictor->GetOutput(i);
auto ref_output_tensor = ref_predictor->GetOutput(i);
auto tar_output_tensor_data = tar_output_tensor->data<float>();
auto ref_output_tensor_data = ref_output_tensor->data<float>();
auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape());
auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape());
EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size);
for (size_t j = 0; j < ref_output_tensor_size; j++) {
auto abs_diff =
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]);
auto rel_diff = abs_diff / (std::fabs(ref_output_tensor_data[j]) + 1e-6);
VLOG(5) << "val: " << tar_output_tensor_data[j]
<< " ref: " << ref_output_tensor_data[j]
<< " abs_diff: " << abs_diff << " rel_diff: " << rel_diff;
EXPECT_LT(rel_diff, 0.1);
if (output_tensor_type[i] == "float32") {
CHECK_TENSOR_WITH_TYPE(float)
} else if (output_tensor_type[i] == "int64") {
CHECK_TENSOR_WITH_TYPE(int64_t)
}
}
#undef CHECK_TENSOR_WITH_TYPE
}
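
For clarity, the tolerance rule encoded in CHECK_TENSOR_WITH_TYPE as a standalone predicate: a target value passes if its relative difference from the reference stays below 10%, with 1e-6 guarding the division for near-zero references (assumes <cmath>):

bool WithinRelTol(float target, float reference, float tol = 0.1f) {
  float abs_diff = std::fabs(target - reference);
  float rel_diff = abs_diff / (std::fabs(reference) + 1e-6f);
  return rel_diff < tol;
}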
std::shared_ptr<lite_api::PaddlePredictor> TestModel(
......@@ -112,6 +144,7 @@ std::shared_ptr<lite_api::PaddlePredictor> TestModel(
const std::string& params_file,
const std::vector<lite_api::Place>& valid_places,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const std::vector<std::string>& input_tensor_type,
const std::string& optimized_model_dir) {
// Generate optimized model
lite_api::CxxConfig cxx_config;
......@@ -128,7 +161,7 @@ std::shared_ptr<lite_api::PaddlePredictor> TestModel(
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1);
predictor = lite_api::CreatePaddlePredictor(mobile_config);
FillInputTensors(predictor, input_tensor_shape, 1);
FillInputTensors(predictor, input_tensor_shape, input_tensor_type, 1);
// Run optimized model
for (int i = 0; i < FLAGS_warmup; i++) {
predictor->Run();
......@@ -148,10 +181,13 @@ TEST(Subgraph, generate_model_and_check_precision) {
"the path of model files.";
return;
}
// Parsing the shapes of input tensors from strings, supported formats:
// Parsing the shape of input tensors from strings, supported formats:
// "1,3,224,224" and "1,3,224,224:1,80"
std::vector<std::vector<int64_t>> input_tensor_shape =
ShapeParsing(FLAGS_input_tensor_shape);
auto input_tensor_shape = ShapeParsing(FLAGS_input_tensor_shape);
// Parsing the data type of input and output tensors from strings, supported
// formats: "float32" and "float32:int64:int8"
auto input_tensor_type = TypeParsing(FLAGS_input_tensor_type);
auto output_tensor_type = TypeParsing(FLAGS_output_tensor_type);
std::vector<lite_api::Place> valid_places({
#ifdef LITE_WITH_ARM
lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
......@@ -166,6 +202,7 @@ TEST(Subgraph, generate_model_and_check_precision) {
FLAGS_params_file,
valid_places,
input_tensor_shape,
input_tensor_type,
FLAGS_optimized_model_dir + "/ref_opt_model");
// Generate and run optimized model on NPU/XPU as the target predictor
#ifdef LITE_WITH_NPU
......@@ -179,10 +216,11 @@ TEST(Subgraph, generate_model_and_check_precision) {
FLAGS_params_file,
valid_places,
input_tensor_shape,
input_tensor_type,
FLAGS_optimized_model_dir + "/tar_opt_model");
// Check the difference of the output tensors between reference predictor and
// target predictor
CheckOutputTensors(tar_predictor, ref_predictor, FLAGS_output_tensor_num);
CheckOutputTensors(tar_predictor, ref_predictor, output_tensor_type);
}
} // namespace lite
......
......@@ -21,24 +21,41 @@ namespace lite {
namespace subgraph {
namespace npu {
int ActConverter(void* ctx, OpLite* op) {
int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Create act node and set input node which is obtained from the node map
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto act_node = graph->AddNode<ge::op::Activation>(out_var_name);
act_node->set_input_x(*graph->GetNode(x_var_name));
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Act node
auto act_node = graph->AddNode<ge::op::Activation>(out_name);
act_node->set_input_x(*x_node);
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc.
act_node->set_attr_mode(CvtActMode(op_type));
if (op_type == "relu_clipped") {
auto Relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef");
act_node->set_attr_coef(Relu_clipped_coef);
......
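
Every bridge updated in this commit repeats the same get-or-add idiom for its input nodes; a hypothetical helper capturing the pattern (the converters inline it rather than share a helper):

std::shared_ptr<ge::Operator> GetOrAddDataNode(Graph* graph,
                                               const std::string& name,
                                               const DDim& dims) {
  if (graph->HasNode(name)) {
    return graph->GetNode(name);  // reuse the node produced by a preceding op
  }
  return graph->AddNode(name, dims);  // otherwise create a ge::op::Data node
}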
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int ArgmaxConverter(void* ctx, OpLite* op) {
int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,15 +30,34 @@ int ArgmaxConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
int axis = op_info->GetAttr<int64_t>("axis");
auto argmax_node = graph->AddNode<ge::op::ArgMax>(out_var_name);
argmax_node->set_input_x1(*graph->GetNode(x_var_name));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
auto x2 = graph->AddNode(out_var_name + "/axis", axis);
argmax_node->set_input_x2(*x2);
// Axis node
auto axis_const_node = graph->AddNode(out_name + "/axis", axis);
// Argmax node
auto argmax_node = graph->AddNode<ge::op::ArgMax>(out_name);
argmax_node->set_input_x1(*x_node);
argmax_node->set_input_x2(*axis_const_node);
return SUCCESS;
}
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int BatchNormConverter(void* ctx, OpLite* op) {
int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,32 +30,59 @@ int BatchNormConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Output("Y").front();
auto batch_norm_node = graph->AddNode<ge::op::BatchNormExt2>(y_var_name);
batch_norm_node->set_input_x(*graph->GetNode(x_var_name));
auto scale_var_name = op_info->Input("Scale").front();
auto scale = scope->FindVar(scale_var_name)->GetMutable<Tensor>();
auto scale_const_node = graph->AddNode(scale_var_name, *scale);
auto bias_var_name = op_info->Input("Bias").front();
auto bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto bias_const_node = graph->AddNode(bias_var_name, *bias);
auto mean_var_name = op_info->Input("Mean").front();
auto mean = scope->FindVar(mean_var_name)->GetMutable<Tensor>();
auto mean_const_node = graph->AddNode(mean_var_name, *mean);
auto variance_var_name = op_info->Input("Variance").front();
auto variance = scope->FindVar(variance_var_name)->GetMutable<Tensor>();
auto variance_const_node = graph->AddNode(variance_var_name, *variance);
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto scale_name = op_info->Input("Scale").front();
auto scale_type = kernel->GetInputDeclType("Scale");
CHECK(scale_type->precision() == PRECISION(kFloat));
CHECK(scale_type->layout() == DATALAYOUT(kNCHW));
auto scale = scope->FindMutableTensor(scale_name);
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto mean_name = op_info->Input("Mean").front();
auto mean_type = kernel->GetInputDeclType("Mean");
CHECK(mean_type->precision() == PRECISION(kFloat));
CHECK(mean_type->layout() == DATALAYOUT(kNCHW));
auto mean = scope->FindMutableTensor(mean_name);
auto variance_name = op_info->Input("Variance").front();
auto variance_type = kernel->GetInputDeclType("Variance");
CHECK(variance_type->precision() == PRECISION(kFloat));
CHECK(variance_type->layout() == DATALAYOUT(kNCHW));
auto variance = scope->FindMutableTensor(variance_name);
auto y_name = op_info->Output("Y").front();
auto y_type = kernel->GetOutputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
float momentum = op_info->GetAttr<float>("momentum");
float epsilon = op_info->GetAttr<float>("epsilon");
int mode = 1; // bnScale, bnBias tensor dims are 1xCx1x1
bool use_global_stats = op_info->GetAttr<bool>("use_global_stats");
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Scale, Bias, Mean, Variance node
auto scale_const_node = graph->AddNode(scale_name, *scale);
auto bias_const_node = graph->AddNode(bias_name, *bias);
auto mean_const_node = graph->AddNode(mean_name, *mean);
auto variance_const_node = graph->AddNode(variance_name, *variance);
// Batch Norm node
auto batch_norm_node = graph->AddNode<ge::op::BatchNormExt2>(y_name);
batch_norm_node->set_input_x(*x_node);
batch_norm_node->set_input_scale(*scale_const_node);
batch_norm_node->set_input_offset(*bias_const_node);
batch_norm_node->set_input_mean(*mean_const_node);
......
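
For reference, what the BatchNormExt2 node is expected to compute per element in inference mode from the inputs wired above (a scalar sketch, not the NPU code path; assumes <cmath>):

// y = scale * (x - mean) / sqrt(variance + epsilon) + bias
float BatchNormRef(float x, float scale, float bias, float mean,
                   float variance, float epsilon) {
  return scale * (x - mean) / std::sqrt(variance + epsilon) + bias;
}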
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int ConcatConverter(void* ctx, OpLite* op) {
int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,23 +30,35 @@ int ConcatConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " << op_type << " ... ";
auto x_var_names = op_info->Input("X");
auto out_var_name = op_info->Output("Out").front();
// Get input and output vars and op attributes
auto x_names = op_info->Input("X");
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axis = op_info->GetAttr<int>("axis");
auto num = x_var_names.size();
auto concat_node = graph->AddNode<ge::op::Concat>(out_var_name);
auto num = x_names.size();
// Traverse all of the input nodes, which are added into the newly created
// concat node
auto concat_node = graph->AddNode<ge::op::Concat>(out_name);
concat_node->set_attr_axis(axis);
concat_node->set_attr_N(num);
concat_node->create_dynamic_input_x(num);
int idx = 1;
for (auto& x_var_name : x_var_names) {
if (graph->HasNode(x_var_name)) {
concat_node->set_dynamic_input_x(idx, *graph->GetNode(x_var_name));
for (auto& x_name : x_names) {
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_const_node = graph->AddNode(x_var_name, *x);
concat_node->set_dynamic_input_x(idx, *x_const_node);
x_node = graph->AddNode(x_name, x_dims);
}
concat_node->set_dynamic_input_x(idx, *x_node);
idx++;
}
return SUCCESS;
......
......@@ -22,7 +22,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int ConvConverter(void* ctx, OpLite* op) {
int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -31,16 +31,25 @@ int ConvConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " << op_type << "... ";
// Get input, filter and op attributes
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<Tensor>();
// Get input and output vars and op attributes
auto input_name = op_info->Input("Input").front();
auto input_type = kernel->GetInputDeclType("Input");
CHECK(input_type->precision() == PRECISION(kFloat));
CHECK(input_type->layout() == DATALAYOUT(kNCHW));
auto input = scope->FindMutableTensor(input_name);
auto input_dims = input->dims();
auto output_var_name = op_info->Output("Output").front();
auto output = scope->FindVar(output_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<Tensor>();
auto filter_name = op_info->Input("Filter").front();
auto filter_type = kernel->GetInputDeclType("Filter");
CHECK(filter_type->precision() == PRECISION(kFloat));
CHECK(filter_type->layout() == DATALAYOUT(kNCHW));
auto filter = scope->FindMutableTensor(filter_name);
auto filter_dims = filter->dims();
auto output_name = op_info->Output("Output").front();
auto output_type = kernel->GetOutputDeclType("Output");
CHECK(output_type->precision() == PRECISION(kFloat));
CHECK(output_type->layout() == DATALAYOUT(kNCHW));
auto output = scope->FindMutableTensor(output_name);
auto output_dims = output->dims();
auto bs = input_dims[0];
auto ic = input_dims[1];
auto oc = filter_dims[0];
......@@ -57,6 +66,14 @@ int ConvConverter(void* ctx, OpLite* op) {
CHECK_EQ(strides.size(), 2L);
CHECK_EQ(dilations.size(), 2L);
// Input node
std::shared_ptr<ge::Operator> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
}
if (paddings.size() == 2L) {
for (size_t i = 0; i < strides.size(); ++i) {
int copy_pad = *(paddings.begin() + 2 * i);
......@@ -91,10 +108,10 @@ int ConvConverter(void* ctx, OpLite* op) {
"performance.";
}
// Create filter node
auto filter_const_node = graph->AddNode(filter_var_name, *filter);
// Filter node
auto filter_const_node = graph->AddNode(filter_name, *filter);
// Create bias node if a bias tensor exists
// Add bias node if a bias tensor exists
// Supports the bias nodes with the following dimensions
// 0: {oc}
// 1: {1, oc, oh, ow}
......@@ -102,8 +119,11 @@ int ConvConverter(void* ctx, OpLite* op) {
std::shared_ptr<ge::Operator> bias_node = nullptr;
bool is_channel_bias = false;
if (HasInputArg(op_info, scope, "Bias")) {
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
auto bias_data_size = bias_dims.production();
auto output_data_size = output_dims.production();
......@@ -124,21 +144,21 @@ int ConvConverter(void* ctx, OpLite* op) {
<< output_dims;
return FAILED;
}
if (graph->HasNode(bias_var_name)) {
// Bias node from input map
bias_node = graph->GetNode(bias_var_name);
if (graph->HasNode(bias_name)) {
// Bias node from input node
bias_node = graph->GetNode(bias_name);
} else {
// Bias node with const data
bias_node = graph->AddNode(bias_var_name, *bias, bias_shape);
bias_node = graph->AddNode(bias_name, *bias, bias_shape);
}
}
// Create conv node and set input, filter, bias nodes and attributes
// Conv node
std::shared_ptr<ge::Operator> conv_node = nullptr;
if (use_depthwise_conv && is_depthwise_mode) {
auto depthwise_conv_node =
graph->AddNode<ge::op::ConvolutionDepthwise>(output_var_name);
depthwise_conv_node->set_input_x(*graph->GetNode(input_var_name));
graph->AddNode<ge::op::ConvolutionDepthwise>(output_name);
depthwise_conv_node->set_input_x(*input_node);
depthwise_conv_node->set_input_filter(*filter_const_node);
depthwise_conv_node->set_attr_mode(1);
depthwise_conv_node->set_attr_algo(0);
......@@ -157,15 +177,14 @@ int ConvConverter(void* ctx, OpLite* op) {
// ConvolutionDepthwise Op doesn't support bias, so append Add node to
// support bias
if (bias_node != nullptr) {
auto add_node = graph->AddNode<ge::op::Add>(output_var_name);
auto add_node = graph->AddNode<ge::op::Add>(output_name);
add_node->set_input_x1(*depthwise_conv_node);
add_node->set_input_x2(*bias_node);
conv_node = add_node;
}
} else {
auto common_conv_node =
graph->AddNode<ge::op::Convolution>(output_var_name);
common_conv_node->set_input_x(*graph->GetNode(input_var_name));
auto common_conv_node = graph->AddNode<ge::op::Convolution>(output_name);
common_conv_node->set_input_x(*input_node);
common_conv_node->set_input_w(*filter_const_node);
common_conv_node->set_attr_mode(1);
common_conv_node->set_attr_pad_mode(0); // NOTSET
......@@ -185,7 +204,7 @@ int ConvConverter(void* ctx, OpLite* op) {
if (is_channel_bias) {
common_conv_node->set_input_b(*bias_node);
} else {
auto add_node = graph->AddNode<ge::op::Add>(output_var_name);
auto add_node = graph->AddNode<ge::op::Add>(output_name);
add_node->set_input_x1(*common_conv_node);
add_node->set_input_x2(*bias_node);
conv_node = add_node;
......@@ -196,7 +215,7 @@ int ConvConverter(void* ctx, OpLite* op) {
if (fuse_relu) {
// Append relu node if fuse_relu is true
auto relu_node = graph->AddNode<ge::op::Activation>(output_var_name);
auto relu_node = graph->AddNode<ge::op::Activation>(output_name);
relu_node->set_input_x(*conv_node);
relu_node->set_attr_mode(CvtActMode("relu"));
}
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int ConvTransposeConverter(void* ctx, OpLite* op) {
int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -31,15 +31,24 @@ int ConvTransposeConverter(void* ctx, OpLite* op) {
VLOG(3) << "[NPU] Converting " << op_type << "... ";
// Get input, output and op attributes
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<Tensor>();
auto input_shape = input->dims().Vectorize();
auto output_var_name = op_info->Output("Output").front();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<Tensor>();
auto filter_shape = filter->dims().Vectorize();
CHECK_EQ(input_shape.size(), 4);
CHECK_EQ(filter_shape.size(), 4);
auto input_name = op_info->Input("Input").front();
auto input_type = kernel->GetInputDeclType("Input");
CHECK(input_type->precision() == PRECISION(kFloat));
CHECK(input_type->layout() == DATALAYOUT(kNCHW));
auto input = scope->FindMutableTensor(input_name);
auto input_dims = input->dims();
CHECK_EQ(input_dims.size(), 4);
auto filter_name = op_info->Input("Filter").front();
auto filter_type = kernel->GetInputDeclType("Filter");
CHECK(filter_type->precision() == PRECISION(kFloat));
CHECK(filter_type->layout() == DATALAYOUT(kNCHW));
auto filter = scope->FindMutableTensor(filter_name);
auto filter_dims = filter->dims();
CHECK_EQ(filter_dims.size(), 4);
auto output_name = op_info->Output("Output").front();
auto output_type = kernel->GetOutputDeclType("Output");
CHECK(output_type->precision() == PRECISION(kFloat));
CHECK(output_type->layout() == DATALAYOUT(kNCHW));
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto groups = op_info->GetAttr<int>("groups");
......@@ -48,6 +57,15 @@ int ConvTransposeConverter(void* ctx, OpLite* op) {
CHECK_EQ(strides.size(), 2L);
CHECK_EQ(dilations.size(), 2L);
// Input node
std::shared_ptr<ge::Operator> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
}
// Create input sizes node to describe the dimensions of input tensor
if (paddings.size() == 2L) {
for (size_t i = 0; i < 2L; ++i) {
int copy_pad = *(paddings.begin() + 2 * i);
......@@ -56,32 +74,26 @@ int ConvTransposeConverter(void* ctx, OpLite* op) {
}
CHECK_EQ(paddings.size(), 4L)
<< "[NPU] Paddings size should be the same or twice as the input size.";
// Create deconv node
auto conv_transpose_node =
graph->AddNode<ge::op::Deconvolution>(output_var_name);
// Create input sizes node to describe the dimensions of input tensor
std::vector<int32_t> input_sizes;
input_sizes.push_back(input_shape[0]);
input_sizes.push_back(filter_shape[1] * groups);
input_sizes.push_back(input_dims[0]);
input_sizes.push_back(filter_dims[1] * groups);
for (int i = 0; i < strides.size(); i++) {
int kernel_ext = dilations[i] * (filter_shape[i + 2] - 1) + 1;
int kernel_ext = dilations[i] * (filter_dims[i + 2] - 1) + 1;
int output_size =
(input_shape[i + 2] - 1) * strides[i] + kernel_ext - 2 * paddings[i];
(input_dims[i + 2] - 1) * strides[i] + kernel_ext - 2 * paddings[i];
input_sizes.push_back(output_size);
}
auto input_sizes_const_node =
graph->AddNode(output_var_name + "/input_sizes", input_sizes);
conv_transpose_node->set_input_input_sizes(*input_sizes_const_node);
// Create filter node
auto filter_const_node = graph->AddNode(filter_var_name, *filter);
conv_transpose_node->set_input_filter(*filter_const_node);
graph->AddNode(output_name + "/input_sizes", input_sizes);
// Set input node
conv_transpose_node->set_input_x(*graph->GetNode(input_var_name));
// Filter node
auto filter_const_node = graph->AddNode(filter_name, *filter);
// Deconv node
auto conv_transpose_node = graph->AddNode<ge::op::Deconvolution>(output_name);
conv_transpose_node->set_input_input_sizes(*input_sizes_const_node);
conv_transpose_node->set_input_filter(*filter_const_node);
conv_transpose_node->set_input_x(*input_node);
// Set attributes
conv_transpose_node->set_attr_format(0); // NCHW
conv_transpose_node->set_attr_pad_mode(0); // NOTSET
......@@ -93,21 +105,23 @@ int ConvTransposeConverter(void* ctx, OpLite* op) {
conv_transpose_node->set_attr_stride(
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
conv_transpose_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_shape[2], filter_shape[3]}));
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
// Append an Add node for the bias if it exists
std::shared_ptr<ge::Operator> output_node = conv_transpose_node;
if (HasInputArg(op_info, scope, "Bias")) {
// Create bias node
auto bias_var_name = op_info->Input("Bias").front();
CHECK(!graph->HasNode(bias_var_name));
auto* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto channel_size = bias->dims().production();
CHECK_EQ(channel_size, filter_shape[1] * groups);
CHECK_EQ(channel_size, filter_dims[1] * groups);
auto bias_const_node =
graph->AddNode(bias_var_name, *bias, {1, channel_size, 1, 1});
graph->AddNode(bias_name, *bias, {1, channel_size, 1, 1});
// Append add node to add bias node
auto add_node = graph->AddNode<ge::op::Add>(output_var_name);
auto add_node = graph->AddNode<ge::op::Add>(output_name);
add_node->set_input_x1(*conv_transpose_node);
add_node->set_input_x2(*bias_const_node);
output_node = add_node;
......@@ -115,7 +129,7 @@ int ConvTransposeConverter(void* ctx, OpLite* op) {
if (fuse_relu) {
// Append relu node if fuse_relu is true
auto relu_node = graph->AddNode<ge::op::Activation>(output_var_name);
auto relu_node = graph->AddNode<ge::op::Activation>(output_name);
relu_node->set_input_x(*output_node);
relu_node->set_attr_mode(CvtActMode("relu"));
}
......
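
The input_sizes loop above applies the standard transposed-convolution output-size formula; a standalone restatement with a worked example (input 8, filter 3, stride 2, dilation 1, padding 1 yields 15):

int DeconvOutputSize(int input, int filter, int stride, int dilation, int pad) {
  int kernel_ext = dilation * (filter - 1) + 1;        // 1 * (3 - 1) + 1 = 3
  return (input - 1) * stride + kernel_ext - 2 * pad;  // (8 - 1) * 2 + 3 - 2 = 15
}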
......@@ -21,10 +21,10 @@ namespace lite {
namespace subgraph {
namespace npu {
std::vector<int64_t> CvtYShape(const Tensor& x, Tensor* y, int axis) {
auto x_dims = x.dims();
std::vector<int64_t> CvtYShape(const DDim& x_dims,
const DDim& y_dims,
int axis) {
CHECK_EQ(x_dims.size(), 4UL) << "[NPU] Only 4-dimensional x is supported";
auto y_dims = y->dims();
CHECK_GE(x_dims.size(), y_dims.size());
if (axis < 0) {
......@@ -45,7 +45,7 @@ std::vector<int64_t> CvtYShape(const Tensor& x, Tensor* y, int axis) {
return y_new_shape;
}
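
A minimal sketch of the broadcasting rule CvtYShape implements, under the assumption (suggested by the surrounding checks) that y's dimensions are aligned to the 4-D x shape at `axis` and the remaining positions are padded with 1:

std::vector<int64_t> BroadcastYShape(const DDim& x_dims, const DDim& y_dims,
                                     int axis) {
  if (axis < 0) axis = x_dims.size() - y_dims.size();
  std::vector<int64_t> y_new_shape(x_dims.size(), 1);
  for (size_t i = 0; i < y_dims.size(); i++) {
    y_new_shape[axis + i] = y_dims[i];
  }
  return y_new_shape;  // e.g. x {2,3,4,5}, y {3,4}, axis 1 -> {1,3,4,1}
}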
int ElementwiseConverter(void* ctx, OpLite* op) {
int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -54,41 +54,62 @@ int ElementwiseConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
auto out_var_name = op_info->Output("Out").front();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Input("Y").front();
auto y_type = kernel->GetInputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
auto y = scope->FindMutableTensor(y_name);
auto y_dims = y->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axis = op_info->GetAttr<int>("axis");
std::shared_ptr<ge::Operator> elementwise_node = nullptr;
std::shared_ptr<ge::Operator> x_node = graph->GetNode(x_var_name);
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Y node
std::shared_ptr<ge::Operator> y_node = nullptr;
if (graph->HasNode(y_var_name)) {
y_node = graph->GetNode(y_var_name);
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
} else {
auto x = scope->FindTensor(x_var_name);
auto y = scope->FindMutableTensor(y_var_name);
auto y_new_shape = CvtYShape(*x, y, axis);
y_node = graph->AddNode(y_var_name, y, y_new_shape);
auto y_new_shape = CvtYShape(x_dims, y_dims, axis);
y_node = graph->AddNode(y_name, y_new_shape);
}
// Elementwise node
std::shared_ptr<ge::Operator> elementwise_node = nullptr;
if (op_type == "elementwise_add" ||
op_type == "fusion_elementwise_add_activation") {
auto elt_node = graph->AddNode<ge::op::Add>(out_var_name);
auto elt_node = graph->AddNode<ge::op::Add>(out_name);
elt_node->set_input_x1(*x_node);
elt_node->set_input_x2(*y_node);
elementwise_node = elt_node;
} else if (op_type == "elementwise_sub") {
auto elt_node = graph->AddNode<ge::op::Sub>(out_var_name);
auto elt_node = graph->AddNode<ge::op::Sub>(out_name);
elt_node->set_input_x1(*x_node);
elt_node->set_input_x2(*y_node);
elementwise_node = elt_node;
} else if (op_type == "elementwise_mul") {
auto elt_node = graph->AddNode<ge::op::Mul>(out_var_name);
auto elt_node = graph->AddNode<ge::op::Mul>(out_name);
elt_node->set_input_x(*x_node);
elt_node->set_input_y(*y_node);
elementwise_node = elt_node;
} else if (op_type == "elementwise_div") {
auto elt_node = graph->AddNode<ge::op::RealDiv>(out_var_name);
auto elt_node = graph->AddNode<ge::op::RealDiv>(out_name);
elt_node->set_input_x1(*x_node);
elt_node->set_input_x2(*y_node);
elementwise_node = elt_node;
......@@ -97,9 +118,10 @@ int ElementwiseConverter(void* ctx, OpLite* op) {
return FAILED;
}
// Act node
if (op_type == "fusion_elementwise_add_activation") {
auto act_type = op_info->GetAttr<std::string>("act_type");
auto act_node = graph->AddNode<ge::op::Activation>(out_var_name);
auto act_node = graph->AddNode<ge::op::Activation>(out_name);
act_node->set_input_x(*elementwise_node);
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc.
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int FCConverter(void* ctx, OpLite* op) {
int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,36 +30,44 @@ int FCConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("Input").front();
auto w_var_name = op_info->Input("W").front();
auto out_var_name = op_info->Output("Out").front();
int in_num_col_dims = op_info->GetAttr<int>("in_num_col_dims");
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto w = scope->FindVar(w_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims();
auto input_name = op_info->Input("Input").front();
auto input_type = kernel->GetInputDeclType("Input");
CHECK(input_type->precision() == PRECISION(kFloat));
CHECK(input_type->layout() == DATALAYOUT(kNCHW));
auto input = scope->FindMutableTensor(input_name);
auto input_dims = input->dims();
CHECK_GE(input_dims.size(), 2UL);
auto w_name = op_info->Input("W").front();
auto w_type = kernel->GetInputDeclType("W");
CHECK(w_type->precision() == PRECISION(kFloat));
CHECK(w_type->layout() == DATALAYOUT(kNCHW));
auto w = scope->FindMutableTensor(w_name);
auto w_dims = w->dims();
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
int m = x_dims.Slice(0, in_num_col_dims).production();
int k = x_dims.Slice(in_num_col_dims, x_dims.size()).production();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
int in_num_col_dims = op_info->GetAttr<int>("in_num_col_dims");
int m = input_dims.Slice(0, in_num_col_dims).production();
int k = input_dims.Slice(in_num_col_dims, input_dims.size()).production();
int n = w_dims[1];
CHECK_EQ(k * n, w_dims.production());
VLOG(3) << "[NPU] x dims: " << x_dims << " w dims: " << w_dims << " m: " << m
<< " k: " << k << " n: " << n;
VLOG(3) << "[NPU] input dims: " << input_dims << " w dims: " << w_dims
<< " m: " << m << " k: " << k << " n: " << n;
auto fc_node = graph->AddNode<ge::op::FullConnection>(out_var_name + "/fc");
CHECK(!graph->HasNode(w_var_name));
// Reshape x to (m, k, 1, 1)
auto reshaped_x_node =
graph->AddNode<ge::op::Reshape>(x_var_name + "/reshape");
reshaped_x_node->set_input_tensor(*graph->GetNode(x_var_name));
reshaped_x_node->set_attr_shape({m, k, 1, 1});
reshaped_x_node->set_attr_axis(0);
fc_node->set_input_x(*reshaped_x_node);
// Create input node and reshape it to (m, k, 1, 1)
std::shared_ptr<ge::Operator> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
}
auto reshaped_input_node =
graph->AddNode<ge::op::Reshape>(input_name + "/reshape");
reshaped_input_node->set_input_tensor(*input_node);
reshaped_input_node->set_attr_shape({m, k, 1, 1});
reshaped_input_node->set_attr_axis(0);
// Create w const node, set its shape to (n, k, 1, 1) and fill with
// the transposed w tensor
......@@ -72,23 +80,26 @@ int FCConverter(void* ctx, OpLite* op) {
transpose_w_data[j * k + i] = w_data[i * n + j];
}
}
auto w_const_node = graph->AddNode(w_var_name, transpose_w);
fc_node->set_input_w(*w_const_node);
auto trans_w_const_node = graph->AddNode(w_name, transpose_w);
// FC node
auto fc_node = graph->AddNode<ge::op::FullConnection>(out_name + "/fc");
fc_node->set_input_x(*reshaped_input_node);
fc_node->set_input_w(*trans_w_const_node);
// Add bias node if bias tensor exists
if (HasInputArg(op_info, scope, "Bias")) {
auto bias_var_name = op_info->Input("Bias").front();
auto bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
CHECK(!graph->HasNode(bias_var_name));
CHECK_EQ(bias_dims.production(), n);
auto bias_const_node = graph->AddNode(bias_var_name, *bias, {1, n, 1, 1});
auto bias_const_node = graph->AddNode(bias_name, *bias, {1, n, 1, 1});
fc_node->set_input_b(*bias_const_node);
}
// Reshape output of fc_node from (m, n, 1, 1) to (m, n)
auto reshaped_fc_node = graph->AddNode<ge::op::Reshape>(out_var_name);
// Reshape output of FC node from (m, n, 1, 1) to (m, n)
auto reshaped_fc_node = graph->AddNode<ge::op::Reshape>(out_name);
reshaped_fc_node->set_input_tensor(*fc_node);
reshaped_fc_node->set_attr_shape({m, n});
reshaped_fc_node->set_attr_axis(0);
......
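
The transpose loop above repacks w from row-major (k, n) into row-major (n, k), since FullConnection consumes the weight as (n, k, 1, 1); a standalone version of the same index mapping:

std::vector<float> TransposeWeight(const float* w, int k, int n) {
  std::vector<float> w_t(static_cast<size_t>(k) * n);
  for (int i = 0; i < k; i++) {    // row of w
    for (int j = 0; j < n; j++) {  // column of w
      w_t[j * k + i] = w[i * n + j];
    }
  }
  return w_t;
}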
......@@ -22,35 +22,25 @@ namespace subgraph {
namespace npu {
// Const node
std::shared_ptr<ge::op::Const> Graph::AddNode(const std::string& name,
const Tensor& tensor,
PrecisionType ptype,
DataLayoutType ltype) {
return AddNode(name, tensor, tensor.dims().Vectorize(), ptype, ltype);
}
std::shared_ptr<ge::op::Const> Graph::AddNode(const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType ptype,
DataLayoutType ltype) {
CHECK(!HasNode(name)) << "Node " << name << " redefined.";
auto node = AddNode<ge::op::Const>(name);
node->set_attr_value(CvtTensor(tensor, shape, ptype, ltype));
PrecisionType precision,
DataLayoutType layout) {
auto node = AddNode<ge::op::Const>(name, precision, layout);
node->set_attr_value(CvtTensor(tensor, shape, precision, layout));
return node;
}
// Data node
std::shared_ptr<ge::op::Data> Graph::AddNode(const std::string& name,
std::vector<int64_t> shape,
PrecisionType ptype,
DataLayoutType ltype) {
CHECK(!HasNode(name)) << "Node " << name << " redefined.";
PrecisionType precision,
DataLayoutType layout) {
auto node = AddNode<ge::op::Data>(name);
ge::TensorDesc desc(
ge::Shape(shape), CvtDataLayoutType(ltype), CvtPrecisionType(ptype));
ge::Shape(shape), CvtDataLayoutType(layout), CvtPrecisionType(precision));
node->update_input_desc_x(desc);
nodes_.insert(std::make_pair(name, node));
return node;
}
......
......@@ -28,11 +28,35 @@ namespace lite {
namespace subgraph {
namespace npu {
// Type and registers of converters for converting Paddle Ops to HiAI IR graph
// Type of graph nodes
class Type {
public:
Type(PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW),
bool persistable = false)
: precision_(precision), layout_(layout), persistable_(persistable) {}
void set_precision(PrecisionType precision) { precision_ = precision; }
void set_layout(DataLayoutType layout) { layout_ = layout; }
void set_persistable(bool persistable) { persistable_ = persistable; }
PrecisionType precision() const { return precision_; }
DataLayoutType layout() const { return layout_; }
bool persistable() const { return persistable_; }
private:
PrecisionType precision_{PRECISION(kFloat)};
DataLayoutType layout_{DATALAYOUT(kNCHW)};
bool persistable_{false};
};
// Graph to collect all of converted HiAI IR nodes
class Graph {
public:
template <typename T>
std::shared_ptr<T> AddNode(const std::string& name) {
std::shared_ptr<T> AddNode(const std::string& name,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
auto unique_name = [&](const std::string& key) {
int idx = 1;
auto it = counts_.find(key);
......@@ -43,8 +67,12 @@ class Graph {
}
return key + "_" + std::to_string(idx);
};
bool persistable = typeid(T) == typeid(ge::op::Const);
auto it = nodes_.find(name);
if (it != nodes_.end()) {
// Only variable (non-persistable) nodes can rebind the name
CHECK(!it->second.second.persistable() && !persistable)
<< "[NPU] Node " << name << " redefined.";
// Generate a new unique name as the key to bind the origin node:
// new_name->node
nodes_.insert(std::make_pair(unique_name(name + "_var"), it->second));
......@@ -52,7 +80,8 @@ class Graph {
}
// Create a new node and bind with the name: name->new_node
auto node = std::make_shared<T>(unique_name(name + "_op"));
nodes_.insert(std::make_pair(name, node));
nodes_.insert(std::make_pair(
name, std::make_pair(node, Type(precision, layout, persistable))));
return node;
}
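
A sketch of the rebinding behavior above: when two ops write the same output name (e.g. a conv2d whose output is consumed and overwritten by an in-place activation), the old binding moves to a generated alias and the name is bound to the newest node; the new persistable check forbids rebinding const nodes. Hypothetical usage:

void RebindDemo(Graph* graph) {
  auto conv = graph->AddNode<ge::op::Convolution>("out");  // "out" -> conv
  auto act = graph->AddNode<ge::op::Activation>("out");    // conv aliased as
                                                           // "out_var_1", then
                                                           // "out" -> act
  CHECK(graph->GetNode("out") == act);
}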
......@@ -60,30 +89,41 @@ class Graph {
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const Tensor& tensor,
PrecisionType ptype = PRECISION(kFloat),
DataLayoutType ltype = DATALAYOUT(kNCHW));
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, tensor, tensor.dims().Vectorize(), precision, layout);
}
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType ptype = PRECISION(kFloat),
DataLayoutType ltype = DATALAYOUT(kNCHW));
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const Tensor& tensor,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, tensor, dims.Vectorize(), precision, layout);
}
template <typename T>
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const std::vector<T>& data,
std::vector<int64_t> shape = {},
DataLayoutType ltype = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) {
const std::type_info& info = typeid(T);
PrecisionType ptype = PRECISION(kFloat);
PrecisionType precision = PRECISION(kFloat);
if (info == typeid(float)) {
ptype = PRECISION(kFloat);
precision = PRECISION(kFloat);
} else if (info == typeid(int8_t)) {
ptype = PRECISION(kFloat);
precision = PRECISION(kFloat);
} else if (info == typeid(int32_t)) {
ptype = PRECISION(kInt32);
precision = PRECISION(kInt32);
} else {
LOG(FATAL) << "[NPU] Unknow data type " << info.name();
}
......@@ -101,7 +141,16 @@ class Graph {
std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
reinterpret_cast<const uint8_t*>(data.data()),
data.size() * sizeof(T));
return AddNode(name, tensor, ptype, ltype);
return AddNode(name, tensor, precision, layout);
}
template <typename T>
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const std::vector<T>& data,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, data, dims.Vectorize(), layout);
}
template <typename T>
......@@ -109,25 +158,47 @@ class Graph {
const std::string& name,
T value,
std::vector<int64_t> shape = {1},
DataLayoutType ltype = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) {
int64_t size = 1;
for (auto i : shape) {
size *= i;
}
std::vector<T> data(size, value);
return AddNode(name, data, shape, ltype);
return AddNode(name, data, shape, layout);
}
template <typename T>
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
T value,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, value, dims.Vectorize(), layout);
}
// Data node
std::shared_ptr<ge::op::Data> AddNode(
const std::string& name,
std::vector<int64_t> shape,
PrecisionType ptype = PRECISION(kFloat),
DataLayoutType ltype = DATALAYOUT(kNCHW));
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<ge::op::Data> AddNode(
const std::string& name,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, dims.Vectorize(), precision, layout);
}
std::shared_ptr<ge::Operator> GetNode(std::string name) {
CHECK(HasNode(name)) << "[NPU] Node " << name << " not found.";
return nodes_.at(name);
return nodes_.at(name).first;
}
const Type& GetType(const std::string& name) {
CHECK(HasNode(name)) << "[NPU] Node " << name << " not found.";
return nodes_.at(name).second;
}
bool HasNode(const std::string& name) {
......@@ -135,7 +206,9 @@ class Graph {
}
private:
std::unordered_map<std::string, std::shared_ptr<ge::Operator>> nodes_;
std::unordered_map<std::string,
std::pair<std::shared_ptr<ge::Operator>, Type>>
nodes_;
std::unordered_map<std::string, int> counts_;
};
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int InterpolateConverter(void* ctx, OpLite* op) {
int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,14 +30,20 @@ int InterpolateConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input, output and attributes from lite op
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto x_h = x_dims[2];
auto x_w = x_dims[3];
CHECK_EQ(x_dims.size(), 4);
auto out_var_name = op_info->Output("Out").front();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto scale = op_info->GetAttr<float>("scale");
auto out_w = op_info->GetAttr<int>("out_w");
auto out_h = op_info->GetAttr<int>("out_h");
......@@ -48,6 +54,14 @@ int InterpolateConverter(void* ctx, OpLite* op) {
"align_corners = false isn't "
"supported in HiAI DDK";
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Priority: OutSize > scale > out_h/out_w
if (scale > 0) {
out_h = static_cast<int>(x_h * scale);
......@@ -56,14 +70,17 @@ int InterpolateConverter(void* ctx, OpLite* op) {
out_w = out_w > 0 ? out_w : -1;
}
// Update out_h and out_w if has OutSize
// Update out_h and out_w and create out_size node if has OutSize
std::shared_ptr<ge::Operator> out_size_node = nullptr;
if (HasInputArg(op_info, scope, "OutSize")) {
auto out_size_var_name = op_info->Input("OutSize").front();
if (graph->HasNode(out_size_var_name)) {
out_size_node = graph->GetNode(out_size_var_name);
auto out_size_name = op_info->Input("OutSize").front();
auto out_size_type = kernel->GetInputDeclType("OutSize");
CHECK(out_size_type->precision() == PRECISION(kInt32));
CHECK(out_size_type->layout() == DATALAYOUT(kNCHW));
if (graph->HasNode(out_size_name)) {
out_size_node = graph->GetNode(out_size_name);
} else {
auto out_size = scope->FindVar(out_size_var_name)->GetMutable<Tensor>();
auto out_size = scope->FindMutableTensor(out_size_name);
CHECK_EQ(out_size->numel(), 2);
auto out_size_data = out_size->mutable_data<int>();
// Update out_h and out_w if has OutSize
......@@ -80,20 +97,20 @@ int InterpolateConverter(void* ctx, OpLite* op) {
<< " is too large, should not exceed " << largest_multiple
<< " in HiAI DDK";
}
out_size_node = graph->AddNode(out_var_name + "/out_size",
out_size_node = graph->AddNode(out_name + "/out_size",
std::vector<int>({out_h, out_w}));
}
if (interp_method == "bilinear") {
auto bilinear_interp_node =
graph->AddNode<ge::op::ResizeBilinear>(out_var_name);
bilinear_interp_node->set_input_x(*graph->GetNode(x_var_name));
graph->AddNode<ge::op::ResizeBilinear>(out_name);
bilinear_interp_node->set_input_x(*x_node);
bilinear_interp_node->set_input_size(*out_size_node);
bilinear_interp_node->set_attr_align_corners(align_corners);
} else if (interp_method == "nearest") {
auto nearest_interp_node =
graph->AddNode<ge::op::ResizeNearestNeighbor>(out_var_name);
nearest_interp_node->set_input_image(*graph->GetNode(x_var_name));
graph->AddNode<ge::op::ResizeNearestNeighbor>(out_name);
nearest_interp_node->set_input_image(*x_node);
nearest_interp_node->set_input_size(*out_size_node);
nearest_interp_node->set_attr_align_corners(align_corners);
} else {
......
......@@ -22,7 +22,7 @@ namespace subgraph {
namespace npu {
// Note: all of the input weight vars should be handled in this converter
int MulConverter(void* ctx, OpLite* op) {
int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -31,13 +31,23 @@ int MulConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto y = scope->FindVar(y_var_name)->GetMutable<lite::Tensor>();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Input("Y").front();
auto y_type = kernel->GetInputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
auto y = scope->FindMutableTensor(y_name);
auto y_dims = y->dims();
auto out_var_name = op_info->Output("Out").front();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
int x_num_col_dims = op_info->GetAttr<int>("x_num_col_dims");
int y_num_col_dims = op_info->GetAttr<int>("y_num_col_dims");
int m = x_dims.Slice(0, x_num_col_dims).production();
......@@ -46,40 +56,45 @@ int MulConverter(void* ctx, OpLite* op) {
<< "[NPU] columns of X must be equal with rows of Y";
int n = y_dims.Slice(y_num_col_dims, y_dims.size()).production();
VLOG(3) << "m:" << m << ",n:" << n << ",k:" << k;
VLOG(3) << "x_var_name:" << x_var_name
<< ", is data: " << graph->HasNode(x_var_name);
VLOG(3) << "y_var_name:" << y_var_name
<< ", is data: " << graph->HasNode(y_var_name);
CHECK(graph->HasNode(x_var_name))
VLOG(3) << "x_name:" << x_name << ", is data: " << graph->HasNode(x_name);
VLOG(3) << "y_name:" << y_name << ", is data: " << graph->HasNode(y_name);
CHECK(graph->HasNode(x_name))
<< "[NPU] MatMul in HiAI DDK only support X is data, Y is const yet.";
auto mul_node = graph->AddNode<ge::op::MatMul>(out_var_name);
// Add input x node which supports persistable and non-persistable tensors, and
// X node which supports persistable and non-persistable tensors, and
// reshape to (m, k)
if (graph->HasNode(x_var_name)) {
auto reshaped_x_node =
graph->AddNode<ge::op::Reshape>(x_var_name + "/reshape");
reshaped_x_node->set_input_tensor(*graph->GetNode(x_var_name));
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
auto reshaped_x_node = graph->AddNode<ge::op::Reshape>(x_name + "/reshape");
reshaped_x_node->set_input_tensor(*x_node);
reshaped_x_node->set_attr_shape({m, k});
reshaped_x_node->set_attr_axis(0);
mul_node->set_input_x1(*reshaped_x_node);
x_node = reshaped_x_node;
} else {
auto x_const_node = graph->AddNode(x_var_name, *x, {m, k});
mul_node->set_input_x1(*x_const_node);
auto x_const_node = graph->AddNode(x_name, *x, {m, k});
x_node = x_const_node;
}
// Add input y node which only supports persistable tensors, and reshape to
// Y node which only supports persistable tensors, and reshape to
// (k,n)
if (graph->HasNode(y_var_name)) {
auto reshaped_y_node =
graph->AddNode<ge::op::Reshape>(y_var_name + "/reshape");
reshaped_y_node->set_input_tensor(*graph->GetNode(y_var_name));
std::shared_ptr<ge::Operator> y_node = nullptr;
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
auto reshaped_y_node = graph->AddNode<ge::op::Reshape>(y_name + "/reshape");
reshaped_y_node->set_input_tensor(*y_node);
reshaped_y_node->set_attr_shape({k, n});
reshaped_y_node->set_attr_axis(0);
mul_node->set_input_x2(*reshaped_y_node);
y_node = reshaped_y_node;
} else {
auto y_const_node = graph->AddNode(y_var_name, *y, {k, n});
mul_node->set_input_x2(*y_const_node);
auto y_const_node = graph->AddNode(y_name, *y, {k, n});
y_node = y_const_node;
}
// Matmul node
auto mul_node = graph->AddNode<ge::op::MatMul>(out_name);
mul_node->set_input_x1(*x_node);
mul_node->set_input_x2(*y_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
......
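
Worked example of the flattening above: x with dims {2, 3, 4} and x_num_col_dims = 1 gives m = 2 and k = 12; y with dims {12, 5} and y_num_col_dims = 1 gives k = 12 and n = 5, so MatMul multiplies (2, 12) by (12, 5). A standalone sketch of what DDim::Slice(...).production() computes:

int64_t SliceProduction(const std::vector<int64_t>& dims, size_t begin,
                        size_t end) {
  int64_t prod = 1;
  for (size_t i = begin; i < end; i++) prod *= dims[i];
  return prod;  // SliceProduction({2, 3, 4}, 1, 3) == 12
}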
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int Pad2dConverter(void* ctx, OpLite* op) {
int Pad2dConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,38 +30,54 @@ int Pad2dConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto pad2d_node = graph->AddNode<ge::op::Pad>(out_var_name);
pad2d_node->set_input_x(*graph->GetNode(x_var_name));
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("Input");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto padding = op_info->GetAttr<std::vector<int>>("paddings");
CHECK_EQ(padding.size(), 4);
auto mode = op_info->GetAttr<std::string>("mode");
if (mode == "constant") {
pad2d_node->set_attr_mode(0);
} else if (mode == "reflect") {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
pad2d_node->set_attr_mode(1);
return FAILED;
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
return FAILED;
x_node = graph->AddNode(x_name, x_dims);
}
auto x_dims = scope->FindTensor(x_var_name)->dims();
auto padding = op_info->GetAttr<std::vector<int>>("paddings");
CHECK_EQ(padding.size(), 4);
// Padding node
int xds = x_dims.size();
padding.insert(padding.begin(), xds * 2 - 4, 0);
auto padding_const_node =
graph->AddNode(out_var_name + "/padding", padding, {xds, 2});
pad2d_node->set_input_padding(*padding_const_node);
graph->AddNode(out_name + "/padding", padding, {xds, 2});
// Pad node
auto pad2d_node = graph->AddNode<ge::op::Pad>(out_name);
pad2d_node->set_input_x(*x_node);
pad2d_node->set_input_padding(*padding_const_node);
auto mode = op_info->GetAttr<std::string>("mode");
if (mode == "constant") {
// Pad value node
auto pad_value = op_info->GetAttr<float>("pad_value");
auto pad_value_const_node =
graph->AddNode(out_var_name + "/pad_value", pad_value);
graph->AddNode(out_name + "/pad_value", pad_value);
pad2d_node->set_input_constant_values(*pad_value_const_node);
pad2d_node->set_attr_T(0); // type of pad_value: 0:float 3:int32
pad2d_node->set_attr_mode(0);
} else if (mode == "reflect") {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
pad2d_node->set_attr_mode(1);
return FAILED;
} else {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
return FAILED;
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
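HiAI's Pad takes an {rank, 2} matrix of (pad_before, pad_after) pairs, while pad2d's attribute carries only four values for H and W; the `padding.insert(...)` above prepends `xds * 2 - 4` zeros so the N and C rows pad nothing. A sketch of that expansion (plain C++, values illustrative):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> padding{1, 2, 3, 4};  // pad2d's {top, bottom, left, right}
  int xds = 4;                           // rank of the input tensor (NCHW)
  padding.insert(padding.begin(), xds * 2 - 4, 0);  // zero rows for N and C
  // padding is now {0,0, 0,0, 1,2, 3,4}, i.e. a row-major {xds, 2} matrix of
  // (pad_before, pad_after) pairs, one row per dimension:
  for (int i = 0; i < xds; ++i)
    std::printf("dim %d: before=%d after=%d\n",
                i, padding[2 * i], padding[2 * i + 1]);
  return 0;
}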
......
......@@ -22,7 +22,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int PoolConverter(void* ctx, OpLite* op) {
int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -31,14 +31,32 @@ int PoolConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_var_name);
auto out_var_name = op_info->Output("Out").front();
auto pool_node = graph->AddNode<ge::op::Pooling>(out_var_name);
pool_node->set_input_x(*graph->GetNode(x_var_name));
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto pooling_type = op_info->GetAttr<std::string>("pooling_type");
auto global_pooling = op_info->GetAttr<bool>("global_pooling");
auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// pool mode
int mode = 0;
auto pooling_type = op_info->GetAttr<std::string>("pooling_type");
if (pooling_type == "max") {
mode = 0;
} else if (pooling_type == "avg") {
......@@ -49,8 +67,8 @@ int PoolConverter(void* ctx, OpLite* op) {
LOG(WARNING) << "[NPU] Unsupported pooling type: " << pooling_type;
return FAILED;
}
pool_node->set_attr_mode(mode);
// pad mode
int pad_mode = 0;
std::string padding_algorithm("");
if (op_info->HasAttr("padding_algorithm")) {
......@@ -61,16 +79,8 @@ int PoolConverter(void* ctx, OpLite* op) {
} else if (padding_algorithm == "VALID") {
pad_mode = 5;
}
pool_node->set_attr_pad_mode(pad_mode);
bool global_pooling = op_info->GetAttr<bool>("global_pooling");
pool_node->set_attr_global_pooling(global_pooling);
auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
pool_node->set_attr_window(
ge::AttrValue::LIST_INT(ksize.begin(), ksize.end()));
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
// paddings and strides
if (paddings.size() == 2L) {
for (size_t i = 0; i < 2L; ++i) {
int copy_pad = *(paddings.begin() + 2 * i);
......@@ -91,15 +101,25 @@ int PoolConverter(void* ctx, OpLite* op) {
x->dims(),
strides,
ksize);
pool_node->set_attr_pad(ge::AttrValue::LIST_INT{
paddings[0], paddings[1], paddings[2], paddings[3]});
pool_node->set_attr_stride(
ge::AttrValue::LIST_INT(strides.begin(), strides.end()));
// ceil mode
int ceil_mode = 0;
if (op_info->HasAttr("ceil_mode")) {
ceil_mode = op_info->GetAttr<bool>("ceil_mode") ? 1 : 0;
}
// Pooling node
auto pool_node = graph->AddNode<ge::op::Pooling>(out_name);
pool_node->set_input_x(*x_node);
pool_node->set_attr_mode(mode);
pool_node->set_attr_pad_mode(pad_mode);
pool_node->set_attr_global_pooling(global_pooling);
pool_node->set_attr_window(
ge::AttrValue::LIST_INT(ksize.begin(), ksize.end()));
pool_node->set_attr_pad(ge::AttrValue::LIST_INT{
paddings[0], paddings[1], paddings[2], paddings[3]});
pool_node->set_attr_stride(
ge::AttrValue::LIST_INT(strides.begin(), strides.end()));
pool_node->set_attr_ceil_mode(ceil_mode);
// pool_node->set_attr_data_mode(data_mode);
return REBUILD_WHEN_SHAPE_CHANGED;
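As a side note on the ceil_mode attribute handled above, floor vs. ceil division only changes the pooled extent when the window doesn't tile the padded input evenly. A small illustrative sketch, using the usual pooling output-size formula (assumed, not taken from the HiAI DDK):

#include <cstdio>

// Usual pooled output-extent formula; ceil_mode rounds the division up.
int PooledSize(int in, int ksize, int pad, int stride, bool ceil_mode) {
  int num = in + 2 * pad - ksize;
  return (ceil_mode ? (num + stride - 1) / stride : num / stride) + 1;
}

int main() {
  // 6-wide input, 3-wide window, stride 2, no padding:
  std::printf("floor: %d\n", PooledSize(6, 3, 0, 2, false));  // 2
  std::printf("ceil : %d\n", PooledSize(6, 3, 0, 2, true));   // 3
  return 0;
}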
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int ReduceMeanConverter(void* ctx, OpLite* op) {
int ReduceMeanConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,10 +30,17 @@ int ReduceMeanConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input and op attributes
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Input("Out").front();
auto x_dims = scope->FindTensor(x_var_name)->dims();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Input("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto keep_dim = op_info->GetAttr<bool>("keep_dim");
auto dim = op_info->GetAttr<std::vector<int>>("dim");
CHECK(!dim.empty()) << "[NPU] \"dim\" of reduce_mean should not be empty.";
......@@ -44,21 +51,36 @@ int ReduceMeanConverter(void* ctx, OpLite* op) {
}
std::sort(dim.begin(), dim.end());
// Create reduce_mean(using reduce_sum + scale) node and set input node from
// node map
auto reduce_sum_node =
graph->AddNode<ge::op::ReduceSum>(out_var_name + "/reducesum");
reduce_sum_node->set_input_x(*graph->GetNode(x_var_name));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Using ReduceSum + Scale to implement ReduceMean
auto dim_const_node = graph->AddNode(out_var_name + "/dim", dim);
// Dim node
auto dim_const_node = graph->AddNode(out_name + "/dim", dim);
// Reduce Sum node
auto reduce_sum_node =
graph->AddNode<ge::op::ReduceSum>(out_name + "/reducesum");
reduce_sum_node->set_input_x(*x_node);
reduce_sum_node->set_input_w(*dim_const_node);
reduce_sum_node->set_attr_keep_dims(keep_dim);
// Scale node
auto scale_node = graph->AddNode<ge::op::Scale>(out_name);
scale_node->set_input_x(*reduce_sum_node);
scale_node->set_attr_axis(1);
// Add filter node(fill with scale)
float scale = 1;
for (size_t i = 0; i < dim.size(); i++) {
scale /= x_dims[dim[i]];
}
std::vector<int64_t> scale_bias_shape = x_dims.Vectorize();
if (keep_dim) {
for (size_t i = 0; i < dim.size(); i++) {
......@@ -73,13 +95,9 @@ int ReduceMeanConverter(void* ctx, OpLite* op) {
remove(scale_bias_shape.begin(), scale_bias_shape.end(), kDelFlag),
scale_bias_shape.end());
}
auto filter_const_node =
graph->AddNode(out_var_name + "/filter", scale, scale_bias_shape);
auto scale_node = graph->AddNode<ge::op::Scale>(out_var_name);
scale_node->set_input_x(*reduce_sum_node);
graph->AddNode(out_name + "/filter", scale, scale_bias_shape);
scale_node->set_input_filter(*filter_const_node);
scale_node->set_attr_axis(1);
return REBUILD_WHEN_SHAPE_CHANGED;
}
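The decomposition above relies on mean over dims D being sum over D times 1 / prod(x_dims[d] for d in D); the Scale node's filter is filled with exactly that constant. A numeric check of the loop that computes it (plain C++, values illustrative):

#include <cstdio>
#include <vector>

int main() {
  std::vector<long long> x_dims{2, 3, 4, 5};
  std::vector<int> dim{1, 3};  // reduce over dims 1 and 3
  float scale = 1.0f;
  for (int d : dim) scale /= x_dims[d];  // 1 / (3 * 5)
  std::printf("scale = %g (each mean averages %lld elements)\n",
              scale, x_dims[1] * x_dims[3]);
  return 0;
}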
......
......@@ -33,7 +33,8 @@ inline bool CHECK_REBUILD_WHEN_SHAPE_CHANGED(int status) {
return status & REBUILD_WHEN_SHAPE_CHANGED;
}
using cvt_func_type = std::function<int(void* ctx, OpLite* op)>;
using cvt_func_type =
std::function<int(void* ctx, OpLite* op, KernelBase* kernel)>;
using cvt_map_type =
std::unordered_map<std::string,
std::unordered_map<std::string, cvt_func_type>>;
......
......@@ -22,7 +22,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int ReshapeConverter(void* ctx, OpLite* op) {
int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -31,25 +31,44 @@ int ReshapeConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// Create reshape node and set input node from inputs_map
auto reshape_node = graph->AddNode<ge::op::Reshape>(out_var_name);
reshape_node->set_input_tensor(*graph->GetNode(x_var_name));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Reshape node
auto reshape_node = graph->AddNode<ge::op::Reshape>(out_name);
reshape_node->set_input_tensor(*x_node);
// Read shape from "ShapeTensor"(input), or "Shape"(input), or "shape"(attr)
if (HasInputArg(op_info, scope, "ShapeTensor")) {
LOG(WARNING) << "[NPU] not support \"Shape\" from more than one Tensor.";
return FAILED;
} else if (HasInputArg(op_info, scope, "Shape")) {
auto actual_shape_var_name = op_info->Input("Shape").front();
if (!graph->HasNode(actual_shape_var_name)) {
auto actual_shape =
scope->FindVar(actual_shape_var_name)->GetMutable<Tensor>();
auto actual_shape_name = op_info->Input("Shape").front();
// auto actual_shape_type = kernel->GetInputDeclType("Shape");
// CHECK(actual_shape_type->precision() == PRECISION(kInt32));
// CHECK(actual_shape_type->layout() == DATALAYOUT(kNCHW));
std::shared_ptr<ge::Operator> actual_shape_node = nullptr;
if (graph->HasNode(actual_shape_name)) {
actual_shape_node = graph->GetNode(actual_shape_name);
} else {
auto actual_shape = scope->FindMutableTensor(actual_shape_name);
auto actual_shape_dims = actual_shape->dims();
auto actual_shape_data = actual_shape->mutable_data<int>();
auto shape =
......@@ -63,12 +82,11 @@ int ReshapeConverter(void* ctx, OpLite* op) {
<< out_shape.size();
}
auto actual_shape_const_node =
graph->AddNode(actual_shape_var_name,
graph->AddNode(actual_shape_name,
std::vector<int>(out_shape.begin(), out_shape.end()));
reshape_node->set_input_w(*actual_shape_const_node);
} else {
reshape_node->set_input_w(*graph->GetNode(actual_shape_var_name));
actual_shape_node = actual_shape_const_node;
}
reshape_node->set_input_w(*actual_shape_node);
} else {
auto shape = op_info->GetAttr<std::vector<int>>("shape");
auto out_dims = lite::operators::ValidateShape(shape, x_dims);
......@@ -82,6 +100,7 @@ int ReshapeConverter(void* ctx, OpLite* op) {
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
}
// XShape node
if (op_type == "reshape2") {
// Append an extra reshape node to calc XShape
std::vector<int64_t> xshape_dims(x_dims.size() + 1, 1);
......@@ -92,10 +111,14 @@ int ReshapeConverter(void* ctx, OpLite* op) {
LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
"but XShape has "
<< xshape_dims.size();
return FAILED;
}
auto xshape_var_name = op_info->Output("XShape").front();
auto xshape_node = graph->AddNode<ge::op::Reshape>(xshape_var_name);
xshape_node->set_input_tensor(*graph->GetNode(x_var_name));
auto xshape_name = op_info->Output("XShape").front();
// auto xshape_type = kernel->GetOutputDeclType("XShape");
// CHECK(xshape_type->precision() == PRECISION(kFloat));
// CHECK(xshape_type->layout() == DATALAYOUT(kNCHW));
auto xshape_node = graph->AddNode<ge::op::Reshape>(xshape_name);
xshape_node->set_input_tensor(*x_node);
xshape_node->set_attr_shape(
ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end()));
}
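The `shape` attribute resolved by ValidateShape above follows Paddle's reshape semantics: a 0 keeps the input extent at that index, and a single -1 is inferred from the remaining volume. A minimal re-derivation of those two rules (illustrative only; ValidateShape itself also checks the error cases):

#include <cstdio>
#include <vector>

int main() {
  std::vector<long long> x_dims{2, 3, 4};
  std::vector<int> shape{0, -1, 2};  // keep dim 0, infer dim 1, set dim 2 = 2
  long long in = 2 * 3 * 4, known = 1;
  std::vector<long long> out;
  int infer_at = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) {
      infer_at = static_cast<int>(i);
      out.push_back(1);  // placeholder, fixed up below
    } else {
      long long d = (shape[i] == 0) ? x_dims[i] : shape[i];
      out.push_back(d);
      known *= d;
    }
  }
  if (infer_at >= 0) out[infer_at] = in / known;
  for (auto d : out) std::printf("%lld ", d);  // prints: 2 6 2
  std::printf("\n");
  return 0;
}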
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int ScaleConverter(void* ctx, OpLite* op) {
int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -31,11 +31,17 @@ int ScaleConverter(void* ctx, OpLite* op) {
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims().Vectorize();
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
CHECK_GE(x_dims.size(), 2);
auto out_var_name = op_info->Output("Out").front();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
std::vector<int64_t> scale_bias_shape = {x_dims[1]};
float scale = op_info->GetAttr<float>("scale");
float bias = op_info->GetAttr<float>("bias");
......@@ -44,23 +50,31 @@ int ScaleConverter(void* ctx, OpLite* op) {
bias *= scale;
}
// Create scale node and set input node from inputs_map
auto scale_node = graph->AddNode<ge::op::Scale>(out_var_name);
scale_node->set_input_x(*graph->GetNode(x_var_name));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Scale node
auto scale_node = graph->AddNode<ge::op::Scale>(out_name);
scale_node->set_input_x(*x_node);
scale_node->set_attr_axis(1);
// Add filter node(fill with scale)
auto filter_const_node =
graph->AddNode(out_var_name + "/filter", scale, scale_bias_shape);
graph->AddNode(out_name + "/filter", scale, scale_bias_shape);
scale_node->set_input_filter(*filter_const_node);
// Add bias node(fill with bias)
if (fabs(bias) > 1e-6f) {
auto bias_const_node =
graph->AddNode(out_var_name + "/bias", bias, scale_bias_shape);
graph->AddNode(out_name + "/bias", bias, scale_bias_shape);
scale_node->set_input_bias(*bias_const_node);
scale_node->set_attr_has_bias_value(true);
}
scale_node->set_attr_axis(1);
return REBUILD_WHEN_SHAPE_CHANGED;
}
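Paddle's scale op computes y = scale * x + bias when bias_after_scale is true and y = scale * (x + bias) otherwise; the `bias *= scale;` above folds the second form into the first so a single Scale node suffices. A quick check of that identity:

#include <cstdio>

int main() {
  float x = 2.0f, scale = 3.0f, bias = 0.5f;
  bool bias_after_scale = false;       // i.e. y = scale * (x + bias)
  float b = bias;
  if (!bias_after_scale) b *= scale;   // the folding done in the converter
  float y = scale * x + b;             // the single form the NPU node emits
  std::printf("y = %g (reference: %g)\n", y, scale * (x + bias));  // both 7.5
  return 0;
}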
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int ShuffleChannelConverter(void* ctx, OpLite* op) {
int ShuffleChannelConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,13 +30,31 @@ int ShuffleChannelConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto shuffle_channel_node =
graph->AddNode<ge::op::ShuffleChannel>(out_var_name);
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto group = op_info->GetAttr<int>("group");
shuffle_channel_node->set_input_x(*graph->GetNode(x_var_name));
shuffle_channel_node->set_attr_group(op_info->GetAttr<int>("group"));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Shuffle Channel node
auto shuffle_channel_node = graph->AddNode<ge::op::ShuffleChannel>(out_name);
shuffle_channel_node->set_input_x(*x_node);
shuffle_channel_node->set_attr_group(group);
return SUCCESS;
}
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int SoftmaxConverter(void* ctx, OpLite* op) {
int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,9 +30,17 @@ int SoftmaxConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto x_dims = scope->FindVar(x_var_name)->GetMutable<Tensor>()->dims();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axis = op_info->GetAttr<int>("axis");
if (x_dims.size() > 3) {
CHECK(!(axis == 2 && x_dims[3] > 1))
......@@ -40,8 +48,17 @@ int SoftmaxConverter(void* ctx, OpLite* op) {
<< " :x_w = " << x_dims[3];
}
auto softmax_node = graph->AddNode<ge::op::Softmax>(out_var_name);
softmax_node->set_input_x(*graph->GetNode(x_var_name));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Softmax node
auto softmax_node = graph->AddNode<ge::op::Softmax>(out_name);
softmax_node->set_input_x(*x_node);
softmax_node->set_attr_axis(axis);
return REBUILD_WHEN_SHAPE_CHANGED;
}
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int SplitConverter(void* ctx, OpLite* op) {
int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,15 +30,33 @@ int SplitConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " << op_type << " ... ";
auto x_var_name = op_info->Input("X").front();
auto out_var_names = op_info->Output("Out");
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_names = op_info->Output("Out");
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axis = op_info->GetAttr<int>("axis");
auto num = op_info->GetAttr<int>("num");
auto sections = op_info->GetAttr<std::vector<int>>("sections");
int64_t sections_num = static_cast<int64_t>(sections.size());
auto split_node = graph->AddNode<ge::op::Split>(op_type + "/" + x_var_name);
split_node->set_input_x(*graph->GetNode(x_var_name));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Split node
auto split_node = graph->AddNode<ge::op::Split>(op_type + "/" + x_name);
split_node->set_input_x(*x_node);
split_node->set_attr_axis(static_cast<int64_t>(axis));
if (num > 0) {
split_node->set_attr_output_num(static_cast<int64_t>(num));
......@@ -48,12 +66,12 @@ int SplitConverter(void* ctx, OpLite* op) {
split_node->set_attr_size_split(size_split);
}
split_node->create_dynamic_output_y(out_var_names.size());
split_node->create_dynamic_output_y(out_names.size());
int idx = 1;
for (auto& out_var_name : out_var_names) {
for (auto& out_name : out_names) {
auto zero_const_node =
graph->AddNode(out_var_name + "/zero" + std::to_string(idx), 0);
auto add_node = graph->AddNode<ge::op::Add>(out_var_name);
graph->AddNode(out_name + "/zero" + std::to_string(idx), 0);
auto add_node = graph->AddNode<ge::op::Add>(out_name);
add_node->set_input_x1(*split_node, "y" + std::to_string(idx));
add_node->set_input_x2(*zero_const_node);
idx++;
......
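Two details of the split bridge above are worth noting: the per-output extents come either from `num` equal parts or from the explicit `sections` list, and each dynamic output y1..yN is routed through an Add-with-zero node purely so it can be registered under its Paddle variable name (the add is an identity). A sketch of the sizing rule (plain C++, values illustrative):

#include <cstdio>
#include <vector>

int main() {
  int axis_extent = 9;
  int num = 0;                          // num > 0 would mean equal parts
  std::vector<int> sections{2, 3, 4};   // otherwise explicit sizes, sum == 9
  if (num > 0) {
    for (int i = 0; i < num; ++i)
      std::printf("y%d extent: %d\n", i + 1, axis_extent / num);
  } else {
    for (size_t i = 0; i < sections.size(); ++i)
      std::printf("y%zu extent: %d\n", i + 1, sections[i]);
  }
  return 0;
}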
......@@ -21,18 +21,38 @@ namespace lite {
namespace subgraph {
namespace npu {
int SqrtConverter(void* ctx, OpLite* op) {
int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto sqrt_node = graph->AddNode<ge::op::Sqrt>(out_var_name);
sqrt_node->set_input_x(*graph->GetNode(x_var_name));
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Sqrt node
auto sqrt_node = graph->AddNode<ge::op::Sqrt>(out_name);
sqrt_node->set_input_x(*x_node);
return SUCCESS;
}
......
......@@ -21,18 +21,38 @@ namespace lite {
namespace subgraph {
namespace npu {
int SquareConverter(void* ctx, OpLite* op) {
int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto square_node = graph->AddNode<ge::op::Square>(out_var_name);
square_node->set_input_x(*graph->GetNode(x_var_name));
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Square node
auto square_node = graph->AddNode<ge::op::Square>(out_name);
square_node->set_input_x(*x_node);
return SUCCESS;
}
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int TransposeConverter(void* ctx, OpLite* op) {
int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,13 +30,28 @@ int TransposeConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Input("Out").front();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Input("Out").front();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
auto transpose_node = graph->AddNode<ge::op::Permute>(out_var_name);
transpose_node->set_input_x(*graph->GetNode(x_var_name));
auto w_const_node = graph->AddNode(out_var_name + "/w", 1.0f);
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Transpose node
auto transpose_node = graph->AddNode<ge::op::Permute>(out_name);
transpose_node->set_input_x(*x_node);
auto w_const_node = graph->AddNode(out_name + "/w", 1.0f);
transpose_node->set_input_w(*w_const_node);
transpose_node->set_attr_order(
ge::AttrValue::LIST_INT(axis.begin(), axis.end()));
......
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int UnsqueezeConverter(void* ctx, OpLite* op) {
int UnsqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,14 +30,31 @@ int UnsqueezeConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " << op_type << "... ";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto out_shape = scope->FindTensor(out_var_name)->dims().Vectorize();
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto out_shape = scope->FindTensor(out_name)->dims().Vectorize();
CHECK(op_info->HasAttr("axes"))
<< "[NPU] unsqueeze not support axes from tensor now";
auto unsqueeze_node = graph->AddNode<ge::op::Reshape>(out_var_name);
unsqueeze_node->set_input_tensor(*graph->GetNode(x_var_name));
// X node
std::shared_ptr<ge::Operator> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Unsqueeze node
auto unsqueeze_node = graph->AddNode<ge::op::Reshape>(out_name);
unsqueeze_node->set_input_tensor(*x_node);
unsqueeze_node->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
return REBUILD_WHEN_SHAPE_CHANGED;
......
......@@ -44,12 +44,21 @@ ge::DataType CvtPrecisionType(PrecisionType itype) {
case PRECISION(kFloat):
otype = ge::DT_FLOAT;
break;
case PRECISION(kFP16):
otype = ge::DT_FLOAT16;
break;
case PRECISION(kInt8):
otype = ge::DT_INT8;
break;
case PRECISION(kInt16):
otype = ge::DT_INT16;
break;
case PRECISION(kInt32):
otype = ge::DT_INT32;
break;
case PRECISION(kInt64):
otype = ge::DT_INT64;
break;
default:
LOG(FATAL) << "[NPU] Can not convert precision type("
<< PrecisionToStr(itype) << ") from Lite to NPU";
......@@ -64,6 +73,9 @@ ge::Format CvtDataLayoutType(DataLayoutType itype) {
case DATALAYOUT(kNCHW):
otype = ge::FORMAT_NCHW;
break;
case DATALAYOUT(kNHWC):
otype = ge::FORMAT_NHWC;
break;
// TODO(hong19860320) support more data layout type
default:
LOG(FATAL) << "[NPU] Can not convert data layout type("
......@@ -75,39 +87,22 @@ ge::Format CvtDataLayoutType(DataLayoutType itype) {
ge::TensorPtr CvtTensor(const Tensor& in_tensor,
std::vector<int64_t> out_shape,
PrecisionType in_ptype,
DataLayoutType in_ltype) {
const uint8_t* in_data = nullptr;
PrecisionType in_precision,
DataLayoutType in_layout) {
auto in_size = in_tensor.dims().production();
auto in_shape = in_tensor.dims().Vectorize();
if (out_shape.empty()) {
out_shape = in_shape;
}
int in_bytes;
if (in_ptype == PRECISION(kFloat)) {
in_data = reinterpret_cast<const uint8_t*>(in_tensor.data<float>());
in_bytes = in_size * sizeof(float);
} else if (in_ptype == PRECISION(kInt32)) {
in_data = reinterpret_cast<const uint8_t*>(in_tensor.data<int32_t>());
in_bytes = in_size * sizeof(int32_t);
} else if (in_ptype == PRECISION(kInt8)) {
in_data = reinterpret_cast<const uint8_t*>(in_tensor.data<int8_t>());
in_bytes = in_size * sizeof(int8_t);
} else {
LOG(FATAL) << "[NPU] Unknow precision type " << PrecisionToStr(in_ptype);
}
ge::DataType out_ptype = CvtPrecisionType(in_ptype);
ge::Format out_ltype = CvtDataLayoutType(in_ltype);
ge::TensorDesc out_desc(ge::Shape(out_shape), out_ltype, out_ptype);
CHECK_EQ(out_ltype, ge::FORMAT_NCHW);
ge::TensorDesc out_desc(ge::Shape(out_shape),
CvtDataLayoutType(in_layout),
CvtPrecisionType(in_precision));
auto out_size = out_desc.GetShape().GetShapeSize();
CHECK_EQ(out_size, in_size);
ge::TensorPtr out_tensor = std::make_shared<ge::Tensor>();
out_tensor->SetTensorDesc(out_desc);
out_tensor->SetData(in_data, in_bytes);
out_tensor->SetData(reinterpret_cast<const uint8_t*>(in_tensor.raw_data()),
in_tensor.memory_size());
return out_tensor;
}
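The rewrite of CvtTensor above replaces the per-precision if/else ladder with raw_data() plus memory_size(), so any precision CvtPrecisionType accepts is copied without adding a new branch. A sketch of that type-erased copy idea (plain C++, no ge:: types):

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  std::vector<int64_t> src{1, 2, 3};            // any element type works
  const void* raw = src.data();                 // ~ Tensor::raw_data()
  size_t bytes = src.size() * sizeof(src[0]);   // ~ Tensor::memory_size()
  std::vector<uint8_t> dst(bytes);
  std::memcpy(dst.data(), raw, bytes);          // one copy path for all types
  std::printf("copied %zu bytes\n", bytes);
  return 0;
}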
......
......@@ -72,8 +72,8 @@ ge::Format CvtDataLayoutType(DataLayoutType itype);
ge::TensorPtr CvtTensor(const Tensor& in_tensor,
std::vector<int64_t> out_shape = {},
PrecisionType in_ptype = PRECISION(kFloat),
DataLayoutType in_ltype = DATALAYOUT(kNCHW));
PrecisionType in_precision = PRECISION(kFloat),
DataLayoutType in_layout = DATALAYOUT(kNCHW));
template <typename T>
ge::TensorPtr CreateTensorAndFillData(const std::vector<T>& data,
......@@ -85,8 +85,12 @@ ge::TensorPtr CreateTensorAndFillData(const std::vector<T>& data,
type = ge::DT_FLOAT;
} else if (info == typeid(int8_t)) {
type = ge::DT_INT8;
} else if (info == typeid(int16_t)) {
type = ge::DT_INT16;
} else if (info == typeid(int32_t)) {
type = ge::DT_INT32;
} else if (info == typeid(int64_t)) {
type = ge::DT_INT64;
} else {
LOG(FATAL) << "[NPU] Unknow value type " << info.name();
}
......
......@@ -29,19 +29,9 @@ namespace npu {
int SubgraphEngine::BuildDeviceProgram() {
int status = 0;
// Convert all of input data vars and added into the HiAI IR graph
// Convert all of the ops and their input vars and weights, and add them into
// the NPU HiAI IR graph
subgraph::npu::Graph graph;
for (auto& input_name : input_names_) {
auto input_tensor = scope_->FindMutableTensor(input_name);
CHECK(input_tensor);
auto input_node =
graph.AddNode(input_name, input_tensor->dims().Vectorize());
CHECK(input_node);
// HiAI DDK doesn't support dynamic dimensions/shapes, so need to rebuild
// the program when the shape of any input tensor is changed.
status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED;
}
// Convert all of ops and its weights and added into the HiAI IR graph
const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program_) {
auto op = inst.op();
......@@ -52,29 +42,56 @@ int SubgraphEngine::BuildDeviceProgram() {
if (!bridges.Exists("NPU", op_type)) {
return subgraph::FAILED;
}
auto kernel = inst.kernel();
status |= bridges.Select("NPU", op_type)(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op));
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED;
}
}
// Set the input and output nodes of the HiAI IR graph
std::vector<ge::Operator> input_nodes, output_nodes;
// Collect the valid input and output nodes in the HiAI IR graph and update
// the input and output names
device_inames_.clear();
device_onames_.clear();
std::vector<ge::Operator> device_inodes;
std::vector<ge::Operator> device_onodes;
for (auto& input_name : input_names_) {
input_nodes.push_back(*graph.GetNode(input_name));
if (graph.HasNode(input_name)) {
if (!graph.GetType(input_name).persistable()) {
device_inodes.push_back(*graph.GetNode(input_name));
device_inames_.push_back(input_name);
} else {
LOG(WARNING) << "[NPU] Input node " << input_name
<< " is skipped because it is a persistable node.";
}
} else {
LOG(WARNING) << "[NPU] Input node " << input_name
<< " is skipped because it does not exist.";
}
}
for (auto& output_name : output_names_) {
output_nodes.push_back(*graph.GetNode(output_name));
if (graph.HasNode(output_name)) {
device_onodes.push_back(*graph.GetNode(output_name));
device_onames_.push_back(output_name);
} else {
LOG(WARNING) << "[NPU] Output node " << output_name
<< " is skipped because it does not exist.";
}
}
// Build the HiAI IR graph to HiAI om model
device_program_ =
lite::npu::Device::Global().Build(model_name_, input_nodes, output_nodes);
CHECK(!device_inames_.empty())
<< "[NPU] No input nodes found for building NPU model";
CHECK(!device_onames_.empty())
<< "[NPU] No output nodes found for building NPU model";
// Build the HiAI IR graph into the HiAI om model as the device program
device_program_ = lite::npu::Device::Global().Build(
model_name_, device_inodes, device_onodes);
if (device_program_ == nullptr) {
LOG(WARNING) << "[NPU] Build model failed!";
return subgraph::FAILED;
}
// Query and check the dimensions of input and output tensors
// Query and check the dimensions of valid input and output tensors
std::vector<hiai::TensorDimension> device_idims, device_odims;
if (device_program_->GetModelIOTensorDim(
model_name_, device_idims, device_odims) != hiai::AI_SUCCESS) {
......@@ -82,44 +99,75 @@ int SubgraphEngine::BuildDeviceProgram() {
<< "[NPU] Get the dimensions of input and output tensors failed!";
return subgraph::FAILED;
}
CHECK_EQ(device_idims.size(), input_names_.size());
CHECK_EQ(device_odims.size(), output_names_.size());
origin_idims_.resize(input_names_.size());
origin_itensors_.resize(input_names_.size());
device_idatasizes_.resize(input_names_.size());
device_itensors_.resize(input_names_.size());
origin_odims_.resize(output_names_.size());
origin_otensors_.resize(output_names_.size());
device_odatasizes_.resize(output_names_.size());
device_otensors_.resize(output_names_.size());
for (int i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]);
CHECK_EQ(device_idims.size(), device_inames_.size());
CHECK_EQ(device_odims.size(), device_onames_.size());
origin_idims_.resize(device_inames_.size());
origin_itensors_.resize(device_inames_.size());
device_itensors_.resize(device_inames_.size());
origin_odims_.resize(device_onames_.size());
origin_otensors_.resize(device_onames_.size());
device_otensors_.resize(device_onames_.size());
for (int i = 0; i < device_inames_.size(); i++) {
auto type = graph.GetType(device_inames_[i]);
auto precision = type.precision();
auto layout = type.layout();
origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]);
CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims();
VLOG(3) << "[NPU] Input dims[" << i << "]: {" << device_idims[i].GetNumber()
<< "," << device_idims[i].GetChannel() << ","
VLOG(3) << "[NPU] Inputs[" << i
<< "] precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " dims: {"
<< device_idims[i].GetNumber() << ","
<< device_idims[i].GetChannel() << ","
<< device_idims[i].GetHeight() << "," << device_idims[i].GetWidth()
<< "}";
device_idatasizes_[i] =
// Prepare the device input tensors
CHECK_EQ(origin_idims_[i].production(),
device_idims[i].GetNumber() * device_idims[i].GetChannel() *
device_idims[i].GetHeight() * device_idims[i].GetWidth();
CHECK_EQ(device_idatasizes_[i], origin_idims_[i].production());
device_idims[i].GetHeight() * device_idims[i].GetWidth());
device_itensors_[i].reset(new hiai::AiTensor);
device_itensors_[i]->Init(&(device_idims[i]));
}
for (int i = 0; i < output_names_.size(); i++) {
origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]);
for (int i = 0; i < device_onames_.size(); i++) {
auto type = graph.GetType(device_onames_[i]);
auto precision = type.precision();
auto layout = type.layout();
origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]);
CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "[NPU] Output dims[" << i << "]: {"
VLOG(3) << "[NPU] Outputs[" << i
<< "] precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " dims: {"
<< device_odims[i].GetNumber() << ","
<< device_odims[i].GetChannel() << ","
<< device_odims[i].GetHeight() << "," << device_odims[i].GetWidth()
<< "}";
device_odatasizes_[i] =
// Prepare the device output tensors
switch (precision) {
case PRECISION(kFloat):
origin_otensors_[i]->mutable_data<float>();
break;
case PRECISION(kInt8):
origin_otensors_[i]->mutable_data<int8_t>();
break;
case PRECISION(kInt16):
origin_otensors_[i]->mutable_data<int16_t>();
break;
case PRECISION(kInt32):
origin_otensors_[i]->mutable_data<int32_t>();
break;
case PRECISION(kInt64):
origin_otensors_[i]->mutable_data<int64_t>();
break;
default:
LOG(FATAL) << "[NPU] " << device_onames_[i]
<< " can't mutable data with precision type "
<< PrecisionToStr(precision);
break;
}
CHECK_EQ(origin_odims_[i].production(),
device_odims[i].GetNumber() * device_odims[i].GetChannel() *
device_odims[i].GetHeight() * device_odims[i].GetWidth();
CHECK_EQ(device_odatasizes_[i], origin_odims_[i].production());
device_odims[i].GetHeight() * device_odims[i].GetWidth());
device_otensors_[i].reset(new hiai::AiTensor);
device_otensors_[i]->Init(&(device_odims[i]));
}
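Note the shape-consistency check the loops above perform for every bound tensor: the Lite tensor's element count must equal the N*C*H*W volume HiAI reports for the device tensor. A stripped-down sketch of that check (the n/c/h/w values are hypothetical stand-ins for the hiai::TensorDimension getters):

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  std::vector<long long> origin_dims{1, 3, 224, 224};  // Lite-side dims
  long long production = 1;
  for (auto d : origin_dims) production *= d;
  // Hypothetical values standing in for the hiai::TensorDimension getters:
  long long n = 1, c = 3, h = 224, w = 224;
  assert(production == n * c * h * w);
  std::printf("ok: %lld elements on both sides\n", production);
  return 0;
}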
......@@ -128,10 +176,10 @@ int SubgraphEngine::BuildDeviceProgram() {
int SubgraphEngine::LaunchDeviceProgram() {
// Copy the data of origin input tensors to the buffer of input HiAI tensors
for (size_t i = 0; i < input_names_.size(); i++) {
std::memcpy(static_cast<float*>(device_itensors_[i]->GetBuffer()),
origin_itensors_[i]->mutable_data<float>(),
sizeof(float) * static_cast<size_t>(device_idatasizes_[i]));
for (size_t i = 0; i < device_itensors_.size(); i++) {
std::memcpy(device_itensors_[i]->GetBuffer(),
origin_itensors_[i]->raw_data(),
origin_itensors_[i]->memory_size());
}
// Run the HiAI model by name
std::string key = "model_name"; // Note: key seems must be model_name
......@@ -149,10 +197,10 @@ int SubgraphEngine::LaunchDeviceProgram() {
hiai::AI_SUCCESS);
VLOG(3) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us";
// Copy the data of output HiAI tensor to the buffer of origin output tensors
for (size_t i = 0; i < output_names_.size(); i++) {
std::memcpy(origin_otensors_[i]->mutable_data<float>(),
static_cast<float*>(device_otensors_[i]->GetBuffer()),
sizeof(float) * static_cast<size_t>(device_odatasizes_[i]));
for (size_t i = 0; i < device_otensors_.size(); i++) {
std::memcpy(const_cast<void*>(origin_otensors_[i]->raw_data()),
device_otensors_[i]->GetBuffer(),
device_otensors_[i]->GetSize());
}
return 0;
}
......
......@@ -43,8 +43,8 @@ class SubgraphEngine : public subgraph::Engine {
std::string model_name_;
hiai::AiContext model_context_;
std::vector<int64_t> device_idatasizes_;
std::vector<int64_t> device_odatasizes_;
std::vector<std::string> device_inames_;
std::vector<std::string> device_onames_;
std::vector<std::shared_ptr<hiai::AiTensor>> device_itensors_;
std::vector<std::shared_ptr<hiai::AiTensor>> device_otensors_;
std::unique_ptr<hiai::AiModelMngerClient> device_program_{nullptr};
......
......@@ -24,7 +24,7 @@
//,
REGISTER_LITE_KERNEL(lookup_table,
kX86,
kInt64,
kFloat,
kNCHW,
paddle::lite::kernels::x86::LookupTableCompute<float>,
def)
......@@ -34,7 +34,7 @@ REGISTER_LITE_KERNEL(lookup_table,
.Finalize();
REGISTER_LITE_KERNEL(lookup_table_v2,
kX86,
kInt64,
kFloat,
kNCHW,
paddle::lite::kernels::x86::LookupTableCompute<float>,
def)
......
......@@ -24,7 +24,7 @@ namespace kernels {
namespace x86 {
template <typename T>
class LookupTableCompute : public KernelLite<TARGET(kX86), PRECISION(kInt64)> {
class LookupTableCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::LookupTableParam;
......
......@@ -79,4 +79,4 @@ TEST(lookup_table_x86, compute) {
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(lookup_table, kX86, kInt64, kNCHW, def);
USE_LITE_KERNEL(lookup_table, kX86, kFloat, kNCHW, def);
......@@ -21,5 +21,5 @@ REGISTER_LITE_KERNEL(stack,
paddle::lite::kernels::x86::StackCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -14,6 +14,11 @@ lite_cc_library(subgraph_bridge_pool_op_xpu SRCS pool_op.cc DEPS ${subgraph_brid
lite_cc_library(subgraph_bridge_softmax_op_xpu SRCS softmax_op.cc DEPS ${subgraph_bridge_deps_xpu})
lite_cc_library(subgraph_bridge_mul_op_xpu SRCS mul_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_batch_norm_op_xpu SRCS batch_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_stack_op_xpu SRCS stack_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_gather_op_xpu SRCS gather_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_scale_op_xpu SRCS scale_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_lookup_table_op_xpu SRCS lookup_table_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_slice_op_xpu SRCS slice_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_transpose_op_xpu SRCS transpose_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_reshape_op_xpu SRCS reshape_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
......@@ -30,6 +35,11 @@ set(xpu_subgraph_bridges
subgraph_bridge_softmax_op_xpu
subgraph_bridge_mul_op_xpu
subgraph_bridge_batch_norm_op_xpu
subgraph_bridge_stack_op_xpu
subgraph_bridge_gather_op_xpu
subgraph_bridge_scale_op_xpu
subgraph_bridge_lookup_table_op_xpu
subgraph_bridge_slice_op_xpu
subgraph_bridge_transpose_op_xpu
subgraph_bridge_reshape_op_xpu
subgraph_bridge_layer_norm_op_xpu
......
......@@ -21,21 +21,42 @@ namespace lite {
namespace subgraph {
namespace xpu {
int ActConverter(void* ctx, OpLite* op) {
int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Create act node and set params from op
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
CHECK(graph->HasNode(x_var_name));
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Act node
if (op_type == "relu") {
graph->AddNode(out_var_name,
graph->builder_.CreateRelu(*graph->GetNode(x_var_name)));
graph->AddNode(out_name, graph->builder_.CreateRelu(*x_node));
} else if (op_type == "tanh") {
graph->AddNode(out_name, graph->builder_.CreateUnaryOp("tanh", *x_node));
} else if (op_type == "gelu") {
graph->AddNode(out_name, graph->builder_.CreateGelu(*x_node));
} else {
// TODO(hong19860320) support more activation ops
LOG(WARNING) << "[XPU] Unsupported activation type " << op_type;
......@@ -50,3 +71,5 @@ int ActConverter(void* ctx, OpLite* op) {
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU, relu, paddle::lite::subgraph::xpu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU, tanh, paddle::lite::subgraph::xpu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU, gelu, paddle::lite::subgraph::xpu::ActConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/test_helper.h"
#include "lite/operators/activation_ops.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
void relu_ref(const std::shared_ptr<operators::ActivationOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
DDim x_dims = x->dims();
DDim out_dims = out->dims();
CHECK_EQ(x_dims.production(), out_dims.production());
for (int i = 0; i < out_dims.production(); i++) {
out_data[i] = std::max(0.f, x_data[i]);
}
}
void test_relu(int bs, int ic, int ih, int iw) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string out_ref_var_name("out_ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float, int>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("relu");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
// create and convert op to XPU model, and run it on XPU
auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
relu_ref(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
TEST(XPUBridges, relu) {
for (auto bs : {1, 3}) {
for (auto ic : {3, 4}) {
for (auto ih : {2, 5}) {
for (auto iw : {5, 9}) {
VLOG(3) << "bs: " << bs << " ic: " << ic << " ih: " << ih
<< " iw: " << iw;
test_relu(bs, ic, ih, iw);
}
}
}
}
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(relu);
USE_XPU_BRIDGE(relu);
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace xpu {
int BatchNormConverter(void* ctx, OpLite* op) {
int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,35 +30,62 @@ int BatchNormConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input vars and op attributes
auto x_var_name = op_info->Input("X").front();
auto scale_var_name = op_info->Input("Scale").front();
auto* scale = scope->FindMutableTensor(scale_var_name);
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindMutableTensor(bias_var_name);
auto mean_var_name = op_info->Input("Mean").front();
auto* mean = scope->FindMutableTensor(mean_var_name);
auto variance_var_name = op_info->Input("Variance").front();
auto* variance = scope->FindMutableTensor(variance_var_name);
auto y_var_name = op_info->Output("Y").front();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto scale_name = op_info->Input("Scale").front();
auto scale_type = kernel->GetInputDeclType("Scale");
CHECK(scale_type->precision() == PRECISION(kFloat));
CHECK(scale_type->layout() == DATALAYOUT(kNCHW));
auto scale = scope->FindMutableTensor(scale_name);
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto mean_name = op_info->Input("Mean").front();
auto mean_type = kernel->GetInputDeclType("Mean");
CHECK(mean_type->precision() == PRECISION(kFloat));
CHECK(mean_type->layout() == DATALAYOUT(kNCHW));
auto mean = scope->FindMutableTensor(mean_name);
auto variance_name = op_info->Input("Variance").front();
auto variance_type = kernel->GetInputDeclType("Variance");
CHECK(variance_type->precision() == PRECISION(kFloat));
CHECK(variance_type->layout() == DATALAYOUT(kNCHW));
auto variance = scope->FindMutableTensor(variance_name);
auto y_name = op_info->Output("Y").front();
auto y_type = kernel->GetOutputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
auto epsilon = op_info->GetAttr<float>("epsilon");
// Create scale, bias, mean, variance nodes
auto scale_const_node = graph->AddNode(scale_var_name, *scale);
auto bias_const_node = graph->AddNode(bias_var_name, *bias);
auto mean_const_node = graph->AddNode(mean_var_name, *mean);
auto variance_const_node = graph->AddNode(variance_var_name, *variance);
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Create batch_norm node and set params from op
auto batch_norm_node =
graph->builder_.CreateBatchNorm(*graph->GetNode(x_var_name),
// Scale, Bias, Mean, Variance node
auto scale_const_node = graph->AddNode(scale_name, *scale);
auto bias_const_node = graph->AddNode(bias_name, *bias);
auto mean_const_node = graph->AddNode(mean_name, *mean);
auto variance_const_node = graph->AddNode(variance_name, *variance);
// Batch Norm node and extract the first field as the output node
auto batch_norm_node = graph->builder_.CreateBatchNorm(*x_node,
*scale_const_node,
*bias_const_node,
*mean_const_node,
*variance_const_node,
1,
epsilon);
graph->AddNode(y_var_name, graph->builder_.GetField(batch_norm_node, 0));
graph->AddNode(y_name, graph->builder_.GetField(batch_norm_node, 0));
return SUCCESS;
}
......
......@@ -22,7 +22,7 @@ namespace lite {
namespace subgraph {
namespace xpu {
int ConvConverter(void* ctx, OpLite* op) {
int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -31,14 +31,23 @@ int ConvConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " << op_type << "... ";
// Get input, filter and op attributes
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<Tensor>();
// Get input and output vars and op attributes
auto input_name = op_info->Input("Input").front();
auto input_type = kernel->GetInputDeclType("Input");
CHECK(input_type->precision() == PRECISION(kFloat));
CHECK(input_type->layout() == DATALAYOUT(kNCHW));
auto input = scope->FindMutableTensor(input_name);
auto input_dims = input->dims();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<Tensor>();
auto filter_name = op_info->Input("Filter").front();
auto filter_type = kernel->GetInputDeclType("Filter");
CHECK(filter_type->precision() == PRECISION(kFloat));
CHECK(filter_type->layout() == DATALAYOUT(kNCHW));
auto filter = scope->FindMutableTensor(filter_name);
auto filter_dims = filter->dims();
auto output_var_name = op_info->Output("Output").front();
auto output_name = op_info->Output("Output").front();
auto output_type = kernel->GetOutputDeclType("Output");
CHECK(output_type->precision() == PRECISION(kFloat));
CHECK(output_type->layout() == DATALAYOUT(kNCHW));
auto bs = input_dims[0];
auto oc = filter_dims[0];
CHECK_EQ(input_dims.size(), 4);
......@@ -51,6 +60,14 @@ int ConvConverter(void* ctx, OpLite* op) {
CHECK_EQ(strides.size(), 2L);
CHECK_EQ(dilations.size(), 2L);
// Input node
std::shared_ptr<xtcl::xExpr> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
}
if (paddings.size() == 2L) {
for (size_t i = 0; i < strides.size(); ++i) {
int copy_pad = *(paddings.begin() + 2 * i);
......@@ -81,14 +98,14 @@ int ConvConverter(void* ctx, OpLite* op) {
}
DDim output_dims(output_shape);
// Create filter node
auto filter_const_node = graph->AddNode(filter_var_name, *filter);
// Filter node
auto filter_const_node = graph->AddNode(filter_name, *filter);
// Create conv node and set input, filter, bias nodes and attributes
// Conv node
auto conv_attrs = xtcl::make_node<xtcl::network::Conv2DAttrs>();
conv_attrs->strides = std::move(CvtShape(strides));
conv_attrs->padding = std::move(CvtShape(paddings));
conv_attrs->dilation = std::move(CvtShape(dilations));
conv_attrs->strides = std::move(CvtShape<xtcl::xIndexExpr>(strides));
conv_attrs->padding = std::move(CvtShape<xtcl::xIndexExpr>(paddings));
conv_attrs->dilation = std::move(CvtShape<xtcl::xIndexExpr>(dilations));
conv_attrs->groups = groups;
// conv_attrs->channels = nullptr;
conv_attrs->kernel_size = std::move(xtcl::Array<xtcl::xIndexExpr>(nullptr));
......@@ -96,19 +113,22 @@ int ConvConverter(void* ctx, OpLite* op) {
conv_attrs->kernel_layout = "OIHW";
conv_attrs->out_layout = "";
// conv_attrs->out_dtype = "";
auto conv_node = graph->AddNode(
output_var_name,
auto conv_node =
graph->AddNode(output_name,
graph->builder_.CreateConv2D(
*graph->GetNode(input_var_name), *filter_const_node, conv_attrs));
*input_node, *filter_const_node, conv_attrs));
// Create bias node if exists bias
// Add bias node if a bias input exists
// supports bias nodes with the following dimensions:
// 0: {oc}
// 1: {1, oc, oh, ow}
// 2: {n, oc, oh, ow}
if (HasInputArg(op_info, scope, "Bias")) {
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
auto bias_data_size = bias_dims.production();
auto output_data_size = output_dims.production();
......@@ -130,21 +150,21 @@ int ConvConverter(void* ctx, OpLite* op) {
<< output_dims;
}
std::shared_ptr<xtcl::xExpr> bias_node = nullptr;
if (graph->HasNode(bias_var_name)) {
if (graph->HasNode(bias_name)) {
// Bias node from input node
bias_node = graph->GetNode(bias_var_name);
bias_node = graph->GetNode(bias_name);
} else {
// Bias node with const tensor
bias_node = graph->AddNode(bias_var_name, *bias, bias_shape);
// Bias node with const data
bias_node = graph->AddNode(bias_name, *bias, bias_shape);
}
std::shared_ptr<xtcl::xExpr> add_node = nullptr;
if (is_channel_bias) {
add_node = graph->AddNode(
output_var_name,
output_name,
graph->builder_.CreateBiasAdd(*conv_node, 1, *bias_node));
} else {
add_node = graph->AddNode(
output_var_name,
output_name,
graph->builder_.CreateBinaryOp("add", *conv_node, *bias_node));
}
conv_node = add_node;
......@@ -152,7 +172,7 @@ int ConvConverter(void* ctx, OpLite* op) {
if (fuse_relu) {
// Append relu node if fuse_relu is true
graph->AddNode(output_var_name, graph->builder_.CreateRelu(*conv_node));
graph->AddNode(output_name, graph->builder_.CreateRelu(*conv_node));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......
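The bias handling above distinguishes a channel bias {oc}, applied via CreateBiasAdd along axis 1, from the full-shape cases {1, oc, oh, ow} and {n, oc, oh, ow}, which fall back to an elementwise add. A standalone sketch of that dispatch (plain std::vector in place of DDim; illustrative, not the bridge itself):

#include <cstdint>
#include <vector>

enum class BiasMode { kChannelBiasAdd, kElementwiseAdd, kUnsupported };

// Classify a conv bias by its element count against an output of shape
// {n, oc, oh, ow}, mirroring the three supported cases listed above.
BiasMode ClassifyConvBias(const std::vector<int64_t>& bias_dims,
                          int64_t n, int64_t oc, int64_t oh, int64_t ow) {
  int64_t size = 1;
  for (auto d : bias_dims) size *= d;
  if (size == oc) return BiasMode::kChannelBiasAdd;  // {oc}
  if (size == oc * oh * ow || size == n * oc * oh * ow)
    return BiasMode::kElementwiseAdd;  // {1, oc, oh, ow} or {n, oc, oh, ow}
  return BiasMode::kUnsupported;
}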
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace xpu {
int ElementwiseConverter(void* ctx, OpLite* op) {
int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(op != nullptr);
CHECK(ctx != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,39 +30,49 @@ int ElementwiseConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input, and attributes
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
auto out_var_name = op_info->Output("Out").front();
auto axis = op_info->GetAttr<int>("axis");
auto x = scope->FindMutableTensor(x_var_name);
auto y = scope->FindMutableTensor(y_var_name);
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Input("Y").front();
auto y_type = kernel->GetInputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
auto y = scope->FindMutableTensor(y_name);
auto y_dims = y->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axis = op_info->GetAttr<int>("axis");
// Create x and y node
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_var_name)) {
x_node = graph->GetNode(x_var_name);
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_var_name, *x);
x_node = graph->AddNode(x_name, x_dims);
}
// Y node
std::shared_ptr<xtcl::xExpr> y_node = nullptr;
if (graph->HasNode(y_var_name)) {
y_node = graph->GetNode(y_var_name);
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
} else {
y_node = graph->AddNode(y_var_name, *y);
y_node = graph->AddNode(y_name, y_dims);
}
// Create elementwise node and set input, attributes
// Elementwise node
std::shared_ptr<xtcl::xExpr> elementwise_node = nullptr;
if (y_dims.size() == 1) {
elementwise_node = graph->AddNode(
out_var_name, graph->builder_.CreateBiasAdd(*x_node, axis, *y_node));
out_name, graph->builder_.CreateBiasAdd(*x_node, axis, *y_node));
} else if (x_dims.size() == y_dims.size()) {
elementwise_node = graph->AddNode(
out_var_name, graph->builder_.CreateBinaryOp("add", *x_node, *y_node));
out_name, graph->builder_.CreateBinaryOp("add", *x_node, *y_node));
} else {
LOG(WARNING)
<< "[XPU] elementwise_add only support y of one dimension, or x "
......
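The elementwise bridge above lowers a 1-D y to CreateBiasAdd and same-rank operands to a binary add. Assuming CreateBiasAdd follows the usual axis-broadcast semantics, its effect on NCHW data with axis = 1 can be checked against this plain C++ reference:

#include <vector>

// Reference semantics of a bias add on NCHW data with axis = 1:
// every element x[n][c][h][w] gets y[c] added (compact row-major layout).
void BiasAddAxis1Ref(std::vector<float>* x, const std::vector<float>& y,
                     int n, int c, int h, int w) {
  for (int in = 0; in < n; ++in)
    for (int ic = 0; ic < c; ++ic)
      for (int ih = 0; ih < h; ++ih)
        for (int iw = 0; iw < w; ++iw)
          (*x)[((in * c + ic) * h + ih) * w + iw] += y[ic];
}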
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto index_name = op_info->Input("Index").front();
auto index_type = kernel->GetInputDeclType("Index");
CHECK(index_type->precision() == PRECISION(kInt32) ||
index_type->precision() == PRECISION(kInt64));
CHECK(index_type->layout() == DATALAYOUT(kNCHW));
auto index = scope->FindMutableTensor(index_name);
auto index_dims = index->dims();
CHECK(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1));
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto out = scope->FindMutableTensor(out_name);
auto out_dims = out->dims();
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Index node
std::shared_ptr<xtcl::xExpr> index_node = nullptr;
if (graph->HasNode(index_name)) {
index_node = graph->GetNode(index_name);
} else {
index_node = graph->AddNode(
index_name, index_dims, index_type->precision(), index_type->layout());
}
// Flatten index node
if (index_dims.size() != 1) {
index_node =
graph->AddNode(index_name + "/reshape",
graph->builder_.CreateReshape(*index_node, {-1}),
index_type->precision(),
index_type->layout());
}
// Reshape the gathered result to the inferred shape to form the output node
auto gather_node = graph->AddNode(
out_name,
graph->builder_.CreateGather(*x_node, *index_node, /* axis= */ 0));
if (out_dims.size() != 2) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*gather_node, CvtShape<xtcl::Integer>(out_dims)));
}
return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
gather,
paddle::lite::subgraph::xpu::GatherConverter);
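For the shape bookkeeping above: an index of shape {m, 1} is flattened to {m}, gathering along axis 0 of x {k, d} yields {m, d}, and the trailing reshape restores out_dims when it is not 2-D. A reference gather (illustrative only, not the XTCL call), assuming int32 indices and float data:

#include <cstdint>
#include <vector>

// out[i][j] = x[index[i]][j] for a 2-D x of shape {rows, cols}.
std::vector<float> GatherAxis0(const std::vector<float>& x, int64_t cols,
                               const std::vector<int32_t>& index) {
  std::vector<float> out(index.size() * cols);
  for (size_t i = 0; i < index.size(); ++i)
    for (int64_t j = 0; j < cols; ++j)
      out[i * cols + j] = x[index[i] * cols + j];
  return out;
}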
......@@ -22,7 +22,9 @@ namespace subgraph {
namespace xpu {
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
const xtcl::xExpr& layer) {
const xtcl::xExpr& layer,
PrecisionType precision,
DataLayoutType layout) {
auto unique_name = [&](const std::string& key) {
int idx = 1;
auto it = counts_.find(key);
......@@ -35,7 +37,8 @@ std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
};
auto it = nodes_.find(name);
if (it != nodes_.end()) {
CHECK(params_.find(name) == params_.end()) << "[XPU] Node " << name
// Only variable can rebind the name
CHECK(!it->second.second.persistable()) << "[XPU] Node " << name
<< " redefined.";
// Generate a new unique name as the key to bind the origin node if the
// origin node isn't a const node: new_name->node
......@@ -44,7 +47,8 @@ std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
}
// Create a new node and bind with the name: name->new_node
auto node = std::make_shared<xtcl::xExpr>(layer);
nodes_.insert(std::make_pair(name, node));
nodes_.insert(std::make_pair(
name, std::make_pair(node, Type(precision, layout, false))));
builder_.SetLayer(unique_name(name + "_op"));
return node;
}
......@@ -52,31 +56,36 @@ std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
// Const node
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
const Tensor& tensor,
PrecisionType ptype,
DataLayoutType ltype) {
return AddNode(name, tensor, tensor.dims().Vectorize(), ptype, ltype);
PrecisionType precision,
DataLayoutType layout) {
return AddNode(name, tensor, tensor.dims().Vectorize(), precision, layout);
}
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType ptype,
DataLayoutType ltype) {
auto node = AddNode(name, shape, ptype, ltype);
PrecisionType precision,
DataLayoutType layout) {
CHECK(!HasNode(name)) << "[XPU] Node " << name << " redefined.";
auto node = std::make_shared<xtcl::xExpr>(builder_.CreateTensor(
name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision)));
nodes_.insert(std::make_pair(
name, std::make_pair(node, Type(precision, layout, true))));
params_.emplace(
std::make_pair(name, *CvtTensor(tensor, shape, ptype, ltype)));
std::make_pair(name, *CvtTensor(tensor, shape, precision, layout)));
return node;
}
// Data node
std::shared_ptr<xtcl::xExpr> Graph::AddNode(const std::string& name,
std::vector<int64_t> shape,
PrecisionType ptype,
DataLayoutType ltype) {
CHECK(!HasNode(name));
auto node = std::make_shared<xtcl::xExpr>(
builder_.CreateTensor(name, CvtShape(shape), CvtPrecisionType(ptype)));
nodes_.insert(std::make_pair(name, node));
PrecisionType precision,
DataLayoutType layout) {
CHECK(!HasNode(name)) << "[XPU] Node " << name << " redefined.";
auto node = std::make_shared<xtcl::xExpr>(builder_.CreateTensor(
name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision)));
nodes_.insert(std::make_pair(
name, std::make_pair(node, Type(precision, layout, false))));
return node;
}
......
......@@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"
......@@ -27,42 +28,75 @@ namespace lite {
namespace subgraph {
namespace xpu {
// The Context of the converters which used for converting the ops of subgraph
// to the XPU IR graph
// Type of graph nodes
class Type {
public:
Type(PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW),
bool persistable = false)
: precision_(precision), layout_(layout), persistable_(persistable) {}
void set_precision(PrecisionType precision) { precision_ = precision; }
void set_layout(DataLayoutType layout) { layout_ = layout; }
void set_persistable(bool persistable) { persistable_ = persistable; }
PrecisionType precision() const { return precision_; }
DataLayoutType layout() const { return layout_; }
bool persistable() const { return persistable_; }
private:
PrecisionType precision_{PRECISION(kFloat)};
DataLayoutType layout_{DATALAYOUT(kNCHW)};
bool persistable_{false};
};
// Graph to collect all of converted XPU IR nodes
class Graph {
public:
// Layer node
std::shared_ptr<xtcl::xExpr> AddNode(const std::string& name,
const xtcl::xExpr& layer);
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const xtcl::xExpr& layer,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
// Const node
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const Tensor& tensor,
PrecisionType ptype = PRECISION(kFloat),
DataLayoutType ltype = DATALAYOUT(kNCHW));
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType ptype = PRECISION(kFloat),
DataLayoutType ltype = DATALAYOUT(kNCHW));
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const Tensor& tensor,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, tensor, dims.Vectorize(), precision, layout);
}
template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const std::vector<T>& data,
std::vector<int64_t> shape = {},
DataLayoutType ltype = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) {
const std::type_info& info = typeid(T);
PrecisionType ptype = PRECISION(kFloat);
PrecisionType precision = PRECISION(kFloat);
if (info == typeid(float)) {
ptype = PRECISION(kFloat);
precision = PRECISION(kFloat);
} else if (info == typeid(int8_t)) {
ptype = PRECISION(kFloat);
precision = PRECISION(kFloat);
} else if (info == typeid(int32_t)) {
ptype = PRECISION(kInt32);
precision = PRECISION(kInt32);
} else {
LOG(FATAL) << "[XPU] Unknown data type " << info.name();
}
......@@ -80,7 +114,16 @@ class Graph {
std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
reinterpret_cast<const uint8_t*>(data.data()),
data.size() * sizeof(T));
return AddNode(name, tensor, ptype, ltype);
return AddNode(name, tensor, precision, layout);
}
template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
const std::vector<T>& data,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, data, dims.Vectorize(), layout);
}
template <typename T>
......@@ -88,25 +131,47 @@ class Graph {
const std::string& name,
T value,
std::vector<int64_t> shape = {1},
DataLayoutType ltype = DATALAYOUT(kNCHW)) {
DataLayoutType layout = DATALAYOUT(kNCHW)) {
int64_t size = 1;
for (auto i : shape) {
size *= i;
}
std::vector<T> data(size, value);
return AddNode(name, data, shape, ltype);
return AddNode(name, data, shape, layout);
}
template <typename T>
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
T value,
DDim dims,
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, value, dims.Vectorize(), layout);
}
// Data node
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
std::vector<int64_t> shape,
PrecisionType ptype = PRECISION(kFloat),
DataLayoutType ltype = DATALAYOUT(kNCHW));
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
std::shared_ptr<xtcl::xExpr> AddNode(
const std::string& name,
DDim dims,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return AddNode(name, dims.Vectorize(), precision, layout);
}
std::shared_ptr<xtcl::xExpr> GetNode(const std::string& name) {
CHECK(HasNode(name)) << "[XPU] Node " << name << " not found.";
return nodes_.at(name);
return nodes_.at(name).first;
}
const Type& GetType(const std::string& name) {
CHECK(HasNode(name)) << "[XPU] Node " << name << " not found.";
return nodes_.at(name).second;
}
bool HasNode(const std::string& name) {
......@@ -119,7 +184,8 @@ class Graph {
xtcl::network::xTensorCompiler::ParamNDArrayMap params_;
private:
std::unordered_map<std::string, std::shared_ptr<xtcl::xExpr>> nodes_;
std::unordered_map<std::string, std::pair<std::shared_ptr<xtcl::xExpr>, Type>>
nodes_;
std::unordered_map<std::string, int> counts_;
};
......
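The name-to-(node, Type) map above allows a non-persistable name to be rebound: the previous node is preserved under a generated unique key and the plain name resolves to the newest node, which is how fused ops (e.g. conv + relu) chain onto the same output name. A minimal runnable analog of that bookkeeping (an assumption about intent, not the exact implementation):

#include <string>
#include <unordered_map>

// Minimal analog of the rebinding rule above: an int stands in for the
// shared xExpr. On redefinition the old node moves to a generated key
// ("name_1", "name_2", ...) and the plain name tracks the newest node.
struct MiniGraph {
  std::unordered_map<std::string, int> nodes;
  std::unordered_map<std::string, int> counts;

  void Add(const std::string& name, int node) {
    auto it = nodes.find(name);
    if (it != nodes.end()) {
      int idx = ++counts[name];
      nodes[name + "_" + std::to_string(idx)] = it->second;  // keep old node
    }
    nodes[name] = node;  // plain name resolves to the newest node
  }
};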
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace xpu {
int LayerNormConverter(void* ctx, OpLite* op) {
int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,33 +30,92 @@ int LayerNormConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input vars and op attributes
auto x_var_name = op_info->Input("X").front();
auto scale_var_name = op_info->Input("Scale").front();
auto* scale = scope->FindMutableTensor(scale_var_name);
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindMutableTensor(bias_var_name);
auto y_var_name = op_info->Output("Y").front();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Output("Y").front();
auto y_type = kernel->GetOutputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
auto y = scope->FindMutableTensor(y_name);
auto y_dims = y->dims();
auto epsilon = op_info->GetAttr<float>("epsilon");
auto axis = op_info->GetAttr<int>("begin_norm_axis");
auto x_rank = static_cast<int>(x_dims.size());
axis = axis < 0 ? (x_rank + axis) : axis;
bool reshape = axis != (x_rank - 1);  // XPU only supports normalizing the last dimension
auto x_inner_size = x_dims.Slice(axis, x_rank).production();
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
if (reshape) {
auto reshaped_x_dims = x_dims.Slice(0, axis).Vectorize();
reshaped_x_dims.push_back(x_inner_size);
x_node =
graph->AddNode(x_name + "/reshape",
graph->builder_.CreateReshape(
*x_node, CvtShape<xtcl::Integer>(reshaped_x_dims)));
}
// Scale node
std::shared_ptr<xtcl::xExpr> scale_const_node = nullptr;
if (HasInputArg(op_info, scope, "Scale")) {
auto scale_name = op_info->Input("Scale").front();
auto scale_type = kernel->GetInputDeclType("Scale");
CHECK(scale_type->precision() == PRECISION(kFloat));
CHECK(scale_type->layout() == DATALAYOUT(kNCHW));
auto scale = scope->FindMutableTensor(scale_name);
auto scale_dims = scale->dims();
CHECK_EQ(scale_dims.size(), 1);
CHECK_EQ(scale_dims.production(), x_inner_size);
scale_const_node = graph->AddNode(scale_name, *scale);
} else {
scale_const_node =
graph->AddNode(y_name + "/scale_one", 1.0f, {x_inner_size});
}
// Create scale, bias nodes
auto scale_const_node = graph->AddNode(scale_var_name, *scale);
auto bias_const_node = graph->AddNode(bias_var_name, *bias);
// Bias node
std::shared_ptr<xtcl::xExpr> bias_const_node = nullptr;
if (HasInputArg(op_info, scope, "Bias")) {
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
CHECK_EQ(bias_dims.size(), 1);
CHECK_EQ(bias_dims.production(), x_inner_size);
bias_const_node = graph->AddNode(bias_name, *bias);
} else {
bias_const_node =
graph->AddNode(y_name + "/bias_zero", 0.0f, {x_inner_size});
}
// Create node and set params from op
// Layer Norm node
auto layer_norm_node =
graph->builder_.CreateLayerNorm(*graph->GetNode(x_var_name),
graph->AddNode(y_name,
graph->builder_.CreateLayerNorm(*x_node,
*scale_const_node,
*bias_const_node,
axis,
epsilon,
true,
true);
graph->AddNode(y_var_name, graph->builder_.GetField(layer_norm_node, 0));
return SUCCESS;
true));
if (reshape) {
graph->AddNode(y_name,
graph->builder_.CreateReshape(
*layer_norm_node, CvtShape<xtcl::Integer>(y_dims)));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace xpu
......
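To make the reshape trick concrete: XPU normalizes only the last dimension, so x {2, 3, 4, 5} with begin_norm_axis = 2 is reshaped to {2, 3, 20}, normalized over the trailing 20 elements, then reshaped back to y_dims. The inner-size computation matches x_dims.Slice(axis, x_rank).production():

#include <cstdint>
#include <vector>

// Product of dims from `axis` to the end.
int64_t InnerSize(const std::vector<int64_t>& dims, int axis) {
  int64_t size = 1;
  for (size_t i = static_cast<size_t>(axis); i < dims.size(); ++i)
    size *= dims[i];
  return size;  // e.g. InnerSize({2, 3, 4, 5}, 2) == 20
}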
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto ids_name = op_info->Input("Ids").front();
auto ids_type = kernel->GetInputDeclType("Ids");
CHECK(ids_type->precision() == PRECISION(kInt64));
CHECK(ids_type->layout() == DATALAYOUT(kNCHW));
auto ids = scope->FindMutableTensor(ids_name);
auto ids_dims = ids->dims();
auto w_name = op_info->Input("W").front();
auto w_type = kernel->GetInputDeclType("W");
CHECK(w_type->precision() == PRECISION(kFloat));
CHECK(w_type->layout() == DATALAYOUT(kNCHW));
auto w = scope->FindMutableTensor(w_name);
auto w_dims = w->dims();
CHECK_EQ(w_dims.size(), 2);
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto out = scope->FindMutableTensor(out_name);
auto out_dims = out->dims();
auto padding_idx = op_info->GetAttr<int64_t>("padding_idx");
if (padding_idx != -1) {
LOG(WARNING) << "[XPU] Only padding_idx=-1 is supported.";
return FAILED;
}
// Ids node
std::shared_ptr<xtcl::xExpr> ids_node = nullptr;
if (graph->HasNode(ids_name)) {
ids_node = graph->GetNode(ids_name);
} else {
ids_node = graph->AddNode(
ids_name, ids_dims, ids_type->precision(), ids_type->layout());
}
// Flatten Ids node
if (ids_dims.size() != 1) {
ids_node = graph->AddNode(ids_name + "/reshape",
graph->builder_.CreateReshape(*ids_node, {-1}),
ids_type->precision(),
ids_type->layout());
}
auto w_const_node = graph->AddNode(w_name, *w);
// Reshape the gathered result to the inferred shape to form the output node
auto gather_node = graph->AddNode(
out_name,
graph->builder_.CreateGather(*w_const_node, *ids_node, /* axis= */ 0));
if (out_dims.size() != 2) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*gather_node, CvtShape<xtcl::Integer>(out_dims)));
}
return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
lookup_table,
paddle::lite::subgraph::xpu::LookupTableConverter);
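The lookup_table lowering above is the same gather pattern: ids {m, 1} flatten to {m}, gathering along axis 0 of W {vocab, emb} gives {m, emb}, and a reshape restores out_dims (typically {m, 1, emb}). A reference computation with the int64 ids checked above; padding handling is deliberately absent since padding_idx != -1 is rejected by the bridge:

#include <cstdint>
#include <vector>

// out[i][j] = W[ids[i]][j] for W of shape {vocab, emb_dim}.
std::vector<float> LookupTableRef(const std::vector<float>& w, int64_t emb_dim,
                                  const std::vector<int64_t>& ids) {
  std::vector<float> out(ids.size() * emb_dim);
  for (size_t i = 0; i < ids.size(); ++i)
    for (int64_t j = 0; j < emb_dim; ++j)
      out[i * emb_dim + j] = w[ids[i] * emb_dim + j];
  return out;
}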
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace xpu {
int MulConverter(void* ctx, OpLite* op) {
int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -30,45 +30,57 @@ int MulConverter(void* ctx, OpLite* op) {
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input, and attributes
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
auto out_var_name = op_info->Output("Out").front();
auto y = scope->FindMutableTensor(y_var_name);
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Input("Y").front();
auto y_type = kernel->GetInputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
auto y = scope->FindMutableTensor(y_name);
auto y_dims = y->dims();
CHECK_EQ(y_dims.size(), 2) << "[XPU] only supports y_dims.size() == 2";
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto out = scope->FindMutableTensor(out_name);
auto out_dims = out->dims();
auto x_num_col_dims = op_info->GetAttr<int>("x_num_col_dims");
CHECK_EQ(x_num_col_dims, 1) << "xpu now only support x_num_col_dims == 1";
auto y_num_col_dims = op_info->GetAttr<int>("x_num_col_dims");
CHECK_EQ(y_num_col_dims, 1) << "xpu now only support y_num_col_dims == 1";
// Flatten x node
auto x_node = graph->AddNode(
x_var_name + "/flatten",
graph->builder_.CreateBatchFlatten(*graph->GetNode(x_var_name)));
auto x_matrix_dims = x_dims.Flatten2D(x_num_col_dims);
auto y_num_col_dims = op_info->GetAttr<int>("y_num_col_dims");
auto y_matrix_dims = y_dims.Flatten2D(y_num_col_dims);
CHECK_EQ(x_matrix_dims[1], y_matrix_dims[0]);
// Transpose y data and create y node
Tensor transpose_y;
DDim transpose_y_dims(std::vector<int64_t>{y_dims[1], y_dims[0]});
transpose_y.Resize(transpose_y_dims);
auto transpose_y_data = transpose_y.mutable_data<float>();
auto y_data = y->mutable_data<float>();
for (int i = 0; i < transpose_y_dims[0]; i++) {
for (int j = 0; j < transpose_y_dims[1]; j++) {
transpose_y_data[i * transpose_y_dims[1] + j] =
y_data[j * transpose_y_dims[0] + i];
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Flatten X node
if (x_dims.size() != 2) {
x_node =
graph->AddNode(x_name + "/reshape",
graph->builder_.CreateReshape(
*x_node, {-1, static_cast<int>(y_matrix_dims[0])}));
}
auto y_const_node = graph->AddNode(y_var_name + "/transpose", transpose_y);
// Create mul node and set params from op
graph->AddNode(
out_var_name,
graph->builder_.CreateDense(*x_node,
static_cast<int>(y_dims[1]),
::xtcl::NullValue<::xtcl::DataType>(),
*y_const_node));
// Y node
auto y_const_node = graph->AddNode(y_name, *y, y_matrix_dims);
// Reshape the matmul result to the inferred shape to form the output node
auto matmul_node = graph->AddNode(
out_name, graph->builder_.CreateMatmul2D(*x_node, *y_const_node, false));
if (out_dims.size() != 2) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*matmul_node, CvtShape<xtcl::Integer>(out_dims)));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......
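The new mul lowering relies on Flatten2D: dims [0, k) fold into the row extent and [k, n) into the column extent, so the op reduces to a 2-D matmul plus a reshape to out_dims. A sketch of that folding (plain vectors in place of DDim):

#include <array>
#include <cstdint>
#include <vector>

// Fold dims [0, k) into rows and [k, n) into cols, as DDim::Flatten2D does.
// e.g. Flatten2D({2, 3, 4}, 1) == {2, 12}; Flatten2D({2, 3, 4}, 2) == {6, 4}.
std::array<int64_t, 2> Flatten2D(const std::vector<int64_t>& dims, int k) {
  std::array<int64_t, 2> m{1, 1};
  for (int i = 0; i < static_cast<int>(dims.size()); ++i)
    m[i < k ? 0 : 1] *= dims[i];
  return m;
}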
......@@ -15,6 +15,7 @@
#pragma once
USE_SUBGRAPH_BRIDGE(XPU, relu);
USE_SUBGRAPH_BRIDGE(XPU, tanh);
USE_SUBGRAPH_BRIDGE(XPU, conv2d);
USE_SUBGRAPH_BRIDGE(XPU, depthwise_conv2d);
USE_SUBGRAPH_BRIDGE(XPU, elementwise_add);
......@@ -22,8 +23,15 @@ USE_SUBGRAPH_BRIDGE(XPU, pool2d);
USE_SUBGRAPH_BRIDGE(XPU, softmax);
USE_SUBGRAPH_BRIDGE(XPU, mul);
USE_SUBGRAPH_BRIDGE(XPU, batch_norm);
USE_SUBGRAPH_BRIDGE(XPU, stack);
USE_SUBGRAPH_BRIDGE(XPU, gather);
USE_SUBGRAPH_BRIDGE(XPU, scale);
USE_SUBGRAPH_BRIDGE(XPU, lookup_table);
USE_SUBGRAPH_BRIDGE(XPU, slice);
USE_SUBGRAPH_BRIDGE(XPU, transpose);
USE_SUBGRAPH_BRIDGE(XPU, transpose2);
USE_SUBGRAPH_BRIDGE(XPU, reshape);
USE_SUBGRAPH_BRIDGE(XPU, reshape2);
USE_SUBGRAPH_BRIDGE(XPU, layer_norm);
USE_SUBGRAPH_BRIDGE(XPU, gelu);
USE_SUBGRAPH_BRIDGE(XPU, dropout);
......@@ -21,17 +21,26 @@ namespace lite {
namespace subgraph {
namespace xpu {
int PoolConverter(void* ctx, OpLite* op) {
int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input, and attributes
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto pooling_type = op_info->GetAttr<std::string>("pooling_type");
auto ceil_mode = op_info->GetAttr<bool>("ceil_mode");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
......@@ -40,35 +49,39 @@ int PoolConverter(void* ctx, OpLite* op) {
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto exclusive = op_info->GetAttr<bool>("exclusive");
// Create pool node and set params from op
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Pool node
if (pooling_type == "max") {
if (global_pooling) {
graph->AddNode(
out_var_name,
graph->builder_.CreateGlobalMaxPool2D(*graph->GetNode(x_var_name)));
graph->AddNode(out_name, graph->builder_.CreateGlobalMaxPool2D(*x_node));
} else {
graph->AddNode(
out_var_name,
graph->builder_.CreateMaxPool2D(*graph->GetNode(x_var_name),
CvtShape(ksize),
CvtShape(strides),
CvtShape(paddings),
out_name,
graph->builder_.CreateMaxPool2D(*x_node,
CvtShape<xtcl::xIndexExpr>(ksize),
CvtShape<xtcl::xIndexExpr>(strides),
CvtShape<xtcl::xIndexExpr>(paddings),
"NCHW",
ceil_mode));
}
} else if (pooling_type == "avg") {
if (global_pooling) {
graph->AddNode(
out_var_name,
graph->builder_.CreateGlobalAvgPool2D(*graph->GetNode(x_var_name)));
graph->AddNode(out_name, graph->builder_.CreateGlobalAvgPool2D(*x_node));
} else {
// !exclusive ---> count_include_pad
graph->AddNode(
out_var_name,
graph->builder_.CreateAvgPool2D(*graph->GetNode(x_var_name),
CvtShape(ksize),
CvtShape(strides),
CvtShape(paddings),
out_name,
graph->builder_.CreateAvgPool2D(*x_node,
CvtShape<xtcl::xIndexExpr>(ksize),
CvtShape<xtcl::xIndexExpr>(strides),
CvtShape<xtcl::xIndexExpr>(paddings),
"NCHW",
ceil_mode,
!exclusive));
......
......@@ -22,7 +22,7 @@ namespace lite {
namespace subgraph {
namespace xpu {
int ReshapeConverter(void* ctx, OpLite* op) {
int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -31,40 +31,65 @@ int ReshapeConverter(void* ctx, OpLite* op) {
auto op_type = op_info->Type();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Create node and set params from op
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
std::vector<int> shape;
if (op_info->HasInput("ShapeTensor") &&
!op_info->Input("ShapeTensor").empty()) {
for (auto var_name : op_info->Input("ShapeTensor")) {
shape.emplace_back(scope->FindMutableTensor(var_name)->data<int>()[0]);
if (HasInputArg(op_info, scope, "ShapeTensor")) {
auto shape_tensor_names = op_info->Input("ShapeTensor");
// auto shape_tensor_type = kernel->GetInputDeclType("ShapeTensor");
// CHECK(shape_tensor_type->precision() == PRECISION(kInt32));
// CHECK(shape_tensor_type->layout() == DATALAYOUT(kNCHW));
for (auto shape_tensor_name : shape_tensor_names) {
auto shape_tensor = scope->FindMutableTensor(shape_tensor_name);
auto shape_tensor_data = shape_tensor->mutable_data<int>();
shape.emplace_back(shape_tensor_data[0]);
}
CHECK_GT(shape.size(), 0)
<< "ShapeError: When `shape` in ReshapeOp is a list or tuple "
<< "[XPU] ShapeError: When `shape` in ReshapeOp is a list or tuple "
"which contains Tensor, the shape's size can't be zero. "
"But received shape's size is "
<< shape.size();
} else if (op_info->HasInput("Shape") && !op_info->Input("Shape").empty()) {
auto shape_tensor =
scope->FindMutableTensor(op_info->Input("Shape").front());
auto shape_data = shape_tensor->data<int>();
shape = std::vector<int>(shape_data, shape_data + shape_tensor->numel());
} else if (HasInputArg(op_info, scope, "Shape")) {
auto actual_shape_name = op_info->Input("Shape").front();
// auto actual_shape_type = kernel->GetInputDeclType("Shape");
// CHECK(actual_shape_type->precision() == PRECISION(kInt32));
// CHECK(actual_shape_type->layout() == DATALAYOUT(kNCHW));
auto actual_shape = scope->FindMutableTensor(actual_shape_name);
auto actual_shape_dims = actual_shape->dims();
auto actual_shape_data = actual_shape->mutable_data<int>();
shape = std::vector<int>(
actual_shape_data, actual_shape_data + actual_shape_dims.production());
} else if (op_info->HasAttr("shape")) {
shape = op_info->GetAttr<std::vector<int>>("shape");
} else {
LOG(FATAL) << "no new shape for reshape op";
LOG(WARNING) << "[XPU] No new shape for reshape op";
return FAILED;
}
auto out_dims =
operators::ValidateShape(shape, scope->FindTensor(x_var_name)->dims());
CHECK(graph->HasNode(x_var_name));
graph->AddNode(out_var_name,
graph->builder_.CreateReshape(*graph->GetNode(x_var_name),
Cvt2ArrayInt(out_dims)));
auto out_dims = operators::ValidateShape(shape, x_dims);
return SUCCESS;
// Reshape node
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*x_node, CvtShape<xtcl::Integer>(out_dims)));
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace xpu
......
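Note the precedence above: a non-empty ShapeTensor list wins over a Shape tensor, which wins over the shape attribute; otherwise the bridge returns FAILED. ValidateShape then resolves the target against x_dims; a sketch of the conventional 0/-1 rules it is assumed to enforce (0 copies the input dim at that position, a single -1 is inferred from the element count):

#include <cstdint>
#include <vector>

// Resolve a reshape target against the input dims. Sketch only; the real
// checks live in operators::ValidateShape.
std::vector<int64_t> ResolveShape(const std::vector<int>& shape,
                                  const std::vector<int64_t>& x_dims) {
  int64_t x_numel = 1;
  for (auto d : x_dims) x_numel *= d;
  std::vector<int64_t> out(shape.size());
  int64_t known = 1;
  int infer = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) {
      out[i] = x_dims[i];  // keep the input dim
    } else if (shape[i] == -1) {
      infer = static_cast<int>(i);  // resolved below
      continue;
    } else {
      out[i] = shape[i];
    }
    known *= out[i];
  }
  if (infer >= 0) out[infer] = x_numel / known;
  return out;
}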
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
float scale = op_info->GetAttr<float>("scale");
bool bias_after_scale = op_info->GetAttr<bool>("bias_after_scale");
float bias = op_info->GetAttr<float>("bias");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Scale node
graph->AddNode(
out_name,
graph->builder_.CreateScale(*x_node, scale, bias, bias_after_scale));
return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
scale,
paddle::lite::subgraph::xpu::ScaleConverter);
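For reference, the two orderings selected by bias_after_scale above are out = scale * x + bias (true) and out = scale * (x + bias) (false):

// Reference semantics of the scale op converted above.
inline float ScaleRef(float x, float scale, float bias, bool bias_after_scale) {
  return bias_after_scale ? scale * x + bias : scale * (x + bias);
}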
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto input_name = op_info->Input("Input").front();
auto input_type = kernel->GetInputDeclType("Input");
CHECK(input_type->precision() == PRECISION(kFloat));
CHECK(input_type->layout() == DATALAYOUT(kNCHW));
auto input = scope->FindMutableTensor(input_name);
auto input_dims = input->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axes = op_info->GetAttr<std::vector<int>>("axes");
auto starts = op_info->GetAttr<std::vector<int>>("starts");
auto ends = op_info->GetAttr<std::vector<int>>("ends");
// Input node
std::shared_ptr<xtcl::xExpr> input_node = nullptr;
if (graph->HasNode(input_name)) {
input_node = graph->GetNode(input_name);
} else {
input_node = graph->AddNode(input_name, input_dims);
}
// Calculate the begin and end of the slice along every dimension,
// and create the strided-slice node as the output node
xtcl::Array<xtcl::Integer> begin, end, strides;
for (size_t i = 0; i < input_dims.size(); ++i) {
auto it = std::find(axes.cbegin(), axes.cend(), i);
if (it == axes.cend()) {
// If not found, don't slice this axis
int s = 0;
int e = input_dims[i];
begin.push_back(s);
end.push_back(e);
strides.push_back(1);
} else {
int offset = it - axes.cbegin();
int s = starts[offset];
int e = ends[offset];
begin.push_back(s);
end.push_back(e);
strides.push_back(1);
}
}
graph->AddNode(
out_name,
graph->builder_.CreateStridedSlice(*input_node, begin, end, strides));
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
slice,
paddle::lite::subgraph::xpu::SliceConverter);
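Worked example of the bound construction above: input_dims = {3, 4, 5}, axes = {1}, starts = {1}, ends = {3} produce begin = {0, 1, 0}, end = {3, 3, 5}, strides = {1, 1, 1}. The same logic as a standalone helper (out-of-range ends are passed through; any clamping is assumed to happen inside the XTCL op):

#include <algorithm>
#include <vector>

// Unlisted axes get the full range [0, dim); listed axes get
// [starts[k], ends[k]), matching the loop in SliceConverter.
void BuildSliceBounds(const std::vector<int64_t>& dims,
                      const std::vector<int>& axes,
                      const std::vector<int>& starts,
                      const std::vector<int>& ends,
                      std::vector<int>* begin, std::vector<int>* end) {
  for (size_t i = 0; i < dims.size(); ++i) {
    auto it = std::find(axes.cbegin(), axes.cend(), static_cast<int>(i));
    if (it == axes.cend()) {
      begin->push_back(0);
      end->push_back(static_cast<int>(dims[i]));
    } else {
      size_t k = it - axes.cbegin();
      begin->push_back(starts[k]);
      end->push_back(ends[k]);
    }
  }
}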
......@@ -21,23 +21,38 @@ namespace lite {
namespace subgraph {
namespace xpu {
int SoftmaxConverter(void* ctx, OpLite* op) {
int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get op's attributes
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axis = op_info->GetAttr<int>("axis");
// Create softmax node and set params from ops
graph->AddNode(
out_var_name,
graph->builder_.CreateSoftmax(*graph->GetNode(x_var_name), axis));
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Softmax node
graph->AddNode(out_name, graph->builder_.CreateSoftmax(*x_node, axis));
return SUCCESS;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
int StackConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_names = op_info->Input("X");
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto y_name = op_info->Output("Y").front();
auto y_type = kernel->GetOutputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
int axis = op_info->GetAttr<int>("axis");
// X nodes
xtcl::Array<xtcl::xExpr> x_nodes;
for (auto& x_name : x_names) {
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
x_nodes.push_back(*x_node);
}
// Stack node
graph->AddNode(y_name,
graph->builder_.CreateStack(
xtcl::network::TupleNode::make(x_nodes), axis));
return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
stack,
paddle::lite::subgraph::xpu::StackConverter);
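Stacking n tensors of identical shape {d0, ..., dk-1} at axis a inserts a new dimension of extent n at position a; e.g. three {2, 3} inputs stacked at axis = 0 give {3, 2, 3}:

#include <cstdint>
#include <vector>

// Output shape of stack(xs, axis) for n inputs of identical shape `dims`.
std::vector<int64_t> StackOutShape(const std::vector<int64_t>& dims,
                                   int axis, int64_t n) {
  std::vector<int64_t> out(dims);
  out.insert(out.begin() + axis, n);
  return out;  // e.g. StackOutShape({2, 3}, 0, 3) == {3, 2, 3}
}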
......@@ -21,26 +21,42 @@ namespace lite {
namespace subgraph {
namespace xpu {
int TransposeConverter(void* ctx, OpLite* op) {
int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Create node and set params from op
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto axis = op_info->GetAttr<std::vector<int>>("axis");
CHECK(graph->HasNode(x_var_name));
graph->AddNode(
out_var_name,
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Transpose node
graph->AddNode(out_name,
graph->builder_.CreateTranspose(
*graph->GetNode(x_var_name),
Cvt2ArrayInt(std::vector<int64_t>(axis.begin(), axis.end()))));
*x_node,
CvtShape<xtcl::Integer>(
std::vector<int64_t>(axis.begin(), axis.end()))));
return SUCCESS;
}
......
......@@ -47,9 +47,15 @@ xtcl::DataType CvtPrecisionType(PrecisionType in_type) {
case PRECISION(kInt8):
out_type = ::xtcl::Int(8);
break;
case PRECISION(kInt16):
out_type = ::xtcl::Int(16);
break;
case PRECISION(kInt32):
out_type = ::xtcl::Int(32);
break;
case PRECISION(kInt64):
out_type = ::xtcl::Int(64);
break;
default:
LOG(FATAL) << "[XPU] Can not convert precision type("
<< PrecisionToStr(in_type) << ") from Lite to XPU";
......@@ -58,7 +64,7 @@ xtcl::DataType CvtPrecisionType(PrecisionType in_type) {
return out_type;
}
DLDataType CvtDataType(PrecisionType in_type) {
DLDataType CvtDLDataType(PrecisionType in_type) {
DLDataType out_type = {kDLFloat, 32, 1};
switch (in_type) {
case PRECISION(kFloat):
......@@ -67,76 +73,64 @@ DLDataType CvtDataType(PrecisionType in_type) {
case PRECISION(kInt8):
out_type = {kDLInt, 8, 1};
break;
case PRECISION(kInt16):
out_type = {kDLInt, 16, 1};
break;
case PRECISION(kInt32):
out_type = {kDLInt, 32, 1};
break;
case PRECISION(kInt64):
out_type = {kDLInt, 64, 1};
break;
default:
LOG(FATAL) << "[XPU] Can not convert data type("
<< PrecisionToStr(in_type) << ") from Lite to XPU";
LOG(FATAL) << "[XPU] Can not convert precision type("
<< PrecisionToStr(in_type) << ") from Lite to XPU DLDataType";
break;
}
return out_type;
}
xtcl::Array<xtcl::xIndexExpr> CvtShape(const std::vector<int>& in_shape) {
xtcl::Array<xtcl::xIndexExpr> out_shape;
for (auto dim : in_shape) {
out_shape.push_back(dim);
DLDeviceType CvtDLDeviceType(TargetType in_type) {
DLDeviceType out_type = kDLCPU;
switch (in_type) {
case TARGET(kX86):
out_type = kDLCPU;
break;
case TARGET(kHost):
out_type = kDLCPU;
break;
case TARGET(kCUDA):
out_type = kDLGPU;
break;
case TARGET(kXPU):
out_type = kDLCPU;
break;
default:
LOG(FATAL) << "[XPU] Can not convert target type(" << TargetToStr(in_type)
<< ") from Lite to XPU DLDeviceType";
break;
}
return out_shape;
}
xtcl::Array<xtcl::xIndexExpr> CvtShape(const std::vector<int64_t>& in_shape) {
return CvtShape(std::vector<int>(in_shape.begin(), in_shape.end()));
}
xtcl::Array<xtcl::xIndexExpr> CvtShape(const DDim& in_dims) {
return CvtShape(in_dims.Vectorize());
return out_type;
}
std::shared_ptr<xtcl::xNDArray> CvtTensor(const Tensor& in_tensor,
std::vector<int64_t> out_shape,
PrecisionType in_ptype,
DataLayoutType in_ltype) {
const uint8_t* in_data = nullptr;
auto in_size = in_tensor.dims().production();
PrecisionType in_precision,
DataLayoutType in_layout) {
auto in_shape = in_tensor.dims().Vectorize();
if (out_shape.empty()) {
out_shape = in_shape;
}
int in_bytes;
if (in_ptype == PRECISION(kFloat)) {
in_data = reinterpret_cast<const uint8_t*>(in_tensor.data<float>());
in_bytes = in_size * sizeof(float);
} else if (in_ptype == PRECISION(kInt32)) {
in_data = reinterpret_cast<const uint8_t*>(in_tensor.data<int32_t>());
in_bytes = in_size * sizeof(int32_t);
} else if (in_ptype == PRECISION(kInt8)) {
in_data = reinterpret_cast<const uint8_t*>(in_tensor.data<int8_t>());
in_bytes = in_size * sizeof(int8_t);
} else {
LOG(FATAL) << "[XPU] Unknow precision type " << PrecisionToStr(in_ptype);
}
auto out_tensor = std::make_shared<xtcl::xNDArray>(
xtcl::xNDArray::Empty(out_shape, CvtDataType(in_ptype), {kDLCPU, 0}));
xtcl::xNDArray::Empty(out_shape,
CvtDLDataType(in_precision),
{CvtDLDeviceType(TARGET(kHost)), 0}));
auto out_data =
reinterpret_cast<uint8_t*>(out_tensor->ToDLPack()->dl_tensor.data);
std::memcpy(out_data, in_data, in_bytes);
std::memcpy(out_data, in_tensor.raw_data(), in_tensor.memory_size());
return out_tensor;
}
xtcl::Array<xtcl::Integer> Cvt2ArrayInt(const std::vector<int64_t>& input) {
xtcl::Array<xtcl::Integer> output;
for (auto i : input) {
output.push_back(i);
}
return output;
}
xtcl::Array<xtcl::Integer> Cvt2ArrayInt(const DDim& input) {
return Cvt2ArrayInt(input.Vectorize());
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
......
......@@ -33,22 +33,33 @@ bool HasInputArg(const OpInfo* op_info,
xtcl::DataType CvtPrecisionType(PrecisionType in_type);
DLDataType CvtDataType(PrecisionType in_type);
DLDataType CvtDLDataType(PrecisionType in_type);
DLDeviceType CvtDLDeviceType(TargetType in_type);
xtcl::Array<xtcl::xIndexExpr> CvtShape(const std::vector<int>& in_shape);
template <typename T>
xtcl::Array<T> CvtShape(const std::vector<int>& in_shape) {
xtcl::Array<T> out_shape;
for (auto dim : in_shape) {
out_shape.push_back(dim);
}
return out_shape;
}
xtcl::Array<xtcl::xIndexExpr> CvtShape(const std::vector<int64_t>& in_shape);
template <typename T>
xtcl::Array<T> CvtShape(const std::vector<int64_t>& in_shape) {
return CvtShape<T>(std::vector<int>(in_shape.begin(), in_shape.end()));
}
xtcl::Array<xtcl::xIndexExpr> CvtShape(const DDim& in_dims);
template <typename T>
xtcl::Array<T> CvtShape(const DDim& in_dims) {
return CvtShape<T>(in_dims.Vectorize());
}
std::shared_ptr<xtcl::xNDArray> CvtTensor(
const Tensor& in_tensor,
std::vector<int64_t> out_shape = {},
PrecisionType in_ptype = PRECISION(kFloat),
DataLayoutType in_ltype = DATALAYOUT(kNCHW));
xtcl::Array<xtcl::Integer> Cvt2ArrayInt(const std::vector<int64_t>& input);
xtcl::Array<xtcl::Integer> Cvt2ArrayInt(const DDim& input);
PrecisionType in_precision = PRECISION(kFloat),
DataLayoutType in_layout = DATALAYOUT(kNCHW));
} // namespace xpu
} // namespace subgraph
......
......@@ -20,6 +20,7 @@
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/paddle_use_bridges.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
......@@ -28,19 +29,9 @@ namespace xpu {
int SubgraphEngine::BuildDeviceProgram() {
int status = 0;
// Convert all of input data vars and added into the XPU IR graph
// Convert all of the ops with their input vars and weights, and add them
// into the XPU IR graph
subgraph::xpu::Graph graph;
for (auto& input_name : input_names_) {
auto input_tensor = scope_->FindMutableTensor(input_name);
CHECK(input_tensor);
auto input_node =
graph.AddNode(input_name, input_tensor->dims().Vectorize());
CHECK(input_node);
// XTCL doesn't support dynamic dimensions/shapes, so need to rebuild
// the program when the shape of any input tensor is changed.
status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED;
}
// Convert all of ops and its weights and added into the XPU IR graph
const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program_) {
auto op = inst.op();
......@@ -51,62 +42,140 @@ int SubgraphEngine::BuildDeviceProgram() {
if (!bridges.Exists("XPU", op_type)) {
return subgraph::FAILED;
}
auto kernel = inst.kernel();
status |= bridges.Select("XPU", op_type)(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op));
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED;
}
}
// Obtain the output nodes of the XPU IR graph and build the graph to XPU
// Obtain the output nodes of the XPU IR graph and build the graph to the XPU
// runtime
std::vector<xtcl::xExpr*> output_nodes;
std::vector<std::string> valid_output_names;
device_inames_.clear();
device_onames_.clear();
std::vector<xtcl::xExpr*> device_inodes;
std::vector<xtcl::xExpr*> device_onodes;
for (auto& input_name : input_names_) {
if (graph.HasNode(input_name)) {
if (!graph.GetType(input_name).persistable()) {
device_inodes.push_back(graph.GetNode(input_name).get());
device_inames_.push_back(input_name);
} else {
LOG(WARNING) << "[XPU] Input node " << input_name
<< " is skipped because it is a persistable node.";
}
} else {
LOG(WARNING) << "[XPU] Input node " << input_name
<< " is skipped because it does not exist.";
}
}
for (auto& output_name : output_names_) {
if (graph.HasNode(output_name)) {
output_nodes.push_back(graph.GetNode(output_name).get());
valid_output_names.push_back(output_name);
device_onodes.push_back(graph.GetNode(output_name).get());
device_onames_.push_back(output_name);
} else {
LOG(WARNING) << "[XPU] Output node " << output_name
<< " is skipped because it does not exist.";
}
}
CHECK(!valid_output_names.empty()) << "[XPU] no valid output names";
CHECK(!device_inames_.empty())
<< "[XPU] No input nodes found for building XPU model";
CHECK(!device_onames_.empty())
<< "[XPU] No output nodes found for building XPU model";
device_program_ = lite::xpu::Device::Global().Build(
&graph.builder_, &graph.params_, &output_nodes);
&graph.builder_, &graph.params_, &device_onodes);
if (device_program_ == nullptr) {
LOG(WARNING) << "[XPU] Build model failed!";
return subgraph::FAILED;
}
// Query and check the dimensions of input and output tensors
origin_idims_.resize(input_names_.size());
origin_itensors_.resize(input_names_.size());
origin_odims_.resize(valid_output_names.size());
origin_otensors_.resize(valid_output_names.size());
for (int i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]);
origin_idims_.resize(device_inames_.size());
origin_itensors_.resize(device_inames_.size());
device_itensors_.resize(device_inames_.size());
origin_odims_.resize(device_onames_.size());
origin_otensors_.resize(device_onames_.size());
device_otensors_.resize(device_onames_.size());
for (int i = 0; i < device_inames_.size(); i++) {
auto type = graph.GetType(device_inames_[i]);
auto precision = type.precision();
auto layout = type.layout();
origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]);
CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims();
VLOG(3) << "[XPU] Input dims[" << i << "]: " << origin_idims_[i];
VLOG(3) << "[XPU] Inputs[" << i
<< "] precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout)
<< " dims: " << origin_idims_[i];
// Prepare the device input tensors which share data with the origin input
// tensors
device_itensors_[i].data = nullptr;
device_itensors_[i].ctx.device_type =
subgraph::xpu::CvtDLDeviceType(TARGET(kHost));
device_itensors_[i].ctx.device_id = 0;
device_itensors_[i].ndim = origin_idims_[i].size();
device_itensors_[i].dtype = subgraph::xpu::CvtDLDataType(precision);
device_itensors_[i].shape = const_cast<int64_t*>(
static_cast<const int64_t*>(origin_idims_[i].data().data()));
device_itensors_[i].strides = nullptr;
device_itensors_[i].byte_offset = 0;
}
for (int i = 0; i < valid_output_names.size(); i++) {
origin_otensors_[i] = scope_->FindMutableTensor(valid_output_names[i]);
for (int i = 0; i < device_onames_.size(); i++) {
auto type = graph.GetType(device_onames_[i]);
auto precision = type.precision();
auto layout = type.layout();
origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]);
CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "[XPU] Output dims[" << i << "]: " << origin_odims_[i];
VLOG(3) << "[XPU] Outputs[" << i
<< "] precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout)
<< " dims: " << origin_odims_[i];
// Prepare the device output tensors which share data with the origin output
// tensors
switch (precision) {
case PRECISION(kFloat):
origin_otensors_[i]->mutable_data<float>();
break;
case PRECISION(kInt8):
origin_otensors_[i]->mutable_data<int8_t>();
break;
case PRECISION(kInt16):
origin_otensors_[i]->mutable_data<int16_t>();
break;
case PRECISION(kInt32):
origin_otensors_[i]->mutable_data<int32_t>();
break;
case PRECISION(kInt64):
origin_otensors_[i]->mutable_data<int64_t>();
break;
default:
LOG(FATAL) << "[XPU] " << device_onames_[i]
<< " can't mutable data with precision type "
<< PrecisionToStr(precision);
break;
}
device_otensors_[i].data = nullptr;
device_otensors_[i].ctx.device_type =
subgraph::xpu::CvtDLDeviceType(TARGET(kHost));
device_otensors_[i].ctx.device_id = 0;
device_otensors_[i].ndim = origin_odims_[i].size();
device_otensors_[i].dtype = subgraph::xpu::CvtDLDataType(precision);
device_otensors_[i].shape = const_cast<int64_t*>(
static_cast<const int64_t*>(origin_odims_[i].data().data()));
device_otensors_[i].strides = nullptr;
device_otensors_[i].byte_offset = 0;
}
return status;
}
int SubgraphEngine::LaunchDeviceProgram() {
// Copy the data of origin input tensors to the buffer of input XPU tensors
for (size_t i = 0; i < input_names_.size(); i++) {
auto input_ndarray =
xtcl::xNDArray::Empty(origin_itensors_[i]->dims().Vectorize(),
{kDLFloat, 32, 1},
{kDLCPU, 0});
std::memcpy(static_cast<float*>(input_ndarray.ToDLPack()->dl_tensor.data),
origin_itensors_[i]->mutable_data<float>(),
sizeof(float) * origin_itensors_[i]->dims().production());
device_program_->SetInputZeroCopy(input_names_[i],
&input_ndarray.ToDLPack()->dl_tensor);
for (size_t i = 0; i < device_itensors_.size(); i++) {
// Update the data pointer of DLTensor to track the origin input tensors
device_itensors_[i].data =
const_cast<void*>(origin_itensors_[i]->raw_data());
device_program_->SetInputZeroCopy(device_inames_[i], &device_itensors_[i]);
}
// Run the XPU model
auto GetCurrentUS = []() -> double {
......@@ -117,12 +186,11 @@ int SubgraphEngine::LaunchDeviceProgram() {
auto start_time = GetCurrentUS();
device_program_->Run();
VLOG(3) << "[XPU] Process cost " << GetCurrentUS() - start_time << " us";
// Copy the data of output XPU tensor to the buffer of origin output tensors
for (size_t i = 0; i < origin_otensors_.size(); i++) {
auto output_ndarray = device_program_->GetOutput(i);
std::memcpy(origin_otensors_[i]->mutable_data<float>(),
static_cast<float*>(output_ndarray.ToDLPack()->dl_tensor.data),
sizeof(float) * origin_otensors_[i]->dims().production());
for (size_t i = 0; i < device_otensors_.size(); i++) {
// Update the data pointer of DLTensor to track the origin output tensors
device_otensors_[i].data =
const_cast<void*>(origin_otensors_[i]->raw_data());
device_program_->CopyOutputTo(i, &device_otensors_[i]);
}
return 0;
}
......
......@@ -41,6 +41,10 @@ class SubgraphEngine : public subgraph::Engine {
int BuildDeviceProgram() override;
int LaunchDeviceProgram() override;
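// Input/output variable names of the device program, plus the DLTensor
// descriptors that alias the origin tensors' buffers at launch time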
std::vector<std::string> device_inames_;
std::vector<std::string> device_onames_;
std::vector<DLTensor> device_itensors_;
std::vector<DLTensor> device_otensors_;
std::unique_ptr<xtcl::network::xRuntimeInstance> device_program_{nullptr};
};
......
......@@ -120,6 +120,7 @@ REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(gelu, paddle::lite::operators::ActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
......
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_power_compute SRCS power_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_shuffle_channel_compute SRCS shuffle_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_yolo_box_compute SRCS yolo_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_fc SRCS fc_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_elementwise_compute SRCS elementwise_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_fc_compute SRCS fc_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_elementwise_compute SRCS elementwise_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lrn_compute SRCS lrn_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_decode_bboxes_compute SRCS decode_bboxes_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_coder_compute SRCS box_coder_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -41,7 +41,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -49,6 +49,8 @@ if(LITE_BUILD_EXTRA)
#lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_aligned_mat_mul_compute SRCS search_aligned_mat_mul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_seq_fc_compute SRCS search_seq_fc_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
......@@ -34,7 +34,8 @@ enum activation_type_test {
LOG,
EXP,
FLOOR,
RSQRT
RSQRT,
GELU
};
class ActivationComputeTester : public arena::TestCase {
......@@ -184,6 +185,13 @@ class ActivationComputeTester : public arena::TestCase {
}
break;
}
case GELU: {
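// GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))); the constant below is
// 1 / sqrt(2)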
for (int i = 0; i < dims_.production(); i++) {
output_data[i] = x_data[i] * 0.5 *
(1.0 + std::erf(x_data[i] * 0.70710678118654752440));
}
break;
}
default:
LOG(INFO) << "the type of activation is unknow.";
}
......@@ -243,8 +251,8 @@ class ActivationComputeTester : public arena::TestCase {
TEST(Activation_relu, precision) {
LOG(INFO) << "test relu op";
float abs_error = 2e-5;
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
......@@ -280,8 +288,8 @@ TEST(Activation_relu, precision) {
TEST(Activation_leaky_relu, precision) {
LOG(INFO) << "test leaky_relu op";
float abs_error = 2e-5;
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
......@@ -317,8 +325,8 @@ TEST(Activation_leaky_relu, precision) {
TEST(Activation_relu_clipped, precision) {
LOG(INFO) << "test relu clipped op";
float abs_error = 2e-5;
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
......@@ -384,8 +392,8 @@ TEST(Activation_prelu, precision) {
TEST(Activation_sigmoid, precision) {
LOG(INFO) << "test sigmoid op";
float abs_error = 2e-5;
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
......@@ -419,13 +427,15 @@ TEST(Activation_sigmoid, precision) {
TEST(Activation_tanh, precision) {
LOG(INFO) << "test tanh op";
float abs_error = 2e-5;
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
......@@ -621,5 +631,25 @@ TEST(Activation_rsqrt, precision) {
}
#endif
}
TEST(Activation_gelu, precision) {
LOG(INFO) << "test gelu op";
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "gelu", GELU));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
} // namespace lite
} // namespace paddle
......@@ -161,7 +161,7 @@ class FcOPTest : public arena::TestCase {
}
};
void test_fc(Place place) {
void test_fc(Place place, float abs_error) {
for (auto& m : {1, 3, 16}) {
for (auto& n : {1, 4, 16, 128, 256, 1024}) {
for (auto& k : {1, 16, 128, 1024}) {
......@@ -172,10 +172,12 @@ void test_fc(Place place) {
std::unique_ptr<arena::TestCase> tester(
new FcOPTest(place, "def", dim_in, wdim, bdim, 1));
#ifdef LITE_WITH_ARM
if (place == TARGET(kARM)) {
auto& ctx = tester->context()->As<ARMContext>();
ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 1);
}
#endif
arena::Arena arena(std::move(tester), place, 6e-5);
arena::Arena arena(std::move(tester), place, abs_error);
if (!arena.TestPrecision()) {
LOG(ERROR) << "run m: " << m << ", n: " << n << ", k: " << k
<< ", bias: " << (bflag ? "true" : "false") << " failed";
......@@ -188,13 +190,17 @@ void test_fc(Place place) {
}
TEST(FcOP, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_fc(place);
Place place;
float abs_error = 6e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 2e-1; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
#endif
test_fc(place, abs_error);
}
} // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class GatherComputeTest : public arena::TestCase {
protected:
// common attributes for this op.
std::string op_type_ = "gather";
std::string x_ = "x";
std::string index_ = "index";
std::string out_ = "out";
DDim x_dims_{{5, 4, 2, 3}};
DDim index_dims_{{2, 1}};
public:
GatherComputeTest(const Place& place,
const std::string& alias,
const DDim& x_dims,
const DDim& index_dims)
: TestCase(place, alias), x_dims_(x_dims), index_dims_(index_dims) {}
void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(x_);
auto index = scope->FindTensor(index_);
auto x_dims = x->dims();
auto index_dims = index->dims();
CHECK(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1));
auto out = scope->NewTensor(out_);
CHECK(out);
int batch_size = index_dims[0];
DDim out_dims = x_dims;
out_dims[0] = batch_size;
out->Resize(out_dims);
auto x_data = x->data<float>();
auto index_data = index->data<int>();
auto out_data = out->mutable_data<float>();
auto slice_num = x_dims[0];
auto slice_size = x_dims.Slice(1, x_dims.size()).production();
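// Gather along axis 0: out[i, ...] = x[index[i], ...], copying one slice
// of slice_size elements per index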
for (int i = 0; i < batch_size; i++) {
auto index = index_data[i];
CHECK_LT(index, slice_num) << "gather index[i] expected < " << slice_num
<< " but got " << index;
CHECK_GE(index, 0) << "gather index[i] expected >= 0 but got " << index;
memcpy(out_data + i * slice_size,
x_data + index * slice_size,
slice_size * sizeof(float));
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(op_type_);
op_desc->SetInput("X", {x_});
op_desc->SetInput("Index", {index_});
op_desc->SetOutput("Out", {out_});
}
void PrepareData() override {
std::vector<float> x(x_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
std::vector<int32_t> index(index_dims_.production());
fill_data_rand<int32_t>(
index.data(), 0, x_dims_[0] - 1, index_dims_.production());
SetCommonTensor(x_, x_dims_, x.data());
SetCommonTensor(index_, index_dims_, index.data());
}
};
TEST(Gather, precision) {
LOG(INFO) << "test gather op";
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
for (auto x_dims :
std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) {
for (auto index_dims :
std::vector<std::vector<int64_t>>{{3, 1}, {7, 1}, {10, 1}}) {
std::unique_ptr<arena::TestCase> tester(
new GatherComputeTest(place, "def", DDim(x_dims), DDim(index_dims)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
} // namespace lite
} // namespace paddle
This diff is collapsed.
......@@ -82,11 +82,17 @@ class ScaleComputeTester : public arena::TestCase {
};
TEST(Scale, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
abs_error = 3e-4; // Some operations use fp16 in XPU
#elif defined(LITE_WITH_X86)
place = TARGET(kX86);
#else
return;
#endif
for (float scale : {0.123, 2., -1.2}) {
......@@ -94,7 +100,7 @@ TEST(Scale, precision) {
for (bool bias_before : {true, false}) {
std::unique_ptr<arena::TestCase> tester(
new ScaleComputeTester(place, "def", scale, bias, bias_before));
arena::Arena arena(std::move(tester), place, 2e-5);
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
......@@ -102,11 +108,13 @@ TEST(Scale, precision) {
}
TEST(Scale, performance) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
Place place;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_X86)
place = TARGET(kX86);
#else
return;
#endif
std::unique_ptr<arena::TestCase> tester(
......
......@@ -267,14 +267,14 @@ void test_slice_tensor_list(Place place) {
}
TEST(Slice, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_slice(place);
test_slice_tensor(place);
test_slice_tensor_list(place);
#elif defined(LITE_WITH_XPU)
Place place(TARGET(kXPU));
test_slice(place);
#endif
}
......
......@@ -103,13 +103,15 @@ void test_stack(Place place) {
}
TEST(Stack, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
Place place;
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_stack(place);
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
test_stack(place);
}
} // namespace lite
......