Unverified commit c8bd5e26, authored by H HappyAngel, committed by GitHub

Merge pull request #147 from PaddlePaddle/develop

pull
......@@ -10,7 +10,6 @@ class CxxPredictor
```python
from paddlelite.lite import *
from lite_core import *
# 1. Set up CxxConfig
config = CxxConfig()
......
......@@ -12,7 +12,6 @@ Tensor is the data structure Paddle-Lite uses to wrap the underlying data and
```python
from paddlelite.lite import *
from lite_core import *
# 1. Set up CxxConfig
config = CxxConfig()
......
......@@ -48,7 +48,7 @@ class Place{
Example:
```python
from lite_core import *
from paddlelite.lite import *
Place{TargetType(ARM), PrecisionType(FP32), DataLayoutType(NCHW)}
```
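A minimal end-to-end sketch of how `CxxConfig` and `Place` fit together in the Python API (a sketch only: it assumes an ARM build of the `paddlelite` wheel, and the exact spelling of the `TargetType`/`PrecisionType`/`DataLayoutType` enums may vary between versions):
```python
from paddlelite.lite import *

# 1. Set up CxxConfig (the model directory below is a placeholder)
config = CxxConfig()
config.set_model_dir("./mobilenet_v1")
# 2. Declare which hardware places the predictor may run on
config.set_valid_places(
    [Place(TargetType.ARM, PrecisionType.FP32, DataLayoutType.NCHW)])
# 3. Create the predictor from the config
predictor = create_paddle_predictor(config)
```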
......@@ -2,29 +2,31 @@
X2Paddle can convert Caffe, TensorFlow, and ONNX models into models supported by Paddle.
[X2Paddle](https://github.com/PaddlePaddle/X2Paddle) supports converting Caffe/TensorFlow models to PaddlePaddle models. The models currently supported by X2Paddle are listed in [x2paddle_model_zoo](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle_model_zoo.md)
[X2Paddle](https://github.com/PaddlePaddle/X2Paddle) supports converting Caffe/TensorFlow models to PaddlePaddle models.
The supported models can be found in the **X2Paddle model zoo:**
https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle_model_zoo.md
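For reference, a hedged sketch of driving the conversion from Python by shelling out to the `x2paddle` command-line tool (the flag names and file paths below are assumptions; check them against the installed x2paddle version):
```python
import subprocess

# Convert a TensorFlow frozen graph into a Paddle model (paths are placeholders).
subprocess.run(
    [
        "x2paddle",
        "--framework=tensorflow",   # source framework: caffe / tensorflow / onnx
        "--model=frozen_model.pb",  # path to the source model
        "--save_dir=paddle_model",  # output directory for the converted model
    ],
    check=True,
)
```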
## Multi-framework support
|Model | caffe | tensorflow | onnx |
|---|---|---|---|
|mobilenetv1 | Y | Y | |
|mobilenetv2 | Y | Y | Y |
|resnet18 | Y | Y | |
|resnet50 | Y | Y | Y |
|mnasnet | Y | Y | |
|efficientnet | Y | Y | Y |
|squeezenetv1.1 | Y | Y | Y |
|shufflenet | Y | Y | |
|mobilenet_ssd | Y | Y | |
|mobilenet_yolov3 | | Y | |
|inceptionv4 | | | |
|mtcnn | Y | Y | |
|facedetection | Y | | |
|unet | Y | Y | |
|ocr_attention | | | |
|vgg16 | | | |
## Installation
......
......@@ -41,10 +41,26 @@ namespace lite_api {
bool IsOpenCLBackendValid() {
bool opencl_valid = false;
#ifdef LITE_WITH_OPENCL
bool opencl_lib_found = paddle::lite::CLWrapper::Global()->OpenclLibFound();
#ifdef LITE_WITH_LOG
LOG(INFO) << "opencl_lib_found:" << opencl_lib_found;
#endif
if (opencl_lib_found == false) return false;
bool dlsym_success = paddle::lite::CLWrapper::Global()->DlsymSuccess();
#ifdef LITE_WITH_LOG
LOG(INFO) << "dlsym_success:" << dlsym_success;
#endif
if (dlsym_success == false) return false;
opencl_valid = paddle::lite::CLRuntime::Global()->OpenCLAvaliableForDevice();
#endif
#ifdef LITE_WITH_LOG
LOG(INFO) << "opencl_valid:" << opencl_valid;
#endif
return opencl_valid;
}
......@@ -62,50 +78,28 @@ void Tensor::Resize(const shape_t &shape) {
tensor(raw_tensor_)->Resize(shape);
}
// Tensor::data
template <>
const float *Tensor::data() const {
return ctensor(raw_tensor_)->data<float>();
}
template <>
const int8_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int8_t>();
}
template <>
const uint8_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<uint8_t>();
}
template <>
const int64_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int64_t>();
}
template <>
const int32_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int32_t>();
template <typename T>
const T *Tensor::data() const {
return ctensor(raw_tensor_)->data<T>();
}
// Tensor::mutable_data
template <>
int *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int>(type);
}
template <>
float *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<float>(type);
}
template <>
int8_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int8_t>(type);
}
template <>
uint8_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<uint8_t>(type);
}
template <>
int64_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int64_t>(type);
template <typename T>
T *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<T>(type);
}
template const float *Tensor::data<float>() const;
template const int8_t *Tensor::data<int8_t>() const;
template const uint8_t *Tensor::data<uint8_t>() const;
template const int64_t *Tensor::data<int64_t>() const;
template const int32_t *Tensor::data<int32_t>() const;
template int *Tensor::mutable_data(TargetType type) const;
template float *Tensor::mutable_data(TargetType type) const;
template int8_t *Tensor::mutable_data(TargetType type) const;
template uint8_t *Tensor::mutable_data(TargetType type) const;
template int64_t *Tensor::mutable_data(TargetType type) const;
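// Note: the explicit instantiations above export the templated data<T>()/
// mutable_data<T>() accessors for exactly the element types the public Tensor
// API supports, so the template definitions can remain in this .cc file.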
template <typename T, TargetType type>
void Tensor::CopyFromCpu(const T *src_data) {
T *data = tensor(raw_tensor_)->mutable_data<T>(type);
......@@ -244,6 +238,18 @@ ConfigBase::ConfigBase(PowerMode mode, int threads) {
#endif
}
void ConfigBase::set_opencl_tune(bool enable_tune) {
#ifdef LITE_WITH_OPENCL
if (paddle::lite_api::IsOpenCLBackendValid()) {
enable_opencl_tune_ = enable_tune;
paddle::lite::CLRuntime::Global()->set_auto_tune(enable_opencl_tune_);
#ifdef LITE_WITH_OPENCL
LOG(INFO) << "auto_tune:" << paddle::lite::CLRuntime::Global()->auto_tune();
#endif
}
#endif
}
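// Note: set_opencl_tune (above) is effectively a no-op unless the OpenCL
// backend is available; when it is, the flag is simply forwarded to
// paddle::lite::CLRuntime::Global()->set_auto_tune().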
void ConfigBase::set_power_mode(paddle::lite_api::PowerMode mode) {
#ifdef LITE_WITH_ARM
lite::DeviceInfo::Global().SetRunMode(mode, threads_);
......
......@@ -124,6 +124,8 @@ class LITE_API ConfigBase {
std::string model_dir_;
int threads_{1};
PowerMode mode_{LITE_POWER_NO_BIND};
// gpu
bool enable_opencl_tune_{false};
// to save subgraph model for npu/xpu/...
std::string subgraph_model_cache_dir_{""};
int device_id_{0};
......@@ -139,6 +141,9 @@ class LITE_API ConfigBase {
// set Power_mode
void set_power_mode(PowerMode mode);
PowerMode power_mode() const { return mode_; }
// set GPU opencl tune
void set_opencl_tune(bool enable_tune);
bool opencl_tune() const { return enable_opencl_tune_; }
// set subgraph_model_dir
void set_subgraph_model_cache_dir(std::string subgraph_model_cache_dir) {
subgraph_model_cache_dir_ = subgraph_model_cache_dir;
......
......@@ -67,6 +67,15 @@ bool Device::Build(std::vector<ge::Operator>& input_nodes, // NOLINT
std::lock_guard<std::mutex> lock(device_mutex_);
// Convert the HiAI IR graph to the HiAI om model
ge::Graph ir_graph("graph");
// set input node attr "index" if input node size > 1
if (input_nodes.size() > 1) {
int idx = 0;
for (auto node : input_nodes) {
node.SetAttr("index", idx);
idx++;
}
}
VLOG(3) << "Getting input node size " << input_nodes.size();
ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes);
// Build IR model
......
......@@ -24,50 +24,28 @@ bool AclModelClient::LoadFromMem(const void* data, uint32_t size) {
return true;
}
auto ret = aclmdlQuerySizeFromMem(
data, size, &model_memory_size_, &model_weight_size_);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] query model size from memory failed!";
return false;
}
ret = aclrtMalloc(
&model_memory_ptr_, model_memory_size_, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] malloc buffer for model memory "
"failed, require size is "
<< model_memory_size_;
return false;
}
ret = aclrtMalloc(
&model_weight_ptr_, model_weight_size_, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] malloc buffer for model weigth "
"failed, require size is "
<< model_weight_size_;
return false;
}
ret = aclmdlLoadFromMemWithMem(data,
size,
&model_id_,
model_memory_ptr_,
model_memory_size_,
model_weight_ptr_,
model_weight_size_);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Load model from memory failed!";
return false;
}
ACL_CALL(aclmdlQuerySizeFromMem(
data, size, &model_memory_size_, &model_weight_size_));
ACL_CALL(aclrtMalloc(
&model_memory_ptr_, model_memory_size_, ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CALL(aclrtMalloc(
&model_weight_ptr_, model_weight_size_, ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CALL(aclmdlLoadFromMemWithMem(data,
size,
&model_id_,
model_memory_ptr_,
model_memory_size_,
model_weight_ptr_,
model_weight_size_));
model_desc_ = aclmdlCreateDesc();
if (model_desc_ == nullptr) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] create model description failed!";
return false;
}
ret = aclmdlGetDesc(model_desc_, model_id_);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] get model description failed!";
return false;
}
VLOG(3) << "[HUAWEI_ASCEND_NPU] AclModelClient LoadFromMem success.";
ACL_CALL(aclmdlGetDesc(model_desc_, model_id_));
VLOG(3) << "[HUAWEI_ASCEND_NPU] Load model form memeory success.";
load_flag_ = true;
return true;
}
......@@ -77,49 +55,28 @@ bool AclModelClient::LoadFromFile(const char* model_path) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] model is already loaded!";
return true;
}
auto ret =
aclmdlQuerySize(model_path, &model_memory_size_, &model_weight_size_);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] query model size from file failed!";
return false;
}
ret = aclrtMalloc(
&model_memory_ptr_, model_memory_size_, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] malloc buffer for model memory "
"failed, require size is "
<< model_memory_size_;
return false;
}
ret = aclrtMalloc(
&model_weight_ptr_, model_weight_size_, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] malloc buffer for model weigth "
"failed, require size is "
<< model_weight_size_;
return false;
}
ret = aclmdlLoadFromFileWithMem(model_path,
&model_id_,
model_memory_ptr_,
model_memory_size_,
model_weight_ptr_,
model_weight_size_);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Load model from file failed!";
return false;
}
ACL_CALL(
aclmdlQuerySize(model_path, &model_memory_size_, &model_weight_size_));
ACL_CALL(aclrtMalloc(
&model_memory_ptr_, model_memory_size_, ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CALL(aclrtMalloc(
&model_weight_ptr_, model_weight_size_, ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CALL(aclmdlLoadFromFileWithMem(model_path,
&model_id_,
model_memory_ptr_,
model_memory_size_,
model_weight_ptr_,
model_weight_size_));
model_desc_ = aclmdlCreateDesc();
if (model_desc_ == nullptr) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] create model description failed!";
return false;
}
ret = aclmdlGetDesc(model_desc_, model_id_);
if (ret != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] get model description failed!";
return false;
}
VLOG(3) << "[HUAWEI_ASCEND_NPU] Loading model file success:" << model_path;
ACL_CALL(aclmdlGetDesc(model_desc_, model_id_));
VLOG(3) << "[HUAWEI_ASCEND_NPU] Load model form file success: " << model_path;
load_flag_ = true;
return true;
}
......@@ -132,33 +89,25 @@ bool AclModelClient::GetModelIOTensorDim(
return false;
}
size_t input_num = aclmdlGetNumInputs(model_desc_);
VLOG(3) << "[HUAWEI_ASCEND_NPU] input numher is " << input_num;
VLOG(3) << "[HUAWEI_ASCEND_NPU] input number is " << input_num;
for (size_t i = 0; i < input_num; i++) {
VLOG(3) << "[HUAWEI_ASCEND_NPU] printing input [" << i << "] ....";
aclmdlIODims input_dim;
aclmdlGetInputDims(model_desc_, i, &input_dim);
ACL_CALL(aclmdlGetInputDims(model_desc_, i, &input_dim));
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
VLOG(3) << "[HUAWEI_ASCEND_NPU] data_type of inputs[" << i << "] is "
<< data_type;
aclFormat data_format = aclmdlGetInputFormat(model_desc_, i);
VLOG(3) << "[HUAWEI_ASCEND_NPU] data_format of inputs[" << i << "] is "
<< data_format;
TensorDesc tensor_desc = TensorDesc(data_type, input_dim, data_format);
input_tensor->push_back(tensor_desc);
}
size_t output_num = aclmdlGetNumOutputs(model_desc_);
VLOG(3) << "[HUAWEI_ASCEND_NPU] output numher is " << output_num;
VLOG(3) << "[HUAWEI_ASCEND_NPU] output number is " << output_num;
for (size_t i = 0; i < output_num; i++) {
VLOG(3) << "[HUAWEI_ASCEND_NPU] printing output [" << i << "] ....";
aclmdlIODims output_dim;
aclmdlGetOutputDims(model_desc_, i, &output_dim);
ACL_CALL(aclmdlGetOutputDims(model_desc_, i, &output_dim));
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
VLOG(3) << "[HUAWEI_ASCEND_NPU] data_type of outputs[" << i << "] is "
<< data_type;
aclFormat data_format = aclmdlGetOutputFormat(model_desc_, i);
VLOG(3) << "[HUAWEI_ASCEND_NPU] data_format of outputs[" << i << "] is "
<< data_format;
TensorDesc tensor_desc = TensorDesc(data_type, output_dim, data_format);
output_tensor->push_back(tensor_desc);
}
......@@ -181,28 +130,16 @@ bool AclModelClient::GetTensorFromDataset(
uint32_t device_size = aclGetDataBufferSize(buffer_device);
void* tensor_data = nullptr;
aclError ret = aclrtMallocHost(&tensor_data, device_size);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "[HUAWEI_ASCEND_NPU] aclrtMallocHost failed, ret " << ret;
return false;
}
ret = aclrtMemcpy(tensor_data,
device_size,
device_data,
device_size,
ACL_MEMCPY_DEVICE_TO_HOST);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "[HUAWEI_ASCEND_NPU] aclrtMemcpy failed, ret " << ret;
return false;
}
if (output_tensor->at(i)->SetData(reinterpret_cast<uint8_t*>(tensor_data),
device_size) != ge::GRAPH_SUCCESS) {
LOG(ERROR) << "[HUAWEI_ASCEND_NPU] SetData to output tensor failed";
return false;
}
}
VLOG(3)
<< "[HUAWEI_ASCEND_NPU] Get output tensor from output dataset succeed.";
ACL_CALL(aclrtMallocHost(&tensor_data, device_size));
ACL_CALL(aclrtMemcpy(tensor_data,
device_size,
device_data,
device_size,
ACL_MEMCPY_DEVICE_TO_HOST));
ATC_CALL(output_tensor->at(i)->SetData(
reinterpret_cast<uint8_t*>(tensor_data), device_size));
}
VLOG(3) << "[HUAWEI_ASCEND_NPU] Get output tensor from dataset succeed.";
return true;
}
......@@ -218,37 +155,33 @@ void AclModelClient::CreateInputDataset(
auto item = input_tensor->at(i);
size_t buffer_size = item->GetSize();
void* buffer_device = nullptr;
aclError ret =
aclrtMalloc(&buffer_device, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR)
<< "[HUAWEI_ASCEND_NPU] input malloc device buffer failed. size is "
<< buffer_size;
return;
}
ACL_CALL(
aclrtMalloc(&buffer_device, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY));
void* buffer_data = reinterpret_cast<void*>(item->GetData());
ret = aclrtMemcpy(buffer_device,
buffer_size,
buffer_data,
buffer_size,
ACL_MEMCPY_HOST_TO_DEVICE);
auto ret = aclrtMemcpy(buffer_device,
buffer_size,
buffer_data,
buffer_size,
ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "[HUAWEI_ASCEND_NPU] input memcpy failed, buffer size is "
<< buffer_size;
aclrtFree(buffer_device);
ACL_CALL(aclrtFree(buffer_device));
return;
}
aclDataBuffer* data_buffer =
aclCreateDataBuffer(buffer_device, buffer_size);
if (data_buffer == nullptr) {
LOG(ERROR) << "[HUAWEI_ASCEND_NPU] output aclCreateDataBuffer failed!";
aclrtFree(buffer_device);
ACL_CALL(aclrtFree(buffer_device));
return;
}
if (aclmdlAddDatasetBuffer(input_dataset_, data_buffer) != ACL_ERROR_NONE) {
LOG(ERROR) << "[HUAWEI_ASCEND_NPU] input aclmdlAddDatasetBuffer failed!";
aclrtFree(buffer_device);
aclDestroyDataBuffer(data_buffer);
ACL_CALL(aclrtFree(buffer_device));
ACL_CALL(aclDestroyDataBuffer(data_buffer));
return;
}
}
......@@ -266,26 +199,20 @@ void AclModelClient::CreateOutputDataset(
for (size_t i = 0; i < output_size; i++) {
size_t buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);
void* buffer_device = nullptr;
aclError ret =
aclrtMalloc(&buffer_device, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR)
<< "[HUAWEI_ASCEND_NPU] output malloc device buffer failed. size is "
<< buffer_size;
return;
}
ACL_CALL(
aclrtMalloc(&buffer_device, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY));
aclDataBuffer* data_buffer =
aclCreateDataBuffer(buffer_device, buffer_size);
if (data_buffer == nullptr) {
LOG(ERROR) << "[HUAWEI_ASCEND_NPU] output aclCreateDataBuffer failed!";
aclrtFree(buffer_device);
ACL_CALL(aclrtFree(buffer_device));
return;
}
if (aclmdlAddDatasetBuffer(output_dataset_, data_buffer) !=
ACL_ERROR_NONE) {
LOG(ERROR) << "[HUAWEI_ASCEND_NPU] output aclmdlAddDatasetBuffer failed!";
aclrtFree(buffer_device);
aclDestroyDataBuffer(data_buffer);
ACL_CALL(aclrtFree(buffer_device));
ACL_CALL(aclDestroyDataBuffer(data_buffer));
return;
}
}
......@@ -332,21 +259,13 @@ void AclModelClient::DestroyDataset(aclmdlDataset** dataset) {
aclDataBuffer* buffer_device = aclmdlGetDatasetBuffer(*dataset, i);
void* device_data = aclGetDataBufferAddr(buffer_device);
if (device_data == nullptr) {
LOG(WARNING)
<< "[HUAWEI_ASCEND_NPU] failed to get data buffer of deivce data!";
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] failed to get data buffer!";
} else {
if (aclrtFree(device_data) != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] failed to free deivce data!";
}
}
if (aclDestroyDataBuffer(buffer_device) != ACL_ERROR_NONE) {
LOG(WARNING)
<< "[HUAWEI_ASCEND_NPU] failed to destroy deivce data buffer!";
ACL_CALL(aclrtFree(device_data));
}
ACL_CALL(aclDestroyDataBuffer(buffer_device));
}
if (aclmdlDestroyDataset(*dataset) != ACL_ERROR_NONE) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] failed to destroy dataset!";
}
ACL_CALL(aclmdlDestroyDataset(*dataset));
*dataset = nullptr;
VLOG(3) << "[HUAWEI_ASCEND_NPU] Destroy dataset success.";
}
......@@ -361,24 +280,20 @@ bool AclModelClient::UnloadModel() {
DestroyDataset(&input_dataset_);
DestroyDataset(&output_dataset_);
aclError ret = aclmdlUnload(model_id_);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "unload model failed, model id is " << model_id_;
return false;
}
ACL_CALL(aclmdlUnload(model_id_));
if (model_desc_ != nullptr) {
(void)aclmdlDestroyDesc(model_desc_);
ACL_CALL(aclmdlDestroyDesc(model_desc_));
model_desc_ = nullptr;
}
if (model_memory_ptr_ != nullptr) {
aclrtFree(model_memory_ptr_);
ACL_CALL(aclrtFree(model_memory_ptr_));
model_memory_ptr_ = nullptr;
model_memory_size_ = 0;
}
if (model_weight_ptr_ != nullptr) {
aclrtFree(model_weight_ptr_);
ACL_CALL(aclrtFree(model_weight_ptr_));
model_weight_ptr_ = nullptr;
model_weight_size_ = 0;
}
......
......@@ -35,32 +35,39 @@ class TensorDesc {
ge_tensor_desc_ = new ge::TensorDesc(
GetGeShape(dims), GetGeFormat(format), GetGeDataType(data_type));
CHECK(ge_tensor_desc_ != nullptr);
VLOG(3) << "[HUAWEI_ASCEND_NPU] Getting data shape : " << repr();
}
~TensorDesc() { ge_tensor_desc_ = nullptr; }
int64_t GetNumber() const {
return ge_tensor_desc_->GetShape().GetDim(dim_order[0]);
}
int64_t GetChannel() const {
return ge_tensor_desc_->GetShape().GetDim(dim_order[1]);
}
int64_t GetHeight() const {
return ge_tensor_desc_->GetShape().GetDim(dim_order[2]);
const ge::TensorDesc& GetGeTensorDesc() const { return *ge_tensor_desc_; }
std::string repr() const {
STL::stringstream ss;
size_t dim_size = ge_tensor_desc_->GetShape().GetDimNum();
if (dim_size == 0) {
ss << "{}";
return ss.str();
}
ss << "{";
for (size_t i = 0; i < dim_size - 1; i++) {
ss << ge_tensor_desc_->GetShape().GetDim(i) << ",";
}
ss << ge_tensor_desc_->GetShape().GetDim(dim_size - 1);
ss << "}";
return ss.str();
}
int64_t GetWidth() const {
return ge_tensor_desc_->GetShape().GetDim(dim_order[3]);
int64_t production() const {
return ge_tensor_desc_->GetShape().GetShapeSize();
}
const ge::TensorDesc& GetGeTensorDesc() const { return *ge_tensor_desc_; }
private:
ge::Shape GetGeShape(aclmdlIODims dims) {
ge::Shape ge_shape({0, 0, 0, 0});
auto shape_data = std::vector<int64_t>({1L, 1L, 1L, 1L});
shape_data.resize(dims.dimCount);
ge::Shape ge_shape(shape_data);
for (size_t i = 0; i < dims.dimCount; i++) {
if (ge_shape.SetDim(i, dims.dims[i]) != ge::GRAPH_SUCCESS) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] ge::Shape SetDim failed!";
} else {
VLOG(3) << "[HUAWEI_ASCEND_NPU] Setting Ge Shape[" << i << "] = <"
<< dims.dims[i] << ">";
}
ATC_CALL(ge_shape.SetDim(i, dims.dims[i]));
}
return ge_shape;
}
......@@ -80,6 +87,8 @@ class TensorDesc {
LOG(FATAL) << "[HUAWEI_ASCEND_NPU] format not supported:" << format;
break;
}
VLOG(3) << "[HUAWEI_ASCEND_NPU] Getting data format : "
<< CvtFormat(ge_format);
return ge_format;
}
ge::DataType GetGeDataType(aclDataType data_type) {
......@@ -110,6 +119,8 @@ class TensorDesc {
LOG(FATAL) << "[HUAWEI_ASCEND_NPU] data type not supported!";
break;
}
VLOG(3) << "[HUAWEI_ASCEND_NPU] Getting data type : "
<< CvtDataType(ge_datatype);
return ge_datatype;
}
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include <string>
#include "acl/acl.h"
#include "ge/ge_api_types.h"
#include "ge/ge_ir_build.h"
......@@ -21,11 +23,16 @@
#include "graph/tensor.h"
#include "graph/types.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/replace_stl/stream.h"
/*
* This file contains some Huawei Ascend NPU specific utils.
*/
namespace paddle {
namespace lite {
namespace huawei_ascend_npu {
#define ACL_CALL(msg) \
CHECK_EQ(reinterpret_cast<aclError>(msg), ACL_ERROR_NONE) \
<< (msg) << " Huawei Ascend NPU ACL Error: " \
......@@ -38,10 +45,6 @@
<< ::paddle::lite::huawei_ascend_npu::AtcErrorInfo( \
reinterpret_cast<uint32_t>(msg))
namespace paddle {
namespace lite {
namespace huawei_ascend_npu {
static const char* AtcErrorInfo(uint32_t error) {
switch (error) {
#define LITE_ATC_ERROR_INFO(xx) \
......@@ -123,6 +126,61 @@ static const char* AclErrorInfo(int error) {
}
}
static const std::string& CvtFormat(ge::Format format) {
static const int MAX_FORMAT_LENGTH = 25;
static const std::string format2string[] = {
"FORMAT_NCHW = 0",
"FORMAT_NHWC = 1",
"FORMAT_ND = 2",
"FORMAT_NC1HWC0 = 3",
"FORMAT_FRACTAL_Z = 4",
"FORMAT_NC1C0HWPAD = 5",
"FORMAT_NHWC1C0 = 6",
"FORMAT_FSR_NCHW = 7",
"FORMAT_FRACTAL_DECONV = 8",
"FORMAT_C1HWNC0 = 9",
"FORMAT_FRACTAL_DECONV_TRANSPOSE = 10",
"FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11",
"FORMAT_NC1HWC0_C04 = 12",
"FORMAT_FRACTAL_Z_C04 = 13",
"FORMAT_CHWN = 14",
"FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15",
"FORMAT_HWCN = 16",
"FORMAT_NC1KHKWHWC0 = 17",
"FORMAT_BN_WEIGHT = 18",
"FORMAT_FILTER_HWCK = 19",
"FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20",
"FORMAT_HASHTABLE_LOOKUP_KEYS = 21",
"FORMAT_HASHTABLE_LOOKUP_VALUE = 22",
"FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23",
"FORMAT_HASHTABLE_LOOKUP_HITS = 24"};
auto x = static_cast<int>(format);
CHECK_LT(x, MAX_FORMAT_LENGTH);
return format2string[x];
}
static const std::string& CvtDataType(ge::DataType data_type) {
static const int MAX_DATATYPE_LENGTH = 14;
static const std::string datatype2string[] = {"DT_FLOAT=0",
"DT_FLOAT16=1",
"DT_INT8=2",
"DT_INT32=3",
"DT_UINT8=4",
"Unknown=5",
"DT_INT16=6",
"DT_UINT16=7",
"DT_UINT32=8",
"DT_INT64=9",
"DT_UINT64=10",
"DT_DOUBLE=11",
"DT_BOOL=12",
"DT_STRING=13"};
auto x = static_cast<int>(data_type);
CHECK_LT(x, MAX_DATATYPE_LENGTH);
return datatype2string[x];
}
} // namespace huawei_ascend_npu
} // namespace lite
} // namespace paddle
......@@ -34,15 +34,20 @@ cl::Program &CLContext::GetProgram(const std::string &file_name,
std::string program_key = program_key_ss.str();
auto it = programs_.find(program_key);
if (it != programs_.end()) {
#ifdef LITE_WITH_LOG
VLOG(3) << " --- program -> " << program_key << " has been built --- ";
#endif
return *(it->second);
}
auto program = CLRuntime::Global()->CreateProgram(GetContext(), file_name);
#ifdef LITE_WITH_LOG
VLOG(3) << " --- begin build program -> " << program_key << " --- ";
#endif
CLRuntime::Global()->BuildProgram(program.get(), options);
#ifdef LITE_WITH_LOG
VLOG(3) << " --- end build program -> " << program_key << " --- ";
#endif
programs_[program_key] = std::move(program);
......@@ -54,14 +59,20 @@ void CLContext::AddKernel(const std::string &kernel_name,
const std::string &options,
const std::string &time_stamp) {
cl_int status{CL_SUCCESS};
#ifdef LITE_WITH_LOG
VLOG(3) << " --- to get program " << file_name << " --- ";
#endif
auto program = GetProgram(file_name, options);
#ifdef LITE_WITH_LOG
VLOG(3) << " --- end get program --- ";
VLOG(3) << " --- to create kernel: " << kernel_name << " --- ";
#endif
std::shared_ptr<cl::Kernel> kernel(
new cl::Kernel(program, kernel_name.c_str(), &status));
CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
VLOG(3) << " --- end create kernel --- ";
#endif
kernels_.emplace_back(std::move(kernel));
STL::stringstream kernel_key;
kernel_key << kernel_name << options << time_stamp;
......@@ -69,7 +80,9 @@ void CLContext::AddKernel(const std::string &kernel_name,
}
cl::Kernel &CLContext::GetKernel(const int index) {
#ifdef LITE_WITH_LOG
VLOG(3) << " --- kernel count: " << kernels_.size() << " --- ";
#endif
CHECK(static_cast<size_t>(index) < kernels_.size())
<< "The index must be less than the size of kernels.";
CHECK(kernels_[index] != nullptr)
......
......@@ -65,9 +65,11 @@ class CLContext {
cl::NDRange LocalWorkSizeTune(cl::NDRange global_work_size,
size_t max_work_size,
int divitor = 2);
cl::NDRange LocalWorkSizeTuneReverse(cl::NDRange global_work_size,
size_t max_work_size,
int divitor = 2);
bool IsArmMali();
// cl::NDRange LocalWorkSizeConv1x1(cl::NDRange global_work_size,
// size_t max_work_size);
......
......@@ -25,6 +25,13 @@ CLRuntime* CLRuntime::Global() {
}
CLRuntime::~CLRuntime() {
#ifdef LITE_WITH_LOG
LOG(INFO) << "is_cl_runtime_initialized_:" << is_cl_runtime_initialized_;
#endif
if (is_cl_runtime_initialized_ == false) {
return;
}
if (command_queue_ != nullptr) {
command_queue_->flush();
command_queue_->finish();
......@@ -38,18 +45,53 @@ CLRuntime::~CLRuntime() {
}
bool CLRuntime::Init() {
#ifdef LITE_WITH_LOG
LOG(INFO) << "is_cl_runtime_initialized_:" << is_cl_runtime_initialized_;
#endif
if (is_cl_runtime_initialized_) {
return true;
}
bool opencl_lib_found = paddle::lite::CLWrapper::Global()->OpenclLibFound();
#ifdef LITE_WITH_LOG
LOG(INFO) << "opencl_lib_found:" << opencl_lib_found;
#endif
if (opencl_lib_found == false) {
return false;
}
bool dlsym_success = paddle::lite::CLWrapper::Global()->DlsymSuccess();
#ifdef LITE_WITH_LOG
LOG(INFO) << "dlsym_success:" << dlsym_success;
#endif
if (dlsym_success == false) {
return false;
}
bool is_platform_init = InitializePlatform();
bool is_device_init = InitializeDevice();
#ifdef LITE_WITH_LOG
LOG(INFO) << "is_platform_init:" << is_platform_init;
#endif
if (is_platform_init == false) {
return false;
}
bool is_device_init = InitializeDevice();
#ifdef LITE_WITH_LOG
LOG(INFO) << "is_device_init:" << is_device_init;
#endif
if (is_device_init == false) {
return false;
}
if ((is_platform_init == true) && (is_device_init == true)) {
is_platform_device_init_success_ = true;
context_ = CreateContext();
command_queue_ = CreateCommandQueue(context());
is_cl_runtime_initialized_ = true;
#ifdef LITE_WITH_LOG
LOG(INFO) << "set is_cl_runtime_initialized_ = true";
#endif
}
return is_cl_runtime_initialized_;
}
......@@ -138,20 +180,24 @@ GpuType CLRuntime::ParseGpuTypeFromDeviceName(std::string device_name) {
const std::string kMALI_PATTERN_STR = "Mali";
const std::string kADRENO_PATTERN_STR = "QUALCOMM Adreno(TM)";
const std::string kPOWERVR_PATTERN_STR = "PowerVR";
std::string gpu_type_str = "";
if (device_name == kADRENO_PATTERN_STR) {
LOG(INFO) << "adreno gpu";
gpu_type_str = "adreno gpu";
return GpuType::QUALCOMM_ADRENO;
} else if (device_name.find(kMALI_PATTERN_STR) != std::string::npos) {
LOG(INFO) << "mali gpu";
gpu_type_str = "mali gpu";
return GpuType::ARM_MALI;
} else if (device_name.find(kPOWERVR_PATTERN_STR) != std::string::npos) {
LOG(INFO) << "powerVR gpu";
gpu_type_str = "powerVR gpu";
return GpuType::IMAGINATION_POWERVR;
} else {
LOG(INFO) << "others gpu";
gpu_type_str = "others gpu";
return GpuType::UNKNOWN;
}
#ifdef LITE_WITH_LOG
LOG(INFO) << "gpu_type_str:" << gpu_type_str;
#endif
}
bool CLRuntime::InitializeDevice() {
......
......@@ -70,27 +70,30 @@ class CLRuntime {
static CLRuntime* Global();
bool OpenCLAvaliableForDevice() {
bool opencl_lib_found = paddle::lite::CLWrapper::Global()->OpenclLibFound();
LOG(INFO) << "opencl_lib_found:" << opencl_lib_found;
if (opencl_lib_found == false) return false;
bool dlsym_success = paddle::lite::CLWrapper::Global()->DlsymSuccess();
LOG(INFO) << "dlsym_success:" << dlsym_success;
if (opencl_lib_found == false) return false;
// note(ysh329): entered this func means:
// 1. opencl_lib_found must be true
// 2. dlsym_success must be true
InitializeDevice();
bool support_fp16 =
static_cast<bool>(device_info_["CL_DEVICE_EXTENSIONS_FP16"]);
#ifdef LITE_WITH_LOG
LOG(INFO) << "support_fp16:" << support_fp16;
#endif
if (support_fp16 == false) return false;
is_device_avaliable_for_opencl_ =
dlsym_success && opencl_lib_found && support_fp16;
is_device_avaliable_for_opencl_ = support_fp16;
#ifdef LITE_WITH_LOG
LOG(INFO) << "is_device_avaliable_for_opencl_:"
<< is_device_avaliable_for_opencl_;
#endif
return is_device_avaliable_for_opencl_;
}
void set_auto_tune(bool enable_tune) { auto_tune_ = enable_tune; }
bool auto_tune() { return auto_tune_; }
bool Init();
cl::Platform& platform();
......@@ -195,6 +198,8 @@ class CLRuntime {
bool is_cl_runtime_initialized_{false};
bool is_platform_device_init_success_{false};
bool auto_tune_{false};
};
} // namespace lite
......
......@@ -347,18 +347,23 @@ class Context<TargetType::kX86> {
#ifdef LITE_WITH_OPENCL
template <>
class Context<TargetType::kOpenCL> {
std::shared_ptr<CLContext> cl_context_;
std::shared_ptr<CLContext> cl_context_{nullptr};
public:
CLContext* cl_context() { return cl_context_.get(); }
void InitOnce() {
// Init cl runtime.
CHECK(CLRuntime::Global()->IsInitSuccess()) << "OpenCL runtime init failed";
if (CLRuntime::Global()->IsInitSuccess() == false) {
LOG(ERROR) << "OpenCL runtime init failed";
}
cl_context_ = std::make_shared<CLContext>();
}
void CopySharedTo(OpenCLContext* ctx) { ctx->cl_context_ = cl_context_; }
void CopySharedTo(OpenCLContext* ctx) {
if (ctx && cl_context_) {
ctx->cl_context_ = cl_context_;
}
}
};
#endif
......
......@@ -117,11 +117,11 @@ void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
<< " must be 1";
}
for (int i = 0; i < paddings1.size(); i++) {
CHECK_EQ(paddings1[i], 1) << "paddings[" << i << "]: " << paddings1[i]
<< " must be 1";
CHECK_EQ(paddings1[i], 0) << "paddings1[" << i << "]: " << paddings1[i]
<< " must be 0";
}
for (int i = 0; i < dilations1.size(); i++) {
CHECK_EQ(dilations1[i], 1) << "dilations[" << i << "]: " << dilations1[i]
CHECK_EQ(dilations1[i], 1) << "dilations1[" << i << "]: " << dilations1[i]
<< " must be 1";
}
// compute new_weight and new bias
......
......@@ -159,9 +159,12 @@ RuntimeProgram::RuntimeProgram(
int block_idx)
: exec_scope_(exec_scope) {
#ifdef LITE_WITH_OPENCL
bool opencl_valid = CLRuntime::Global()->OpenCLAvaliableForDevice();
using OpenCLContext = Context<TargetType::kOpenCL>;
std::unique_ptr<KernelContext> local_ctx(new KernelContext());
local_ctx->As<OpenCLContext>().InitOnce();
std::unique_ptr<KernelContext> unique_opencl_ctx(new KernelContext());
if (opencl_valid) {
unique_opencl_ctx->As<OpenCLContext>().InitOnce();
}
#endif
CHECK(program_desc);
auto block_size = program_desc->BlocksSize();
......@@ -227,15 +230,24 @@ RuntimeProgram::RuntimeProgram(
}
#ifdef LITE_WITH_OPENCL
if (kernel->target() == TARGET(kOpenCL)) {
std::unique_ptr<KernelContext> ctx(new KernelContext());
(*local_ctx).As<OpenCLContext>().CopySharedTo(&ctx->As<OpenCLContext>());
kernel->SetContext(std::move(ctx));
if (opencl_valid) {
std::unique_ptr<KernelContext> ctx(new KernelContext());
(*unique_opencl_ctx)
.As<OpenCLContext>()
.CopySharedTo(&ctx->As<OpenCLContext>());
kernel->SetContext(std::move(ctx));
} else {
LOG(ERROR) << "opencl_valid:" << opencl_valid;
}
} else {
kernel->SetContext(
ContextScheduler::Global().NewContext(kernel->target()));
}
#else
kernel->SetContext(ContextScheduler::Global().NewContext(kernel->target()));
if (kernel != nullptr) {
kernel->SetContext(
ContextScheduler::Global().NewContext(kernel->target()));
}
#endif
instructions_[kRootBlockIdx].emplace_back(std::move(op), std::move(kernel));
}
......
......@@ -92,6 +92,7 @@ void RunModel(std::string model_dir,
if (is_opencl_backend_valid) {
// give opencl nb model dir
config.set_model_from_file(model_dir);
config.set_opencl_tune(false); // default is false
} else {
std::cout << "Unsupport opencl nb model." << std::endl;
exit(1);
......
......@@ -208,6 +208,8 @@ REGISTER_LITE_KERNEL(lstm,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("C0", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Cell", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kARM))})
......
......@@ -13,9 +13,11 @@
// limitations under the License.
#include "lite/kernels/arm/prior_box_compute.h"
#include <algorithm>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
......@@ -46,9 +48,301 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
}
}
void PriorBoxCompute::Run() {
auto& param = Param<operators::PriorBoxParam>();
inline void fast_free(void* ptr) {
if (ptr) {
free(static_cast<void**>(ptr)[-1]);
}
}
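// density_prior_box computes the prior (anchor) boxes and their variances on
// the host for one feature map, covering both the fixed_size/density path and
// the classic min_size/max_size/aspect_ratio path.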
void density_prior_box(const lite::Tensor* input,
const lite::Tensor* image,
lite::Tensor* boxes,
lite::Tensor* variances,
const std::vector<float>& min_size_,
const std::vector<float>& fixed_size_,
const std::vector<float>& fixed_ratio_,
const std::vector<int>& density_size_,
const std::vector<float>& max_size_,
const std::vector<float>& aspect_ratio_,
const std::vector<float>& variance_,
int img_w_,
int img_h_,
float step_w_,
float step_h_,
float offset_,
int prior_num_,
bool is_flip_,
bool is_clip_,
const std::vector<std::string>& order_,
bool min_max_aspect_ratios_order) {
// compute output shape
int win1 = input->dims()[3];
int hin1 = input->dims()[2];
DDim shape_out({hin1, win1, prior_num_, 4});
boxes->Resize(shape_out);
variances->Resize(shape_out);
float* _cpu_data = boxes->mutable_data<float>();
float* _variance_data = variances->mutable_data<float>();
const int width = win1;
const int height = hin1;
int img_width = img_w_;
int img_height = img_h_;
if (img_width == 0 || img_height == 0) {
img_width = image->dims()[3];
img_height = image->dims()[2];
}
float step_w = step_w_;
float step_h = step_h_;
if (step_w == 0 || step_h == 0) {
step_w = static_cast<float>(img_width) / width;
step_h = static_cast<float>(img_height) / height;
}
float offset = offset_;
int step_average = static_cast<int>((step_w + step_h) * 0.5); // add
int channel_size = height * width * prior_num_ * 4;
int idx = 0;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
float center_x = (w + offset) * step_w;
float center_y = (h + offset) * step_h;
float box_width;
float box_height;
if (fixed_size_.size() > 0) {
// add
for (int s = 0; s < fixed_size_.size(); ++s) {
int fixed_size = fixed_size_[s];
int com_idx = 0;
box_width = fixed_size;
box_height = fixed_size;
if (fixed_ratio_.size() > 0) {
for (int r = 0; r < fixed_ratio_.size(); ++r) {
float ar = fixed_ratio_[r];
int density = density_size_[s];
int shift = step_average / density;
float box_width_ratio = fixed_size_[s] * sqrt(ar);
float box_height_ratio = fixed_size_[s] / sqrt(ar);
for (int p = 0; p < density; ++p) {
for (int c = 0; c < density; ++c) {
float center_x_temp =
center_x - step_average / 2.0f + shift / 2.f + c * shift;
float center_y_temp =
center_y - step_average / 2.0f + shift / 2.f + p * shift;
// xmin
_cpu_data[idx++] =
(center_x_temp - box_width_ratio / 2.f) / img_width >= 0
? (center_x_temp - box_width_ratio / 2.f) / img_width
: 0;
// ymin
_cpu_data[idx++] =
(center_y_temp - box_height_ratio / 2.f) / img_height >= 0
? (center_y_temp - box_height_ratio / 2.f) /
img_height
: 0;
// xmax
_cpu_data[idx++] =
(center_x_temp + box_width_ratio / 2.f) / img_width <= 1
? (center_x_temp + box_width_ratio / 2.f) / img_width
: 1;
// ymax
_cpu_data[idx++] =
(center_y_temp + box_height_ratio / 2.f) / img_height <= 1
? (center_y_temp + box_height_ratio / 2.f) /
img_height
: 1;
}
}
}
} else {
// this code for density anchor box
if (density_size_.size() > 0) {
CHECK_EQ(fixed_size_.size(), density_size_.size())
<< "fixed_size_ should be same with density_size_";
int density = density_size_[s];
int shift = fixed_size_[s] / density;
for (int r = 0; r < density; ++r) {
for (int c = 0; c < density; ++c) {
float center_x_temp =
center_x - fixed_size / 2.f + shift / 2.f + c * shift;
float center_y_temp =
center_y - fixed_size / 2.f + shift / 2.f + r * shift;
// xmin
_cpu_data[idx++] =
(center_x_temp - box_width / 2.f) / img_width >= 0
? (center_x_temp - box_width / 2.f) / img_width
: 0;
// ymin
_cpu_data[idx++] =
(center_y_temp - box_height / 2.f) / img_height >= 0
? (center_y_temp - box_height / 2.f) / img_height
: 0;
// xmax
_cpu_data[idx++] =
(center_x_temp + box_width / 2.f) / img_width <= 1
? (center_x_temp + box_width / 2.f) / img_width
: 1;
// ymax
_cpu_data[idx++] =
(center_y_temp + box_height / 2.f) / img_height <= 1
? (center_y_temp + box_height / 2.f) / img_height
: 1;
}
}
}
// rest of priors: will never come here!!!
for (int r = 0; r < aspect_ratio_.size(); ++r) {
float ar = aspect_ratio_[r];
if (fabs(ar - 1.) < 1e-6) {
continue;
}
int density = density_size_[s];
int shift = fixed_size_[s] / density;
float box_width_ratio = fixed_size_[s] * sqrt(ar);
float box_height_ratio = fixed_size_[s] / sqrt(ar);
for (int p = 0; p < density; ++p) {
for (int c = 0; c < density; ++c) {
float center_x_temp =
center_x - fixed_size / 2.f + shift / 2.f + c * shift;
float center_y_temp =
center_y - fixed_size / 2.f + shift / 2.f + p * shift;
// xmin
_cpu_data[idx++] =
(center_x_temp - box_width_ratio / 2.f) / img_width >= 0
? (center_x_temp - box_width_ratio / 2.f) / img_width
: 0;
// ymin
_cpu_data[idx++] =
(center_y_temp - box_height_ratio / 2.f) / img_height >= 0
? (center_y_temp - box_height_ratio / 2.f) /
img_height
: 0;
// xmax
_cpu_data[idx++] =
(center_x_temp + box_width_ratio / 2.f) / img_width <= 1
? (center_x_temp + box_width_ratio / 2.f) / img_width
: 1;
// ymax
_cpu_data[idx++] =
(center_y_temp + box_height_ratio / 2.f) / img_height <= 1
? (center_y_temp + box_height_ratio / 2.f) /
img_height
: 1;
}
}
}
}
}
} else {
float* min_buf = reinterpret_cast<float*>(
TargetWrapper<TARGET(kHost)>::Malloc(sizeof(float) * 4));
float* max_buf = reinterpret_cast<float*>(
TargetWrapper<TARGET(kHost)>::Malloc(sizeof(float) * 4));
float* com_buf =
reinterpret_cast<float*>(TargetWrapper<TARGET(kHost)>::Malloc(
sizeof(float) * aspect_ratio_.size() * 4));
for (int s = 0; s < min_size_.size(); ++s) {
int min_idx = 0;
int max_idx = 0;
int com_idx = 0;
int min_size = min_size_[s];
// first prior: aspect_ratio = 1, size = min_size
box_width = box_height = min_size;
//! xmin
min_buf[min_idx++] = (center_x - box_width / 2.f) / img_width;
//! ymin
min_buf[min_idx++] = (center_y - box_height / 2.f) / img_height;
//! xmax
min_buf[min_idx++] = (center_x + box_width / 2.f) / img_width;
//! ymax
min_buf[min_idx++] = (center_y + box_height / 2.f) / img_height;
if (max_size_.size() > 0) {
int max_size = max_size_[s];
//! second prior: aspect_ratio = 1, size = sqrt(min_size * max_size)
box_width = box_height = sqrtf(min_size * max_size);
//! xmin
max_buf[max_idx++] = (center_x - box_width / 2.f) / img_width;
//! ymin
max_buf[max_idx++] = (center_y - box_height / 2.f) / img_height;
//! xmax
max_buf[max_idx++] = (center_x + box_width / 2.f) / img_width;
//! ymax
max_buf[max_idx++] = (center_y + box_height / 2.f) / img_height;
}
//! rest of priors
for (int r = 0; r < aspect_ratio_.size(); ++r) {
float ar = aspect_ratio_[r];
if (fabs(ar - 1.) < 1e-6) {
continue;
}
box_width = min_size * sqrt(ar);
box_height = min_size / sqrt(ar);
//! xmin
com_buf[com_idx++] = (center_x - box_width / 2.f) / img_width;
//! ymin
com_buf[com_idx++] = (center_y - box_height / 2.f) / img_height;
//! xmax
com_buf[com_idx++] = (center_x + box_width / 2.f) / img_width;
//! ymax
com_buf[com_idx++] = (center_y + box_height / 2.f) / img_height;
}
if (min_max_aspect_ratios_order) {
memcpy(_cpu_data + idx, min_buf, sizeof(float) * min_idx);
idx += min_idx;
memcpy(_cpu_data + idx, max_buf, sizeof(float) * max_idx);
idx += max_idx;
memcpy(_cpu_data + idx, com_buf, sizeof(float) * com_idx);
idx += com_idx;
} else {
memcpy(_cpu_data + idx, min_buf, sizeof(float) * min_idx);
idx += min_idx;
memcpy(_cpu_data + idx, com_buf, sizeof(float) * com_idx);
idx += com_idx;
memcpy(_cpu_data + idx, max_buf, sizeof(float) * max_idx);
idx += max_idx;
}
}
TargetWrapper<TARGET(kHost)>::Free(min_buf);
TargetWrapper<TARGET(kHost)>::Free(max_buf);
TargetWrapper<TARGET(kHost)>::Free(com_buf);
}
}
}
//! clip the prior's coordinate such that it is within [0, 1]
if (is_clip_) {
for (int d = 0; d < channel_size; ++d) {
_cpu_data[d] = std::min(std::max(_cpu_data[d], 0.f), 1.f);
}
}
//! set the variance.
int count = 0;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int i = 0; i < prior_num_; ++i) {
for (int j = 0; j < 4; ++j) {
_variance_data[count] = variance_[j];
++count;
}
}
}
}
}
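// ReInitWhenNeeded recomputes the prior boxes only when the input/image shapes
// change; the results are cached in boxes_tmp_/variances_tmp_ and copied to the
// kernel outputs in Run().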
void PriorBoxCompute::ReInitWhenNeeded() {
auto& param = this->template Param<param_t>();
auto input_dims = param.input->dims();
auto image_dims = param.image->dims();
if (last_input_shape_ == input_dims && last_image_shape_ == image_dims) {
return;
}
bool is_flip = param.flip;
bool is_clip = param.clip;
std::vector<float> min_size = param.min_sizes;
......@@ -66,25 +360,35 @@ void PriorBoxCompute::Run() {
prior_num += max_size.size();
std::vector<std::string> order = param.order;
bool min_max_aspect_ratios_order = param.min_max_aspect_ratios_order;
density_prior_box(param.input,
param.image,
&boxes_tmp_,
&variances_tmp_,
min_size,
std::vector<float>(),
std::vector<float>(),
std::vector<int>(),
max_size,
aspect_ratios_vec,
variance,
img_w,
img_h,
step_w,
step_h,
offset,
prior_num,
is_flip,
is_clip,
order,
min_max_aspect_ratios_order);
last_input_shape_ = input_dims;
last_image_shape_ = image_dims;
}
lite::arm::math::prior_box(param.input,
param.image,
&param.boxes,
&param.variances,
min_size,
max_size,
aspect_ratios_vec,
variance,
img_w,
img_h,
step_w,
step_h,
offset,
prior_num,
is_flip,
is_clip,
order,
min_max_aspect_ratios_order);
void PriorBoxCompute::Run() {
auto& param = this->template Param<param_t>();
param.boxes->CopyDataFrom(boxes_tmp_);
param.variances->CopyDataFrom(variances_tmp_);
}
} // namespace arm
......
......@@ -26,8 +26,14 @@ class PriorBoxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
using param_t = operators::PriorBoxParam;
void Run() override;
void ReInitWhenNeeded() override;
virtual ~PriorBoxCompute() = default;
private:
Tensor boxes_tmp_;
Tensor variances_tmp_;
DDim last_input_shape_;
DDim last_image_shape_;
};
} // namespace arm
......
......@@ -9,6 +9,8 @@ set(huawei_ascend_npu_subgraph_bridge_deps subgraph_bridge_registry subgraph_bri
lite_cc_library(subgraph_bridge_act_op_huawei_ascend_npu SRCS act_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_conv_op_huawei_ascend_npu SRCS conv_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_interpolate_op_huawei_ascend_npu SRCS interpolate_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_concat_op_huawei_ascend_npu SRCS concat_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_registry
......@@ -16,4 +18,6 @@ set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_graph_huawei_ascend_npu
subgraph_bridge_act_op_huawei_ascend_npu
subgraph_bridge_conv_op_huawei_ascend_npu
subgraph_bridge_interpolate_op_huawei_ascend_npu
subgraph_bridge_concat_op_huawei_ascend_npu
CACHE INTERNAL "huawei_ascend_npu_subgraph_bridges")
......@@ -49,6 +49,10 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto act_node = graph->template Add<ActType>(out_name);
auto act_op = act_node->template data<ActType>();
act_op->set_input_x(*x_node->data());
TENSOR_UPDATE_INPUT(
act_op, x, ge::FORMAT_NCHW, CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_OUTPUT(
act_op, y, ge::FORMAT_NCHW, CvtPrecisionType(act_node->precision()));
return SUCCESS;
}
......@@ -84,6 +88,10 @@ int ActConverter<ge::op::LeakyRelu>(void* ctx, OpLite* op, KernelBase* kernel) {
// only for leaky_relu
auto alpha = op_info->GetAttr<float>("alpha");
act_op->set_attr_negative_slope(alpha);
TENSOR_UPDATE_INPUT(
act_op, x, ge::FORMAT_NCHW, CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_OUTPUT(
act_op, y, ge::FORMAT_NCHW, CvtPrecisionType(act_node->precision()));
return SUCCESS;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " << op_type << " ... ";
// Get input and output vars and op attributes
auto x_names = op_info->Input("X");
auto axis = op_info->GetAttr<int>("axis");
auto out_name = op_info->Output("Out").front();
auto num = x_names.size();
if (op_info->HasInput("AxisTensor")) {
// axis node
auto axis_name = op_info->Input("AxisTensor").front();
auto axis_tensor = scope->FindMutableTensor(axis_name);
std::shared_ptr<Node> axis_node = nullptr;
if (graph->Has(axis_name)) {
axis_node = graph->Get(axis_name);
} else {
axis_node = graph->Add(axis_name, *axis_tensor);
}
// concat node
auto concat_node = graph->Add<ge::op::Concat>(out_name);
auto concat_op = concat_node->data<ge::op::Concat>();
// set axis input
concat_op->set_input_concat_dim(*axis_node->data());
TENSOR_UPDATE_INPUT(concat_op,
concat_dim,
ge::FORMAT_NCHW,
CvtPrecisionType(axis_node->precision()));
// set dynamic input
concat_op->set_attr_N(num);
concat_op->create_dynamic_input_x(num);
int idx = 0;
for (auto& x_name : x_names) {
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
concat_op->set_dynamic_input_x(idx, *x_node->data());
TENSOR_UPDATE_DYNAMIC_INPUT(concat_op,
x,
idx,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
idx++;
}
TENSOR_UPDATE_OUTPUT(concat_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(concat_node->precision()));
} else {
auto concat_node = graph->Add<ge::op::ConcatD>(out_name);
auto concat_op = concat_node->data<ge::op::ConcatD>();
concat_op->set_attr_concat_dim(axis);
concat_op->set_attr_N(num);
concat_op->create_dynamic_input_x(num);
int idx = 0;
for (auto& x_name : x_names) {
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
concat_op->set_dynamic_input_x(idx, *x_node->data());
TENSOR_UPDATE_DYNAMIC_INPUT(concat_op,
x,
idx,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
idx++;
}
TENSOR_UPDATE_OUTPUT(concat_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(concat_node->precision()));
}
return SUCCESS;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
concat,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ConcatConverter);
......@@ -35,7 +35,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto input_name = op_info->Input("Input").front();
auto input = scope->FindMutableTensor(input_name);
auto input_dims = input->dims();
ge::DataType ge_data_type = CvtPrecisionType(input->precision());
auto filter_name = op_info->Input("Filter").front();
auto filter = scope->FindMutableTensor(filter_name);
......@@ -99,6 +98,22 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
input_dims,
filter_dims);
// Check Restrictions: HxW(input) == HxW(filter) if output feature h*w = 1*1
if (output_dims[2] == 1 && output_dims[3] == 1) {
int input_h = input_dims[2] + paddings[0] + paddings[1];
int input_w = input_dims[3] + paddings[2] + paddings[3];
int filter_h = (filter_dims[2] - 1) * dilations[0] + 1;
int filter_w = (filter_dims[3] - 1) * dilations[1] + 1;
CHECK_EQ(input_h, filter_h) << "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK "
"restriction: if output HxW = 1x1, then "
"input height after padding should equal to "
"filter height after dilation";
CHECK_EQ(input_w, filter_w) << "[HUAWEI_ASCEND_NPU] Huawei Ascend NPU DDK "
"restriction: if output HxW = 1x1, then "
"input width after padding should equal to "
"filter width after dilation";
}
// Check depthwise mode, and decide whether to use DepthwiseConv2D Op
bool use_depthwise_conv = false;
bool is_depthwise_mode = (ic == groups && oc == groups && groups != 1);
......@@ -148,20 +163,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
}
// Ascend must update convop desc, or IR model build will fail
ge::TensorDesc conv2d_input_desc_x(
ge::Shape(CvtShape(input_dims)), ge::FORMAT_NCHW, ge_data_type);
ge::TensorDesc conv2d_input_desc_filter(
ge::Shape(CvtShape(filter_dims)), ge::FORMAT_NCHW, ge_data_type);
ge::TensorDesc conv2d_input_desc_bias(
ge::Shape(bias_shape), ge::FORMAT_ND, ge_data_type);
ge::TensorDesc conv2d_output_desc_y(
ge::Shape(CvtShape(output_dims)), ge::FORMAT_NCHW, ge_data_type);
// Setting desc name
conv2d_input_desc_x.SetName("conv2d_input_desc_x");
conv2d_input_desc_filter.SetName("conv2d_input_desc_filter");
conv2d_input_desc_bias.SetName("conv2d_input_desc_bias");
conv2d_output_desc_y.SetName("conv2d_output_desc_y");
// Conv node
std::shared_ptr<Node> conv_node = nullptr;
if (use_depthwise_conv && is_depthwise_mode) {
......@@ -177,12 +178,19 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
conv_op->set_attr_data_format("NCHW");
if (bias_node != nullptr && is_channel_bias) {
conv_op->set_input_bias(*bias_node->data());
conv_op->update_input_desc_bias(conv2d_input_desc_bias);
TENSOR_UPDATE_INPUT(conv_op,
bias,
ge::FORMAT_NCHW,
CvtPrecisionType(bias_node->precision()));
}
// update tensor desc to conv2d
conv_op->update_input_desc_x(conv2d_input_desc_x);
conv_op->update_input_desc_filter(conv2d_input_desc_filter);
conv_op->update_output_desc_y(conv2d_output_desc_y);
TENSOR_UPDATE_INPUT(
conv_op, x, ge::FORMAT_NCHW, CvtPrecisionType(input_node->precision()));
TENSOR_UPDATE_INPUT(conv_op,
filter,
ge::FORMAT_NCHW,
CvtPrecisionType(filter_node->precision()));
TENSOR_UPDATE_OUTPUT(
conv_op, y, ge::FORMAT_NCHW, CvtPrecisionType(conv_node->precision()));
} else {
conv_node = graph->Add<ge::op::Conv2D>(output_name);
auto conv_op = conv_node->data<ge::op::Conv2D>();
......@@ -198,12 +206,19 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
conv_op->set_attr_data_format("NCHW");
if (bias_node != nullptr && is_channel_bias) {
conv_op->set_input_bias(*bias_node->data());
conv_op->update_input_desc_bias(conv2d_input_desc_bias);
TENSOR_UPDATE_INPUT(conv_op,
bias,
ge::FORMAT_NCHW,
CvtPrecisionType(bias_node->precision()));
}
// update tensor desc to conv2d
conv_op->update_input_desc_x(conv2d_input_desc_x);
conv_op->update_input_desc_filter(conv2d_input_desc_filter);
conv_op->update_output_desc_y(conv2d_output_desc_y);
TENSOR_UPDATE_INPUT(
conv_op, x, ge::FORMAT_NCHW, CvtPrecisionType(input_node->precision()));
TENSOR_UPDATE_INPUT(conv_op,
filter,
ge::FORMAT_NCHW,
CvtPrecisionType(filter_node->precision()));
TENSOR_UPDATE_OUTPUT(
conv_op, y, ge::FORMAT_NCHW, CvtPrecisionType(conv_node->precision()));
}
// append Add node to support bias
if (bias_node != nullptr && !is_channel_bias) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// #include "lite/operators/interpolate_op.h"
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto x_h = x_dims[2];
auto x_w = x_dims[3];
CHECK_EQ(x_dims.size(), 4);
auto out_name = op_info->Output("Out").front();
auto scale = op_info->GetAttr<float>("scale");
auto out_w = op_info->GetAttr<int>("out_w");
auto out_h = op_info->GetAttr<int>("out_h");
auto align_corners = op_info->GetAttr<bool>("align_corners");
int align_mode =
op_info->HasAttr("align_mode") ? op_info->GetAttr<int>("align_mode") : 1;
auto interp_method = op_info->GetAttr<std::string>("interp_method");
if (align_mode == 0 && !align_corners) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] align_mode = 0 && "
"align_corners = false isn't "
"supported in Huawei Ascend NPU DDK";
return FAILED;
}
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Priority: OutSize > scale > out_h/out_w
if (scale > 0) {
out_h = static_cast<int>(x_h * scale);
out_w = static_cast<int>(x_w * scale);
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
}
// Update out_h and out_w and create out_size node if has OutSize
std::shared_ptr<Node> out_size_node = nullptr;
if (HasInputArg(op_info, scope, "OutSize")) {
auto out_size_name = op_info->Input("OutSize").front();
if (graph->Has(out_size_name)) {
out_size_node = graph->Get(out_size_name);
} else {
auto out_size = scope->FindMutableTensor(out_size_name);
CHECK_EQ(out_size->numel(), 2);
CHECK(out_size->persistable());
auto out_size_data = out_size->mutable_data<int>();
// Update out_h and out_w if has OutSize
out_h = out_size_data[0];
out_w = out_size_data[1];
}
}
if (out_size_node == nullptr) {
out_size_node =
graph->Add(out_name + "/out_size", std::vector<int>({out_h, out_w}));
}
if (interp_method == "bilinear") {
auto bilinear_interp_node = graph->Add<ge::op::ResizeBilinearV2>(out_name);
auto bilinear_interp_op =
bilinear_interp_node->data<ge::op::ResizeBilinearV2>();
bilinear_interp_op->set_input_x(*x_node->data());
bilinear_interp_op->set_input_size(*out_size_node->data());
bilinear_interp_op->set_attr_align_corners(align_corners);
TENSOR_UPDATE_INPUT(bilinear_interp_op,
x,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_INPUT(bilinear_interp_op,
size,
ge::FORMAT_NCHW,
CvtPrecisionType(out_size_node->precision()));
TENSOR_UPDATE_OUTPUT(bilinear_interp_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(bilinear_interp_node->precision()));
} else if (interp_method == "nearest") {
auto nearest_interp_node =
graph->Add<ge::op::ResizeNearestNeighborV2>(out_name);
auto nearest_interp_op =
nearest_interp_node->data<ge::op::ResizeNearestNeighborV2>();
nearest_interp_op->set_input_x(*x_node->data());
nearest_interp_op->set_input_size(*out_size_node->data());
nearest_interp_op->set_attr_align_corners(align_corners);
TENSOR_UPDATE_INPUT(nearest_interp_op,
x,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_INPUT(nearest_interp_op,
size,
ge::FORMAT_NCHW,
CvtPrecisionType(out_size_node->precision()));
TENSOR_UPDATE_OUTPUT(nearest_interp_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(nearest_interp_node->precision()));
} else {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Unsupported interpolate method: "
<< interp_method;
return FAILED;
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
bilinear_interp,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::InterpolateConverter);
REGISTER_SUBGRAPH_BRIDGE(
nearest_interp,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::InterpolateConverter);
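// Both bilinear_interp and nearest_interp are registered to the same
// converter; it dispatches on the op's "interp_method" attribute.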
......@@ -25,3 +25,6 @@ USE_SUBGRAPH_BRIDGE(softplus, kHuaweiAscendNPU);
// conv
USE_SUBGRAPH_BRIDGE(conv2d, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(depthwise_conv2d, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(bilinear_interp, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(nearest_interp, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(concat, kHuaweiAscendNPU);
......@@ -156,61 +156,6 @@ int CvtActMode(std::string act_type) {
return act_mode;
}
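// Helpers that map ge::Format / ge::DataType enum values to readable strings.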
const std::string& CvtFormat(ge::Format format) {
static const int MAX_FORMAT_LENGTH = 25;
static const std::string format2string[] = {
"FORMAT_NCHW = 0",
"FORMAT_NHWC = 1",
"FORMAT_ND = 2",
"FORMAT_NC1HWC0 = 3",
"FORMAT_FRACTAL_Z = 4",
"FORMAT_NC1C0HWPAD = 5",
"FORMAT_NHWC1C0 = 6",
"FORMAT_FSR_NCHW = 7",
"FORMAT_FRACTAL_DECONV = 8",
"FORMAT_C1HWNC0 = 9",
"FORMAT_FRACTAL_DECONV_TRANSPOSE = 10",
"FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11",
"FORMAT_NC1HWC0_C04 = 12",
"FORMAT_FRACTAL_Z_C04 = 13",
"FORMAT_CHWN = 14",
"FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15",
"FORMAT_HWCN = 16",
"FORMAT_NC1KHKWHWC0 = 17",
"FORMAT_BN_WEIGHT = 18",
"FORMAT_FILTER_HWCK = 19",
"FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20",
"FORMAT_HASHTABLE_LOOKUP_KEYS = 21",
"FORMAT_HASHTABLE_LOOKUP_VALUE = 22",
"FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23",
"FORMAT_HASHTABLE_LOOKUP_HITS = 24"};
auto x = static_cast<int>(format);
CHECK_LT(x, MAX_FORMAT_LENGTH);
return format2string[x];
}
const std::string& CvtDataType(ge::DataType data_type) {
static const int MAX_DATATYPE_LENGTH = 14;
static const std::string datatype2string[] = {"DT_FLOAT=0",
"DT_FLOAT16=1",
"DT_INT8=2",
"DT_INT32=3",
"DT_UINT8=4",
"Unknown=5",
"DT_INT16=6",
"DT_UINT16=7",
"DT_UINT32=8",
"DT_INT64=9",
"DT_UINT64=10",
"DT_DOUBLE=11",
"DT_BOOL=12",
"DT_STRING=13"};
auto x = static_cast<int>(data_type);
CHECK_LT(x, MAX_DATATYPE_LENGTH);
return datatype2string[x];
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
......
......@@ -30,6 +30,17 @@ namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
#define TENSOR_UPDATE_INPUT(op, attr, format, dtype) \
ge::TensorDesc _##op##_input_desc_##attr(ge::Shape(), format, dtype); \
op->update_input_desc_##attr(_##op##_input_desc_##attr);
#define TENSOR_UPDATE_OUTPUT(op, attr, format, dtype) \
ge::TensorDesc _##op##_output_desc_##attr(ge::Shape(), format, dtype); \
op->update_output_desc_##attr(_##op##_output_desc_##attr);
#define TENSOR_UPDATE_DYNAMIC_INPUT(op, attr, idx, format, dtype) \
ge::TensorDesc _##op##_input_desc_##attr##_##idx( \
ge::Shape(), format, dtype); \
op->update_dynamic_input_desc_##attr(idx, _##op##_input_desc_##attr##_##idx);
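// The TENSOR_UPDATE_* macros above build a ge::TensorDesc with an empty shape
// plus the given format/precision, and use it to refresh the named input,
// output or dynamic-input descriptor of the GE operator.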
// Type/tensor converters for converting Paddle type/tensor to HiAI type/tensor
bool HasInputArg(const OpInfo* op_info,
const Scope* scope,
......@@ -50,9 +61,6 @@ ge::Tensor CvtTensor(const Tensor& in_tensor,
int CvtActMode(std::string act_type);
const std::string& CvtFormat(ge::Format format);
const std::string& CvtDataType(ge::DataType data_type);
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
......
......@@ -241,32 +241,18 @@ bool DeviceProgram::ShareBufferWithOriginTensors(
VLOG(3) << "[HUAWEI_ASCEND_NPU] Inputs[" << i
<< "] name: " << input_names[i]
<< " origin dims:" << (*origin_itensors)[i]->dims().repr()
<< " device dims: {" << device_idims_[i].GetNumber() << ","
<< device_idims_[i].GetChannel() << ","
<< device_idims_[i].GetHeight() << ","
<< device_idims_[i].GetWidth() << "}";
<< " device dims:" << device_idims_[i].repr();
CHECK_EQ((*origin_itensors)[i]->dims().production(),
device_idims_[i].GetNumber() * device_idims_[i].GetChannel() *
device_idims_[i].GetHeight() * device_idims_[i].GetWidth());
device_idims_[i].production());
// reset tensor desc
if ((*device_itensors)[i]->SetTensorDesc(
device_idims_[i].GetGeTensorDesc()) != ge::GRAPH_SUCCESS) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] ge::Tensor input tensor "
"SetTensorDesc failed!";
} else {
VLOG(3) << "[HUAWEI_ASCEND_NPU] ge::Tensor input tensor SetTensorDesc "
"success.";
}
ATC_CALL((*device_itensors)[i]->SetTensorDesc(
device_idims_[i].GetGeTensorDesc()));
// copy data from origin to device
if ((*device_itensors)[i]->SetData(
reinterpret_cast<uint8_t*>((*origin_itensors)[i]->raw_data()),
(*origin_itensors)[i]->memory_size()) != ge::GRAPH_SUCCESS) {
LOG(WARNING)
<< "[HUAWEI_ASCEND_NPU] ge::Tensor input tensor SetData failed!";
} else {
VLOG(3) << "[HUAWEI_ASCEND_NPU] ge::Tensor input tensor SetData success.";
}
ATC_CALL((*device_itensors)[i]->SetData(
reinterpret_cast<uint8_t*>((*origin_itensors)[i]->raw_data()),
(*origin_itensors)[i]->memory_size()));
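// ATC_CALL presumably checks the ge:: return status itself, replacing the
// explicit if/else error handling used by the old code above.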
VLOG(3)
<< "[HUAWEI_ASCEND_NPU] Init the input tensors for the device program "
"and share their buffers with the origin input tensors";
......@@ -285,26 +271,13 @@ bool DeviceProgram::ShareBufferWithOriginTensors(
VLOG(3) << "[HUAWEI_ASCEND_NPU] Outputs[" << i
<< "] name: " << output_names[i]
<< " origin dims:" << (*origin_otensors)[i]->dims().repr()
<< " device dims: {" << device_odims_[i].GetNumber() << ","
<< device_odims_[i].GetChannel() << ","
<< device_odims_[i].GetHeight() << ","
<< device_odims_[i].GetWidth() << "}";
<< " device dims:" << device_odims_[i].repr();
CHECK_EQ((*origin_otensors)[i]->dims().production(),
device_odims_[i].GetNumber() * device_odims_[i].GetChannel() *
device_odims_[i].GetHeight() * device_odims_[i].GetWidth());
device_odims_[i].production());
// reset tensor desc
if ((*device_otensors)[i]->SetTensorDesc(
device_odims_[i].GetGeTensorDesc()) != ge::GRAPH_SUCCESS) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] ge::Tensor output tensor "
"SetTensorDesc failed!";
} else {
VLOG(3) << "[HUAWEI_ASCEND_NPU] ge::Tensor output tensor SetTensorDesc "
"success.";
}
VLOG(3)
<< "[HUAWEI_ASCEND_NPU] Init the output tensors for the device program "
"and share their buffers with the origin output tensors";
ATC_CALL((*device_otensors)[i]->SetTensorDesc(
device_odims_[i].GetGeTensorDesc()));
}
return true;
}
......@@ -321,8 +294,7 @@ bool DeviceProgram::SharedBufferWithOutputTensors(
for (size_t i = 0; i < output_names.size(); i++) {
CHECK_EQ((*origin_otensors)[i]->dims().production(),
device_odims_[i].GetNumber() * device_odims_[i].GetChannel() *
device_odims_[i].GetHeight() * device_odims_[i].GetWidth());
device_odims_[i].production());
// Share data buf between device_itensor and origin_itensor
std::shared_ptr<Buffer> buffer = std::make_shared<Buffer>(
......
......@@ -20,6 +20,7 @@
#include <utility>
#include <vector>
#include "graph/compatible/all_ops.h"
#include "graph/op/all_ops.h"
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"
......
......@@ -127,12 +127,11 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
// LayerNorm node
auto layer_norm_node = graph->Add<ge::op::InstanceNorm>(y_name);
auto layer_norm_op = layer_norm_node->data<ge::op::InstanceNorm>();
auto layer_norm_node = graph->Add<hiai::op::LayerNorm>(y_name);
auto layer_norm_op = layer_norm_node->data<hiai::op::LayerNorm>();
layer_norm_op->set_input_x(*x_node->data());
layer_norm_op->set_input_scale(*scale_node->data());
layer_norm_op->set_input_bias(*bias_node->data());
layer_norm_op->set_attr_reduction_indices(ge::AttrValue::LIST_INT({3}));
layer_norm_op->set_input_gamma(*scale_node->data());
layer_norm_op->set_input_beta(*bias_node->data());
layer_norm_op->set_attr_epsilon(epsilon);
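// hiai::op::LayerNorm takes gamma/beta as inputs, so the reduction_indices
// attribute needed by the old ge::op::InstanceNorm path is no longer set.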
// Reshaped Y node if needs
......
......@@ -32,16 +32,24 @@ namespace opencl {
void ConvImageCompute::PrepareForRun() {
ReInitWhenNeeded();
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const bool is_mali = context.cl_context()->IsArmMali();
use_tune_ = CLRuntime::Global()->auto_tune();
if (!is_mali) {
use_tune_ = false;
}
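// Kernel auto-tuning is only kept for ARM Mali GPUs; on other devices it is
// forced off here.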
#ifdef LITE_WITH_LOG
LOG(INFO) << "use_tune_" << use_tune_;
#endif
auto filter_dims = conv_param_->filter->dims();
filter_tensor_n_ = filter_dims[0];
filter_tensor_c_ = filter_dims[1];
filter_tensor_h_ = filter_dims[2];
filter_tensor_w_ = filter_dims[3];
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const bool is_mali = context.cl_context()->IsArmMali();
auto paddings = *conv_param_->paddings;
pad_up_ = paddings[0];
pad_down_ = paddings[1];
......@@ -65,6 +73,7 @@ void ConvImageCompute::PrepareForRun() {
bool stride_equal = stride_h_ == stride_w_;
bool dilation_equal = dilation_h_ == dilation_w_;
#ifdef LITE_WITH_LOG
VLOG(3) << "Is arm mali / " << (is_mali ? "Yes" : "No");
VLOG(3) << "Is relu fused? / " << (relu_fused_ ? "Yes" : "No");
VLOG(3) << "groups:" << groups_ << " stride_h_:" << stride_h_
......@@ -83,6 +92,8 @@ void ConvImageCompute::PrepareForRun() {
VLOG(3) << "dilation_equal:" << dilation_equal;
VLOG(3) << "padding :" << pad_up_ << " " << pad_down_ << " " << pad_left_
<< " " << pad_right_;
#endif
CHECK(pad_equal && stride_equal && dilation_equal);
CHECK_GE(conv_param_->dilations->size(), 2);
CHECK(dilation_h_ == dilation_w_);
......@@ -91,10 +102,6 @@ void ConvImageCompute::PrepareForRun() {
CHECK_GE(conv_param_->strides.size(), 2);
CHECK(stride_h_ == stride_w_);
if (!is_mali) {
use_tune_ = false;
}
/*********************************************
* Upload filter, bias to opencl device
*********************************************/
......
......@@ -152,7 +152,7 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
cl::NDRange local_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
bool use_lws_{true};
bool use_tune_{true};
bool use_tune_{false};
};
} // namespace opencl
......
......@@ -147,6 +147,8 @@ TEST(Concat, precision) {
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_X86)
......@@ -157,6 +159,10 @@ TEST(Concat, precision) {
for (int axis : {1, 2}) {
for (bool is_use_axis_tensor : {false, true}) {
// is_use_axis_tensor = true has bugs in Huawei Ascend NPU DDK
if (place == TARGET(kHuaweiAscendNPU) && is_use_axis_tensor) {
continue;
}
LOG(INFO) << "axis:" << axis
<< ", is_use_axis_tensor:" << is_use_axis_tensor;
std::unique_ptr<arena::TestCase> tester(
......
......@@ -296,6 +296,11 @@ void TestConvStrides(Place place, float abs_error = 2e-5) {
for (auto out_channels : {1, 3}) {
for (auto strides :
std::vector<std::vector<int>>{{2, 2}, {3, 3}, {1, 2}, {3, 1}}) {
// Check the Huawei Ascend NPU restriction when output HxW = 1x1:
// input_w after padding (= 4) should equal filter_w after dilation (= 3)
if (place == TARGET(kHuaweiAscendNPU) && dims[3] == 4) {
continue;
}
std::unique_ptr<arena::TestCase> tester(new ConvComputeTester(
place, "def", DDim(dims), out_channels, 3, strides));
arena::Arena arena(std::move(tester), place, abs_error);
......@@ -415,13 +420,16 @@ TEST(Conv2d, precision) {
abs_error = 5e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 5e-2; // Using fp16 in NPU
abs_error = 1e-2; // Using fp16 in NPU
#else
return;
#endif
TestConvKsize(place, abs_error);
// Huawei Ascend NPU DDK does not support groups > 1
#if !defined(LITE_WITH_HUAWEI_ASCEND_NPU)
TestConvGroups(place, abs_error);
#endif
TestConvDilations(place, abs_error);
TestConvStrides(place, abs_error);
TestConvPaddings(place, abs_error);
......
......@@ -420,6 +420,12 @@ void TestInterpAlignMode(Place place, float abs_error = 2e-5) {
if (place == TARGET(kARM) && align_mode == 1 && !align_corners) {
continue;
}
// align_mode = 0 && align_corners = false NOT supported in Huawei
// Ascend NPU DDK
if (place == TARGET(kHuaweiAscendNPU) && align_mode == 0 &&
!align_corners) {
continue;
}
std::unique_ptr<arena::TestCase> tester(
new NearestInterpComputeTester(place,
"def",
......@@ -443,6 +449,9 @@ TEST(Interp, precision) {
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
......