Commit 1dbcd51d authored by hong19860320, committed by GitHub

[LITE][NPU][XPU] Refine subgraph pass, and support NPU/XPU model generation at execution time (#2576)
Parent ead74728
......@@ -118,7 +118,7 @@ file(WRITE ${offline_lib_registry_file} "") # clean
function(lite_cc_library TARGET)
set(options SHARED shared STATIC static MODULE module)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS NPU_DEPS XPU_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS LIGHT_DEPS
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS LIGHT_DEPS
HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS)
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
......@@ -128,10 +128,10 @@ function(lite_cc_library TARGET)
X86_DEPS ${args_X86_DEPS}
CUDA_DEPS ${args_CUDA_DEPS}
CL_DEPS ${args_CL_DEPS}
NPU_DEPS ${args_NPU_DEPS}
XPU_DEPS ${args_XPU_DEPS}
ARM_DEPS ${args_ARM_DEPS}
FPGA_DEPS ${args_FPGA_DEPS}
NPU_DEPS ${args_NPU_DEPS}
XPU_DEPS ${args_XPU_DEPS}
PROFILE_DEPS ${args_PROFILE_DEPS}
LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS}
......@@ -161,7 +161,7 @@ function(lite_cc_binary TARGET)
set(options " -g ")
endif()
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS
LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS)
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
......@@ -173,6 +173,8 @@ function(lite_cc_binary TARGET)
CL_DEPS ${args_CL_DEPS}
ARM_DEPS ${args_ARM_DEPS}
FPGA_DEPS ${args_FPGA_DEPS}
NPU_DEPS ${args_NPU_DEPS}
XPU_DEPS ${args_XPU_DEPS}
PROFILE_DEPS ${args_PROFILE_DEPS}
LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS}
......@@ -205,7 +207,7 @@ function(lite_cc_test TARGET)
endif()
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS
LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS
ARGS
COMPILE_LEVEL # (basic|extra)
......@@ -225,6 +227,8 @@ function(lite_cc_test TARGET)
CL_DEPS ${args_CL_DEPS}
ARM_DEPS ${args_ARM_DEPS}
FPGA_DEPS ${args_FPGA_DEPS}
NPU_DEPS ${args_NPU_DEPS}
XPU_DEPS ${args_XPU_DEPS}
PROFILE_DEPS ${args_PROFILE_DEPS}
LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS}
......@@ -267,7 +271,7 @@ endif()
function(add_kernel TARGET device level)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS
LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS
ARGS)
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
......@@ -360,11 +364,12 @@ function(add_kernel TARGET device level)
lite_cc_library(${TARGET} SRCS ${args_SRCS}
DEPS ${args_DEPS}
X86_DEPS ${args_X86_DEPS}
XPU_DEPS ${args_XPU_DEPS}
CUDA_DEPS ${args_CUDA_DEPS}
CL_DEPS ${args_CL_DEPS}
ARM_DEPS ${args_ARM_DEPS}
FPGA_DEPS ${args_FPGA_DEPS}
NPU_DEPS ${args_NPU_DEPS}
XPU_DEPS ${args_XPU_DEPS}
PROFILE_DEPS ${args_PROFILE_DEPS}
LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS}
......@@ -383,7 +388,7 @@ endif()
function(add_operator TARGET level)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS
LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS
ARGS)
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
......@@ -409,11 +414,12 @@ function(add_operator TARGET level)
lite_cc_library(${TARGET} SRCS ${args_SRCS}
DEPS ${args_DEPS}
X86_DEPS ${args_X86_DEPS}
XPU_DEPS ${args_XPU_DEPS}
CUDA_DEPS ${args_CUDA_DEPS}
CL_DEPS ${args_CL_DEPS}
ARM_DEPS ${args_ARM_DEPS}
FPGA_DEPS ${args_FPGA_DEPS}
NPU_DEPS ${args_NPU_DEPS}
XPU_DEPS ${args_XPU_DEPS}
PROFILE_DEPS ${args_PROFILE_DEPS}
LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS}
......
......@@ -89,7 +89,7 @@ else()
endif()
find_library(XPU_SDK_LLVM_FILE NAMES LLVM-8
PATHS ${XPU_SDK_ROOT}/XTDK/shlib)
PATHS ${XPU_SDK_ROOT}/XTDK/shlib/gcc482)
if(NOT XPU_SDK_LLVM_FILE)
message(FATAL_ERROR "Can not find LLVM Library in ${XPU_SDK_ROOT}")
......
......@@ -42,7 +42,7 @@ else()
add_dependencies(paddle_light_api_shared op_list_h kernel_list_h)
if (LITE_WITH_NPU)
# Need to add the HIAI builder and runtime libs dependency
target_link_libraries(paddle_light_api_shared ${npu_runtime_libs})
target_link_libraries(paddle_light_api_shared ${npu_builder_libs} ${npu_runtime_libs})
endif()
endif()
endif()
......@@ -78,8 +78,8 @@ if (NOT LITE_ON_TINY_PUBLISH)
DEPS ${cxx_api_deps} ${ops} ${host_kernels} program
X86_DEPS ${x86_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass
XPU_DEPS ${xpu_kernels} ${xpu_bridges} xpu_pass
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
endif()
......
......@@ -108,7 +108,7 @@ USE_LITE_OP(while)
USE_LITE_OP(lod_reset)
USE_LITE_OP(lookup_table)
USE_LITE_OP(multiclass_nms)
USE_LITE_OP(graph_op)
USE_LITE_OP(subgraph)
USE_LITE_OP(sequence_expand)
USE_LITE_OP(sequence_pool)
USE_LITE_OP(reduce_max)
......
......@@ -30,7 +30,7 @@ else()
add_dependencies(paddle_lite_jni op_list_h kernel_list_h)
if (LITE_WITH_NPU)
# Need to add the HIAI builder and runtime libs dependency
target_link_libraries(paddle_lite_jni ${npu_runtime_libs})
target_link_libraries(paddle_lite_jni ${npu_builder_libs} ${npu_runtime_libs})
endif()
endif()
......
......@@ -139,22 +139,15 @@ std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
// append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() {
std::vector<const cpp::OpDesc *> feeds;
std::vector<const cpp::OpDesc *> fetchs;
#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU)
// The shape of the input tensors must be determined before generating the NPU and XPU
// programs.
auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < current_block->OpsSize(); i++) {
auto op = current_block->GetOp<cpp::OpDesc>(i);
#else
if (!program_) {
GenRuntimeProgram();
}
std::vector<const cpp::OpDesc *> feeds;
std::vector<const cpp::OpDesc *> fetchs;
const auto &insts = program_->instructions();
for (size_t i = 0; i < program_->num_instructions(); i++) {
const auto &op = insts[i].op()->op_info();
#endif
if (op->Type() == "feed") {
feeds.push_back(op);
} else if (op->Type() == "fetch") {
......
......@@ -90,6 +90,10 @@ std::vector<Place> ParserValidPlaces() {
TARGET(kARM)); // enable kARM CPU kernel when no opencl kernel
} else if (target_repr == "x86") {
valid_places.emplace_back(TARGET(kX86));
} else if (target_repr == "npu") {
valid_places.emplace_back(TARGET(kNPU));
} else if (target_repr == "xpu") {
valid_places.emplace_back(TARGET(kXPU));
} else {
LOG(FATAL) << lite::string_format(
"Wrong target '%s' found, please check the command flag "
......
......@@ -20,12 +20,6 @@ USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(type_target_cast_pass);
USE_MIR_PASS(generate_program_pass);
#ifdef LITE_WITH_NPU
USE_MIR_PASS(generate_npu_program_pass);
#endif
#ifdef LITE_WITH_XPU
USE_MIR_PASS(generate_xpu_program_pass);
#endif
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
......@@ -45,3 +39,5 @@ USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass);
USE_MIR_PASS(type_layout_cast_pass);
USE_MIR_PASS(memory_optimize_pass);
USE_MIR_PASS(npu_subgraph_pass);
USE_MIR_PASS(xpu_subgraph_pass);
......@@ -2,5 +2,4 @@ if(NOT LITE_WITH_NPU)
return()
endif()
lite_cc_library(npu_runtime SRCS runtime.cc DEPS ${npu_runtime_libs})
lite_cc_library(npu_builder SRCS builder.cc DEPS ${npu_builder_libs} npu_runtime tensor op scope)
lite_cc_library(device_npu SRCS device.cc DEPS ${npu_builder_libs} ${npu_runtime_libs})
......@@ -12,47 +12,56 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/runtime.h"
#include <string>
#include <vector>
#include "lite/backends/npu/device.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace npu {
// Create a HiAI model manager to load the om model from a lite tensor, and return the
// manager and a unique model name
bool LoadModel(const lite::Tensor &model_data,
std::shared_ptr<hiai::AiModelMngerClient> *model_client,
std::string *model_name) {
LOG(INFO) << "[NPU] Load model.";
auto model_data_ptr = model_data.data<int8_t>();
auto model_data_size = model_data.numel() * sizeof(int8_t);
if (model_data_ptr == nullptr || model_data_size == 0) {
return false;
std::unique_ptr<hiai::AiModelMngerClient> Device::Build(
std::string& model_name, // NOLINT
std::vector<ge::Operator>& input_nodes, // NOLINT
std::vector<ge::Operator>& output_nodes // NOLINT
) {
VLOG(3) << "[NPU] Build model";
// Build the HiAI IR graph into the HiAI om model
ge::Graph ir_graph("graph");
ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes);
ge::Model om_model("model", "model");
om_model.SetGraph(ir_graph);
domi::HiaiIrBuild ir_build;
domi::ModelBufferData om_model_buf;
if (!ir_build.CreateModelBuff(om_model, om_model_buf)) {
LOG(WARNING) << "[NPU] CreateModelBuff failed!";
return nullptr;
}
*model_client = std::make_shared<hiai::AiModelMngerClient>();
int ret = (*model_client)->Init(nullptr);
if (ret != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] AiModelMngerClient init failed(" << ret << ")!";
return false;
if (!ir_build.BuildIRModel(om_model, om_model_buf)) {
LOG(WARNING) << "[NPU] BuildIRModel failed!";
ir_build.ReleaseModelBuff(om_model_buf);
return nullptr;
}
*model_name = "model.om";
// Create a HiAI model manager client to load the HiAI om model
std::unique_ptr<hiai::AiModelMngerClient> model_client(
new hiai::AiModelMngerClient());
if (model_client->Init(nullptr) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] AiModelMngerClient init failed)!";
ir_build.ReleaseModelBuff(om_model_buf);
return nullptr;
}
model_name = "model_" + std::to_string(model_count_++) + ".om";
auto model_desc = std::make_shared<hiai::AiModelDescription>(
*model_name,
DeviceInfo::Global().freq_level(),
DeviceInfo::Global().framework_type(),
DeviceInfo::Global().model_type(),
DeviceInfo::Global().device_type());
model_desc->SetModelBuffer(model_data_ptr, model_data_size);
model_name, freq_level(), framework_type(), model_type(), device_type());
model_desc->SetModelBuffer(om_model_buf.data, om_model_buf.length);
std::vector<std::shared_ptr<hiai::AiModelDescription>> model_descs;
model_descs.push_back(model_desc);
if ((*model_client)->Load(model_descs) != hiai::AI_SUCCESS) {
if (model_client->Load(model_descs) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!";
return false;
ir_build.ReleaseModelBuff(om_model_buf);
return nullptr;
}
return true;
ir_build.ReleaseModelBuff(om_model_buf);
return model_client;
}
} // namespace npu
......
......@@ -13,38 +13,47 @@
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ai_ddk_lib/include/HiAiModelManagerService.h"
#include "lite/core/tensor.h"
#include "ai_ddk_lib/include/hiai_ir_build.h"
namespace paddle {
namespace lite {
namespace npu {
class DeviceInfo {
class Device {
public:
static DeviceInfo &Global() {
static DeviceInfo x;
static Device& Global() {
static Device x;
return x;
}
DeviceInfo() {}
Device() {}
int freq_level() { return freq_level_; }
int framework_type() { return framework_type_; }
int model_type() { return model_type_; }
int device_type() { return device_type_; }
// Build the HiAI IR graph into an om model, and return a HiAI model manager client
// which loads the om model and runs inference.
std::unique_ptr<hiai::AiModelMngerClient> Build(
std::string& model_name, // NOLINT
std::vector<ge::Operator>& input_nodes, // NOLINT
std::vector<ge::Operator>& output_nodes // NOLINT
); // NOLINT
private:
int freq_level_{3};
int framework_type_{0};
int model_type_{0};
int device_type_{0};
int model_count_{0};
};
bool LoadModel(const lite::Tensor &model_data,
std::shared_ptr<hiai::AiModelMngerClient> *model_client,
std::string *model_name);
} // namespace npu
} // namespace lite
} // namespace paddle
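For reference, a minimal usage sketch of the execution-time model generation API declared above. The call site and helper names are assumptions; in this commit the real caller is the NPU subgraph kernel, whose changes are collapsed in this diff.

```cpp
// Hypothetical call site; `input_nodes`/`output_nodes` are assumed to be the
// ge::Operator nodes produced by the NPU bridges for one fused subgraph.
#include <string>
#include <vector>
#include "lite/backends/npu/device.h"
#include "lite/utils/cp_logging.h"

bool BuildAndLoadNPUModel(std::vector<ge::Operator>& input_nodes,   // NOLINT
                          std::vector<ge::Operator>& output_nodes,  // NOLINT
                          std::string* model_name) {
  // Build() compiles the HiAI IR graph into an om model, loads it into a
  // freshly created model manager client, and fills `model_name`
  // (e.g. "model_0.om").
  auto model_client = paddle::lite::npu::Device::Global().Build(
      *model_name, input_nodes, output_nodes);
  if (model_client == nullptr) {
    LOG(WARNING) << "[NPU] Online model generation failed!";
    return false;
  }
  // The client can now be used to create the input/output tensors and run
  // inference with the model identified by `model_name`.
  return true;
}
```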
......@@ -2,5 +2,4 @@ if(NOT LITE_WITH_XPU)
return()
endif()
lite_cc_library(xpu_runtime SRCS runtime.cc DEPS ${xpu_runtime_libs})
lite_cc_library(xpu_builder SRCS builder.cc DEPS ${xpu_builder_libs} xpu_runtime tensor op scope)
lite_cc_library(device_xpu SRCS device.cc DEPS ${xpu_builder_libs} ${xpu_runtime_libs})
......@@ -12,33 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/xpu/runtime.h"
#include <vector>
#include "lite/backends/xpu/device.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace xpu {
// Extract the model data and recover the XPU model for inference; the function
// is called by the graph computing kernel when the graph op is executed.
// Due to the lack of XPU APIs for loading and recovering the XPU model from
// memory, the key name is obtained from the weight tensor of the graph op and
// used to fetch the runtime object for inference from the global 'DeviceInfo'.
// TODO(hong19860320) Recover the XPU model from the weight tensor of graph op.
bool LoadModel(const lite::Tensor &model,
std::shared_ptr<xtcl::network::xRuntimeInstance> *runtime) {
LOG(INFO) << "[XPU] Load Model.";
CHECK_GT(model.dims().production(), 0);
std::string name(reinterpret_cast<const char *>(model.data<int8_t>()));
LOG(INFO) << "[XPU] Model Name: " << name;
CHECK(runtime != nullptr);
*runtime = DeviceInfo::Global().Find(name);
if (*runtime == nullptr) {
LOG(WARNING) << "[XPU] Load Model failed!";
return false;
}
return true;
std::unique_ptr<xtcl::network::xRuntimeInstance> Device::Build(
xtcl::network::xNetworkBuilder* builder,
xtcl::network::xTensorCompiler::ParamNDArrayMap* params,
std::vector<xtcl::xExpr*>* outputs) {
VLOG(3) << "[XPU] Build model";
CHECK(builder != nullptr);
CHECK(outputs != nullptr);
CHECK_GT(outputs->size(), 0);
// The XPU compiler builds the graph and fills all of the constant params; only
// one output is supported now.
xtcl::xNetwork network = builder->FinalizeNetwork(*((*outputs)[0]));
auto target = xtcl::Target::Create(device_name_);
auto compiler = xtcl::network::xTensorCompiler(network, target);
compiler.SetParams(*params); // Set the data of constant tensors
compiler.Build();
return std::unique_ptr<xtcl::network::xRuntimeInstance>(
new xtcl::network::xRuntimeInstance(compiler.CreateRuntimeInstance()));
}
} // namespace xpu
......
......@@ -17,31 +17,34 @@
#include <xtcl/xtcl.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class GraphCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
class Device {
public:
using param_t = operators::GraphParam;
void PrepareForRun() override;
void Run() override;
virtual ~GraphCompute() = default;
static Device& Global() {
static Device x;
return x;
}
Device() {}
// Build the XPU graph into an XPU runtime instance, and return the runtime which
// can be used to run inference.
std::unique_ptr<xtcl::network::xRuntimeInstance> Build(
xtcl::network::xNetworkBuilder* builder,
xtcl::network::xTensorCompiler::ParamNDArrayMap* params,
std::vector<xtcl::xExpr*>* outputs);
private:
std::shared_ptr<xtcl::network::xRuntimeInstance> runtime_{nullptr};
// Keep reserved fields
int device_id_{0};
std::string device_name_{"llvm"};
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
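For reference, a minimal usage sketch of the XPU counterpart. The builder, params, and output expressions are assumed to be filled in by the XPU bridges; only the Build() call itself comes from the code above.

```cpp
// Hypothetical call site; the XPU bridges are assumed to have already
// populated `builder`, `params` and `outputs` for one fused subgraph.
#include <memory>
#include <vector>
#include "lite/backends/xpu/device.h"
#include "lite/utils/cp_logging.h"

std::unique_ptr<xtcl::network::xRuntimeInstance> BuildXPURuntime(
    xtcl::network::xNetworkBuilder* builder,
    xtcl::network::xTensorCompiler::ParamNDArrayMap* params,
    std::vector<xtcl::xExpr*>* outputs) {
  // Build() finalizes the network from the first output expression, compiles
  // it for the configured target, and returns a runtime instance that can be
  // used to run inference.
  auto runtime =
      paddle::lite::xpu::Device::Global().Build(builder, params, outputs);
  CHECK(runtime != nullptr) << "[XPU] Online model generation failed!";
  return runtime;
}
```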
......@@ -33,9 +33,9 @@ lite_cc_library(scope SRCS scope.cc DEPS tensor)
lite_cc_library(device_info SRCS device_info.cc DEPS tensor)
if (LITE_WITH_ARM)
lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags NPU_DEPS npu_runtime)
lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags)
else()
lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags XPU_DEPS xpu_runtime)
lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags)
endif()
#-------------------------------------------- GET CODE META INFO ------------------------------------------
......
......@@ -5,6 +5,6 @@ endif()
lite_cc_library(arena_framework SRCS framework.cc DEPS program gtest)
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if((NOT LITE_WITH_OPENCL) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
......@@ -14,13 +14,38 @@
#include "lite/core/arena/framework.h"
#include "lite/core/context.h"
#include "lite/operators/subgraph_op.h"
namespace paddle {
namespace lite {
namespace arena {
void TestCase::CreateInstruction() {
auto op = LiteOpRegistry::Global().Create(op_desc().Type());
std::shared_ptr<lite::OpLite> op = nullptr;
if (place_.target == TARGET(kNPU) || place_.target == TARGET(kXPU)) {
// Create a new block desc to wrap the original op desc
int sub_block_idx = 0;
auto sub_block_desc = new cpp::BlockDesc();
sub_block_desc->ClearOps();
sub_block_desc->ClearVars();
auto sub_block_op_desc = sub_block_desc->AddOp<cpp::OpDesc>();
*sub_block_op_desc = *op_desc_;
// Add the block desc into the subgraph op which is used to replace the
// original op
op_desc_.reset(new cpp::OpDesc());
op_desc_->SetType("subgraph");
op_desc_->SetAttr<int32_t>("sub_block", sub_block_idx);
op_desc_->SetInput("Inputs", op_desc_->input_vars());
op_desc_->SetOutput("Outputs", op_desc_->output_vars());
op_desc_->SetAttr<std::vector<std::string>>(
"input_data_names", sub_block_op_desc->input_vars());
op_desc_->SetAttr<std::vector<std::string>>(
"output_data_names", sub_block_op_desc->output_vars());
op = LiteOpRegistry::Global().Create(op_desc().Type());
static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(sub_block_desc);
} else {
op = LiteOpRegistry::Global().Create(op_desc().Type());
}
CHECK(op) << "no op for " << op_desc().Type();
op->Attach(*op_desc_, inst_scope_);
auto kernels = op->CreateKernels({place_});
......@@ -68,6 +93,19 @@ void TestCase::PrepareInputsForInstruction() {
}
}
TestCase::~TestCase() {
if (op_desc_->Type() == "subgraph") {
// Release the sub-block desc of the subgraph op
auto subgraph_op = const_cast<operators::SubgraphOp*>(
static_cast<const operators::SubgraphOp*>(instruction_->op()));
CHECK(subgraph_op);
auto sub_block_desc = subgraph_op->GetSubBlock();
if (sub_block_desc) {
delete sub_block_desc;
}
}
}
} // namespace arena
} // namespace lite
} // namespace paddle
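For illustration, a minimal sketch of the op desc produced by the wrapping above, assuming the original op is a single `relu` with input `x` and output `y` (all names are illustrative only):

```cpp
// Wrap a single "relu" op into a "subgraph" op, mirroring what
// TestCase::CreateInstruction() does for NPU/XPU places.
auto* sub_block_desc = new cpp::BlockDesc();
auto* relu_desc = sub_block_desc->AddOp<cpp::OpDesc>();
relu_desc->SetType("relu");
relu_desc->SetInput("X", {"x"});
relu_desc->SetOutput("Out", {"y"});

cpp::OpDesc subgraph_desc;
subgraph_desc.SetType("subgraph");
subgraph_desc.SetAttr<int32_t>("sub_block", 0);
subgraph_desc.SetInput("Inputs", {"x"});
subgraph_desc.SetOutput("Outputs", {"y"});
subgraph_desc.SetAttr<std::vector<std::string>>("input_data_names", {"x"});
subgraph_desc.SetAttr<std::vector<std::string>>("output_data_names", {"y"});
// The subgraph op created from subgraph_desc takes ownership of
// sub_block_desc via SubgraphOp::SetSubBlock().
```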
......@@ -42,7 +42,7 @@ class TestCase {
: place_(place), scope_(new Scope), alias_(alias) {
ctx_ = ContextScheduler::Global().NewContext(place_.target);
}
virtual ~TestCase() {}
virtual ~TestCase();
void Prepare() {
PrepareScopes();
......
......@@ -25,12 +25,6 @@
#include "lite/backends/opencl/cl_context.h"
#include "lite/backends/opencl/cl_runtime.h"
#endif
#ifdef LITE_WITH_NPU
#include "lite/backends/npu/runtime.h"
#endif
#ifdef LITE_WITH_XPU
#include "lite/backends/xpu/runtime.h"
#endif
#include <map>
#include <memory>
......@@ -93,7 +87,7 @@ template <>
class Context<TargetType::kXPU> {
public:
Context() {}
explicit Context(const NPUContext& ctx);
explicit Context(const XPUContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {}
void CopySharedTo(XPUContext* ctx) {}
......
......@@ -32,7 +32,7 @@ lite_cc_library(mir_passes
demo_pass.cc
runtime_context_assign_pass.cc
memory_optimize_pass.cc
DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs})
# lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
#mir_ssa_graph scope op
......
......@@ -36,15 +36,6 @@ std::string Visualize(mir::SSAGraph* graph) {
int id = 0;
std::set<std::string> exists_args;
std::map<int, std::string> graph_col; // Different colors of subgraphs
graph_col.insert({{1, "red"},
{2, "green"},
{3, "cyan"},
{4, "bisque3"},
{5, "coral"},
{6, "darkseagreen1"},
{7, "goldenrod1"},
{8, "darkorchid"}});
for (auto& node : graph->mutable_nodes()) {
std::string key;
if (node.IsArg()) {
......@@ -52,24 +43,12 @@ std::string Visualize(mir::SSAGraph* graph) {
} else {
key = string_format("%s%d", node.AsStmt().op_type().c_str(), id++);
}
if (node.IsStmt()) {
auto& stmt = node.AsStmt();
auto sub_id = stmt.subgraph_id();
auto it = graph_col.find(sub_id);
if (sub_id > 0 && it != graph_col.end()) {
dot.AddNode(key,
{Dot::Attr("shape", "box"),
Dot::Attr("style", "filled"),
Dot::Attr("color", "black"),
Dot::Attr("fillcolor", it->second)});
} else {
dot.AddNode(key,
{Dot::Attr("shape", "box"),
Dot::Attr("style", "filled"),
Dot::Attr("color", "black"),
Dot::Attr("fillcolor", "yellow")});
}
dot.AddNode(key,
{Dot::Attr("shape", "box"),
Dot::Attr("style", "filled"),
Dot::Attr("color", "black"),
Dot::Attr("fillcolor", "yellow")});
for (auto& x : node.inlinks) {
auto name = x->AsArg().name;
if (!exists_args.count(name)) {
......
......@@ -50,7 +50,7 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
"lod_reset",
"concat",
"yolo_box",
"graph_op",
"subgraph",
"feed",
"fetch"};
for (auto* tmp : node->inlinks) {
......
......@@ -64,9 +64,6 @@ class Node {
return valid_kernels_;
}
void ClearSubgraphID() { subgraph_id_ = -1 /* note: not 0 */; }
void SetSubgraphID(int id) { subgraph_id_ = id; }
int subgraph_id() const { return subgraph_id_; }
void SetOp(const std::shared_ptr<OpLite>& op) { op_ = op; }
const std::shared_ptr<OpLite> op() const { return op_; }
......@@ -82,11 +79,6 @@ class Node {
// Description.
std::string desc;
protected:
// -1 means not in subgraph, 0 means supported but not one id, id started
// from 1
int subgraph_id_{-1};
};
struct Arg {
......
lite_cc_library(subgraph_detector
SRCS subgraph_detector.cc
DEPS mir_pass types subgraph_op)
lite_cc_library(subgraph_pass
SRCS subgraph_program_pass.cc
DEPS mir_pass types ${mir_fusers})
lite_cc_test(test_subgraph_pass SRCS subgraph_program_pass_test.cc
DEPS subgraph_pass mir_passes gflags model_parser cxx_api
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL)
SRCS subgraph_pass.cc
DEPS mir_pass types context ${mir_fusers} subgraph_detector)
if (WITH_TESTING)
add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v1_tar_gz)
add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz)
set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
set_target_properties(test_subgraph_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
set(subgraph_passes subgraph_pass)
if(LITE_WITH_NPU)
lite_cc_library(npu_pass SRCS generate_npu_program_pass.cc
DEPS mir_pass types context ${mir_fusers} ${npu_bridges} graph_op subgraph_pass)
list(APPEND subgraph_passes npu_pass)
lite_cc_test(test_npu_pass SRCS generate_npu_program_pass_test.cc
DEPS npu_pass mir_passes paddle_api_full paddle_api_light gflags
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1
--optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL)
if (WITH_TESTING)
add_dependencies(test_npu_pass extern_lite_download_mobilenet_v1_tar_gz)
add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz)
lite_cc_test(test_subgraph_detector
SRCS subgraph_detector_test.cc
DEPS subgraph_detector mir_passes gflags model_parser cxx_api
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL)
add_dependencies(test_subgraph_detector
extern_lite_download_mobilenet_v1_tar_gz
extern_lite_download_mobilenet_v2_relu_tar_gz)
set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
set_target_properties(test_npu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
endif()
if(LITE_WITH_XPU)
lite_cc_library(xpu_pass SRCS generate_xpu_program_pass.cc
DEPS mir_pass types context ${mir_fusers} ${xpu_bridges} ${xpu_builder_libs} graph_op subgraph_pass)
list(APPEND subgraph_passes xpu_pass)
lite_cc_test(test_xpu_pass SRCS generate_xpu_program_pass_test.cc
DEPS xpu_pass mir_passes paddle_api_full gflags
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1
--optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL)
if (WITH_TESTING)
add_dependencies(test_xpu_pass extern_lite_download_mobilenet_v1_tar_gz)
add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz)
set_target_properties(test_subgraph_detector PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
lite_cc_test(test_subgraph_pass
SRCS subgraph_pass_test.cc
DEPS mir_passes paddle_api_full paddle_api_light gflags
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1
--optimized_model_dir=${LITE_MODEL_DIR}/lite_model_opt SERIAL)
add_dependencies(test_subgraph_pass
extern_lite_download_mobilenet_v1_tar_gz
extern_lite_download_mobilenet_v2_relu_tar_gz)
set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
set_target_properties(test_xpu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
set_target_properties(test_subgraph_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
set(subgraph_passes ${subgraph_passes} CACHE INTERNAL "subgraph_passes")
message(STATUS "----> subgraph_passes: ${subgraph_passes}")
set(mir_subgraphs subgraph_pass CACHE INTERNAL "mir_subgraphs")
message(STATUS "----> mir_subgraphs: ${mir_subgraphs}")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/generate_npu_program_pass.h"
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/paddle_use_npu_bridges.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
std::shared_ptr<ge::Operator> GenerateNPUProgramPass::CvtVarNode(
lite::mir::Node* var_node, const Scope* scope) {
CHECK(var_node->IsArg());
const auto& arg = var_node->AsArg();
VLOG(4) << "[NPU] Convert var node " << arg.name;
auto* var = scope->FindVar(arg.name);
CHECK(var);
auto* tensor = var->GetMutable<lite::Tensor>();
CHECK(tensor);
auto dims = tensor->dims();
if (arg.is_weight) {
auto wgt = std::make_shared<ge::op::Const>(arg.name);
LOG(INFO) << "[NPU] Convert const var node " << arg.name;
VLOG(4) << dims;
wgt->set_attr_value(lite::npu::CvtTensor(tensor));
return wgt;
} else {
CHECK_EQ(dims.size(), 4);
LOG(INFO) << "[NPU] Convert data var node " << arg.name;
LOG(INFO) << dims;
// TODO(xxx): support more types and dims size
ge::TensorDesc desc(ge::Shape(dims.Vectorize()),
ge::Format::FORMAT_NCHW,
ge::DataType::DT_FLOAT);
// auto size = desc.GetShape().GetShapeSize();
// ge::TensorUtils::SetSize(desc, size*sizeof(float));
// ge::TensorUtils::SetRealDimCnt(desc, 4);
auto data = std::make_shared<ge::op::Data>(arg.name);
data->update_input_desc_x(desc);
return data;
}
return nullptr;
}
void GenerateNPUProgramPass::CvtAllOpNodes(
const std::vector<Node*>& nodes2cvt,
lite::kernels::npu::bridges::node_map_type* converted_vars) {
const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
const auto& cvtfunc_map = bridges.AllFunctions();
// Record all of the converted vars; an op node's inputs must be found in
// converted_vars
for (auto& node : nodes2cvt) {
lite::kernels::npu::bridges::node_map_type node_inputs;
auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg();
// weight should be handled in the converter, so skip here
if (arg.is_weight) {
continue;
}
auto var_name = arg.name;
if (!converted_vars->count(var_name)) {
converted_vars->insert(
std::make_pair(var_name, CvtVarNode(var_node, stmt.op()->scope())));
}
node_inputs.insert(*converted_vars->find(var_name));
}
auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs);
converted_vars->insert(node_outputs.begin(), node_outputs.end());
}
}
std::string GenerateNPUProgramPass::BuildNPUGraph(
const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id) {
auto ordered_nodes = GetTopologicalOrder(op_nodes);
lite::kernels::npu::bridges::node_map_type converted_vars;
CvtAllOpNodes(ordered_nodes, &converted_vars);
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
std::vector<ge::Operator> inputs;
std::vector<ge::Operator> outputs;
for (auto i : in_data_vars) {
auto argname = i->AsArg().name;
in_var_names.push_back(argname);
inputs.push_back(*converted_vars.at(argname));
}
for (auto i : out_data_vars) {
auto argname = i->AsArg().name;
out_var_names.push_back(argname);
outputs.push_back(*converted_vars.at(argname));
}
std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights";
auto any_op = (*op_nodes.begin())->AsStmt().op();
auto weight = any_op->scope()->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
// Compile the IR graph into an NPU model and store the model data into the weight
// tensor with persistable=true, so that the model parser can recognize it and save
// it to the param files
if (!lite::npu::BuildModel(inputs, outputs, weight)) {
LOG(FATAL) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")";
} else {
LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")";
}
return weight_var_name;
}
void GenerateNPUProgramPass::GenNPUSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id) {
std::unordered_set<Node*> in_data_vars;
std::unordered_set<Node*> in_wgt_vars;
std::unordered_set<Node*> out_data_vars;
std::unordered_set<Node*> out_unused_vars;
FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto weight_var_name =
BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
auto any_op = (*op_nodes.begin())->AsStmt().op();
InsertNewNode(graph,
weight_var_name,
any_op->scope(),
any_op->valid_places(),
in_data_vars,
in_wgt_vars,
out_data_vars,
out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "[NPU] Before NPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
for (auto& i : op_map) {
LOG(INFO) << "[NPU] Supported type: " << i.first;
supported_op_types.push_back(i.first);
}
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "[NPU] Converting Subgraph " << id;
GenNPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
}
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(generate_npu_program_pass,
paddle::lite::mir::subgraph::GenerateNPUProgramPass)
.BindTargets({TARGET(kNPU)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/generate_xpu_program_pass.h"
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/backends/xpu/builder.h"
#include "lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
std::shared_ptr<xtcl::xExpr> GenerateXPUProgramPass::CvtVarNode(
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::mir::Node* var_node,
const Scope* scope) {
CHECK(var_node->IsArg());
const auto& arg = var_node->AsArg();
auto var_name = arg.name;
VLOG(4) << "[XPU] Convert var node " << var_name;
auto* var = scope->FindVar(var_name);
CHECK(var);
auto* tensor = var->GetMutable<lite::Tensor>();
CHECK(tensor);
auto dims = tensor->dims();
auto cvted_var_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateTensor(
var_name, lite::xpu::CvtShape(dims), ::xtcl::Float(32)));
if (arg.is_weight) {
auto cvted_var_tensor = lite::xpu::CvtTensor(tensor);
graph_ctx->params->emplace(std::make_pair(var_name, *cvted_var_tensor));
}
return cvted_var_node;
}
void GenerateXPUProgramPass::CvtAllOpNodes(
const std::vector<Node*>& op_nodes,
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes) {
const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance();
const auto& supported_lists = bridges.AllFunctions();
// Record all of the converted var nodes; an op node's inputs must be found in
// cvted_var_nodes
for (auto& node : op_nodes) {
lite::kernels::xpu::bridges::node_map_type input_nodes;
auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg();
// weight should be handled in the converter, so skip here
if (arg.is_weight) {
continue;
}
auto var_name = arg.name;
if (!cvted_var_nodes->count(var_name)) {
cvted_var_nodes->insert(std::make_pair(
var_name, CvtVarNode(graph_ctx, var_node, stmt.op()->scope())));
}
input_nodes.insert(*cvted_var_nodes->find(var_name));
}
auto output_nodes =
supported_lists.at(stmt.op_type())(stmt.op(), graph_ctx, input_nodes);
cvted_var_nodes->insert(output_nodes.begin(), output_nodes.end());
}
}
std::string GenerateXPUProgramPass::BuildXPUGraph(
const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id) {
auto ordered_op_nodes = GetTopologicalOrder(op_nodes);
lite::kernels::xpu::bridges::graph_ctx_type graph_ctx;
graph_ctx.builder = std::make_shared<xtcl::network::xNetworkBuilder>();
graph_ctx.params =
std::make_shared<xtcl::network::xTensorCompiler::ParamNDArrayMap>();
lite::kernels::xpu::bridges::node_map_type cvted_var_nodes;
CvtAllOpNodes(ordered_op_nodes, &graph_ctx, &cvted_var_nodes);
std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights";
auto any_op = (*op_nodes.begin())->AsStmt().op();
auto weight = any_op->scope()->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
// Compile the graph into an XPU model and store the model data into the weight
// tensor with persistable=true, so that the model parser can recognize it and save
// it to the param files
std::vector<std::shared_ptr<xtcl::xExpr>> ordered_cvted_var_nodes;
for (auto out_data_var : out_data_vars) {
auto var_name = out_data_var->AsArg().name;
ordered_cvted_var_nodes.push_back(cvted_var_nodes[var_name]);
}
if (!lite::xpu::BuildModel(graph_ctx.builder,
graph_ctx.params,
&ordered_cvted_var_nodes,
weight)) {
LOG(FATAL) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")";
} else {
LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")";
}
return weight_var_name;
}
void GenerateXPUProgramPass::GenXPUSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id) {
std::unordered_set<Node*> in_data_vars;
std::unordered_set<Node*> in_wgt_vars;
std::unordered_set<Node*> out_data_vars;
std::unordered_set<Node*> out_unused_vars;
FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto weight_var_name =
BuildXPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
auto any_op = (*op_nodes.begin())->AsStmt().op();
InsertNewNode(graph,
weight_var_name,
any_op->scope(),
any_op->valid_places(),
in_data_vars,
in_wgt_vars,
out_data_vars,
out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
void GenerateXPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "[XPU] Before XPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
for (auto& i : op_map) {
LOG(INFO) << "[XPU] Supported type: " << i.first;
supported_op_types.push_back(i.first);
}
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "[XPU] Converting Subgraph " << id;
GenXPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
}
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(generate_xpu_program_pass,
paddle::lite::mir::subgraph::GenerateXPUProgramPass)
.BindTargets({TARGET(kXPU)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
DEFINE_string(model_file, "", "model file path of combined protobuf model");
DEFINE_string(params_file, "", "params file path of combined protobuf model");
DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors");
DEFINE_int32(output_tensor_num, 1, "number of output tensors");
namespace paddle {
namespace lite {
std::vector<std::vector<int64_t>> ParseShape(std::string txt) {
std::vector<std::vector<int64_t>> shape;
while (!txt.empty()) {
size_t idx = txt.find_first_of(":");
std::string dims = txt.substr(0, idx);
std::vector<int64_t> s;
while (!dims.empty()) {
size_t idx = dims.find_first_of(",");
int d = atoi(dims.substr(0, idx).c_str());
VLOG(3) << d;
s.push_back(d);
if (idx == std::string::npos) {
break;
} else {
dims = dims.substr(idx + 1);
}
}
shape.push_back(s);
if (idx == std::string::npos) {
break;
} else {
txt = txt.substr(idx + 1);
}
}
return shape;
}
int64_t ShapeProduction(std::vector<int64_t> shape) {
int64_t s = 1;
for (int64_t dim : shape) {
s *= dim;
}
return s;
}
void FillInputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const float value) {
for (int i = 0; i < input_tensor_shape.size(); i++) {
auto input_tensor = predictor->GetInput(i);
input_tensor->Resize(input_tensor_shape[i]);
auto input_tensor_data = input_tensor->mutable_data<float>();
auto input_tensor_size = ShapeProduction(input_tensor->shape());
for (int j = 0; j < input_tensor_size; j++) {
input_tensor_data[j] = value;
}
}
}
void CompareOutputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
const int output_tensor_num) {
for (int i = 0; i < output_tensor_num; i++) {
auto tar_output_tensor = tar_predictor->GetOutput(i);
auto ref_output_tensor = ref_predictor->GetOutput(i);
auto tar_output_tensor_data = tar_output_tensor->data<float>();
auto ref_output_tensor_data = ref_output_tensor->data<float>();
auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape());
auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape());
EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size);
for (size_t j = 0; j < ref_output_tensor_size; j++) {
auto diff =
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]) /
(std::fabs(ref_output_tensor_data[j]) + 1e-6);
VLOG(3) << diff;
EXPECT_LT(diff, 0.1);
}
}
}
std::shared_ptr<lite_api::PaddlePredictor> TestModel(
const std::string& model_dir,
const std::string& model_file,
const std::string& params_file,
const std::vector<lite_api::Place>& valid_places,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const std::string& optimized_model_dir) {
// generate optimized model
lite_api::CxxConfig cxx_config;
cxx_config.set_model_dir(model_dir);
cxx_config.set_model_file(model_file);
cxx_config.set_param_file(params_file);
cxx_config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
FillInputTensor(predictor, input_tensor_shape, -1);
predictor->SaveOptimizedModel(optimized_model_dir,
lite_api::LiteModelType::kNaiveBuffer);
#if 0 // TODO(hong19860320) supports light api for XPU
// load optimized model
lite_api::MobileConfig mobile_config;
mobile_config.set_model_dir(optimized_model_dir);
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1);
predictor = lite_api::CreatePaddlePredictor(mobile_config);
FillInputTensor(predictor, input_tensor_shape, 1);
#endif
// run optimized model
for (int i = 0; i < FLAGS_warmup; i++) {
predictor->Run();
}
for (int i = 0; i < FLAGS_repeats; i++) {
auto start = GetCurrentUS();
predictor->Run();
LOG(INFO) << i << ", " << GetCurrentUS() - start << "us";
}
return predictor;
}
TEST(XPUSubgraph, compare) {
// parsing input tensor shape, supported formats: "1,3,224,224"
// "1,3,224,224:1,80"
std::vector<std::vector<int64_t>> input_tensor_shape =
ParseShape(FLAGS_input_tensor_shape);
// generate and run optimized CPU model
LOG(INFO) << " ================ CPU ================== ";
auto cpu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kX86), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/CPU");
// generate and run optimized XPU model
LOG(INFO) << " ================ XPU ================== ";
auto xpu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/XPU");
// verify results
CompareOutputTensor(xpu_predictor, cpu_predictor, FLAGS_output_tensor_num);
}
} // namespace lite
} // namespace paddle
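For clarity, the `--input_tensor_shape` flag accepts one shape per input tensor, with dims separated by ',' and tensors separated by ':'. A minimal usage sketch of the helpers above (values are illustrative only):

```cpp
// Example of the shape string format parsed by ParseShape():
//   "1,3,224,224"      -> {{1, 3, 224, 224}}
//   "1,3,224,224:1,80" -> {{1, 3, 224, 224}, {1, 80}}
auto shapes = paddle::lite::ParseShape("1,3,224,224:1,80");
CHECK_EQ(shapes.size(), 2u);
CHECK_EQ(paddle::lite::ShapeProduction(shapes[0]), 1 * 3 * 224 * 224);
CHECK_EQ(paddle::lite::ShapeProduction(shapes[1]), 80);
```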
This diff is collapsed.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
using SubgraphTeller = std::function<bool(Node*)>;
class SubgraphVisualizer {
public:
SubgraphVisualizer(SSAGraph* graph,
const std::vector<std::vector<Node*>>& subgraphs)
: graph_(graph), subgraphs_(subgraphs) {}
std::string operator()();
protected:
SSAGraph* graph_{nullptr};
std::vector<std::vector<Node*>> subgraphs_;
};
/*
* Divide the graph into subgraphs according to the specified conditions.
* Return the divided clusters; a cluster consists of the op nodes in the
* subgraph.
*/
class SubgraphDetector {
public:
// This is a simple representation of the graph. Each node_dat_t holds a
// pointer to the original Node, to avoid changing the original graph in the
// process of graph analysis.
struct node_dat_t;
using node_map_t = std::unordered_map<Node*, node_dat_t*>;
using node_set_t = std::vector<node_dat_t*>;
struct node_dat_t {
explicit node_dat_t(Node* _node) : node(_node) {}
Node* node;
bool marked{false};
node_dat_t* union_find_parent{this};
node_set_t inlinks{};
node_set_t outlinks{};
node_dat_t* UnionFindAncestor();
void UnionFindCombine(node_dat_t* candidate);
};
SubgraphDetector(SSAGraph* graph, const SubgraphTeller& teller)
: graph_(graph), teller_(teller) {}
std::vector<std::vector<Node*>> operator()();
void FlexibleDFS(const node_set_t& source,
bool reverse,
const std::function<bool(const node_dat_t*)>& enter,
const std::function<bool(const node_dat_t*)>& leave);
void InitNodes(node_map_t* nodes);
std::vector<std::vector<Node*>> ExtractSubgraphs(node_map_t* nodes);
protected:
SSAGraph* graph_{nullptr};
SubgraphTeller teller_;
};
/*
* Replace all of the subgraphs with subgraph ops. A block desc is added into
* each subgraph op to wrap the original op nodes, and all of the var nodes of
* the original op nodes are kept as the inputs and outputs of the subgraph op.
*/
class SubgraphFuser {
public:
SubgraphFuser(SSAGraph* graph,
const SubgraphTeller& teller,
int min_subgraph_size)
: graph_(graph), teller_(teller), min_subgraph_size_{min_subgraph_size} {}
void operator()();
// Remove the op nodes of the subgraphs and replace with the subgraph ops.
void ReplaceNodesWithSubgraphs(SSAGraph* graph,
const SubgraphTeller& teller,
int min_subgraph_size);
// Create a subgraph node with a block desc to wrap the original op nodes of
// the subgraph
void InsertNewNode(SSAGraph* graph,
int subgraph_idx,
const std::vector<Node*>& subgraph_nodes);
protected:
SSAGraph* graph_{nullptr};
SubgraphTeller teller_;
int min_subgraph_size_;
};
void ExtractInputsOutputs(const std::vector<Node*>& op_nodes,
std::unordered_set<Node*>* input_var_nodes,
std::unordered_set<Node*>* weight_var_nodes,
std::unordered_set<Node*>* output_var_nodes,
std::unordered_set<Node*>* local_var_nodes,
std::unordered_set<Node*>* unused_var_nodes);
std::unordered_set<const Node*> GetNodes2RM(
const std::vector<Node*>& op_nodes,
const std::vector<std::unordered_set<Node*>>& excluded_var_nodes);
std::vector<Node*> GetTopologicalOrder(
const std::unordered_set<Node*>& unordered_nodes);
} // namespace mir
} // namespace lite
} // namespace paddle
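For clarity, a minimal sketch of the union-find helpers declared on `node_dat_t` (definitions would sit inside namespace paddle::lite::mir); the actual implementation lives in subgraph_detector.cc and may differ in detail:

```cpp
// Find the representative of the cluster this node belongs to by following
// the union_find_parent chain up to the root.
SubgraphDetector::node_dat_t* SubgraphDetector::node_dat_t::UnionFindAncestor() {
  node_dat_t* ancestor = this;
  while (ancestor->union_find_parent != ancestor) {
    ancestor = ancestor->union_find_parent;
  }
  return ancestor;
}

// Merge the candidate's cluster into this node's cluster so that both share
// one representative; ExtractSubgraphs() later groups nodes by representative.
void SubgraphDetector::node_dat_t::UnionFindCombine(node_dat_t* candidate) {
  node_dat_t* root_a = UnionFindAncestor();
  node_dat_t* root_b = candidate->UnionFindAncestor();
  if (root_a != root_b) root_b->union_find_parent = root_a;
}
```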
......@@ -12,68 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include "lite/core/mir/subgraph/subgraph_detector.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/ssa_graph.h"
#include "lite/core/program.h"
#include "lite/model_parser/cpp/program_desc.h"
#include "lite/model_parser/model_parser.h"
DEFINE_string(model_dir, "", "model_dir");
DEFINE_string(model_file, "", "model file path of combined protobuf model");
DEFINE_string(params_file, "", "params file path of combined protobuf model");
namespace paddle {
namespace lite {
TEST(SubgraphTest, models) {
cpp::ProgramDesc program_desc;
auto scope = std::make_shared<Scope>();
// LoadModelPb(FLAGS_model_dir,
// FLAGS_model_dir + "/model",
// FLAGS_model_dir + "/params",
// scope.get(),
// &program_desc,
// true);
LoadModelPb(FLAGS_model_dir, "", "", scope.get(), &program_desc);
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
#ifdef LITE_WITH_ARM
Place{TARGET(kARM), PRECISION(kFloat)},
#endif
#ifdef LITE_WITH_NPU
Place{TARGET(kNPU), PRECISION(kFloat)},
#endif
#ifdef LITE_WITH_XPU
Place{TARGET(kXPU), PRECISION(kFloat)},
#endif
});
lite::Program program(program_desc, scope, valid_places);
auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
graph->Build(program, valid_places);
std::vector<std::string> supported_op_types{"concat",
"conv2d",
"depthwise_conv2d",
"batch_norm",
"scale",
"pool2d",
"mul",
"elementwise_add",
"softmax",
"split",
"relu",
"reshape2",
"transpose2"};
auto* pass = new mir::subgraph::SubgraphProgramPass;
ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
}
// return output_var_names
// The helper functions for building model manually
std::vector<std::string> AddFCDesc(
cpp::BlockDesc* block_desc,
const std::shared_ptr<Scope>& scope,
......@@ -87,20 +44,20 @@ std::vector<std::string> AddFCDesc(
auto* wgt = block_desc->AddVar<cpp::VarDesc>();
wgt->SetName(prefix + "_W");
auto* wtensor = scope->Var(prefix + "_W")->GetMutable<lite::Tensor>();
auto* wtensor = scope->Var(prefix + "_W")->GetMutable<Tensor>();
wtensor->Resize(wshape);
wtensor->mutable_data<float>();
auto* bias = block_desc->AddVar<cpp::VarDesc>();
bias->SetName(prefix + "_Bias");
auto* btensor = scope->Var(prefix + "_Bias")->GetMutable<lite::Tensor>();
auto* btensor = scope->Var(prefix + "_Bias")->GetMutable<Tensor>();
btensor->Resize({wshape[1]});
btensor->mutable_data<float>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
scope->Var(prefix + "_Out")->GetMutable<Tensor>();
op_desc->SetType("fc");
op_desc->SetInput("Input", input_var_names);
......@@ -126,7 +83,7 @@ std::vector<std::string> AddElementwiseAddDesc(
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
scope->Var(prefix + "_Out")->GetMutable<Tensor>();
op_desc->SetType("elementwise_add");
op_desc->SetInput("X", input_X_names);
......@@ -150,7 +107,7 @@ std::vector<std::string> AddFeedDesc(
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
scope->Var(prefix + "_Out")->GetMutable<Tensor>();
op_desc->SetType("feed");
op_desc->SetInput("X", input_X_names);
......@@ -173,7 +130,7 @@ std::vector<std::string> AddFetchDesc(
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
scope->Var(prefix + "_Out")->GetMutable<Tensor>();
op_desc->SetType("fetch");
op_desc->SetInput("X", input_X_names);
......@@ -183,40 +140,88 @@ std::vector<std::string> AddFetchDesc(
return out_var_names;
}
std::unique_ptr<mir::SSAGraph> BuildSimpleNet(
cpp::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
program_desc->ClearBlocks();
auto* block_desc = program_desc->AddBlock<cpp::BlockDesc>();
TEST(Subgraph, detect_simple_model) {
cpp::ProgramDesc program_desc;
std::vector<Place> valid_places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
// Build a simple network
program_desc.ClearBlocks();
auto* block_desc = program_desc.AddBlock<cpp::BlockDesc>();
block_desc->ClearOps();
block_desc->ClearVars();
auto* var_desc = block_desc->AddVar<cpp::VarDesc>();
var_desc->SetName("feed_var");
auto* feed_var = scope->Var("feed_var")->GetMutable<lite::Tensor>();
auto* feed_var = scope->Var("feed_var")->GetMutable<Tensor>();
feed_var->Resize({1, 4});
auto fc1_out = AddFCDesc(block_desc, scope, {"feed_var"}, {4, 5});
auto fc2_out = AddFCDesc(block_desc, scope, fc1_out, {5, 2});
lite::Program program(*program_desc, scope, valid_places);
Program program(program_desc, scope, valid_places);
auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
graph->Build(program, valid_places);
return graph;
// Apply subgraph detector and check results
auto teller = [](mir::Node* node) {
if (!node->IsStmt()) return false;
auto& stmt = node->AsStmt();
auto op_type = stmt.op_type();
const std::vector<std::string> supported_types = {"fc"};
return std::find(supported_types.begin(), supported_types.end(), op_type) !=
supported_types.end();
};
std::vector<std::vector<mir::Node*>> subgraphs =
mir::SubgraphDetector(graph.get(), teller)();
ASSERT_EQ(subgraphs.size(), 1);
ASSERT_EQ(graph->nodes().size(), 9);
mir::SubgraphVisualizer(graph.get(), subgraphs)();
}
TEST(SubGraphTest, SimpleNet) {
TEST(Subgraph, detect_custom_model) {
if (FLAGS_model_dir.empty() && FLAGS_model_file.empty() &&
FLAGS_params_file.empty()) {
LOG(INFO) << "Using --model_dir, or --model_file and --params_file to set "
"the path of model files.";
return;
}
cpp::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildSimpleNet(&program_desc, scope, places);
std::vector<std::string> supported_op_types{"fc"};
auto* pass = new mir::subgraph::SubgraphProgramPass;
ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
ASSERT_EQ(graph->nodes().size(), 9);
// LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
LoadModelPb(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
scope.get(),
&program_desc,
!FLAGS_model_file.empty() && !FLAGS_params_file.empty(),
false);
std::vector<Place> valid_places({
#ifdef LITE_WITH_ARM
Place{TARGET(kARM), PRECISION(kFloat)},
#endif
#ifdef LITE_WITH_X86
Place{TARGET(kX86), PRECISION(kFloat)},
#endif
#ifdef LITE_WITH_NPU
Place{TARGET(kNPU), PRECISION(kFloat)},
#endif
#ifdef LITE_WITH_XPU
Place{TARGET(kXPU), PRECISION(kFloat)},
#endif
});
Program program(program_desc, scope, valid_places);
auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
graph->Build(program, valid_places);
// Apply subgraph detector and check results
auto teller = [](mir::Node* node) {
if (!node->IsStmt()) return false;
auto& stmt = node->AsStmt();
auto op_type = stmt.op_type();
const std::vector<std::string> unsupported_types = {
"feed", "fetch", "subgraph"};
return std::find(unsupported_types.begin(),
unsupported_types.end(),
op_type) == unsupported_types.end();
};
std::vector<std::vector<mir::Node*>> subgraphs =
mir::SubgraphDetector(graph.get(), teller)();
ASSERT_EQ(subgraphs.size(), 1);
mir::SubgraphVisualizer(graph.get(), subgraphs)();
}
} // namespace lite
......
......@@ -12,58 +12,52 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <xtcl/xtcl.h>
#include "lite/core/mir/subgraph/subgraph_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "lite/core/tensor.h"
#include <vector>
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/subgraph/subgraph_detector.h"
namespace paddle {
namespace lite {
namespace xpu {
class DeviceInfo {
public:
static DeviceInfo& Global() {
static DeviceInfo x;
return x;
}
DeviceInfo() {}
void Insert(const std::string& name,
std::shared_ptr<xtcl::network::xRuntimeInstance> runtime) {
if (runtimes_.find(name) != runtimes_.end()) {
LOG(WARNING) << "[XPU] Model " << name << " already exists.";
return;
}
runtimes_.emplace(std::make_pair(name, runtime));
}
void Clear() { runtimes_.clear(); }
std::shared_ptr<xtcl::network::xRuntimeInstance> Find(
const std::string& name) const {
if (runtimes_.find(name) != runtimes_.end()) {
return runtimes_.at(name);
} else {
return nullptr;
}
}
private:
int device_id_{0};
std::string device_name_{"default"};
std::unordered_map<std::string,
std::shared_ptr<xtcl::network::xRuntimeInstance>>
runtimes_;
};
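// A minimal usage sketch of the runtime cache above (illustrative only; the
// model name is a hypothetical placeholder):
//   auto runtime = DeviceInfo::Global().Find("mobilenet_v1_subgraph_0");
//   if (runtime == nullptr) {
//     // build the xtcl runtime for this subgraph, then cache it:
//     // DeviceInfo::Global().Insert("mobilenet_v1_subgraph_0", runtime);
//   }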
bool LoadModel(const lite::Tensor& model,
std::shared_ptr<xtcl::network::xRuntimeInstance>* runtime);
} // namespace xpu
namespace mir {
void NPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::unordered_set<std::string> supported_lists;
#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) supported_lists.insert(#op_type);
#include "lite/kernels/npu/bridges/paddle_use_bridges.h"
#undef USE_SUBGRAPH_BRIDGE
auto teller = [&](Node* node) {
if (!node->IsStmt()) return false;
auto& stmt = node->AsStmt();
return supported_lists.count(stmt.op_type()) != 0;
};
SubgraphFuser fuser(graph.get(), teller, 1 /* min_subgraph_size */);
fuser();
}
void XPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::unordered_set<std::string> supported_lists;
#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) supported_lists.insert(#op_type);
#include "lite/kernels/xpu/bridges/paddle_use_bridges.h"
#undef USE_SUBGRAPH_BRIDGE
auto teller = [&](Node* node) {
if (!node->IsStmt()) return false;
auto& stmt = node->AsStmt();
return supported_lists.count(stmt.op_type()) != 0;
};
SubgraphFuser fuser(graph.get(), teller, 1 /* min_subgraph_size */);
fuser();
}
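// Note: supported_lists is populated by expanding the USE_SUBGRAPH_BRIDGE
// entries from paddle_use_bridges.h. As a sketch, an entry such as
//   USE_SUBGRAPH_BRIDGE(XPU, relu)
// would add "relu" to supported_lists, so every relu stmt node passes the
// teller and becomes a candidate for subgraph fusion (op name illustrative).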
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(npu_subgraph_pass, paddle::lite::mir::NPUSubgraphPass)
.BindTargets({TARGET(kNPU)});
REGISTER_MIR_PASS(xpu_subgraph_pass, paddle::lite::mir::XPUSubgraphPass)
.BindTargets({TARGET(kXPU)});
......@@ -12,30 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/bridges/registry.h"
#include <utility>
#pragma once
#include <memory>
#include <vector>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
Factory& Factory::Instance() {
static Factory g_xpu_bridge;
return g_xpu_bridge;
}
namespace mir {
bool Factory::HasType(const std::string& op_type) const {
return map_.count(op_type);
}
class NPUSubgraphPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
void Factory::Insert(const std::string& op_type, const func_type& func_name) {
map_.insert(std::make_pair(op_type, func_name));
}
class XPUSubgraphPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -30,7 +30,9 @@ DEFINE_int32(output_tensor_num, 1, "number of output tensors");
namespace paddle {
namespace lite {
std::vector<std::vector<int64_t>> ParseShape(std::string txt) {
// Helper functions for loading and running a model from the command line and
// verifying the output data
std::vector<std::vector<int64_t>> ShapeParsing(std::string txt) {
std::vector<std::vector<int64_t>> shape;
while (!txt.empty()) {
size_t idx = txt.find_first_of(":");
......@@ -65,7 +67,7 @@ int64_t ShapeProduction(std::vector<int64_t> shape) {
return s;
}
void FillInputTensor(
void FillInputTensors(
const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const float value) {
......@@ -80,7 +82,7 @@ void FillInputTensor(
}
}
void CompareOutputTensor(
void CheckOutputTensors(
const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
const int output_tensor_num) {
......@@ -96,7 +98,7 @@ void CompareOutputTensor(
auto abs_diff =
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]);
auto rel_diff = abs_diff / (std::fabs(ref_output_tensor_data[j]) + 1e-6);
VLOG(3) << "val: " << tar_output_tensor_data[j]
VLOG(5) << "val: " << tar_output_tensor_data[j]
<< " ref: " << ref_output_tensor_data[j]
<< " abs_diff: " << abs_diff << " rel_diff: " << rel_diff;
EXPECT_LT(rel_diff, 0.1);
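// A quick sanity check of the tolerance (illustrative numbers only): with
// tar = 1.05 and ref = 1.00, abs_diff = 0.05 and
// rel_diff = 0.05 / (1.00 + 1e-6) ~= 0.05, which passes the 0.1 threshold.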
......@@ -111,24 +113,23 @@ std::shared_ptr<lite_api::PaddlePredictor> TestModel(
const std::vector<lite_api::Place>& valid_places,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const std::string& optimized_model_dir) {
// generate optimized model
// Generate optimized model
lite_api::CxxConfig cxx_config;
cxx_config.set_model_dir(model_dir);
cxx_config.set_model_file(model_file);
cxx_config.set_param_file(params_file);
cxx_config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
FillInputTensor(predictor, input_tensor_shape, 1);
predictor->SaveOptimizedModel(optimized_model_dir,
lite_api::LiteModelType::kNaiveBuffer);
// load optimized model
// Load optimized model
lite_api::MobileConfig mobile_config;
mobile_config.set_model_dir(optimized_model_dir);
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1);
predictor = lite_api::CreatePaddlePredictor(mobile_config);
FillInputTensor(predictor, input_tensor_shape, 1);
// run optimized model
FillInputTensors(predictor, input_tensor_shape, 1);
// Run optimized model
for (int i = 0; i < FLAGS_warmup; i++) {
predictor->Run();
}
......@@ -140,32 +141,48 @@ std::shared_ptr<lite_api::PaddlePredictor> TestModel(
return predictor;
}
TEST(NPUSubgraph, compare) {
// parsing input tensor shape, supported formats: "1,3,224,224"
// "1,3,224,224:1,80"
TEST(Subgraph, generate_model_and_check_precision) {
if (FLAGS_model_dir.empty() && FLAGS_model_file.empty() &&
FLAGS_params_file.empty()) {
LOG(INFO) << "Using --model_dir, or --model_file and --params_file to set "
"the path of model files.";
return;
}
// Parse the shapes of the input tensors from a string. Supported formats:
// "1,3,224,224" and "1,3,224,224:1,80"
std::vector<std::vector<int64_t>> input_tensor_shape =
ParseShape(FLAGS_input_tensor_shape);
// generate and run optimized CPU model
LOG(INFO) << " ================ CPU ================== ";
auto cpu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kARM), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/CPU");
// generate and run optimized NPU model
LOG(INFO) << " ================ NPU ================== ";
auto npu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kNPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kARM), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/NPU");
// verify results
CompareOutputTensor(npu_predictor, cpu_predictor, FLAGS_output_tensor_num);
ShapeParsing(FLAGS_input_tensor_shape);
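// e.g. --input_tensor_shape="1,3,224,224:1,80" would be parsed into two
// shapes, {1, 3, 224, 224} and {1, 80} (illustrative values only).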
std::vector<lite_api::Place> valid_places({
#ifdef LITE_WITH_ARM
lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
#endif
#ifdef LITE_WITH_X86
lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
#endif
});
// Generate and run optimized model on CPU as the reference predictor
auto ref_predictor = TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
valid_places,
input_tensor_shape,
FLAGS_optimized_model_dir + "/ref_opt_model");
// Generate and run optimized model on NPU/XPU as the target predictor
#ifdef LITE_WITH_NPU
valid_places.push_back(lite_api::Place{TARGET(kNPU), PRECISION(kFloat)});
#endif
#ifdef LITE_WITH_XPU
valid_places.push_back(lite_api::Place{TARGET(kXPU), PRECISION(kFloat)});
#endif
auto tar_predictor = TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
valid_places,
input_tensor_shape,
FLAGS_optimized_model_dir + "/tar_opt_model");
// Check the difference between the output tensors of the reference and
// target predictors
CheckOutputTensors(tar_predictor, ref_predictor, FLAGS_output_tensor_num);
}
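// A hypothetical invocation of this test (the binary name and paths are
// placeholders; the flags are the ones defined above):
//   ./subgraph_precision_test --model_dir=/path/to/model \
//     --input_tensor_shape=1,3,224,224 --output_tensor_num=1 \
//     --optimized_model_dir=/path/to/opt_models --warmup=1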
} // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
std::unordered_map<int, std::unordered_set<Node*>>
SubgraphProgramPass::ClassifySubgraph(const std::unique_ptr<SSAGraph>& graph) {
std::unordered_map<int, std::unordered_set<Node*>> op_nodes;
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
int sub_id = stmt.subgraph_id();
if (sub_id < 1) continue;
if (!op_nodes.count(sub_id)) {
op_nodes[sub_id] = std::unordered_set<Node*>();
}
op_nodes.at(sub_id).insert(item);
}
return op_nodes;
}
cpp::OpDesc SubgraphProgramPass::GenGraphOpDesc(
const std::string& weight_var_name,
const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names) {
cpp::OpDesc op_desc;
op_desc.SetType("graph_op");
op_desc.SetInput("Inputs", in_var_names);
op_desc.SetInput("Weight", {weight_var_name});
op_desc.SetOutput("Outputs", out_var_names);
return op_desc;
}
void SubgraphProgramPass::InsertNewNode(
const std::unique_ptr<SSAGraph>& graph,
const std::string& weight_var_name,
Scope* scope,
const std::vector<Place>& valid_places,
std::unordered_set<Node*> in_data_vars,
std::unordered_set<Node*> in_wgt_vars,
std::unordered_set<Node*> out_data_vars,
std::unordered_set<Node*> out_unused_vars) {
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
for (auto i : in_data_vars) {
in_var_names.push_back(i->AsArg().name);
}
for (auto i : out_data_vars) {
out_var_names.push_back(i->AsArg().name);
}
auto op_desc = GenGraphOpDesc(weight_var_name, in_var_names, out_var_names);
auto graph_op = LiteOpRegistry::Global().Create("graph_op");
graph_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places);
for (auto& in_var : in_data_vars) {
IR_NODE_LINK_TO(in_var, new_op_node);
}
for (auto& in_var : in_wgt_vars) {
IR_NODE_LINK_TO(in_var, new_op_node);
}
for (auto& out_var : out_data_vars) {
IR_OP_VAR_LINK(new_op_node, out_var);
}
for (auto& out_var : out_unused_vars) {
IR_OP_VAR_LINK(new_op_node, out_var);
}
// Add a weight node to store the pre-compiled NPU model
auto new_weight_node = graph->NewArgumentNode(weight_var_name);
new_weight_node->AsArg().is_weight = true;
new_weight_node->AsArg().is_persist = true;
DirectedLink(new_weight_node, new_op_node);
// assign context
auto& inst = new_op_node->AsStmt();
inst.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
}
void SubgraphProgramPass::SortHelper(
Node* node,
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret) {
for (auto& var_node : node->inlinks) {
if (var_node->inlinks.empty()) continue;
auto* op_node = var_node->inlinks.front();
if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) {
SortHelper(op_node, nodes_all, visited_nodes, ret);
}
}
ret->push_back(node);
visited_nodes->insert(node);
}
std::vector<Node*> SubgraphProgramPass::GetTopologicalOrder(
const std::unordered_set<Node*>& nodes) {
std::unordered_set<const Node*> visited;
std::vector<Node*> ret;
for (auto& node : nodes) {
if (!node->IsStmt()) continue;
if (visited.count(node)) continue;
SortHelper(node, nodes, &visited, &ret);
}
return ret;
}
void SubgraphProgramPass::FindInputOutputVars(
const std::unordered_set<Node*>& op_nodes,
std::unordered_set<Node*>* in_data_vars,
std::unordered_set<Node*>* in_wgt_vars,
std::unordered_set<Node*>* out_data_vars,
std::unordered_set<Node*>* out_unused_vars) {
for (auto& op_node : op_nodes) {
for (auto& in_var : op_node->inlinks) {
if (in_var->AsArg().is_weight) {
in_wgt_vars->insert(in_var);
continue;
}
if (!in_var->inlinks.empty()) {
// var can only come from one op node, so use front
auto* pre_op_node = in_var->inlinks.front();
if (op_nodes.count(pre_op_node)) {
continue;
}
}
in_data_vars->insert(in_var);
}
for (auto& out_var : op_node->outlinks) {
if (out_var->outlinks.empty()) {
// the next op is empty so this var is actually unused
out_unused_vars->insert(out_var);
continue;
}
// a var can have more than one next op node,
// so if any of them is in op_nodes then continue
bool next_op_in_nodes = false;
for (auto& next_op_node : out_var->outlinks) {
if (op_nodes.count(next_op_node)) {
next_op_in_nodes = true;
}
}
if (next_op_in_nodes) {
continue;
}
out_data_vars->insert(out_var);
}
}
}
std::unordered_set<const Node*> SubgraphProgramPass::GetNode2rm(
const std::unordered_set<Node*>& op_nodes,
const std::vector<std::unordered_set<Node*>>& excluded_nodes) {
std::unordered_set<const Node*> nodes2rm(op_nodes.begin(), op_nodes.end());
for (auto& op_node : op_nodes) {
for (auto& in_var : op_node->inlinks) {
if (!nodes2rm.count(in_var)) {
nodes2rm.insert(in_var);
}
}
for (auto& out_var : op_node->outlinks) {
if (!nodes2rm.count(out_var)) {
nodes2rm.insert(out_var);
}
}
}
// some nodes should not be removed
for (auto& e : excluded_nodes) {
for (auto& i : e) {
if (nodes2rm.count(i)) {
nodes2rm.erase(i);
}
}
}
return nodes2rm;
}
void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
auto& op = stmt.op();
auto scope = op->scope();
std::string op_type = op->op_info()->Type();
// Check the dimensions of the input variables in the scope; they must not be empty!
if (op_type == "feed") {
auto input_var_names = op->op_info()->output_names();
CHECK_GE(input_var_names.size(), 1);
for (auto input_var_name : input_var_names) {
auto input_var = scope->FindVar(input_var_name);
CHECK(input_var) << "No input variable '" << input_var_name
<< "' found in scope " << scope;
auto input = input_var->GetMutable<lite::Tensor>();
CHECK(!input->dims().empty()) << "The dimension of input variable '"
<< input_var_name
<< "' can not be empty.";
}
continue;
}
if (op_type == "fetch") {
continue;
}
op->CheckShape();
op->InferShape();
#ifndef LITE_WITH_XPU
// TODO(xxx): remove Launch() eventually
auto& kernels = stmt.kernels();
if (!kernels.empty()) {
auto& kernel = kernels.front();
if (kernel) {
kernel->Launch();
}
}
#endif
}
}
void SubgraphProgramPass::InitSubgraphID(
const std::unique_ptr<SSAGraph>& graph,
const std::vector<std::string>& supported_op_types) {
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
stmt.ClearSubgraphID();
if (std::find(supported_op_types.begin(),
supported_op_types.end(),
stmt.op_type()) != supported_op_types.end()) {
stmt.SetSubgraphID(0);
LOG(INFO) << "supported " << stmt.op_type();
} else {
LOG(INFO) << "======= not supported " << stmt.op_type();
}
}
}
// mark current and all output supported nodes
void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node,
int to_id,
int from_id) {
if (!node) return;
if (node->IsStmt()) {
auto& stmt = node->AsStmt();
if (stmt.subgraph_id() == from_id) {
stmt.SetSubgraphID(to_id);
for (auto& i : node->outlinks) {
ChangeAllOutConnectedID(i, to_id, from_id);
}
} else {
LOG(INFO) << "failed op type:" << stmt.op_type();
return;
}
} else {
// this is an arg node
bool all_out_op_supported = true;
for (auto& i : node->outlinks) {
if (!i->IsStmt()) return;
auto& stmt = i->AsStmt();
if (stmt.subgraph_id() < from_id) {
all_out_op_supported = false;
}
}
if (!all_out_op_supported) {
return;
}
for (auto& i : node->outlinks) {
CHECK(i->IsStmt());
auto& stmt = i->AsStmt();
if (stmt.subgraph_id() == from_id) {
stmt.SetSubgraphID(to_id);
for (auto& o : i->outlinks) {
ChangeAllOutConnectedID(o, to_id, from_id);
}
}
}
}
}
int SubgraphProgramPass::FuseSubgraphID(
const std::unique_ptr<SSAGraph>& graph) {
int sub_id = 1; // ids start from 1, not 0
for (auto& item : graph->StmtTopologicalOrder()) {
// bool inputvar = false;
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
/*
if (stmt.subgraph_id() == -1) {
for (auto& i : item->outlinks) {
for (auto& j : i->outlinks) {
if (j->IsStmt()) {
auto& jstmt = j->AsStmt();
if (jstmt.subgraph_id() == 0) inputvar = true;
}
}
}
}
*/
if (stmt.subgraph_id() != 0) continue;
ChangeAllOutConnectedID(item, sub_id);
sub_id++;
}
return sub_id - 1;
}
int SubgraphProgramPass::FuseSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::vector<std::string>& supported_op_types) {
InitSubgraphID(graph, supported_op_types);
return FuseSubgraphID(graph);
}
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(subgraph_program_pass,
paddle::lite::mir::subgraph::SubgraphProgramPass)
.BindTargets({TARGET(kAny)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
class SubgraphProgramPass : public ProgramPass {
public:
using key2nodes_t = std::map<std::string, Node*>;
// Mark all the linked ops in a subgraph with the same subgraph_id
// and return the number of fused subgraphs
int FuseSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::vector<std::string>& supported_op_types);
void Apply(const std::unique_ptr<SSAGraph>& graph) override {}
protected:
void InferOnce(const std::unique_ptr<SSAGraph>& graph);
// Clear all subgraph ids and mark all ops which could be fused with id zero
void InitSubgraphID(const std::unique_ptr<SSAGraph>& graph,
const std::vector<std::string>& supported_op_types);
// Mark all the linked ops in a subgraph with the same subgraph_id
// and return the number of fused subgraphs
int FuseSubgraphID(const std::unique_ptr<SSAGraph>& graph);
// // GenerateFusedGraph:
// std::unique_ptr<SSAGraph> GenerateFusedGraph(const
// std::unique_ptr<SSAGraph>& graph, int sub_num);
void ChangeAllOutConnectedID(Node* node, int to_id, int from_id = 0);
// The functions below could be useful in child classes
// classify node by subgraph id
std::unordered_map<int, std::unordered_set<Node*>> ClassifySubgraph(
const std::unique_ptr<SSAGraph>& graph);
// generate the graph op desc
cpp::OpDesc GenGraphOpDesc(const std::string& weight_var_name,
const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names);
// insert a new graph op node
void InsertNewNode(const std::unique_ptr<SSAGraph>& graph,
const std::string& weight_var_name,
Scope* scope,
const std::vector<Place>& valid_places,
std::unordered_set<Node*> in_data_vars,
std::unordered_set<Node*> in_wgt_vars,
std::unordered_set<Node*> out_data_vars,
std::unordered_set<Node*> out_unused_vars);
// Sort and return the topological order of the node set
std::vector<Node*> GetTopologicalOrder(
const std::unordered_set<Node*>& nodes);
// Find all input data vars, input weight vars,
// output data vars and unused output vars from the nodes
void FindInputOutputVars(const std::unordered_set<Node*>& op_nodes,
std::unordered_set<Node*>* in_data_vars,
std::unordered_set<Node*>* in_wgt_vars,
std::unordered_set<Node*>* out_data_vars,
std::unordered_set<Node*>* out_unused_vars);
// Return the nodes to be removed from the subgraph
std::unordered_set<const Node*> GetNode2rm(
const std::unordered_set<Node*>& op_nodes,
const std::vector<std::unordered_set<Node*>>& excluded_nodes);
private:
// Sort nodes into execution order
void SortHelper(Node* node,
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret);
};
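// A minimal usage sketch, mirroring the unit test earlier in this diff
// (assumes a built SSAGraph held in a std::unique_ptr and a list of op types
// the target device supports; the op names are illustrative):
//   SubgraphProgramPass pass;
//   int num_subgraphs = pass.FuseSubgraph(graph, {"fc", "conv2d"});
//   // num_subgraphs is the count of fused subgraphs; linked supported ops
//   // now share the same subgraph_id (ids start from 1).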
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -27,12 +27,6 @@
#include "lite/core/program.h"
#include "lite/core/types.h"
#include "lite/model_parser/model_parser.h"
#ifdef LITE_WITH_NPU
#include "lite/core/mir/subgraph/generate_npu_program_pass.h"
#endif
#ifdef LITE_WITH_XPU
#include "lite/core/mir/subgraph/generate_xpu_program_pass.h"
#endif
namespace paddle {
namespace lite {
......@@ -109,7 +103,9 @@ class Optimizer {
"runtime_context_assign_pass",
"argument_type_display_pass",
"memory_optimize_pass"}};
"memory_optimize_pass",
"npu_subgraph_pass",
"xpu_subgraph_pass"}};
RunPasses(passes_local);
} else {
RunPasses(passes);
......@@ -121,13 +117,6 @@ class Optimizer {
// Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
// Extra passes are applied for NPU and XPU because they depend on the shapes
// of the input tensors, so GenRuntimeProgram() must be called after the shapes
// of the input tensors are determined.
std::vector<std::string> subgraph_passes{"generate_npu_program_pass",
"generate_xpu_program_pass"};
RunPasses(subgraph_passes);
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
pass->Apply(graph_);
......
......@@ -18,6 +18,7 @@
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/model_parser/cpp/var_desc.h"
#include "lite/operators/conditional_block_op.h"
#include "lite/operators/subgraph_op.h"
#include "lite/operators/while_op.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/precision_profiler.h"
......@@ -31,10 +32,32 @@ void RuntimeProgram::SaveOpInfosToProgram(cpp::ProgramDesc* desc) {
// NOTE: RuntimeProgram does not have all meta info, so saving the model just
// updates the original model
CHECK(desc->BlocksSize());
auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0);
main_block.ClearOps();
auto main_block = desc->GetBlock<cpp::BlockDesc>(0);
main_block->ClearOps();
for (auto& node : instructions_) {
auto* op = main_block.AddOp<cpp::OpDesc>();
auto op_type = node.op()->op_info()->Type();
if (op_type == "subgraph") {
auto subgraph_op = const_cast<operators::SubgraphOp*>(
static_cast<const operators::SubgraphOp*>(node.op()));
int sub_block_idx = subgraph_op->op_info()->GetAttr<int32_t>("sub_block");
if (sub_block_idx < 0) {
// It's a new subgraph op when its sub_block_idx < 0. Add its sub-block
// desc to the program desc, then update its sub_block_idx to the index
// of the new block desc in the program desc.
sub_block_idx = desc->BlocksSize();
auto sub_block_desc = subgraph_op->GetSubBlock();
CHECK(sub_block_desc);
auto new_block_desc = desc->AddBlock<cpp::BlockDesc>();
*new_block_desc = *sub_block_desc;
delete sub_block_desc;
subgraph_op->mutable_op_info()->SetAttr<int32_t>("sub_block",
sub_block_idx);
subgraph_op->SetSubBlock(new_block_desc);
// Update main block desc after a new subblock desc is added
main_block = desc->GetBlock<cpp::BlockDesc>(0);
}
}
auto op = main_block->AddOp<cpp::OpDesc>();
*op = *node.op()->op_info();
op->SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
}
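// For example (illustrative indices only): if the original program has just a
// main block and a freshly generated subgraph op still carries sub_block = -1,
// the branch above appends its block desc as block #1 and rewrites the op's
// sub_block attribute to 1 before the op is serialized.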
......@@ -142,16 +165,25 @@ void Program::Build(const cpp::ProgramDesc& prog) {
VLOG(4) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
if (op_type == "while" || op_type == "conditional_block") {
if (op_type == "while" || op_type == "conditional_block" ||
op_type == "subgraph") {
auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
auto sub_block =
CHECK(sub_block_idx >= 0 && sub_block_idx < program.BlocksSize())
<< "Invalid attribute sub_block(" << sub_block_idx << ") for "
<< op_type;
auto sub_block_desc =
const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
sub_block_idx);
CHECK(sub_block_desc);
if (op_type == "while") {
static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(sub_block);
static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(
sub_block_desc);
} else if (op_type == "conditional_block") {
static_cast<operators::ConditionalBlockOpLite*>(op.get())->SetSubBlock(
sub_block);
sub_block_desc);
} else if (op_type == "subgraph") {
static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(
sub_block_desc);
}
}
ops_.emplace_back(std::move(op));
......
add_subdirectory(bridges)
if(NOT LITE_WITH_NPU)
return ()
endif()
message(STATUS "compile with lite NPU kernels")
add_kernel(graph_compute_npu NPU basic SRCS graph_compute.cc DEPS ${lite_kernel_deps} npu_runtime)
# lite_cc_test(test_graph_compute_npu SRCS graph_compute_test.cc DEPS graph_compute_npu)
if(NOT LITE_ON_TINY_PUBLISH)
add_subdirectory(bridges)
endif()
add_kernel(subgraph_compute_npu NPU basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} device_npu subgraph_bridge_engine ${npu_subgraph_bridges})
lite_cc_library(npu_bridge_registry SRCS registry.cc)
if(NOT LITE_WITH_NPU AND NOT LITE_WITH_XPU)
return()
endif()
set(npu_bridge_deps npu_bridge_registry npu_builder op)
lite_cc_library(subgraph_bridge_registry
SRCS registry.cc
DEPS op)
lite_cc_library(subgraph_bridge_engine
SRCS engine.cc
DEPS tensor op scope program)
lite_cc_library(npu_bridge_fc_op SRCS fc_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_conv_op SRCS conv_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_mul_op SRCS mul_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_act_op SRCS act_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_scale_op SRCS scale_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_softmax_op SRCS softmax_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_pool_op SRCS pool_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_batch_norm_op SRCS batch_norm_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_elementwise_ops SRCS elementwise_ops.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_reshape_op SRCS reshape_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_conv_transpose_op SRCS conv_transpose_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_interpolate_op SRCS interpolate_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_transpose_op SRCS transpose_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_split_op SRCS split_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_concat_op SRCS concat_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_shuffle_channel_op SRCS shuffle_channel_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_pad2d_op SRCS pad2d_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_square_op SRCS square_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_sqrt_op SRCS sqrt_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_reduce_mean_op SRCS reduce_mean_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_unsqueeze_op SRCS unsqueeze_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_argmax_op SRCS argmax_op.cc DEPS ${npu_bridge_deps})
if(NOT LITE_WITH_NPU)
return()
endif()
set(npu_bridges
npu_bridge_registry
npu_bridge_fc_op
npu_bridge_conv_op
npu_bridge_mul_op
npu_bridge_act_op
npu_bridge_scale_op
npu_bridge_softmax_op
npu_bridge_pool_op
npu_bridge_batch_norm_op
npu_bridge_elementwise_ops
npu_bridge_reshape_op
npu_bridge_conv_transpose_op
npu_bridge_interpolate_op
npu_bridge_transpose_op
npu_bridge_split_op
npu_bridge_concat_op
npu_bridge_shuffle_channel_op
npu_bridge_pad2d_op
npu_bridge_square_op
npu_bridge_sqrt_op
npu_bridge_reduce_mean_op
npu_bridge_unsqueeze_op
npu_bridge_argmax_op
CACHE INTERNAL "npu_bridges")
lite_cc_library(subgraph_bridge_utility_npu SRCS utility.cc DEPS ${npu_builder_libs} tensor)
lite_cc_library(subgraph_bridge_graph_npu SRCS graph.cc DEPS subgraph_bridge_utility_npu)
set(npu_bridge_test_deps ${npu_bridges} ${npu_kernels} ${ops})
set(npu_subgraph_bridge_deps subgraph_bridge_registry subgraph_bridge_utility_npu subgraph_bridge_graph_npu)
lite_cc_test(test_npu_bridge_fc_op SRCS fc_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_conv_op SRCS conv_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_mul_op SRCS mul_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_act_op SRCS act_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_scale_op SRCS scale_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_softmax_op SRCS softmax_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_pool_op SRCS pool_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_batch_norm_op SRCS batch_norm_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_elementwise_ops SRCS elementwise_ops_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_reshape_op SRCS reshape_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_conv_transpose_op SRCS conv_transpose_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_interpolate_op SRCS interpolate_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_transpose_op SRCS transpose_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_split_op SRCS split_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_concat_op SRCS concat_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_shuffle_channel_op SRCS shuffle_channel_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_pad2d_op SRCS pad2d_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_square_op SRCS square_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_sqrt_op SRCS sqrt_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_reduce_mean_op SRCS reduce_mean_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_unsqueeze_op SRCS unsqueeze_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_argmax_op SRCS argmax_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_library(subgraph_bridge_fc_op_npu SRCS fc_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_conv_op_npu SRCS conv_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_mul_op_npu SRCS mul_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_act_op_npu SRCS act_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_scale_op_npu SRCS scale_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_softmax_op_npu SRCS softmax_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_pool_op_npu SRCS pool_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_batch_norm_op_npu SRCS batch_norm_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_elementwise_ops_npu SRCS elementwise_ops.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_reshape_op_npu SRCS reshape_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_conv_transpose_op_npu SRCS conv_transpose_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_interpolate_op_npu SRCS interpolate_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_transpose_op_npu SRCS transpose_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_split_op_npu SRCS split_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_concat_op_npu SRCS concat_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_shuffle_channel_op_npu SRCS shuffle_channel_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_pad2d_op_npu SRCS pad2d_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_square_op_npu SRCS square_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_sqrt_op_npu SRCS sqrt_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_reduce_mean_op_npu SRCS reduce_mean_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_unsqueeze_op_npu SRCS unsqueeze_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_argmax_op_npu SRCS argmax_op.cc DEPS ${npu_subgraph_bridge_deps})
message(STATUS "+++++ npu_bridges: ${npu_bridges}")
set(npu_subgraph_bridges
subgraph_bridge_registry
subgraph_bridge_utility_npu
subgraph_bridge_graph_npu
subgraph_bridge_fc_op_npu
subgraph_bridge_conv_op_npu
subgraph_bridge_mul_op_npu
subgraph_bridge_act_op_npu
subgraph_bridge_scale_op_npu
subgraph_bridge_softmax_op_npu
subgraph_bridge_pool_op_npu
subgraph_bridge_batch_norm_op_npu
subgraph_bridge_elementwise_ops_npu
subgraph_bridge_reshape_op_npu
subgraph_bridge_conv_transpose_op_npu
subgraph_bridge_interpolate_op_npu
subgraph_bridge_transpose_op_npu
subgraph_bridge_split_op_npu
subgraph_bridge_concat_op_npu
subgraph_bridge_shuffle_channel_op_npu
subgraph_bridge_pad2d_op_npu
subgraph_bridge_square_op_npu
subgraph_bridge_sqrt_op_npu
subgraph_bridge_reduce_mean_op_npu
subgraph_bridge_unsqueeze_op_npu
subgraph_bridge_argmax_op_npu
CACHE INTERNAL "npu_subgraph_bridges")
message(STATUS "+++++ npu_subgraph_bridges: ${npu_subgraph_bridges}")
......@@ -12,34 +12,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
const node_map_type& inputs_map) {
auto scope = act_op->scope();
auto op_info = act_op->op_info();
int ActConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
VLOG(3) << "[NPU] Converting " + op_type + "...";
// create act node and set input node from inputs_map
// Create act node and set input node which is obtained from the node map
auto x_var_name = op_info->Input("X").front();
auto act_node = std::make_shared<ge::op::Activation>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
act_node->set_input_x(*inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(act_node);
auto out_var_name = op_info->Output("Out").front();
auto act_node = graph->AddNode<ge::op::Activation>(out_var_name);
act_node->set_input_x(*graph->GetNode(x_var_name));
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc.
act_node->set_attr_mode(lite::npu::CvtActMode(op_type));
act_node->set_attr_mode(CvtActMode(op_type));
if (op_type == "relu_clipped") {
auto Relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef");
......@@ -56,31 +54,33 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
act_node->set_attr_negative_slope(slope);
act_node->set_attr_coef(offset);
}
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = act_node;
return outputs_map;
return SUCCESS;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(sigmoid, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(relu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(tanh, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(relu_clipped,
paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(relu6, paddle::lite::kernels::npu::bridges::ActConverter);
// REGISTER_NPU_BRIDGE(elu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(leaky_relu,
paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(abs, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(softsign,
paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(softplus,
paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(hard_sigmoid,
paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
sigmoid,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, relu, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, tanh, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
relu_clipped,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, relu6, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
leaky_relu,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, abs, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
softsign,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
softplus,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
hard_sigmoid,
paddle::lite::subgraph::npu::ActConverter);
......@@ -12,59 +12,41 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
node_map_type ArgmaxConverter(const std::shared_ptr<lite::OpLite> argmax_op,
const node_map_type& inputs_map) {
auto scope = argmax_op->scope();
auto op_info = argmax_op->op_info();
int ArgmaxConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
int axis = op_info->GetAttr<int64_t>("axis");
std::shared_ptr<ge::op::ArgMax> argmax_node =
std::make_shared<ge::op::ArgMax>(unique_op_type);
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
int axis = op_info->GetAttr<int64_t>("axis");
CHECK(inputs_map.count(x_var_name));
argmax_node->set_input_x1(*inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(argmax_node);
Tensor x2_t;
x2_t.Resize(std::vector<int64_t>{1});
auto x2_t_data = x2_t.mutable_data<int>();
x2_t_data[0] = axis;
auto argmax_node = graph->AddNode<ge::op::ArgMax>(out_var_name);
argmax_node->set_input_x1(*graph->GetNode(x_var_name));
auto x2 = std::make_shared<ge::op::Const>(unique_op_type + "/axis");
x2->set_attr_value(lite::npu::CvtTensor(&x2_t));
auto x2 = graph->AddNode(out_var_name + "/axis", axis);
argmax_node->set_input_x2(*x2);
lite::npu::OpList::Global().add(x2);
// argmax_node->set_attr_axis(axis);
// argmax only support output_type==int32
// argmax_node->set_attr_output_type(3);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = argmax_node;
return outputs_map;
return SUCCESS;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(arg_max,
paddle::lite::kernels::npu::bridges::ArgmaxConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
arg_max,
paddle::lite::subgraph::npu::ArgmaxConverter);
......@@ -12,81 +12,66 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
node_map_type BatchNormConverter(
const std::shared_ptr<lite::OpLite> batch_norm_op,
const node_map_type& inputs_map) {
auto scope = batch_norm_op->scope();
auto op_info = batch_norm_op->op_info();
int BatchNormConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::BatchNormExt2> batch_norm_node =
std::make_shared<ge::op::BatchNormExt2>(unique_op_type);
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Output("Y").front();
auto batch_norm_node = graph->AddNode<ge::op::BatchNormExt2>(y_var_name);
batch_norm_node->set_input_x(*graph->GetNode(x_var_name));
auto scale_var_name = op_info->Input("Scale").front();
lite::Tensor* scale = scope->FindVar(scale_var_name)->GetMutable<Tensor>();
auto npu_scale = std::make_shared<ge::op::Const>(scale_var_name);
npu_scale->set_attr_value(lite::npu::CvtTensor(scale));
lite::npu::OpList::Global().add(npu_scale);
auto scale = scope->FindVar(scale_var_name)->GetMutable<Tensor>();
auto scale_const_node = graph->AddNode(scale_var_name, *scale);
auto bias_var_name = op_info->Input("Bias").front();
lite::Tensor* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto npu_bias = std::make_shared<ge::op::Const>(bias_var_name);
npu_bias->set_attr_value(lite::npu::CvtTensor(bias));
lite::npu::OpList::Global().add(npu_bias);
auto bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto bias_const_node = graph->AddNode(bias_var_name, *bias);
auto mean_var_name = op_info->Input("Mean").front();
lite::Tensor* mean = scope->FindVar(mean_var_name)->GetMutable<Tensor>();
auto npu_mean = std::make_shared<ge::op::Const>(mean_var_name);
npu_mean->set_attr_value(lite::npu::CvtTensor(mean));
lite::npu::OpList::Global().add(npu_mean);
auto mean = scope->FindVar(mean_var_name)->GetMutable<Tensor>();
auto mean_const_node = graph->AddNode(mean_var_name, *mean);
auto variance_var_name = op_info->Input("Variance").front();
lite::Tensor* variance =
scope->FindVar(variance_var_name)->GetMutable<Tensor>();
auto npu_variance = std::make_shared<ge::op::Const>(variance_var_name);
npu_variance->set_attr_value(lite::npu::CvtTensor(variance));
lite::npu::OpList::Global().add(npu_variance);
auto variance = scope->FindVar(variance_var_name)->GetMutable<Tensor>();
auto variance_const_node = graph->AddNode(variance_var_name, *variance);
float npu_momentum = op_info->GetAttr<float>("momentum");
float npu_epsilon = op_info->GetAttr<float>("epsilon");
int npu_mode = 1; // bnScale, bnBias tensor dims are 1xCx1x1
bool npu_use_global_stats = op_info->GetAttr<bool>("use_global_stats");
float momentum = op_info->GetAttr<float>("momentum");
float epsilon = op_info->GetAttr<float>("epsilon");
int mode = 1; // bnScale, bnBias tensor dims are 1xCx1x1
bool use_global_stats = op_info->GetAttr<bool>("use_global_stats");
batch_norm_node->set_input_x(*inputs_map.at(x_var_name));
batch_norm_node->set_input_scale(*npu_scale);
batch_norm_node->set_input_offset(*npu_bias);
batch_norm_node->set_input_mean(*npu_mean);
batch_norm_node->set_input_variance(*npu_variance);
batch_norm_node->set_attr_momentum(npu_momentum);
batch_norm_node->set_attr_epsilon(npu_epsilon);
batch_norm_node->set_attr_mode(npu_mode);
batch_norm_node->set_attr_use_global_stats(npu_use_global_stats);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(batch_norm_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Y").front()] = batch_norm_node;
return outputs_map;
batch_norm_node->set_input_scale(*scale_const_node);
batch_norm_node->set_input_offset(*bias_const_node);
batch_norm_node->set_input_mean(*mean_const_node);
batch_norm_node->set_input_variance(*variance_const_node);
batch_norm_node->set_attr_momentum(momentum);
batch_norm_node->set_attr_epsilon(epsilon);
batch_norm_node->set_attr_mode(mode);
batch_norm_node->set_attr_use_global_stats(use_global_stats);
return SUCCESS;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(batch_norm,
paddle::lite::kernels::npu::bridges::BatchNormConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
batch_norm,
paddle::lite::subgraph::npu::BatchNormConverter);
......@@ -12,58 +12,51 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
node_map_type ConcatConverter(const std::shared_ptr<lite::OpLite> concat_op,
const node_map_type& inputs_map) {
lite::Scope* scope = concat_op->scope();
const lite::OpInfo* op_info = concat_op->op_info();
int ConcatConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " << op_type << " ... ";
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " << op_type << " ... ";
auto x_var_names = op_info->Input("X");
auto out_var_name = op_info->Output("Out").front();
auto axis = op_info->GetAttr<int>("axis");
int num = x_var_names.size();
int index = 0;
std::shared_ptr<ge::op::Concat> output_node =
std::make_shared<ge::op::Concat>(unique_op_type);
output_node->set_attr_axis(axis);
output_node->set_attr_N(num);
output_node->create_dynamic_input_x(num);
for (auto x_var_name : x_var_names) {
if (inputs_map.find(x_var_name) != inputs_map.end()) {
output_node->set_dynamic_input_x(index + 1, *inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
auto num = x_var_names.size();
auto concat_node = graph->AddNode<ge::op::Concat>(out_var_name);
concat_node->set_attr_axis(axis);
concat_node->set_attr_N(num);
concat_node->create_dynamic_input_x(num);
int idx = 1;
for (auto& x_var_name : x_var_names) {
if (graph->HasNode(x_var_name)) {
concat_node->set_dynamic_input_x(idx, *graph->GetNode(x_var_name));
} else {
auto consty = std::make_shared<ge::op::Const>(x_var_name);
auto* x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
consty->set_attr_value(lite::npu::CvtTensor(x));
output_node->set_dynamic_input_x(index + 1, *consty);
lite::npu::OpList::Global().add(consty);
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_const_node = graph->AddNode(x_var_name, *x);
concat_node->set_dynamic_input_x(idx, *x_const_node);
}
index++;
idx++;
}
lite::npu::OpList::Global().add(output_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = output_node;
return outputs_map;
return SUCCESS;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(concat,
paddle::lite::kernels::npu::bridges::ConcatConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
concat,
paddle::lite::subgraph::npu::ConcatConverter);
......@@ -13,32 +13,33 @@
// limitations under the License.
#include "lite/operators/conv_op.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
const node_map_type& inputs_map) {
auto scope = conv_op->scope();
auto op_info = conv_op->op_info();
int ConvConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " << op_type << "... ";
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " << op_type << "... ";
// get input, filter and op attributes
// Get input, filter and op attributes
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input = scope->FindVar(input_var_name)->GetMutable<Tensor>();
auto input_dims = input->dims();
auto output_var_name = op_info->Output("Output").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output = scope->FindVar(output_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
auto filter = scope->FindVar(filter_var_name)->GetMutable<Tensor>();
auto filter_dims = filter->dims();
auto bs = input_dims[0];
auto ic = input_dims[1];
......@@ -63,7 +64,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
}
}
CHECK_EQ(paddings.size(), 4L)
<< "Paddings size should be the same or twice as the input size.";
<< "[NPU] Paddings size should be the same or twice as the input size.";
std::string padding_algorithm("");
if (op_info->HasAttr("padding_algorithm")) {
......@@ -76,9 +77,9 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
input_dims,
filter_dims);
// check depthwise mode, and decide whether use ConvolutionDepthwise Op
// Check depthwise mode, and decide whether to use the ConvolutionDepthwise Op
bool use_depthwise_conv =
false; // whether use ge::op::ConvolutionDepthwise ?
false; // Whether use ge::op::ConvolutionDepthwise ?
bool is_depthwise_mode = ic == groups && oc == groups;
if (is_depthwise_mode &&
!((groups == 1 || groups >= 5) && dilations[0] == 1 &&
......@@ -90,26 +91,19 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
"performance.";
}
// check input
CHECK(inputs_map.count(input_var_name));
lite::npu::OpList::Global().add(inputs_map.at(input_var_name));
// Create filter node
auto filter_const_node = graph->AddNode(filter_var_name, *filter);
// create filter node
CHECK(!inputs_map.count(filter_var_name));
auto filter_const_node = std::make_shared<ge::op::Const>(filter_var_name);
filter_const_node->set_attr_value(lite::npu::CvtTensor(filter));
lite::npu::OpList::Global().add(filter_const_node);
// create bias node if has bias
// supports the bias nodes with the following dimensions
// Create bias node if a bias exists
// Supports the bias nodes with the following dimensions
// 0: {oc}
// 1: {1, oc, oh, ow}
// 2: {n, oc, oh, ow}
std::shared_ptr<ge::Operator> bias_node = nullptr;
bool is_channel_bias = false;
if (lite::npu::HasInputArg(op_info, scope, "Bias")) {
if (HasInputArg(op_info, scope, "Bias")) {
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto bias_dims = bias->dims();
auto bias_data_size = bias_dims.production();
auto output_data_size = output_dims.production();
......@@ -125,28 +119,26 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
// 2: {n, oc, oh, ow}
bias_shape = output_dims.Vectorize();
} else {
LOG(ERROR) << "bias dimension " << bias_dims
<< " isn't supported in conv2d Op when output dimension is "
<< output_dims;
LOG(WARNING) << "[NPU] Bias dimension " << bias_dims
<< " isn't supported in conv2d Op when output dimension is "
<< output_dims;
return FAILED;
}
if (inputs_map.count(bias_var_name)) {
// bias node from input map
bias_node = inputs_map.at(bias_var_name);
if (graph->HasNode(bias_var_name)) {
// Bias node from input map
bias_node = graph->GetNode(bias_var_name);
} else {
// bias node with const data
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(lite::npu::CvtTensor(bias, bias_shape));
bias_node = bias_const_node;
// Bias node with const data
bias_node = graph->AddNode(bias_var_name, *bias, bias_shape);
}
lite::npu::OpList::Global().add(bias_node);
}
// create conv node and set input, filter, bias nodes and attributes
// Create conv node and set input, filter, bias nodes and attributes
std::shared_ptr<ge::Operator> conv_node = nullptr;
if (use_depthwise_conv && is_depthwise_mode) {
auto depthwise_conv_node =
std::make_shared<ge::op::ConvolutionDepthwise>(unique_op_type);
depthwise_conv_node->set_input_x(*inputs_map.at(input_var_name));
graph->AddNode<ge::op::ConvolutionDepthwise>(output_var_name);
depthwise_conv_node->set_input_x(*graph->GetNode(input_var_name));
depthwise_conv_node->set_input_filter(*filter_const_node);
depthwise_conv_node->set_attr_mode(1);
depthwise_conv_node->set_attr_algo(0);
......@@ -161,21 +153,19 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
depthwise_conv_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
lite::npu::OpList::Global().add(depthwise_conv_node);
conv_node = depthwise_conv_node;
// ConvolutionDepthwise Op doesn't support bias, so append Add node to
// support bias
if (bias_node != nullptr) {
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
auto add_node = graph->AddNode<ge::op::Add>(output_var_name);
add_node->set_input_x1(*depthwise_conv_node);
add_node->set_input_x2(*bias_node);
lite::npu::OpList::Global().add(add_node);
conv_node = add_node;
}
} else {
auto common_conv_node =
std::make_shared<ge::op::Convolution>(unique_op_type);
common_conv_node->set_input_x(*inputs_map.at(input_var_name));
graph->AddNode<ge::op::Convolution>(output_var_name);
common_conv_node->set_input_x(*graph->GetNode(input_var_name));
common_conv_node->set_input_w(*filter_const_node);
common_conv_node->set_attr_mode(1);
common_conv_node->set_attr_pad_mode(0); // NOTSET
......@@ -188,7 +178,6 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
common_conv_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
lite::npu::OpList::Global().add(common_conv_node);
conv_node = common_conv_node;
// Convolution Op only supports a bias with dimension {1, oc, 1, 1},
// so append Add node if dimension is {1, oc, oh, ow} or (n, oc, oh, ow)
......@@ -196,37 +185,32 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
if (is_channel_bias) {
common_conv_node->set_input_b(*bias_node);
} else {
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
auto add_node = graph->AddNode<ge::op::Add>(output_var_name);
add_node->set_input_x1(*common_conv_node);
add_node->set_input_x2(*bias_node);
lite::npu::OpList::Global().add(add_node);
conv_node = add_node;
}
}
}
CHECK(conv_node);
node_map_type outputs_map;
if (fuse_relu) {
// append relu node if fuse_relu is true
auto relu_node =
std::make_shared<ge::op::Activation>(unique_op_type + "/relu");
// Append relu node if fuse_relu is true
auto relu_node = graph->AddNode<ge::op::Activation>(output_var_name);
relu_node->set_input_x(*conv_node);
relu_node->set_attr_mode(lite::npu::CvtActMode("relu"));
lite::npu::OpList::Global().add(relu_node);
outputs_map[op_info->Output("Output").front()] = relu_node;
} else {
outputs_map[op_info->Output("Output").front()] = conv_node;
relu_node->set_attr_mode(CvtActMode("relu"));
}
return outputs_map;
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(conv2d, paddle::lite::kernels::npu::bridges::ConvConverter);
REGISTER_NPU_BRIDGE(depthwise_conv2d,
paddle::lite::kernels::npu::bridges::ConvConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
conv2d,
paddle::lite::subgraph::npu::ConvConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
depthwise_conv2d,
paddle::lite::subgraph::npu::ConvConverter);
......@@ -12,30 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
node_map_type ConvTransposeConverter(
const std::shared_ptr<lite::OpLite> conv_transpose_op,
const node_map_type& inputs_map) {
auto scope = conv_transpose_op->scope();
auto op_info = conv_transpose_op->op_info();
int ConvTransposeConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " << op_type << "... ";
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " << op_type << "... ";
// get input, output and op attributes
// Get input, output and op attributes
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input = scope->FindVar(input_var_name)->GetMutable<Tensor>();
auto input_shape = input->dims().Vectorize();
auto output_var_name = op_info->Output("Output").front();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
auto filter = scope->FindVar(filter_var_name)->GetMutable<Tensor>();
auto filter_shape = filter->dims().Vectorize();
CHECK_EQ(input_shape.size(), 4);
CHECK_EQ(filter_shape.size(), 4);
......@@ -54,42 +55,34 @@ node_map_type ConvTransposeConverter(
}
}
CHECK_EQ(paddings.size(), 4L)
<< "Paddings size should be the same or twice as the input size.";
<< "[NPU] Paddings size should be the same or twice as the input size.";
// create deconv node
// Create deconv node
auto conv_transpose_node =
std::make_shared<ge::op::Deconvolution>(unique_op_type);
graph->AddNode<ge::op::Deconvolution>(output_var_name);
// create input sizes node to describe the dimensions of input tensor
std::vector<int32_t> output_shape;
output_shape.push_back(input_shape[0]);
output_shape.push_back(filter_shape[1] * groups);
// Create input sizes node to describe the dimensions of input tensor
std::vector<int32_t> input_sizes;
input_sizes.push_back(input_shape[0]);
input_sizes.push_back(filter_shape[1] * groups);
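  // Output spatial size of the transposed convolution:
  // (in - 1) * stride + dilation * (kernel - 1) + 1 - 2 * pad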
for (int i = 0; i < strides.size(); i++) {
int kernel_ext = dilations[i] * (filter_shape[i + 2] - 1) + 1;
int output_size =
(input_shape[i + 2] - 1) * strides[i] + kernel_ext - 2 * paddings[i];
output_shape.push_back(output_size);
input_sizes.push_back(output_size);
}
auto input_sizes_const_node =
std::make_shared<ge::op::Const>(unique_op_type + "/input_size");
input_sizes_const_node->set_attr_value(
lite::npu::CreateTensorAndFillData(output_shape));
graph->AddNode(output_var_name + "/input_sizes", input_sizes);
conv_transpose_node->set_input_input_sizes(*input_sizes_const_node);
lite::npu::OpList::Global().add(input_sizes_const_node);
// create filter node
CHECK(!inputs_map.count(filter_var_name));
auto filter_const_node = std::make_shared<ge::op::Const>(filter_var_name);
filter_const_node->set_attr_value(lite::npu::CvtTensor(filter));
// Create filter node
auto filter_const_node = graph->AddNode(filter_var_name, *filter);
conv_transpose_node->set_input_filter(*filter_const_node);
lite::npu::OpList::Global().add(filter_const_node);
// set input node
CHECK(inputs_map.count(input_var_name));
conv_transpose_node->set_input_x(*inputs_map.at(input_var_name));
lite::npu::OpList::Global().add(inputs_map.at(input_var_name));
// Set input node
conv_transpose_node->set_input_x(*graph->GetNode(input_var_name));
// set attributes
// Set attributes
conv_transpose_node->set_attr_format(0); // NCHW
conv_transpose_node->set_attr_pad_mode(0); // NOTSET
conv_transpose_node->set_attr_group(groups);
......@@ -101,50 +94,39 @@ node_map_type ConvTransposeConverter(
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
conv_transpose_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_shape[2], filter_shape[3]}));
lite::npu::OpList::Global().add(conv_transpose_node);
// append add node to add bias if has bias
  // Append an Add node to apply the bias if it exists
std::shared_ptr<ge::Operator> output_node = conv_transpose_node;
if (lite::npu::HasInputArg(op_info, scope, "Bias")) {
// create bias node
if (HasInputArg(op_info, scope, "Bias")) {
// Create bias node
auto bias_var_name = op_info->Input("Bias").front();
CHECK(!inputs_map.count(bias_var_name));
auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
CHECK(!graph->HasNode(bias_var_name));
auto* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto channel_size = bias->dims().production();
CHECK_EQ(channel_size, filter_shape[1] * groups);
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(
lite::npu::CvtTensor(bias, {1, channel_size, 1, 1}));
lite::npu::OpList::Global().add(bias_const_node);
// append add node to add bias node
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
auto bias_const_node =
graph->AddNode(bias_var_name, *bias, {1, channel_size, 1, 1});
// Append add node to add bias node
auto add_node = graph->AddNode<ge::op::Add>(output_var_name);
add_node->set_input_x1(*conv_transpose_node);
add_node->set_input_x2(*bias_const_node);
lite::npu::OpList::Global().add(add_node);
output_node = add_node;
}
node_map_type outputs_map;
if (fuse_relu) {
// append relu node if fuse_relu is true
auto relu_node =
std::make_shared<ge::op::Activation>(unique_op_type + "/relu");
// Append relu node if fuse_relu is true
auto relu_node = graph->AddNode<ge::op::Activation>(output_var_name);
relu_node->set_input_x(*output_node);
relu_node->set_attr_mode(lite::npu::CvtActMode("relu"));
lite::npu::OpList::Global().add(relu_node);
outputs_map[op_info->Output("Output").front()] = relu_node;
} else {
outputs_map[op_info->Output("Output").front()] = output_node;
relu_node->set_attr_mode(CvtActMode("relu"));
}
return outputs_map;
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(
conv2d_transpose,
paddle::lite::kernels::npu::bridges::ConvTransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
conv2d_transpose,
paddle::lite::subgraph::npu::ConvTransposeConverter);
......@@ -12,18 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
std::vector<int64_t> CvtYShape(const Tensor& x, Tensor* y, int axis) {
auto x_dims = x.dims();
CHECK_EQ(x_dims.size(), 4UL) << "[NPU] only support 4-dimension x";
CHECK_EQ(x_dims.size(), 4UL) << "[NPU] Only support 4-dimension x";
auto y_dims = y->dims();
CHECK_GE(x_dims.size(), y_dims.size());
......@@ -45,93 +45,86 @@ std::vector<int64_t> CvtYShape(const Tensor& x, Tensor* y, int axis) {
return y_new_shape;
}
node_map_type ElementwiseConverter(
const std::shared_ptr<lite::OpLite> elementwise_op,
const node_map_type& inputs_map) {
auto scope = elementwise_op->scope();
auto op_info = elementwise_op->op_info();
int ElementwiseConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
CHECK(inputs_map.find(x_var_name) != inputs_map.end());
auto out_var_name = op_info->Output("Out").front();
auto axis = op_info->GetAttr<int>("axis");
std::shared_ptr<ge::Operator> elementwise_node = nullptr;
std::shared_ptr<ge::Operator> x_node = inputs_map.at(x_var_name);
std::shared_ptr<ge::Operator> x_node = graph->GetNode(x_var_name);
std::shared_ptr<ge::Operator> y_node = nullptr;
if (inputs_map.find(y_var_name) != inputs_map.end()) {
y_node = inputs_map.at(y_var_name);
if (graph->HasNode(y_var_name)) {
y_node = graph->GetNode(y_var_name);
} else {
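    // y is a persistable tensor; reshape it according to the broadcast axis so
    // it matches x's 4-D layout, then add it as a const node.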
auto y_const_node = std::make_shared<ge::op::Const>(y_var_name);
auto x = scope->FindTensor(x_var_name);
auto y = scope->FindMutableTensor(y_var_name);
auto y_new_shape = CvtYShape(*x, y, axis);
y_const_node->set_attr_value(lite::npu::CvtTensor(y, y_new_shape));
y_node = y_const_node;
y_node = graph->AddNode(y_var_name, y, y_new_shape);
}
lite::npu::OpList::Global().add(x_node);
lite::npu::OpList::Global().add(y_node);
if (op_type == "elementwise_add" ||
op_type == "fusion_elementwise_add_activation") {
auto elt_node = std::make_shared<ge::op::Add>(unique_op_type);
auto elt_node = graph->AddNode<ge::op::Add>(out_var_name);
elt_node->set_input_x1(*x_node);
elt_node->set_input_x2(*y_node);
elementwise_node = elt_node;
} else if (op_type == "elementwise_sub") {
auto elt_node = std::make_shared<ge::op::Sub>(unique_op_type);
auto elt_node = graph->AddNode<ge::op::Sub>(out_var_name);
elt_node->set_input_x1(*x_node);
elt_node->set_input_x2(*y_node);
elementwise_node = elt_node;
} else if (op_type == "elementwise_mul") {
auto elt_node = std::make_shared<ge::op::Mul>(unique_op_type);
auto elt_node = graph->AddNode<ge::op::Mul>(out_var_name);
elt_node->set_input_x(*x_node);
elt_node->set_input_y(*y_node);
elementwise_node = elt_node;
} else if (op_type == "elementwise_div") {
auto elt_node = std::make_shared<ge::op::RealDiv>(unique_op_type);
auto elt_node = graph->AddNode<ge::op::RealDiv>(out_var_name);
elt_node->set_input_x1(*x_node);
elt_node->set_input_x2(*y_node);
elementwise_node = elt_node;
} else {
LOG(FATAL) << "unsupported op type: " << op_type;
LOG(WARNING) << "[NPU] Unsupported op type: " << op_type;
return FAILED;
}
lite::npu::OpList::Global().add(elementwise_node);
node_map_type outputs_map;
if (op_type == "fusion_elementwise_add_activation") {
auto act_type = op_info->GetAttr<std::string>("act_type");
auto act_node =
std::make_shared<ge::op::Activation>(unique_op_type + "/act");
auto act_node = graph->AddNode<ge::op::Activation>(out_var_name);
act_node->set_input_x(*elementwise_node);
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc.
act_node->set_attr_mode(lite::npu::CvtActMode(act_type));
lite::npu::OpList::Global().add(act_node);
outputs_map[op_info->Output("Out").front()] = act_node;
} else {
outputs_map[op_info->Output("Out").front()] = elementwise_node;
act_node->set_attr_mode(CvtActMode(act_type));
}
return outputs_map;
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(elementwise_add,
paddle::lite::kernels::npu::bridges::ElementwiseConverter);
REGISTER_NPU_BRIDGE(fusion_elementwise_add_activation,
paddle::lite::kernels::npu::bridges::ElementwiseConverter);
REGISTER_NPU_BRIDGE(elementwise_sub,
paddle::lite::kernels::npu::bridges::ElementwiseConverter);
REGISTER_NPU_BRIDGE(elementwise_mul,
paddle::lite::kernels::npu::bridges::ElementwiseConverter);
REGISTER_NPU_BRIDGE(elementwise_div,
paddle::lite::kernels::npu::bridges::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
elementwise_add,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
fusion_elementwise_add_activation,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
elementwise_sub,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
elementwise_mul,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
elementwise_div,
paddle::lite::subgraph::npu::ElementwiseConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/engine.h"
#include <sys/time.h>
#include <time.h>
#include <utility>
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
int Engine::BuildDeviceProgram() { return FAILED; }
int Engine::LaunchDeviceProgram() { return 0; }
int Engine::BuildOriginProgram() {
  // TODO(hong19860320) The block_desc needs to be divided into subgraphs at
  // execution time. For now the whole block is treated as a single subgraph.
origin_program_.clear();
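  // Create an op and pick a kernel for each op in the block, so the subgraph
  // can still fall back to the CPU (ARM/X86) kernels if the device program
  // cannot be built.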
for (int op_idx = 0; op_idx < block_desc_->OpsSize(); op_idx++) {
auto op_desc = block_desc_->GetOp<cpp::OpDesc>(op_idx);
CHECK(op_desc);
std::string op_type = op_desc->Type();
auto op = LiteOpRegistry::Global().Create(op_desc->Type());
op->Attach(*op_desc, scope_);
std::unique_ptr<KernelBase> picked_kernel;
if (op_desc->HasAttr(kKernelTypeAttr)) {
// Create op and pick up kernel according to the kKernelTypeAttr attribute
auto kernel_type = op_desc->GetAttr<std::string>(kKernelTypeAttr);
std::string alias;
Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
VLOG(3) << "Found the attr '" << kKernelTypeAttr << "': " << kernel_type
<< " for " << op_type;
auto kernels = op->CreateKernels({place});
CHECK_GT(kernels.size(), 0) << "No kernels found for " << op_type;
auto it = std::find_if(
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& it) {
return it->alias() == alias;
});
CHECK(it != kernels.end());
picked_kernel = std::move(*it);
} else {
VLOG(3) << "The attr '" << kKernelTypeAttr
<< "' not found, pick the first kernel for " << op_type;
#if defined(LITE_WITH_ARM)
auto kernels = op->CreateKernels({Place{TARGET(kARM)}});
#elif defined(LITE_WITH_X86)
auto kernels = op->CreateKernels({Place{TARGET(kX86)}});
#endif
CHECK_GT(kernels.size(), 0) << "No kernels found for " << op_type;
picked_kernel = std::move(kernels.front());
}
picked_kernel->SetContext(
ContextScheduler::Global().NewContext(picked_kernel->target()));
origin_program_.emplace_back(std::move(op), std::move(picked_kernel));
}
return 0;
}
int Engine::LaunchOriginProgram() {
for (auto& inst : origin_program_) {
auto op_type = inst.op()->op_info()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
inst.Run();
}
return 0;
}
int Engine::Build() {
// In order to attach all of the ops of the block desc, we need to build the
// original program firstly.
BuildOriginProgram();
// Run InferShape() of all of ops, and convert Paddle ops to NPU/XPU IR graph
build_device_program_status_ = BuildDeviceProgram();
return build_device_program_status_;
}
bool Engine::InputShapeChanged() {
for (int i = 0; i < origin_itensors_.size(); i++) {
if (origin_itensors_[i]->dims() != origin_idims_[i]) {
return true;
}
}
return false;
}
int Engine::Launch() {
// Rebuild device program when the shapes of input tensors have been changed.
if (CHECK_SUCCESS(build_device_program_status_) &&
CHECK_REBUILD_WHEN_SHAPE_CHANGED(build_device_program_status_) &&
InputShapeChanged()) {
Build();
}
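  // If the device program failed to build, fall back to running the original
  // program with the CPU kernels; otherwise run the device program.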
if (CHECK_FAILED(build_device_program_status_)) {
LaunchOriginProgram();
} else {
LaunchDeviceProgram();
}
return 0;
}
} // namespace subgraph
} // namespace lite
} // namespace paddle
......@@ -14,52 +14,63 @@
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/backends/xpu/builder.h"
#include "lite/core/mir/pass.h"
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/core/op_lite.h"
#include "lite/core/program.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
class GenerateXPUProgramPass : public SubgraphProgramPass {
class Engine {
public:
using key2nodes_t = std::map<std::string, Node*>;
Engine(int block_idx,
cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
lite::Scope *scope)
: block_idx_(block_idx),
block_desc_(block_desc),
input_names_(input_names),
output_names_(output_names),
scope_(scope) {}
virtual ~Engine() = default;
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
virtual int Build();
virtual int Launch();
private:
Engine(const Engine &) = delete;
protected:
// nodes2cvt: op nodes to convert
// return cvted_vars: converted var nodes
void CvtAllOpNodes(
const std::vector<Node*>& op_nodes,
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes);
virtual int BuildDeviceProgram();
virtual int LaunchDeviceProgram();
std::shared_ptr<xtcl::xExpr> CvtVarNode(
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::mir::Node* var_node,
const Scope* scope);
virtual int BuildOriginProgram();
virtual int LaunchOriginProgram();
std::string BuildXPUGraph(const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id);
virtual bool InputShapeChanged();
void GenXPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id);
int block_idx_;
cpp::BlockDesc *block_desc_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
Scope *scope_{nullptr};
  // SUCCESS: device program built successfully. FAILED: device program build
  // failed. REBUILD_WHEN_SHAPE_CHANGED: device program built successfully but
  // needs to be rebuilt when the input shape changes.
int build_device_program_status_{0};
std::vector<DDim> origin_idims_;
std::vector<DDim> origin_odims_;
std::vector<Tensor *> origin_itensors_;
std::vector<Tensor *> origin_otensors_;
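  // The original Paddle ops with their picked CPU kernels, used as the
  // fallback program when the device program is unavailable.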
std::vector<Instruction> origin_program_;
};
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -12,31 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
const node_map_type& inputs_map) {
auto scope = fc_op->scope();
auto op_info = fc_op->op_info();
int FCConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
auto fc_node = std::make_shared<ge::op::FullConnection>(unique_op_type);
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("Input").front();
auto w_var_name = op_info->Input("W").front();
auto out_var_name = op_info->Output("Out").front();
int in_num_col_dims = op_info->GetAttr<int>("in_num_col_dims");
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto w = scope->FindVar(w_var_name)->GetMutable<lite::Tensor>();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto w = scope->FindVar(w_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims();
auto w_dims = w->dims();
......@@ -50,71 +50,54 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
VLOG(3) << "[NPU] x dims: " << x_dims << " w dims: " << w_dims << " m: " << m
<< " k: " << k << " n: " << n;
CHECK(inputs_map.count(x_var_name));
CHECK(!inputs_map.count(w_var_name));
auto fc_node = graph->AddNode<ge::op::FullConnection>(out_var_name + "/fc");
CHECK(!graph->HasNode(w_var_name));
// reshape x to (m, k, 1, 1)
// Reshape x to (m, k, 1, 1)
auto reshaped_x_node =
std::make_shared<ge::op::Reshape>(x_var_name + "_reshape");
reshaped_x_node->set_input_tensor(*inputs_map.at(x_var_name));
graph->AddNode<ge::op::Reshape>(x_var_name + "/reshape");
reshaped_x_node->set_input_tensor(*graph->GetNode(x_var_name));
reshaped_x_node->set_attr_shape({m, k, 1, 1});
reshaped_x_node->set_attr_axis(0);
fc_node->set_input_x(*reshaped_x_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(reshaped_x_node);
// create w const node, set its shape to (k, n, 1, 1) and fill with
// Create w const node, set its shape to (n, k, 1, 1) and fill with
// the transposed w tensor
auto w_const_node = std::make_shared<ge::op::Const>(w_var_name);
ge::TensorDesc w_const_desc(
ge::Shape({n, k, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT);
ge::TensorPtr w_const_tensor = std::make_shared<ge::Tensor>();
w_const_tensor->SetTensorDesc(w_const_desc);
Tensor transpose_w;
transpose_w.Resize({n, k, 1, 1});
auto transpose_w_data = transpose_w.mutable_data<float>();
auto w_data = w->mutable_data<float>();
std::vector<float> transposed_w_data(w_dims.production());
for (int i = 0; i < k; i++) {
for (int j = 0; j < n; j++) {
transposed_w_data[j * k + i] = w_data[i * n + j];
transpose_w_data[j * k + i] = w_data[i * n + j];
}
}
w_const_tensor->SetData(reinterpret_cast<uint8_t*>(transposed_w_data.data()),
transposed_w_data.size() * sizeof(float));
w_const_node->set_attr_value(w_const_tensor);
auto w_const_node = graph->AddNode(w_var_name, transpose_w);
fc_node->set_input_w(*w_const_node);
lite::npu::OpList::Global().add(w_const_node);
// add bias node if bias tensor exists
if (lite::npu::HasInputArg(op_info, scope, "Bias")) {
// Add bias node if bias tensor exists
if (HasInputArg(op_info, scope, "Bias")) {
auto bias_var_name = op_info->Input("Bias").front();
auto bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto bias_dims = bias->dims();
CHECK(!inputs_map.count(bias_var_name));
CHECK(!graph->HasNode(bias_var_name));
CHECK_EQ(bias_dims.production(), n);
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(lite::npu::CvtTensor(bias, {1, n, 1, 1}));
auto bias_const_node = graph->AddNode(bias_var_name, *bias, {1, n, 1, 1});
fc_node->set_input_b(*bias_const_node);
lite::npu::OpList::Global().add(bias_const_node);
}
lite::npu::OpList::Global().add(fc_node);
// reshape output of fc_node from (m, n, 1, 1) to (m, n)
auto reshaped_fc_node =
std::make_shared<ge::op::Reshape>(unique_op_type + "_reshape");
// Reshape output of fc_node from (m, n, 1, 1) to (m, n)
auto reshaped_fc_node = graph->AddNode<ge::op::Reshape>(out_var_name);
reshaped_fc_node->set_input_tensor(*fc_node);
reshaped_fc_node->set_attr_shape({m, n});
reshaped_fc_node->set_attr_axis(0);
lite::npu::OpList::Global().add(reshaped_fc_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = reshaped_fc_node;
return outputs_map;
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(fc, paddle::lite::kernels::npu::bridges::FCConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, fc, paddle::lite::subgraph::npu::FCConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/graph.h"
#include <utility>
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {
// Const node
std::shared_ptr<ge::op::Const> Graph::AddNode(const std::string& name,
const Tensor& tensor,
PrecisionType ptype,
DataLayoutType ltype) {
return AddNode(name, tensor, tensor.dims().Vectorize(), ptype, ltype);
}
std::shared_ptr<ge::op::Const> Graph::AddNode(const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType ptype,
DataLayoutType ltype) {
CHECK(!HasNode(name)) << "Node " << name << " redefined.";
auto node = AddNode<ge::op::Const>(name);
node->set_attr_value(CvtTensor(tensor, shape, ptype, ltype));
return node;
}
// Data node
std::shared_ptr<ge::op::Data> Graph::AddNode(const std::string& name,
std::vector<int64_t> shape,
PrecisionType ptype,
DataLayoutType ltype) {
CHECK(!HasNode(name)) << "Node " << name << " redefined.";
auto node = AddNode<ge::op::Data>(name);
ge::TensorDesc desc(
ge::Shape(shape), CvtDataLayoutType(ltype), CvtPrecisionType(ptype));
node->update_input_desc_x(desc);
nodes_.insert(std::make_pair(name, node));
return node;
}
} // namespace npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {
// Graph to collect and manage the HiAI IR nodes converted from the Paddle Ops
// of a subgraph
class Graph {
public:
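  // Create an op node of type T and bind it to `name`. If `name` is already
  // bound, the existing node is rebound to a generated unique key, so the
  // newly created node becomes the one returned by GetNode(name).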
template <typename T>
std::shared_ptr<T> AddNode(const std::string& name) {
auto unique_name = [&](const std::string& key) {
int idx = 1;
auto it = counts_.find(key);
if (it == counts_.end()) {
counts_.insert(std::make_pair(key, idx));
} else {
idx = ++(it->second);
}
return key + "_" + std::to_string(idx);
};
auto it = nodes_.find(name);
if (it != nodes_.end()) {
// Generate a new unique name as the key to bind the origin node:
// new_name->node
nodes_.insert(std::make_pair(unique_name(name + "_var"), it->second));
nodes_.erase(it);
}
// Create a new node and bind with the name: name->new_node
auto node = std::make_shared<T>(unique_name(name + "_op"));
nodes_.insert(std::make_pair(name, node));
return node;
}
// Const node
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const Tensor& tensor,
PrecisionType ptype = PRECISION(kFloat),
DataLayoutType ltype = DATALAYOUT(kNCHW));
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const Tensor& tensor,
std::vector<int64_t> shape,
PrecisionType ptype = PRECISION(kFloat),
DataLayoutType ltype = DATALAYOUT(kNCHW));
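  // Const node filled from a std::vector<T>; the precision is deduced from T
  // and the shape defaults to a 1-D tensor with data.size() elements.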
template <typename T>
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
const std::vector<T>& data,
std::vector<int64_t> shape = {},
DataLayoutType ltype = DATALAYOUT(kNCHW)) {
const std::type_info& info = typeid(T);
PrecisionType ptype = PRECISION(kFloat);
if (info == typeid(float)) {
ptype = PRECISION(kFloat);
} else if (info == typeid(int8_t)) {
ptype = PRECISION(kFloat);
} else if (info == typeid(int32_t)) {
ptype = PRECISION(kInt32);
} else {
LOG(FATAL) << "[NPU] Unknow data type " << info.name();
}
if (shape.empty()) {
shape = {static_cast<int64_t>(data.size())};
} else {
int size = 1;
for (auto i : shape) {
size *= i;
}
CHECK_EQ(data.size(), size);
}
Tensor tensor;
tensor.Resize(shape);
std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
reinterpret_cast<const uint8_t*>(data.data()),
data.size() * sizeof(T));
return AddNode(name, tensor, ptype, ltype);
}
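  // Const node filled with a single scalar value broadcast to the given shape.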
template <typename T>
std::shared_ptr<ge::op::Const> AddNode(
const std::string& name,
T value,
std::vector<int64_t> shape = {1},
DataLayoutType ltype = DATALAYOUT(kNCHW)) {
int64_t size = 1;
for (auto i : shape) {
size *= i;
}
std::vector<T> data(size, value);
return AddNode(name, data, shape, ltype);
}
// Data node
std::shared_ptr<ge::op::Data> AddNode(
const std::string& name,
std::vector<int64_t> shape,
PrecisionType ptype = PRECISION(kFloat),
DataLayoutType ltype = DATALAYOUT(kNCHW));
std::shared_ptr<ge::Operator> GetNode(std::string name) {
CHECK(HasNode(name)) << "[NPU] Node " << name << " not found.";
return nodes_.at(name);
}
bool HasNode(const std::string& name) {
return nodes_.find(name) != nodes_.end();
}
private:
std::unordered_map<std::string, std::shared_ptr<ge::Operator>> nodes_;
std::unordered_map<std::string, int> counts_;
};
} // namespace npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
......@@ -12,34 +12,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
node_map_type InterpolateConverter(
const std::shared_ptr<lite::OpLite> interpolate_op,
const node_map_type& inputs_map) {
auto scope = interpolate_op->scope();
auto op_info = interpolate_op->op_info();
int InterpolateConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// get input, output and attributes from lite op
// Get input, output and attributes from lite op
auto x_var_name = op_info->Input("X").front();
CHECK(inputs_map.count(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims();
auto x_h = x_dims[2];
auto x_w = x_dims[3];
CHECK_EQ(x_dims.size(), 4);
auto out_var_name = op_info->Output("Out").front();
auto scale = op_info->GetAttr<float>("scale");
auto out_w = op_info->GetAttr<int>("out_w");
auto out_h = op_info->GetAttr<int>("out_h");
......@@ -50,7 +48,7 @@ node_map_type InterpolateConverter(
"align_corners = false isn't "
"supported in HiAI DDK";
// priority: OutSize > scale > out_h/out_w
// Priority: OutSize > scale > out_h/out_w
if (scale > 0) {
out_h = static_cast<int>(x_h * scale);
out_w = static_cast<int>(x_w * scale);
......@@ -58,18 +56,17 @@ node_map_type InterpolateConverter(
out_w = out_w > 0 ? out_w : -1;
}
// update out_h and out_w if has OutSize
  // Update out_h and out_w if OutSize is given
std::shared_ptr<ge::Operator> out_size_node = nullptr;
if (lite::npu::HasInputArg(op_info, scope, "OutSize")) {
if (HasInputArg(op_info, scope, "OutSize")) {
auto out_size_var_name = op_info->Input("OutSize").front();
if (inputs_map.count(out_size_var_name)) {
out_size_node = inputs_map.at(out_size_var_name);
if (graph->HasNode(out_size_var_name)) {
out_size_node = graph->GetNode(out_size_var_name);
} else {
auto out_size =
scope->FindVar(out_size_var_name)->GetMutable<lite::Tensor>();
auto out_size = scope->FindVar(out_size_var_name)->GetMutable<Tensor>();
CHECK_EQ(out_size->numel(), 2);
auto out_size_data = out_size->mutable_data<int>();
// update out_h and out_w if has OutSize
      // Update out_h and out_w if OutSize is given
out_h = out_size_data[0];
out_w = out_size_data[1];
}
......@@ -83,46 +80,37 @@ node_map_type InterpolateConverter(
<< " is too large, should not exceed " << largest_multiple
<< " in HiAI DDK";
}
auto out_size_const_node =
std::make_shared<ge::op::Const>(unique_op_type + "/out_size");
out_size_const_node->set_attr_value(
lite::npu::CreateTensorAndFillData(std::vector<int>({out_h, out_w})));
out_size_node = out_size_const_node;
out_size_node = graph->AddNode(out_var_name + "/out_size",
std::vector<int>({out_h, out_w}));
}
lite::npu::OpList::Global().add(out_size_node);
std::shared_ptr<ge::Operator> interp_node = nullptr;
if (interp_method == "bilinear") {
auto bilinear_interp_node =
std::make_shared<ge::op::ResizeBilinear>(unique_op_type);
bilinear_interp_node->set_input_x(*inputs_map.at(x_var_name));
graph->AddNode<ge::op::ResizeBilinear>(out_var_name);
bilinear_interp_node->set_input_x(*graph->GetNode(x_var_name));
bilinear_interp_node->set_input_size(*out_size_node);
bilinear_interp_node->set_attr_align_corners(align_corners);
interp_node = bilinear_interp_node;
} else if (interp_method == "nearest") {
auto nearest_interp_node =
std::make_shared<ge::op::ResizeNearestNeighbor>(unique_op_type);
nearest_interp_node->set_input_image(*inputs_map.at(x_var_name));
graph->AddNode<ge::op::ResizeNearestNeighbor>(out_var_name);
nearest_interp_node->set_input_image(*graph->GetNode(x_var_name));
nearest_interp_node->set_input_size(*out_size_node);
nearest_interp_node->set_attr_align_corners(align_corners);
interp_node = nearest_interp_node;
} else {
LOG(FATAL) << "[NPU] Unsupported interpolate method: " << interp_method;
LOG(WARNING) << "[NPU] Unsupported interpolate method: " << interp_method;
return FAILED;
}
lite::npu::OpList::Global().add(interp_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = interp_node;
return outputs_map;
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(bilinear_interp,
paddle::lite::kernels::npu::bridges::InterpolateConverter);
REGISTER_NPU_BRIDGE(nearest_interp,
paddle::lite::kernels::npu::bridges::InterpolateConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
bilinear_interp,
paddle::lite::subgraph::npu::InterpolateConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
nearest_interp,
paddle::lite::subgraph::npu::InterpolateConverter);
......@@ -12,24 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
// Note: inputs_map the var_name contains only the data, the weight should be
// handle in this converter
node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
const node_map_type& inputs_map) {
auto scope = mul_op->scope();
auto op_info = mul_op->op_info();
// Note: all of the input weight vars should be handled in this converter
int MulConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
......@@ -37,6 +37,7 @@ node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
auto y = scope->FindVar(y_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
auto y_dims = y->dims();
auto out_var_name = op_info->Output("Out").front();
int x_num_col_dims = op_info->GetAttr<int>("x_num_col_dims");
int y_num_col_dims = op_info->GetAttr<int>("y_num_col_dims");
int m = x_dims.Slice(0, x_num_col_dims).production();
......@@ -44,61 +45,47 @@ node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
CHECK_EQ(k, y_dims.Slice(0, y_num_col_dims).production())
<< "[NPU] columns of X must be equal with rows of Y";
int n = y_dims.Slice(y_num_col_dims, y_dims.size()).production();
LOG(INFO) << "m:" << m << ",n:" << n << ",k:" << k;
LOG(INFO) << "x_var_name:" << x_var_name
<< ", is data: " << inputs_map.count(x_var_name);
LOG(INFO) << "y_var_name:" << y_var_name
<< ", is data: " << inputs_map.count(y_var_name);
CHECK(inputs_map.count(x_var_name))
VLOG(3) << "m:" << m << ",n:" << n << ",k:" << k;
VLOG(3) << "x_var_name:" << x_var_name
<< ", is data: " << graph->HasNode(x_var_name);
VLOG(3) << "y_var_name:" << y_var_name
<< ", is data: " << graph->HasNode(y_var_name);
CHECK(graph->HasNode(x_var_name))
<< "[NPU] MatMul in HiAI DDK only support X is data, Y is const yet.";
auto mul_node = std::make_shared<ge::op::MatMul>(unique_op_type);
// add input x node which supports persistable and non-persistable tensor, and
auto mul_node = graph->AddNode<ge::op::MatMul>(out_var_name);
  // Add the input x node, which supports both persistable and non-persistable
  // tensors, and reshape it to (m, k)
if (inputs_map.count(x_var_name)) {
if (graph->HasNode(x_var_name)) {
auto reshaped_x_node =
std::make_shared<ge::op::Reshape>(x_var_name + "_reshape");
reshaped_x_node->set_input_tensor(*inputs_map.at(x_var_name));
graph->AddNode<ge::op::Reshape>(x_var_name + "/reshape");
reshaped_x_node->set_input_tensor(*graph->GetNode(x_var_name));
reshaped_x_node->set_attr_shape({m, k});
reshaped_x_node->set_attr_axis(0);
mul_node->set_input_x1(*reshaped_x_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(reshaped_x_node);
} else {
auto x_const_node = std::make_shared<ge::op::Const>(x_var_name);
x_const_node->set_attr_value(lite::npu::CvtTensor(x, {m, k}));
auto x_const_node = graph->AddNode(x_var_name, *x, {m, k});
mul_node->set_input_x1(*x_const_node);
lite::npu::OpList::Global().add(x_const_node);
}
// add input y node which only supports persistable tensor, and reshape to (k,
// n)
if (inputs_map.count(y_var_name)) {
  // Add the input y node, which only supports a persistable tensor, and
  // reshape it to (k, n)
if (graph->HasNode(y_var_name)) {
auto reshaped_y_node =
std::make_shared<ge::op::Reshape>(y_var_name + "_reshape");
reshaped_y_node->set_input_tensor(*inputs_map.at(y_var_name));
graph->AddNode<ge::op::Reshape>(y_var_name + "/reshape");
reshaped_y_node->set_input_tensor(*graph->GetNode(y_var_name));
reshaped_y_node->set_attr_shape({k, n});
reshaped_y_node->set_attr_axis(0);
mul_node->set_input_x2(*reshaped_y_node);
lite::npu::OpList::Global().add(inputs_map.at(y_var_name));
lite::npu::OpList::Global().add(reshaped_y_node);
} else {
auto y_const_node = std::make_shared<ge::op::Const>(y_var_name);
y_const_node->set_attr_value(lite::npu::CvtTensor(y, {k, n}));
auto y_const_node = graph->AddNode(y_var_name, *y, {k, n});
mul_node->set_input_x2(*y_const_node);
lite::npu::OpList::Global().add(y_const_node);
}
lite::npu::OpList::Global().add(mul_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = mul_node;
return outputs_map;
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(mul, paddle::lite::kernels::npu::bridges::MulConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU, mul, paddle::lite::subgraph::npu::MulConverter);
......@@ -12,38 +12,39 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace subgraph {
namespace npu {
namespace bridges {
node_map_type Pad2dConverter(const std::shared_ptr<lite::OpLite> pad2d_op,
const node_map_type& inputs_map) {
auto scope = pad2d_op->scope();
auto op_info = pad2d_op->op_info();
int Pad2dConverter(void* ctx, OpLite* op) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::Pad> pad2d_node =
std::make_shared<ge::op::Pad>(unique_op_type);
auto x_var_name = op_info->Input("X").front();
pad2d_node->set_input_x(*inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(pad2d_node);
auto out_var_name = op_info->Output("Out").front();
auto pad2d_node = graph->AddNode<ge::op::Pad>(out_var_name);
pad2d_node->set_input_x(*graph->GetNode(x_var_name));
auto mode = op_info->GetAttr<std::string>("mode");
if (mode == "constant") {
pad2d_node->set_attr_mode(0);
} else if (mode == "reflect") {
LOG(FATAL) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
pad2d_node->set_attr_mode(1);
return FAILED;
} else {
LOG(FATAL) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
return FAILED;
}
auto x_dims = scope->FindTensor(x_var_name)->dims();
......@@ -51,34 +52,25 @@ node_map_type Pad2dConverter(const std::shared_ptr<lite::OpLite> pad2d_op,
CHECK_EQ(padding.size(), 4);
int xds = x_dims.size();
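  // Only the last two (spatial) dimensions are padded; prepend zeros so the
  // padding covers all x_dims.size() dimensions as required by the Pad op.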
padding.insert(padding.begin(), xds * 2 - 4, 0);
auto npu_padding =
std::make_shared<ge::op::Const>(unique_op_type + "/padding");
npu_padding->set_attr_value(
lite::npu::CreateTensorAndFillData<int>(padding, {xds, 2}));
pad2d_node->set_input_padding(*npu_padding);
lite::npu::OpList::Global().add(npu_padding);
auto padding_const_node =
graph->AddNode(out_var_name + "/padding", padding, {xds, 2});
pad2d_node->set_input_padding(*padding_const_node);
if (mode == "constant") {
auto pad_value = op_info->GetAttr<float>("pad_value");
auto npu_pad_value =
std::make_shared<ge::op::Const>(unique_op_type + "/pad_value");
npu_pad_value->set_attr_value(
lite::npu::CreateTensorAndFillData<float>({pad_value}));
pad2d_node->set_input_constant_values(*npu_pad_value);
lite::npu::OpList::Global().add(npu_pad_value);
auto pad_value_const_node =
graph->AddNode(out_var_name + "/pad_value", pad_value);
pad2d_node->set_input_constant_values(*pad_value_const_node);
pad2d_node->set_attr_T(0); // type of pad_value: 0:float 3:int32
}
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = pad2d_node;
return outputs_map;
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(pad2d, paddle::lite::kernels::npu::bridges::Pad2dConverter);
REGISTER_SUBGRAPH_BRIDGE(NPU,
pad2d,
paddle::lite::subgraph::npu::Pad2dConverter);
......@@ -7,7 +7,7 @@ ARM_ABI="armv8" # armv8, armv7
ARM_LANG="gcc" # gcc only yet
ANDROID_STL="c++_shared" # c++_shared/c++_static, c++_shared is used by HiAI DDK 310
DDK_ROOT="$(pwd)/ai_ddk_lib/" # HiAI DDK 310 from https://developer.huawei.com/consumer/cn/hiai/
TARGET_NAME="test_npu_pass" # default target
TARGET_NAME="test_subgraph_pass" # default target
BUILD_EXTRA=OFF # ON(with sequence ops)/OFF
WITH_JAVA=ON # ON(build jar and jni so)/OFF
WITH_TESTING=ON # ON/OFF
......